Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ documentation = "https://docs.rs/xgboost"
readme = "README.md"

[dependencies]
xgboost-sys = "0.1.2"
xgboost-sys = { path="./xgboost-sys" }
libc = "0.2"
derive_builder = "0.5"
log = "0.4"
Expand Down
8 changes: 4 additions & 4 deletions examples/basic/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ fn main() {
let mut data = Vec::new();

let reader = BufReader::new(File::open("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap());
let mut current_row = 0;
let mut current_row: u64 = 0;
for line in reader.lines() {
let line = line.unwrap();
let sample: Vec<&str> = line.split_whitespace().collect();
Expand All @@ -106,16 +106,16 @@ fn main() {
for entry in &sample[1..] {
let pair: Vec<&str> = entry.split(':').collect();
rows.push(current_row);
cols.push(pair[0].parse::<usize>().unwrap());
cols.push(pair[0].parse::<u64>().unwrap());
data.push(pair[1].parse::<f32>().unwrap());
}

current_row += 1;
}

// work out size of sparse matrix from max row/col values
let shape = (*rows.iter().max().unwrap() + 1 as usize,
*cols.iter().max().unwrap() + 1 as usize);
let shape = ((*rows.iter().max().unwrap() + 1) as usize,
(*cols.iter().max().unwrap() + 1) as usize);
let triplet_mat = sprs::TriMatBase::from_triplets(shape, rows, cols, data);
let csr_mat = triplet_mat.to_csr();

Expand Down
2 changes: 1 addition & 1 deletion examples/multiclass_classification/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ publish = false
xgboost = { path = "../../" }
log = "0.4"
env_logger = "0.5"
reqwest = "0.8"
reqwest = { version = "0.11", features = ["blocking"] }
2 changes: 1 addition & 1 deletion examples/multiclass_classification/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ fn download_dataset<P: AsRef<Path>>(dst: P) {
}

debug!("Fetching training dataset from {}", url);
let mut response = reqwest::get(url).expect("failed to download training set data");
let mut response = reqwest::blocking::get(url).expect("failed to download training set data");

let file = File::create(dst).expect(&format!("failed to create file {}", dst.display()));
let mut writer = BufWriter::new(file);
Expand Down
6 changes: 3 additions & 3 deletions src/booster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ impl Booster {
for (dmat, dmat_name) in eval_sets {
let margin = bst.predict_margin(dmat)?;
let eval_result = eval_fn(&margin, dmat);
let mut eval_results = dmat_eval_results.entry(eval_name.to_string())
let eval_results = dmat_eval_results.entry(eval_name.to_string())
.or_insert_with(IndexMap::new);
eval_results.insert(dmat_name.to_string(), eval_result);
}
Expand All @@ -188,7 +188,7 @@ impl Booster {
let mut eval_dmat_results = BTreeMap::new();
for (dmat_name, eval_results) in &dmat_eval_results {
for (eval_name, result) in eval_results {
let mut dmat_results = eval_dmat_results.entry(eval_name).or_insert_with(BTreeMap::new);
let dmat_results = eval_dmat_results.entry(eval_name).or_insert_with(BTreeMap::new);
dmat_results.insert(dmat_name, result);
}
}
Expand Down Expand Up @@ -548,7 +548,7 @@ impl Booster {
let score = metric_parts[1].parse::<f32>()
.unwrap_or_else(|_| panic!("Unable to parse XGBoost metrics output: {}", eval));

let mut metric_map = result.entry(evname.to_string()).or_insert_with(IndexMap::new);
let metric_map = result.entry(evname.to_string()).or_insert_with(IndexMap::new);
metric_map.insert(metric.to_owned(), score);
}
}
Expand Down
17 changes: 9 additions & 8 deletions src/dmatrix.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{slice, ffi, ptr, path::Path};
use libc::{c_uint, c_float};
use std::os::unix::ffi::OsStrExt;
use std::convert::TryInto;

use xgboost_sys;

Expand Down Expand Up @@ -123,17 +124,17 @@ impl DMatrix {
/// `data[indptr[i]:indptr[i+1]`.
///
/// If `num_cols` is set to None, number of columns will be inferred from given data.
pub fn from_csr(indptr: &[usize], indices: &[usize], data: &[f32], num_cols: Option<usize>) -> XGBResult<Self> {
pub fn from_csr(indptr: &[u64], indices: &[usize], data: &[f32], num_cols: Option<usize>) -> XGBResult<Self> {
assert_eq!(indices.len(), data.len());
let mut handle = ptr::null_mut();
let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
let num_cols = num_cols.unwrap_or(0); // infer from data if 0
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(indptr.as_ptr(),
indices.as_ptr(),
data.as_ptr(),
indptr.len(),
data.len(),
num_cols,
indptr.len().try_into().unwrap(),
data.len().try_into().unwrap(),
num_cols.try_into().unwrap(),
&mut handle))?;
Ok(DMatrix::new(handle)?)
}
Expand All @@ -146,17 +147,17 @@ impl DMatrix {
/// `data[indptr[i]:indptr[i+1]`.
///
/// If `num_rows` is set to None, number of rows will be inferred from given data.
pub fn from_csc(indptr: &[usize], indices: &[usize], data: &[f32], num_rows: Option<usize>) -> XGBResult<Self> {
pub fn from_csc(indptr: &[u64], indices: &[usize], data: &[f32], num_rows: Option<usize>) -> XGBResult<Self> {
assert_eq!(indices.len(), data.len());
let mut handle = ptr::null_mut();
let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
let num_rows = num_rows.unwrap_or(0); // infer from data if 0
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(indptr.as_ptr(),
indices.as_ptr(),
data.as_ptr(),
indptr.len(),
data.len(),
num_rows,
indptr.len().try_into().unwrap(),
data.len().try_into().unwrap(),
num_rows.try_into().unwrap(),
&mut handle))?;
Ok(DMatrix::new(handle)?)
}
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl XGBError {
match ret_val {
0 => Ok(()),
-1 => Err(XGBError::from_xgboost()),
_ => panic!(format!("unexpected return value '{}', expected 0 or -1", ret_val)),
_ => panic!("unexpected return value '{}', expected 0 or -1", ret_val),
}
}

Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ extern crate libc;
extern crate tempfile;
extern crate indexmap;

#[macro_use]
macro_rules! xgb_call {
($x:expr) => {
XGBError::check_return_value(unsafe { $x })
Expand Down
2 changes: 1 addition & 1 deletion xgboost-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ readme = "README.md"
libc = "0.2"

[build-dependencies]
bindgen = "0.36"
bindgen = "0.59"