Sign Up
Log In
Log In
or
Sign Up
Places
All Projects
Status Monitor
Collapse sidebar
openSUSE:Factory:Rebuild
xgboost
no-hadoop.patch
Overview
Repositories
Revisions
Requests
Users
Attributes
Meta
File no-hadoop.patch of Package xgboost
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 8d4f2c05..7649df65 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -29,18 +29,6 @@ <artifactId>scala-collection-compat_${scala.binary.version}</artifactId> <version>${scala-collection-compat.version}</version> </dependency> - <dependency> - <groupId>org.apache.hadoop</groupId> - <artifactId>hadoop-hdfs</artifactId> - <version>${hadoop.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>org.apache.hadoop</groupId> - <artifactId>hadoop-common</artifactId> - <version>${hadoop.version}</version> - <scope>provided</scope> - </dependency> <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java deleted file mode 100644 index 655b9902..00000000 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java +++ /dev/null @@ -1,117 +0,0 @@ -package ml.dmlc.xgboost4j.java; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.*; -import java.util.stream.Collectors; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; - -public class ExternalCheckpointManager { - - private Log logger = LogFactory.getLog("ExternalCheckpointManager"); - private String modelSuffix = ".model"; - private Path checkpointPath; - private FileSystem fs; - - public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError { - if (checkpointPath == null || checkpointPath.isEmpty()) { - throw new XGBoostError("cannot create ExternalCheckpointManager with null or" + - " empty checkpoint path"); - } - this.checkpointPath = new 
Path(checkpointPath); - this.fs = fs; - } - - private String getPath(int version) { - return checkpointPath.toUri().getPath() + "/" + version + modelSuffix; - } - - private List<Integer> getExistingVersions() throws IOException { - if (!fs.exists(checkpointPath)) { - return new ArrayList<>(); - } else { - return Arrays.stream(fs.listStatus(checkpointPath)) - .map(path -> path.getPath().getName()) - .filter(fileName -> fileName.endsWith(modelSuffix)) - .map(fileName -> Integer.valueOf( - fileName.substring(0, fileName.length() - modelSuffix.length()))) - .collect(Collectors.toList()); - } - } - - public void cleanPath() throws IOException { - fs.delete(checkpointPath, true); - } - - public Booster loadCheckpointAsBooster() throws IOException, XGBoostError { - List<Integer> versions = getExistingVersions(); - if (versions.size() > 0) { - int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get(); - String checkpointPath = getPath(latestVersion); - InputStream in = fs.open(new Path(checkpointPath)); - logger.info("loaded checkpoint from " + checkpointPath); - Booster booster = XGBoost.loadModel(in); - booster.setVersion(latestVersion); - return booster; - } else { - return null; - } - } - - public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError { - List<String> prevModelPaths = getExistingVersions().stream() - .map(this::getPath).collect(Collectors.toList()); - String eventualPath = getPath(boosterToCheckpoint.getVersion()); - String tempPath = eventualPath + "-" + UUID.randomUUID(); - try (OutputStream out = fs.create(new Path(tempPath), true)) { - boosterToCheckpoint.saveModel(out); - fs.rename(new Path(tempPath), new Path(eventualPath)); - logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion()); - prevModelPaths.stream().forEach(path -> { - try { - fs.delete(new Path(path), true); - } catch (IOException e) { - logger.error("failed to delete outdated checkpoint at " + path, 
e); - } - }); - } - } - - public void cleanUpHigherVersions(int currentRound) throws IOException { - getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> { - try { - fs.delete(new Path(getPath(v)), true); - } catch (IOException e) { - logger.error("failed to clean checkpoint from other training instance", e); - } - }); - } - - public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds) - throws IOException { - if (checkpointInterval > 0) { - List<Integer> prevRounds = - getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList()); - prevRounds.add(0); - int firstCheckpointRound = prevRounds.stream() - .max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval; - List<Integer> arr = new ArrayList<>(); - for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) { - arr.add(i); - } - arr.add(numOfRounds); - return arr; - } else if (checkpointInterval <= 0) { - List<Integer> l = new ArrayList<Integer>(); - l.add(numOfRounds); - return l; - } else { - throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set."); - } - } -} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index bcd0b1b1..3e23c15f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -22,7 +22,6 @@ import java.util.regex.Pattern; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.fs.FileSystem; /** * trainer for xgboost @@ -134,34 +133,35 @@ public class XGBoost { return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null); } - private static void saveCheckpoint( - Booster booster, - int iter, - Set<Integer> checkpointIterations, - ExternalCheckpointManager ecm) throws 
XGBoostError { - try { - if (checkpointIterations.contains(iter)) { - ecm.updateCheckpoint(booster); - } - } catch (Exception e) { - logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e); - throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e); - } - } - - public static Booster trainAndSaveCheckpoint( + /** + * Train a booster given parameters. + * + * @param dtrain Data to be trained. + * @param params Parameters. + * @param round Number of boosting iterations. + * @param watches a group of items to be evaluated during training, this allows user to watch + * performance on the validation set. + * @param metrics array containing the evaluation metrics for each matrix in watches for each + * iteration + * @param earlyStoppingRounds if non-zero, training would be stopped + * after a specified number of consecutive + * goes to the unexpected direction in any evaluation metric. + * @param obj customized objective + * @param eval customized evaluation + * @param booster train from scratch if set to null; train from an existing booster if not null. + * @return The trained booster. 
+ */ + public static Booster train( DMatrix dtrain, Map<String, Object> params, - int numRounds, + int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, - Booster booster, - int checkpointInterval, - String checkpointPath, - FileSystem fs) throws XGBoostError, IOException { + Booster booster) throws XGBoostError { + //collect eval matrixs String[] evalNames; DMatrix[] evalMats; @@ -169,11 +169,6 @@ public class XGBoost { int bestIteration; List<String> names = new ArrayList<String>(); List<DMatrix> mats = new ArrayList<DMatrix>(); - Set<Integer> checkpointIterations = new HashSet<>(); - ExternalCheckpointManager ecm = null; - if (checkpointPath != null) { - ecm = new ExternalCheckpointManager(checkpointPath, fs); - } for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) { names.add(evalEntry.getKey()); @@ -184,7 +179,7 @@ public class XGBoost { evalMats = mats.toArray(new DMatrix[mats.size()]); bestIteration = 0; - metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics; + metrics = metrics == null ? new float[evalNames.length][round] : metrics; //collect all data matrixs DMatrix[] allMats; @@ -209,22 +204,17 @@ public class XGBoost { booster.setParams(params); } - if (ecm != null) { - checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds)); - } - boolean initial_best_score_flag = false; boolean max_direction = false; // begin to train - for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) { + for (int iter = booster.getVersion() / 2; iter < round; iter++) { if (booster.getVersion() % 2 == 0) { if (obj != null) { booster.update(dtrain, obj); } else { booster.update(dtrain, iter); } - saveCheckpoint(booster, iter, checkpointIterations, ecm); booster.saveRabitCheckpoint(); } @@ -290,44 +280,6 @@ public class XGBoost { return booster; } - /** - * Train a booster given parameters. - * - * @param dtrain Data to be trained. 
- * @param params Parameters. - * @param round Number of boosting iterations. - * @param watches a group of items to be evaluated during training, this allows user to watch - * performance on the validation set. - * @param metrics array containing the evaluation metrics for each matrix in watches for each - * iteration - * @param earlyStoppingRounds if non-zero, training would be stopped - * after a specified number of consecutive - * goes to the unexpected direction in any evaluation metric. - * @param obj customized objective - * @param eval customized evaluation - * @param booster train from scratch if set to null; train from an existing booster if not null. - * @return The trained booster. - */ - public static Booster train( - DMatrix dtrain, - Map<String, Object> params, - int round, - Map<String, DMatrix> watches, - float[][] metrics, - IObjective obj, - IEvaluation eval, - int earlyStoppingRounds, - Booster booster) throws XGBoostError { - try { - return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, - earlyStoppingRounds, booster, - -1, null, null); - } catch (IOException e) { - logger.error("training failed in xgboost4j", e); - throw new XGBoostError("training failed in xgboost4j ", e); - } - } - private static Integer tryGetIntFromObject(Object o) { if (o instanceof Integer) { return (int)o; diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala deleted file mode 100644 index 240c2387..00000000 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala - -import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM} -import org.apache.hadoop.fs.FileSystem - -class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem) - extends JavaECM(checkpointPath, fs) { - - def updateCheckpoint(booster: Booster): Unit = { - super.updateCheckpoint(booster.booster) - } - - def loadCheckpointAsScalaBooster(): Booster = { - val loadedBooster = super.loadCheckpointAsBooster() - if (loadedBooster == null) { - null - } else { - new Booster(loadedBooster) - } - } -} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 50d86c89..49ce29d8 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -20,61 +20,12 @@ import java.io.InputStream import ml.dmlc.xgboost4j.java.{XGBoostError, XGBoost => JXGBoost} import scala.jdk.CollectionConverters._ -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path /** * XGBoost Scala Training function. 
*/ object XGBoost { - private[scala] def trainAndSaveCheckpoint( - dtrain: DMatrix, - params: Map[String, Any], - numRounds: Int, - watches: Map[String, DMatrix] = Map(), - metrics: Array[Array[Float]] = null, - obj: ObjectiveTrait = null, - eval: EvalTrait = null, - earlyStoppingRound: Int = 0, - prevBooster: Booster, - checkpointParams: Option[ExternalCheckpointParams]): Booster = { - - // we have to filter null value for customized obj and eval - val jParams: java.util.Map[String, AnyRef] = - params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).toMap.asJava - - val jWatches = watches.mapValues(_.jDMatrix).toMap.asJava - val jBooster = if (prevBooster == null) { - null - } else { - prevBooster.booster - } - - val xgboostInJava = checkpointParams. - map(cp => { - JXGBoost.trainAndSaveCheckpoint( - dtrain.jDMatrix, - jParams, - numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster, - cp.checkpointInterval, - cp.checkpointPath, - new Path(cp.checkpointPath).getFileSystem(new Configuration())) - }). - getOrElse( - JXGBoost.train( - dtrain.jDMatrix, - jParams, - numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster) - ) - if (prevBooster == null) { - new Booster(xgboostInJava) - } else { - // Avoid creating a new SBooster with the same JBooster - prevBooster - } - } - /** * Train a booster given parameters. 
* @@ -104,8 +55,23 @@ object XGBoost { eval: EvalTrait = null, earlyStoppingRound: Int = 0, booster: Booster = null): Booster = { - trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, - booster, None) + val jWatches = watches.mapValues(_.jDMatrix).toMap.asJava + val jBooster = if (booster == null) { + null + } else { + booster.booster + } + val xgboostInJava = JXGBoost.train( + dtrain.jDMatrix, + // we have to filter null value for customized obj and eval + params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).toMap.asJava, + round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster) + if (booster == null) { + new Booster(xgboostInJava) + } else { + // Avoid creating a new SBooster with the same JBooster + booster + } } /** @@ -160,41 +126,3 @@ object XGBoost { new Booster(xgboostInJava) } } - -private[scala] case class ExternalCheckpointParams( - checkpointInterval: Int, - checkpointPath: String, - skipCleanCheckpoint: Boolean) - -private[scala] object ExternalCheckpointParams { - - def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = { - val checkpointPath: String = params.get("checkpoint_path") match { - case None | Some(null) | Some("") => null - case Some(path: String) => path - case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" + - s" an instance of String, but current value is ${params("checkpoint_path")}") - } - - val checkpointInterval: Int = params.get("checkpoint_interval") match { - case None => 0 - case Some(freq: Int) => freq - case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" + - " an instance of Int.") - } - - val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match { - case None => false - case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint - case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" + - " an instance 
of Boolean") - } - if (checkpointPath == null || checkpointInterval == 0) { - None - } else { - Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile)) - } - } -} - -
Locations
Projects
Search
Status Monitor
Help
OpenBuildService.org
Documentation
API Documentation
Code of Conduct
Contact
Support
@OBShq
Terms
openSUSE Build Service is sponsored by
The Open Build Service is an openSUSE project.
Sign Up
Log In
Places
Places
All Projects
Status Monitor