Skip to content

Commit

Permalink
hack: make xgboost test finish quicker
Browse files Browse the repository at this point in the history
  • Loading branch information
MarWeUMR committed Oct 7, 2022
1 parent 325ca07 commit f08a495
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 137 deletions.
66 changes: 0 additions & 66 deletions operators/src/pro/ml/bindings/dmatrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,6 @@ static KEY_BASE_MARGIN: &str = "base_margin";
///
/// It's used as a container for both features (i.e. a row for every instance), and an optional true label for that
/// instance (as an `f32` value).
///
/// Can be created from files, or from dense or sparse
/// ([CSR](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format))
/// or [CSC](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS))) matrices.
///
/// # Examples
///
/// ## Load from file
///
/// Load matrix from file in [LIBSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) or binary format.
///
/// ```should_panic
/// use geoengine_operators::pro::ml::bindings::DMatrix;
///
/// let dmat = DMatrix::load("somefile.txt").unwrap();
/// ```
///
/// ## Create from dense array
///
/// ```
/// use geoengine_operators::pro::ml::bindings::DMatrix;
///
/// let data = &[1.0, 0.5, 0.2, 0.2,
/// 0.7, 1.0, 0.1, 0.1,
/// 0.2, 0.0, 0.0, 1.0];
/// let num_rows = 3;
/// let mut dmat = DMatrix::from_dense(data, num_rows).unwrap();
/// assert_eq!(dmat.shape(), (3, 4));
///
/// // set true labels for each row
/// dmat.set_labels(&[1.0, 0.0, 1.0]);
/// ```
///
/// ## Create from sparse CSR matrix
///
/// Create from sparse representation of
/// ```text
/// [[1.0, 0.0, 2.0],
/// [0.0, 0.0, 3.0],
/// [4.0, 5.0, 6.0]]
/// ```
///
/// ```
/// use geoengine_operators::pro::ml::bindings::DMatrix;
///
/// let indptr = &[0, 2, 3, 6];
/// let indices = &[0, 2, 2, 0, 1, 2];
/// let data = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let dmat = DMatrix::from_csr(indptr, indices, data, None).unwrap();
/// assert_eq!(dmat.shape(), (3, 0));
/// ```
#[derive(Debug)]
pub struct DMatrix {
pub(super) handle: xgboost_sys::DMatrixHandle,
Expand Down Expand Up @@ -149,21 +98,6 @@ impl DMatrix {

/// Create a new `DMatrix` from dense array in row-major order.
///
/// E.g. the matrix
/// ```text
/// [[1.0, 2.0],
/// [3.0, 4.0],
/// [5.0, 6.0]]
/// ```
/// would be converted into a `DMatrix` with
/// ```
/// use geoengine_operators::pro::ml::bindings::DMatrix;
///
/// let data = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let num_rows = 3;
/// let dmat = DMatrix::from_dense(data, num_rows).unwrap();
/// ```
///
/// # Panics
///
/// Will panic if the matrix creation fails with an error not coming from `XGBoost`.
Expand Down
48 changes: 0 additions & 48 deletions operators/src/pro/ml/bindings/mod.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,3 @@
//! # Basic usage example
//!
//! ```
//! use geoengine_operators::pro::ml::bindings::DMatrix;
//! use geoengine_operators::pro::ml::bindings::Booster;
//! use geoengine_operators::pro::ml::bindings::parameters::TrainingParametersBuilder;
//!
//! // training matrix with 5 training examples and 3 features
//! let x_train = &[1.0, 1.0, 1.0,
//! 1.0, 1.0, 0.0,
//! 1.0, 1.0, 1.0,
//! 0.0, 0.0, 0.0,
//! 1.0, 1.0, 1.0];
//! let num_rows = 5;
//! let y_train = &[1.0, 1.0, 1.0, 0.0, 1.0];
//!
//! // convert training data into XGBoost's matrix format
//! let mut dtrain = DMatrix::from_dense(x_train, num_rows).unwrap();
//!
//! // set ground truth labels for the training matrix
//! dtrain.set_labels(y_train).unwrap();
//!
//! // test matrix with 1 row
//! let x_test = &[0.7, 0.9, 0.6];
//! let num_rows = 1;
//! let y_test = &[1.0];
//! let mut dtest = DMatrix::from_dense(x_test, num_rows).unwrap();
//! dtest.set_labels(y_test).unwrap();
//!
//! // specify datasets to evaluate against during training
//! let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
//!
//! // specify overall training setup
//! let training_params = TrainingParametersBuilder::default()
//! .dtrain(&dtrain)
//! .evaluation_sets(Some(evaluation_sets))
//! .build()
//! .unwrap();
//!
//! // train model, and print evaluation data
//! let bst = Booster::train(Some(evaluation_sets), &dtrain, std::collections::HashMap::new(), None).unwrap();
//!
//! println!("{:?}", bst.predict(&dtest).unwrap());
//! ```
//!
//! See the [examples](https://github.com/davechallis/rust-xgboost/tree/master/examples) directory for
//! more detailed examples of different features.
//!
extern crate indexmap;
extern crate libc;
extern crate tempfile;
Expand Down
21 changes: 0 additions & 21 deletions operators/src/pro/ml/bindings/parameters/booster.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,3 @@
//! `BoosterParameters` for specifying the type of booster that is used when training a model.
//!
//! # Example
//!
//! ```
//! use geoengine_operators::pro::ml::bindings::DMatrix;
//! use geoengine_operators::pro::ml::bindings::parameters::BoosterParametersBuilder;
//! use geoengine_operators::pro::ml::bindings::parameters::BoosterType;
//! use geoengine_operators::pro::ml::bindings::parameters::tree::TreeBoosterParametersBuilder;
//!
//! let tree_params = TreeBoosterParametersBuilder::default()
//! .eta(0.2)
//! .gamma(3.0)
//! .subsample(0.75)
//! .build()
//! .unwrap();
//! let booster_params = BoosterParametersBuilder::default()
//! .booster_type(BoosterType::Tree(tree_params))
//! .build()
//! .unwrap();
//! ```
use std::default::Default;

use super::{linear, tree};
Expand Down
10 changes: 8 additions & 2 deletions operators/src/pro/ml/xgboost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,8 @@ mod tests {

#[tokio::test]
async fn xg_op_test() {
// TODO: This test needs a much smaller test data set.

// setup data to predict
let paths = vec![
"s2_10m_de_marburg/b02.tiff",
Expand Down Expand Up @@ -535,11 +537,15 @@ mod tests {

let ctx = MockQueryContext::test_default();
let result_stream = processor.query(qry_rectangle, &ctx).await.unwrap();
let result: Vec<crate::util::Result<RasterTile2D<f32>>> = result_stream.collect().await;

// TODO: .take(1) for a quick test run
let result: Vec<crate::util::Result<RasterTile2D<f32>>> =
result_stream.take(1).collect().await;

let mut all_pixels = Vec::new();

for tile in result {
// TODO: .take(1) for a quick test run
for tile in result.into_iter().take(1) {
let data_of_tile = tile
.unwrap()
.into_materialized_tile()
Expand Down

0 comments on commit f08a495

Please sign in to comment.