22
33import java .io .IOException ;
44import java .util .HashMap ;
5+ import java .util .Map ;
6+
7+ import org .apache .commons .logging .Log ;
8+ import org .apache .commons .logging .LogFactory ;
59
610import ml .dmlc .xgboost4j .*;
711
12+
813/**
914 * Distributed training example, used to quick test distributed training.
1015 *
1116 * @author tqchen
1217 */
1318public class DistTrain {
19+ private static final Log logger = LogFactory .getLog (DistTrain .class );
20+ private Map <String , String > envs = null ;
21+
22+ private class Worker implements Runnable {
23+ private final int workerId ;
24+
25+ Worker (int workerId ) {
26+ this .workerId = workerId ;
27+ }
1428
15- public static void main ( String [] args ) throws IOException , XGBoostError {
16- // always initialize rabit module before training.
17- Rabit . init ( new HashMap <String , String >() );
29+ public void run () {
30+ try {
31+ Map < String , String > worker_env = new HashMap <String , String >(envs );
1832
19- // load file from text file, also binary buffer generated by xgboost4j
20- DMatrix trainMat = new DMatrix ( "../../demo/data/agaricus.txt.train" );
21- DMatrix testMat = new DMatrix ( "../../demo/data/agaricus.txt.test" );
33+ worker_env . put ( "DMLC_TASK_ID" , String . valueOf ( workerId ));
34+ // always initialize rabit module before training.
35+ Rabit . init ( worker_env );
2236
23- HashMap <String , Object > params = new HashMap <String , Object >();
24- params .put ("eta" , 1.0 );
25- params .put ("max_depth" , 2 );
26- params .put ("silent" , 1 );
27- params .put ("objective" , "binary:logistic" );
37+ // load file from text file, also binary buffer generated by xgboost4j
38+ DMatrix trainMat = new DMatrix ("../../demo/data/agaricus.txt.train" );
39+ DMatrix testMat = new DMatrix ("../../demo/data/agaricus.txt.test" );
2840
41+ HashMap <String , Object > params = new HashMap <String , Object >();
42+ params .put ("eta" , 1.0 );
43+ params .put ("max_depth" , 2 );
44+ params .put ("silent" , 1 );
45+ params .put ("nthread" , 2 );
46+ params .put ("objective" , "binary:logistic" );
2947
30- HashMap <String , DMatrix > watches = new HashMap <String , DMatrix >();
31- watches .put ("train" , trainMat );
32- watches .put ("test" , testMat );
48+ HashMap <String , DMatrix > watches = new HashMap <String , DMatrix >();
49+ watches .put ("train" , trainMat );
50+ watches .put ("test" , testMat );
3351
34- //set round
35- int round = 2 ;
52+ //set round
53+ int round = 2 ;
3654
37- //train a boost model
38- Booster booster = XGBoost .train (params , trainMat , round , watches , null , null );
55+ //train a boost model
56+ Booster booster = XGBoost .train (params , trainMat , round , watches , null , null );
57+
58+ // always shutdown rabit module after training.
59+ Rabit .shutdown ();
60+ } catch (Exception ex ){
61+ logger .error (ex );
62+ }
63+ }
64+ }
65+
66+ void start (int nWorkers ) throws IOException , XGBoostError , InterruptedException {
67+ RabitTracker tracker = new RabitTracker (nWorkers );
68+ if (tracker .start ()) {
69+ envs = tracker .getWorkerEnvs ();
70+ for (int i = 0 ; i < nWorkers ; ++i ) {
71+ new Thread (new Worker (i )).start ();
72+ }
73+ tracker .waitFor ();
74+ }
75+ }
3976
40- // always shutdown rabit module after training.
41- Rabit . shutdown ( );
77+ public static void main ( String [] args ) throws IOException , XGBoostError , InterruptedException {
78+ new DistTrain (). start ( Integer . parseInt ( args [ 0 ]) );
4279 }
4380}
0 commit comments