Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ homepage = "https://github.com/davechallis/rust-xgboost"
description = "Machine learning using XGBoost"
documentation = "https://docs.rs/xgboost"
readme = "README.md"
edition = "2021"

[dependencies]
xgboost-sys = { path = "xgboost-sys" }
libc = "0.2"
derive_builder = "0.12"
derive_builder = "0.20"
log = "0.4"
tempfile = "3.9"
indexmap = "2.1"
tempfile = "3.15"
indexmap = "2.7"

[features]
cuda = ["xgboost-sys/cuda"]
85 changes: 28 additions & 57 deletions src/booster.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use dmatrix::DMatrix;
use error::XGBError;
use crate::dmatrix::DMatrix;
use crate::error::XGBError;
use libc;
use std::collections::{BTreeMap, HashMap};
use std::io::{self, BufRead, BufReader, Write};
Expand All @@ -13,7 +13,7 @@ use tempfile;
use xgboost_sys;

use super::XGBResult;
use parameters::{BoosterParameters, TrainingParameters};
use crate::parameters::{BoosterParameters, TrainingParameters};

pub type CustomObjective = fn(&[f32], &DMatrix) -> (Vec<f32>, Vec<f32>);

Expand Down Expand Up @@ -148,29 +148,8 @@ impl Booster {
dmats
};

let mut bst = Booster::new_with_cached_dmats(&params.booster_params, &cached_dmats)?;
// load distributed code checkpoint from rabit
let mut version = bst.load_rabit_checkpoint()?;
debug!("Loaded Rabit checkpoint: version={}", version);
assert!(unsafe { xgboost_sys::RabitGetWorldSize() != 1 || version == 0 });
let start_iteration = version / 2;
for i in start_iteration..params.boost_rounds as i32 {
// distributed code: need to resume to this point
// skip first update if a recovery step
if version % 2 == 0 {
if let Some(objective_fn) = params.custom_objective_fn {
debug!("Boosting in round: {}", i);
bst.update_custom(params.dtrain, objective_fn)?;
} else {
debug!("Updating in round: {}", i);
bst.update(params.dtrain, i)?;
}
let _ = bst.save_rabit_checkpoint()?;
version += 1;
}

assert!(unsafe { xgboost_sys::RabitGetWorldSize() == 1 || version == xgboost_sys::RabitVersionNumber() });

let bst = Booster::new_with_cached_dmats(&params.booster_params, &cached_dmats)?;
for i in 0..params.boost_rounds as i32 {
if let Some(eval_sets) = params.evaluation_sets {
let mut dmat_eval_results = bst.eval_set(eval_sets, i)?;

Expand Down Expand Up @@ -203,10 +182,6 @@ impl Booster {
}
println!();
}

// do checkpoint after evaluation, in case evaluation also updates booster.
let _ = bst.save_rabit_checkpoint();
version += 1;
}

Ok(bst)
Expand Down Expand Up @@ -365,13 +340,16 @@ impl Booster {
let mut out_len = 0;
let mut out = ptr::null_mut();
xgb_call!(xgboost_sys::XGBoosterGetAttrNames(self.handle, &mut out_len, &mut out))?;

let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) };
let out_vec = out_ptr_slice
.iter()
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
.collect();
Ok(out_vec)
if out_len > 0 {
let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) };
let out_vec = out_ptr_slice
.iter()
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
.collect();
Ok(out_vec)
} else {
Ok(Vec::new())
}
}

/// Predict results for given data.
Expand Down Expand Up @@ -517,7 +495,7 @@ impl Booster {
Err(err) => return Err(XGBError::new(err.to_string())),
};

let file_path = tmp_dir.path().join("fmap.txt");
let file_path = tmp_dir.path().join("fmap.json");
let mut file: File = match File::create(&file_path) {
Ok(f) => f,
Err(err) => return Err(XGBError::new(err.to_string())),
Expand Down Expand Up @@ -551,24 +529,18 @@ impl Booster {
&mut out_dump_array
))?;

let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) };
let out_vec: Vec<String> = out_ptr_slice
.iter()
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
.collect();
if out_len > 0 {
let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) };
let out_vec: Vec<String> = out_ptr_slice
.iter()
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
.collect();

assert_eq!(out_len as usize, out_vec.len());
Ok(out_vec.join("\n"))
}

pub(crate) fn load_rabit_checkpoint(&self) -> XGBResult<i32> {
let mut version = 0;
xgb_call!(xgboost_sys::XGBoosterLoadRabitCheckpoint(self.handle, &mut version))?;
Ok(version)
}

pub(crate) fn save_rabit_checkpoint(&self) -> XGBResult<()> {
xgb_call!(xgboost_sys::XGBoosterSaveRabitCheckpoint(self.handle))
assert_eq!(out_len as usize, out_vec.len());
Ok(out_vec.join("\n"))
} else {
Ok(String::new())
}
}

pub fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> {
Expand Down Expand Up @@ -721,7 +693,7 @@ impl fmt::Display for FeatureType {
#[cfg(test)]
mod tests {
use super::*;
use parameters::{self, learning, tree};
use crate::parameters::{self, learning, tree};

fn read_train_matrix() -> XGBResult<DMatrix> {
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#)
Expand All @@ -739,7 +711,6 @@ mod tests {
assert!(res.is_ok());
}


#[test]
fn get_set_attr() {
let mut booster = load_test_booster();
Expand Down
6 changes: 5 additions & 1 deletion src/dmatrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,11 @@ impl DMatrix {
&mut out_dptr
))?;

Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) })
if out_len > 0 {
Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) })
} else {
Err(XGBError::new("error"))
}
}

fn set_float_info(&mut self, field: &str, array: &[f32]) -> XGBResult<()> {
Expand Down
File renamed without changes.
3 changes: 2 additions & 1 deletion xgboost-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ license = "MIT"
repository = "https://github.com/davechallis/rust-xgboost"
description = "Native bindings to the xgboost library"
readme = "README.md"
edition = "2021"

[dependencies]
libc = "0.2"

[build-dependencies]
bindgen = "0.69"
bindgen = "0.71"
cmake = "0.1"

[features]
Expand Down
15 changes: 11 additions & 4 deletions xgboost-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,21 @@ fn main() {
dst.define("BUILD_STATIC_LIB", "ON").define("CMAKE_CXX_STANDARD", "17");

// CMake
let mut dst = Config::new(&xgb_root);
let mut dst = dst.define("BUILD_STATIC_LIB", "ON");

#[cfg(feature = "cuda")]
dst.define("USE_CUDA", "ON")
let mut dst = dst
.define("USE_CUDA", "ON")
.define("BUILD_WITH_CUDA", "ON")
.define("BUILD_WITH_CUDA_CUB", "ON");

#[cfg(target_os = "macos")]
{
let path = PathBuf::from("/opt/homebrew/"); // check for m1 vs intel config
if let Ok(_dir) = std::fs::read_dir(&path) {
dst.define("CMAKE_C_COMPILER", "/opt/homebrew/opt/llvm/bin/clang")
dst = dst
.define("CMAKE_C_COMPILER", "/opt/homebrew/opt/llvm/bin/clang")
.define("CMAKE_CXX_COMPILER", "/opt/homebrew/opt/llvm/bin/clang++")
.define("OPENMP_LIBRARIES", "/opt/homebrew/opt/llvm/lib")
.define("OPENMP_INCLUDES", "/opt/homebrew/opt/llvm/include");
Expand All @@ -54,9 +59,11 @@ fn main() {

#[cfg(feature = "cuda")]
let bindings = bindings.clang_arg("-I/usr/local/cuda/include");
let bindings = bindings.generate().expect("Unable to generate bindings.");
let bindings = bindings
.generate()
.expect("Unable to generate bindings.");

let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
let out_path = PathBuf::from(out_dir);
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings.");
Expand Down
2 changes: 1 addition & 1 deletion xgboost-sys/xgboost
Submodule xgboost updated 797 files