Skip to content

Commit 770b345

Browse files
committed
Merge pull request #907 from tqchen/master
[DIST] Enable multiple thread make rabit and xgboost threadsafe
2 parents 12dc92f + 04bdbca commit 770b345

File tree

18 files changed

+368
-112
lines changed

18 files changed

+368
-112
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,4 @@ tags
7878
*.iml
7979
*.class
8080
target
81-
8281
*.swp

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ This file records the changes in xgboost library in reverse chronological order.
2222
- The windows version is still blocked due to Rtools do not support ```std::thread```.
2323
* rabit and dmlc-core are maintained through git submodule
2424
- Anyone can open PR to update these dependencies now.
25+
* Improvements
26+
- Rabit and xgboost libs are not thread-safe and use thread local PRNGs
27+
- This could fix some of the previous problem which runs xgboost on multiple threads.
28+
* JVM Package
29+
- Enable xgboost4j for java and scala
2530

2631
## v0.47 (2016.01.14)
2732

jvm-packages/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
tracker.py
2+
build.sh

jvm-packages/create_jni.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ fi
2727

2828
rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
2929
mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
30+
# copy python to native resources
31+
cp ../dmlc-core/tracker/dmlc_tracker/tracker.py xgboost4j/src/main/resources/tracker.py
3032

3133
popd > /dev/null
3234
echo "complete"

jvm-packages/test_distributed.sh

100644100755
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
22
# Simple script to test distributed version, to be deleted later.
33
cd xgboost4j-demo
4-
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain
4+
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
55
cd ..

jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,79 @@
22

33
import java.io.IOException;
44
import java.util.HashMap;
5+
import java.util.Map;
6+
7+
import org.apache.commons.logging.Log;
8+
import org.apache.commons.logging.LogFactory;
59

610
import ml.dmlc.xgboost4j.*;
711

12+
813
/**
914
* Distributed training example, used to quick test distributed training.
1015
*
1116
* @author tqchen
1217
*/
1318
public class DistTrain {
19+
private static final Log logger = LogFactory.getLog(DistTrain.class);
20+
private Map<String, String> envs = null;
21+
22+
private class Worker implements Runnable {
23+
private final int workerId;
24+
25+
Worker(int workerId) {
26+
this.workerId = workerId;
27+
}
1428

15-
public static void main(String[] args) throws IOException, XGBoostError {
16-
// always initialize rabit module before training.
17-
Rabit.init(new HashMap<String, String>());
29+
public void run() {
30+
try {
31+
Map<String, String> worker_env = new HashMap<String, String>(envs);
1832

19-
// load file from text file, also binary buffer generated by xgboost4j
20-
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
21-
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
33+
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
34+
// always initialize rabit module before training.
35+
Rabit.init(worker_env);
2236

23-
HashMap<String, Object> params = new HashMap<String, Object>();
24-
params.put("eta", 1.0);
25-
params.put("max_depth", 2);
26-
params.put("silent", 1);
27-
params.put("objective", "binary:logistic");
37+
// load file from text file, also binary buffer generated by xgboost4j
38+
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
39+
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
2840

41+
HashMap<String, Object> params = new HashMap<String, Object>();
42+
params.put("eta", 1.0);
43+
params.put("max_depth", 2);
44+
params.put("silent", 1);
45+
params.put("nthread", 2);
46+
params.put("objective", "binary:logistic");
2947

30-
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
31-
watches.put("train", trainMat);
32-
watches.put("test", testMat);
48+
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
49+
watches.put("train", trainMat);
50+
watches.put("test", testMat);
3351

34-
//set round
35-
int round = 2;
52+
//set round
53+
int round = 2;
3654

37-
//train a boost model
38-
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
55+
//train a boost model
56+
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
57+
58+
// always shutdown rabit module after training.
59+
Rabit.shutdown();
60+
} catch (Exception ex){
61+
logger.error(ex);
62+
}
63+
}
64+
}
65+
66+
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
67+
RabitTracker tracker = new RabitTracker(nWorkers);
68+
if (tracker.start()) {
69+
envs = tracker.getWorkerEnvs();
70+
for (int i = 0; i < nWorkers; ++i) {
71+
new Thread(new Worker(i)).start();
72+
}
73+
tracker.waitFor();
74+
}
75+
}
3976

40-
// always shutdown rabit module after training.
41-
Rabit.shutdown();
77+
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
78+
new DistTrain().start(Integer.parseInt(args[0]));
4279
}
4380
}

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package ml.dmlc.xgboost4j;
22

33
import java.io.IOException;
4+
import java.io.Serializable;
45
import java.util.Map;
56

6-
public interface Booster {
7+
public interface Booster extends Serializable {
78

89
/**
910
* set parameter
@@ -109,12 +110,25 @@ public interface Booster {
109110
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
110111

111112
/**
112-
* save model to modelPath
113-
*
113+
* save model to modelPath, the model path support depends on the path support
114+
* in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
115+
* compiled with HDFS support.
116+
* See also toByteArray
114117
* @param modelPath model path
115118
*/
116119
void saveModel(String modelPath) throws XGBoostError;
117120

121+
/**
122+
* Save the model as byte array representation.
123+
* Write these bytes to a file will give compatible format with other xgboost bindings.
124+
*
125+
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
126+
*
127+
* @return the saved byte array.
128+
* @throws XGBoostError
129+
*/
130+
byte[] toByteArray() throws XGBoostError;
131+
118132
/**
119133
* Dump model into a text file.
120134
*

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class JavaBoosterImpl implements Booster {
5757
setParams(params);
5858
}
5959

60-
6160
/**
6261
* load model from modelPath
6362
*
@@ -440,6 +439,22 @@ public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostErr
440439
return featureScore;
441440
}
442441

442+
/**
443+
* Save the model as byte array representation.
444+
* Write these bytes to a file will give compatible format with other xgboost bindings.
445+
*
446+
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
447+
*
448+
* @return the saved byte array.
449+
* @throws XGBoostError
450+
*/
451+
public byte[] toByteArray() throws XGBoostError {
452+
byte[][] bytes = new byte[1][];
453+
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
454+
return bytes[0];
455+
}
456+
457+
443458
/**
444459
* Load the booster model from thread-local rabit checkpoint.
445460
* This is only used in distributed training.
@@ -475,6 +490,27 @@ private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
475490
return handles;
476491
}
477492

493+
// making Booster serializable
494+
private void writeObject(java.io.ObjectOutputStream out)
495+
throws IOException {
496+
try {
497+
out.writeObject(this.toByteArray());
498+
} catch (XGBoostError ex) {
499+
throw new IOException(ex.toString());
500+
}
501+
}
502+
503+
private void readObject(java.io.ObjectInputStream in)
504+
throws IOException, ClassNotFoundException {
505+
try {
506+
this.init(null);
507+
byte[] bytes = (byte[])in.readObject();
508+
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
509+
} catch (XGBoostError ex) {
510+
throw new IOException(ex.toString());
511+
}
512+
}
513+
478514
@Override
479515
protected void finalize() throws Throwable {
480516
super.finalize();

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.commons.logging.Log;
2222
import org.apache.commons.logging.LogFactory;
2323

24+
2425
/**
2526
* class to load native library
2627
*
@@ -61,12 +62,32 @@ public static synchronized void initXgBoost() throws IOException {
6162
* three characters
6263
*/
6364
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
65+
String temp = createTempFileFromResource(path);
66+
// Finally, load the library
67+
System.load(temp);
68+
}
6469

70+
/**
71+
* Create a temp file that copies the resource from current JAR archive
72+
* <p/>
73+
* The file from JAR is copied into system temp file.
74+
* The temporary file is deleted after exiting.
75+
* Method uses String as filename because the pathname is "abstract", not system-dependent.
76+
* <p/>
77+
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
78+
* {@code path}.
79+
* @param path Path to the resources in the jar
80+
* @return The created temp file.
81+
* @throws IOException
82+
* @throws IllegalArgumentException
83+
*/
84+
static String createTempFileFromResource(String path) throws
85+
IOException, IllegalArgumentException {
86+
// Obtain filename from path
6587
if (!path.startsWith("/")) {
6688
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
6789
}
6890

69-
// Obtain filename from path
7091
String[] parts = path.split("/");
7192
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
7293

@@ -83,7 +104,6 @@ private static void loadLibraryFromJar(String path) throws IOException, IllegalA
83104
if (filename == null || prefix.length() < 3) {
84105
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
85106
}
86-
87107
// Prepare temporary file
88108
File temp = File.createTempFile(prefix, suffix);
89109
temp.deleteOnExit();
@@ -113,9 +133,7 @@ private static void loadLibraryFromJar(String path) throws IOException, IllegalA
113133
os.close();
114134
is.close();
115135
}
116-
117-
// Finally, load the library
118-
System.load(temp.getAbsolutePath());
136+
return temp.getAbsolutePath();
119137
}
120138

121139
/**
@@ -133,8 +151,9 @@ private static void smartLoad(String libName) throws IOException {
133151
try {
134152
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
135153
loadLibraryFromJar(libraryFromJar);
136-
} catch (IOException e1) {
137-
throw e1;
154+
} catch (IOException ioe) {
155+
logger.error("failed to load library from both native path and jar");
156+
throw ioe;
138157
}
139158
}
140159
}

0 commit comments

Comments
 (0)