1- use dmatrix:: DMatrix ;
2- use error:: XGBError ;
1+ use crate :: dmatrix:: DMatrix ;
2+ use crate :: error:: XGBError ;
33use libc;
44use std:: collections:: { BTreeMap , HashMap } ;
55use std:: io:: { self , BufRead , BufReader , Write } ;
@@ -13,7 +13,7 @@ use tempfile;
1313use xgboost_sys;
1414
1515use super :: XGBResult ;
16- use parameters:: { BoosterParameters , TrainingParameters } ;
16+ use crate :: parameters:: { BoosterParameters , TrainingParameters } ;
1717
1818pub type CustomObjective = fn ( & [ f32 ] , & DMatrix ) -> ( Vec < f32 > , Vec < f32 > ) ;
1919
@@ -148,29 +148,8 @@ impl Booster {
148148 dmats
149149 } ;
150150
151- let mut bst = Booster :: new_with_cached_dmats ( & params. booster_params , & cached_dmats) ?;
152- // load distributed code checkpoint from rabit
153- let mut version = bst. load_rabit_checkpoint ( ) ?;
154- debug ! ( "Loaded Rabit checkpoint: version={}" , version) ;
155- assert ! ( unsafe { xgboost_sys:: RabitGetWorldSize ( ) != 1 || version == 0 } ) ;
156- let start_iteration = version / 2 ;
157- for i in start_iteration..params. boost_rounds as i32 {
158- // distributed code: need to resume to this point
159- // skip first update if a recovery step
160- if version % 2 == 0 {
161- if let Some ( objective_fn) = params. custom_objective_fn {
162- debug ! ( "Boosting in round: {}" , i) ;
163- bst. update_custom ( params. dtrain , objective_fn) ?;
164- } else {
165- debug ! ( "Updating in round: {}" , i) ;
166- bst. update ( params. dtrain , i) ?;
167- }
168- let _ = bst. save_rabit_checkpoint ( ) ?;
169- version += 1 ;
170- }
171-
172- assert ! ( unsafe { xgboost_sys:: RabitGetWorldSize ( ) == 1 || version == xgboost_sys:: RabitVersionNumber ( ) } ) ;
173-
151+ let bst = Booster :: new_with_cached_dmats ( & params. booster_params , & cached_dmats) ?;
152+ for i in 0 ..params. boost_rounds as i32 {
174153 if let Some ( eval_sets) = params. evaluation_sets {
175154 let mut dmat_eval_results = bst. eval_set ( eval_sets, i) ?;
176155
@@ -203,10 +182,6 @@ impl Booster {
203182 }
204183 println ! ( ) ;
205184 }
206-
207- // do checkpoint after evaluation, in case evaluation also updates booster.
208- let _ = bst. save_rabit_checkpoint ( ) ;
209- version += 1 ;
210185 }
211186
212187 Ok ( bst)
@@ -365,13 +340,16 @@ impl Booster {
365340 let mut out_len = 0 ;
366341 let mut out = ptr:: null_mut ( ) ;
367342 xgb_call ! ( xgboost_sys:: XGBoosterGetAttrNames ( self . handle, & mut out_len, & mut out) ) ?;
368-
369- let out_ptr_slice = unsafe { slice:: from_raw_parts ( out, out_len as usize ) } ;
370- let out_vec = out_ptr_slice
371- . iter ( )
372- . map ( |str_ptr| unsafe { ffi:: CStr :: from_ptr ( * str_ptr) . to_str ( ) . unwrap ( ) . to_owned ( ) } )
373- . collect ( ) ;
374- Ok ( out_vec)
343+ if out_len > 0 {
344+ let out_ptr_slice = unsafe { slice:: from_raw_parts ( out, out_len as usize ) } ;
345+ let out_vec = out_ptr_slice
346+ . iter ( )
347+ . map ( |str_ptr| unsafe { ffi:: CStr :: from_ptr ( * str_ptr) . to_str ( ) . unwrap ( ) . to_owned ( ) } )
348+ . collect ( ) ;
349+ Ok ( out_vec)
350+ } else {
351+ Ok ( Vec :: new ( ) )
352+ }
375353 }
376354
377355 /// Predict results for given data.
@@ -517,7 +495,7 @@ impl Booster {
517495 Err ( err) => return Err ( XGBError :: new ( err. to_string ( ) ) ) ,
518496 } ;
519497
520- let file_path = tmp_dir. path ( ) . join ( "fmap.txt " ) ;
498+ let file_path = tmp_dir. path ( ) . join ( "fmap.json " ) ;
521499 let mut file: File = match File :: create ( & file_path) {
522500 Ok ( f) => f,
523501 Err ( err) => return Err ( XGBError :: new ( err. to_string ( ) ) ) ,
@@ -551,24 +529,18 @@ impl Booster {
551529 & mut out_dump_array
552530 ) ) ?;
553531
554- let out_ptr_slice = unsafe { slice:: from_raw_parts ( out_dump_array, out_len as usize ) } ;
555- let out_vec: Vec < String > = out_ptr_slice
556- . iter ( )
557- . map ( |str_ptr| unsafe { ffi:: CStr :: from_ptr ( * str_ptr) . to_str ( ) . unwrap ( ) . to_owned ( ) } )
558- . collect ( ) ;
532+ if out_len > 0 {
533+ let out_ptr_slice = unsafe { slice:: from_raw_parts ( out_dump_array, out_len as usize ) } ;
534+ let out_vec: Vec < String > = out_ptr_slice
535+ . iter ( )
536+ . map ( |str_ptr| unsafe { ffi:: CStr :: from_ptr ( * str_ptr) . to_str ( ) . unwrap ( ) . to_owned ( ) } )
537+ . collect ( ) ;
559538
560- assert_eq ! ( out_len as usize , out_vec. len( ) ) ;
561- Ok ( out_vec. join ( "\n " ) )
562- }
563-
564- pub ( crate ) fn load_rabit_checkpoint ( & self ) -> XGBResult < i32 > {
565- let mut version = 0 ;
566- xgb_call ! ( xgboost_sys:: XGBoosterLoadRabitCheckpoint ( self . handle, & mut version) ) ?;
567- Ok ( version)
568- }
569-
570- pub ( crate ) fn save_rabit_checkpoint ( & self ) -> XGBResult < ( ) > {
571- xgb_call ! ( xgboost_sys:: XGBoosterSaveRabitCheckpoint ( self . handle) )
539+ assert_eq ! ( out_len as usize , out_vec. len( ) ) ;
540+ Ok ( out_vec. join ( "\n " ) )
541+ } else {
542+ Ok ( String :: new ( ) )
543+ }
572544 }
573545
574546 pub fn set_param ( & mut self , name : & str , value : & str ) -> XGBResult < ( ) > {
@@ -721,7 +693,7 @@ impl fmt::Display for FeatureType {
721693#[ cfg( test) ]
722694mod tests {
723695 use super :: * ;
724- use parameters:: { self , learning, tree} ;
696+ use crate :: parameters:: { self , learning, tree} ;
725697
726698 fn read_train_matrix ( ) -> XGBResult < DMatrix > {
727699 DMatrix :: load ( r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"# )
@@ -739,7 +711,6 @@ mod tests {
739711 assert ! ( res. is_ok( ) ) ;
740712 }
741713
742-
743714 #[ test]
744715 fn get_set_attr ( ) {
745716 let mut booster = load_test_booster ( ) ;
0 commit comments