Skip to content

Commit 98604a0

Browse files
author
Montana Low
committed
rabit checkpoints removed upstream
1 parent ec745a4 commit 98604a0

File tree

1 file changed

+2
-37
lines changed

1 file changed

+2
-37
lines changed

src/booster.rs

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -148,29 +148,8 @@ impl Booster {
148148
dmats
149149
};
150150

151-
let mut bst = Booster::new_with_cached_dmats(&params.booster_params, &cached_dmats)?;
152-
// load distributed code checkpoint from rabit
153-
let mut version = bst.load_rabit_checkpoint()?;
154-
debug!("Loaded Rabit checkpoint: version={}", version);
155-
assert!(unsafe { xgboost_sys::RabitGetWorldSize() != 1 || version == 0 });
156-
let start_iteration = version / 2;
157-
for i in start_iteration..params.boost_rounds as i32 {
158-
// distributed code: need to resume to this point
159-
// skip first update if a recovery step
160-
if version % 2 == 0 {
161-
if let Some(objective_fn) = params.custom_objective_fn {
162-
debug!("Boosting in round: {}", i);
163-
bst.update_custom(params.dtrain, objective_fn)?;
164-
} else {
165-
debug!("Updating in round: {}", i);
166-
bst.update(params.dtrain, i)?;
167-
}
168-
let _ = bst.save_rabit_checkpoint()?;
169-
version += 1;
170-
}
171-
172-
assert!(unsafe { xgboost_sys::RabitGetWorldSize() == 1 || version == xgboost_sys::RabitVersionNumber() });
173-
151+
let bst = Booster::new_with_cached_dmats(&params.booster_params, &cached_dmats)?;
152+
for i in 0..params.boost_rounds as i32 {
174153
if let Some(eval_sets) = params.evaluation_sets {
175154
let mut dmat_eval_results = bst.eval_set(eval_sets, i)?;
176155

@@ -203,10 +182,6 @@ impl Booster {
203182
}
204183
println!();
205184
}
206-
207-
// do checkpoint after evaluation, in case evaluation also updates booster.
208-
let _ = bst.save_rabit_checkpoint();
209-
version += 1;
210185
}
211186

212187
Ok(bst)
@@ -568,16 +543,6 @@ impl Booster {
568543
}
569544
}
570545

571-
pub(crate) fn load_rabit_checkpoint(&self) -> XGBResult<i32> {
572-
let mut version = 0;
573-
xgb_call!(xgboost_sys::XGBoosterLoadRabitCheckpoint(self.handle, &mut version))?;
574-
Ok(version)
575-
}
576-
577-
pub(crate) fn save_rabit_checkpoint(&self) -> XGBResult<()> {
578-
xgb_call!(xgboost_sys::XGBoosterSaveRabitCheckpoint(self.handle))
579-
}
580-
581546
pub fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> {
582547
let name = ffi::CString::new(name).unwrap();
583548
let value = ffi::CString::new(value).unwrap();

0 commit comments

Comments
 (0)