From bfc8ea6f38dda6f6c9e1769950bd015dc0d9d64a Mon Sep 17 00:00:00 2001 From: Marcus Weber Date: Thu, 1 Dec 2022 14:04:55 +0100 Subject: [PATCH 1/3] add xgboost operator --- CHANGELOG.md | 4 + Settings-default.toml | 3 + datatypes/Cargo.toml | 4 +- datatypes/src/error.rs | 18 +- datatypes/src/util/mod.rs | 42 + operators/Cargo.toml | 2 + operators/src/engine/execution_context.rs | 35 + operators/src/error.rs | 20 + operators/src/pro/ml/mod.rs | 1 + operators/src/pro/ml/xgboost.rs | 920 ++++++++++++++++++ operators/src/pro/mod.rs | 1 + services/Cargo.toml | 1 + services/src/contexts/mod.rs | 133 +++ .../src/datasets/external/netcdfcf/mod.rs | 2 +- .../datasets/external/netcdfcf/overviews.rs | 3 +- services/src/handlers/session.rs | 4 +- services/src/pro/contexts/mod.rs | 131 +++ services/src/pro/handlers/drone_mapping.rs | 4 +- services/src/pro/handlers/users.rs | 7 +- services/src/util/config.rs | 9 + services/src/util/mod.rs | 36 - .../ml/xgboost/s2_10m_de_marburg/model.json | 1 + 22 files changed, 1337 insertions(+), 44 deletions(-) create mode 100644 operators/src/pro/ml/mod.rs create mode 100644 operators/src/pro/ml/xgboost.rs create mode 100644 test_data/pro/ml/xgboost/s2_10m_de_marburg/model.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 49a63b0ff..4525c7fc3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added a new operator `XGBoost`. This operator allows to use a pretrained model in order to make predictions based on some set of raster tile data. + + - https://github.com/geo-engine/geoengine/pull/639 + - Added a handler (`/available`) to the API to check if the service is available. 
- https://github.com/geo-engine/geoengine/pull/681 diff --git a/Settings-default.toml b/Settings-default.toml index d3d81be15..103d8af95 100644 --- a/Settings-default.toml +++ b/Settings-default.toml @@ -77,6 +77,9 @@ provider_defs_path = "./test_data/provider_defs" layer_defs_path = "./test_data/layer_defs" layer_collection_defs_path = "./test_data/layer_collection_defs" +[machinelearning] +model_defs_path = "./test_data/pro/ml" + [gdal] # TODO: find good default # Use 0 for `ALL_CPUS` option or a number >0 for a specific number of threads. diff --git a/datatypes/Cargo.toml b/datatypes/Cargo.toml index 5e4513b43..baf6db636 100644 --- a/datatypes/Cargo.toml +++ b/datatypes/Cargo.toml @@ -19,6 +19,7 @@ async-trait = "0.1" chrono = "0.4" float-cmp = "0.9" gdal = "0.14" +geo-types = "0.7" geo = "0.23" geojson = "0.24" image = "0.24" @@ -37,6 +38,7 @@ uuid = { version = "1.1", features = ["serde", "v4", "v5"] } [dev-dependencies] criterion = "0.4" +tempfile = "3.1" [[bench]] name = "multi_point_collection" @@ -52,4 +54,4 @@ harness = false [[bench]] name = "masked_grid_mapping" -harness = false \ No newline at end of file +harness = false diff --git a/datatypes/src/error.rs b/datatypes/src/error.rs index b1c8439d6..099839387 100644 --- a/datatypes/src/error.rs +++ b/datatypes/src/error.rs @@ -5,8 +5,9 @@ use crate::{ spatial_reference::SpatialReference, }; use snafu::{prelude::*, AsErrorSource, ErrorCompat, IntoError}; -use std::{any::Any, convert::Infallible, sync::Arc}; +use std::{any::Any, convert::Infallible, path::PathBuf, sync::Arc}; +pub type Result = std::result::Result; pub trait ErrorSource: std::error::Error + Send + Sync + Any + 'static + AsErrorSource { fn boxed(self) -> Box where @@ -313,6 +314,15 @@ pub enum Error { a: SpatialReference, b: SpatialReference, }, + + Io { + source: std::io::Error, + }, + + SubPathMustNotEscapeBasePath { + base: PathBuf, + sub_path: PathBuf, + }, } impl From for Error { @@ -338,3 +348,9 @@ impl From for Error { Self::Gdal { 
source: gdal_error } } } + +impl From for Error { + fn from(e: std::io::Error) -> Self { + Self::Io { source: e } + } +} diff --git a/datatypes/src/util/mod.rs b/datatypes/src/util/mod.rs index b37ceaa6b..a108df012 100644 --- a/datatypes/src/util/mod.rs +++ b/datatypes/src/util/mod.rs @@ -8,6 +8,48 @@ mod result; pub mod well_known_data; pub mod test; +use std::path::{Path, PathBuf}; + pub use self::identifiers::Identifier; pub use any::{AsAny, AsAnyArc}; pub use result::Result; + +/// Canonicalize `base`/`sub_path` and ensure the `sub_path` doesn't escape the `base` +/// returns an error if the `sub_path` escapes the `base` +/// +/// This only works if the `Path` you are referring to actually exists. +/// +pub fn canonicalize_subpath(base: &Path, sub_path: &Path) -> crate::error::Result { + let base = base.canonicalize()?; + let path = base.join(sub_path).canonicalize()?; + + if path.starts_with(&base) { + Ok(path) + } else { + Err(crate::error::Error::SubPathMustNotEscapeBasePath { + base, + sub_path: sub_path.into(), + }) + } +} + +#[cfg(test)] +mod mod_tests { + use super::*; + #[test] + fn it_doesnt_escape_base_path() { + let tmp_dir = tempfile::tempdir().unwrap(); + let tmp_path = tmp_dir.path(); + std::fs::create_dir_all(tmp_path.join("foo/bar/foobar")).unwrap(); + std::fs::create_dir_all(tmp_path.join("foo/barfoo")).unwrap(); + + assert_eq!( + canonicalize_subpath(&tmp_path.join("foo/bar"), Path::new("foobar")) + .unwrap() + .to_string_lossy(), + tmp_path.join("foo/bar/foobar").to_string_lossy() + ); + + assert!(canonicalize_subpath(&tmp_path.join("foo/bar"), Path::new("../barfoo")).is_err()); + } +} diff --git a/operators/Cargo.toml b/operators/Cargo.toml index 8bdce6d89..503bdf9a1 100644 --- a/operators/Cargo.toml +++ b/operators/Cargo.toml @@ -50,11 +50,13 @@ tracing = "0.1" tracing-opentelemetry = "0.18" typetag = "0.2" uuid = { version = "1.1", features = ["serde", "v4", "v5"] } +xgboost-rs = "0.1.2" [dev-dependencies] async-stream = "0.3" geo-rand = { 
git = "https://github.com/lelongg/geo-rand", tag = "v0.3.0" } rand = "0.8" +ndarray = "0.15.6" [[bench]] diff --git a/operators/src/engine/execution_context.rs b/operators/src/engine/execution_context.rs index da579d70f..f6acf36dd 100644 --- a/operators/src/engine/execution_context.rs +++ b/operators/src/engine/execution_context.rs @@ -21,9 +21,11 @@ use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; use std::marker::PhantomData; +use std::path::PathBuf; use std::sync::Arc; /// A context that provides certain utility access during operator initialization +#[async_trait::async_trait] pub trait ExecutionContext: Send + Sync + MetaDataProvider @@ -50,6 +52,10 @@ pub trait ExecutionContext: Send op: Box, span: CreateSpan, ) -> Box; + + async fn read_ml_model(&self, path: PathBuf) -> Result; + + async fn write_ml_model(&mut self, path: PathBuf, ml_model_str: String) -> Result<()>; } #[async_trait] @@ -84,6 +90,7 @@ pub struct MockExecutionContext { pub thread_pool: Arc, pub meta_data: HashMap>, pub tiling_specification: TilingSpecification, + pub ml_models: HashMap, } impl TestDefault for MockExecutionContext { @@ -92,6 +99,7 @@ impl TestDefault for MockExecutionContext { thread_pool: create_rayon_thread_pool(0), meta_data: HashMap::default(), tiling_specification: TilingSpecification::test_default(), + ml_models: HashMap::default(), } } } @@ -102,6 +110,7 @@ impl MockExecutionContext { thread_pool: create_rayon_thread_pool(0), meta_data: HashMap::default(), tiling_specification, + ml_models: HashMap::default(), } } @@ -113,6 +122,7 @@ impl MockExecutionContext { thread_pool: create_rayon_thread_pool(num_threads), meta_data: HashMap::default(), tiling_specification, + ml_models: HashMap::default(), } } @@ -136,8 +146,17 @@ impl MockExecutionContext { abort_trigger: Some(abort_trigger), } } + + pub fn initialize_ml_model(&mut self, model_path: PathBuf) -> Result<()> { + let model = std::fs::read_to_string(&model_path)?; + + 
self.ml_models.insert(model_path, model); + + Ok(()) + } } +#[async_trait::async_trait] impl ExecutionContext for MockExecutionContext { fn thread_pool(&self) -> &Arc { &self.thread_pool @@ -170,6 +189,22 @@ impl ExecutionContext for MockExecutionContext { ) -> Box { op } + + async fn read_ml_model(&self, path: PathBuf) -> Result { + let res = self + .ml_models + .get(&path) + .ok_or(Error::MachineLearningModelNotFound)? + .clone(); + + Ok(res) + } + + async fn write_ml_model(&mut self, path: PathBuf, ml_model_str: String) -> Result<()> { + self.ml_models.insert(path, ml_model_str); + + Ok(()) + } } #[async_trait] diff --git a/operators/src/error.rs b/operators/src/error.rs index 2df07d59b..45188563a 100644 --- a/operators/src/error.rs +++ b/operators/src/error.rs @@ -314,6 +314,26 @@ pub enum Error { QueryCanceled, AbortTriggerAlreadyUsed, + + SubPathMustNotEscapeBasePath { + base: PathBuf, + sub_path: PathBuf, + }, + + InvalidDataProviderConfig, + + InvalidMachineLearningConfig, + + MachineLearningModelNotFound, + + InvalidMlModelPath, + CouldNotGetMlModelDirectory, + + #[cfg(feature = "pro")] + #[snafu(context(false))] + XGBoost { + source: crate::pro::ml::xgboost::XGBoostModuleError, + }, } impl From for Error { diff --git a/operators/src/pro/ml/mod.rs b/operators/src/pro/ml/mod.rs new file mode 100644 index 000000000..f503bf645 --- /dev/null +++ b/operators/src/pro/ml/mod.rs @@ -0,0 +1 @@ +pub mod xgboost; diff --git a/operators/src/pro/ml/xgboost.rs b/operators/src/pro/ml/xgboost.rs new file mode 100644 index 000000000..0e984fd01 --- /dev/null +++ b/operators/src/pro/ml/xgboost.rs @@ -0,0 +1,920 @@ +use std::mem; +use std::path::PathBuf; +use std::sync::Arc; + +use async_trait::async_trait; +use futures::{future, StreamExt}; +use geoengine_datatypes::primitives::{ + partitions_extent, time_interval_extent, Measurement, RasterQueryRectangle, SpatialPartition2D, + SpatialResolution, +}; +use geoengine_datatypes::raster::{ + BaseTile, Grid2D, GridOrEmpty, 
GridShape, GridShapeAccess, GridSize, RasterDataType, + RasterTile2D, +}; +use rayon::prelude::ParallelIterator; +use rayon::slice::ParallelSlice; +use rayon::ThreadPool; +use serde::{Deserialize, Serialize}; +use snafu::{ensure, OptionExt, ResultExt}; +use xgboost_rs::{Booster, DMatrix, XGBError}; + +use crate::engine::{ + CreateSpan, ExecutionContext, InitializedRasterOperator, MultipleRasterSources, Operator, + OperatorName, QueryContext, QueryProcessor, RasterOperator, RasterQueryProcessor, + RasterResultDescriptor, TypedRasterQueryProcessor, +}; +use crate::util::stream_zip::StreamVectorZip; +use crate::util::Result; +use futures::stream::BoxStream; +use RasterDataType::F32 as RasterOut; + +use snafu::Snafu; +use tracing::{span, Level}; +use TypedRasterQueryProcessor::F32 as QueryProcessorOut; + +#[derive(Debug, Snafu)] +#[snafu(visibility(pub(crate)), context(suffix(false)), module(error))] +pub enum XGBoostModuleError { + #[snafu(display("The XGBoost library could not complete the operation successfully.",))] + LibraryError { source: XGBError }, + + #[snafu(display("Couldn't parse the model file contents.",))] + ModelFileParsingError { source: std::io::Error }, + + #[snafu(display("Couldn't create a booster instance from the content of the model file.",))] + LoadBoosterFromModelError { source: XGBError }, + + #[snafu(display("Couldn't generate a xgboost dmatrix from the given data.",))] + CreateDMatrixError { source: XGBError }, + + #[snafu(display("Couldn't calculate predictions from the given data.",))] + PredictionError { source: XGBError }, + + #[snafu(display("Couldn't get a base tile.",))] + BaseTileError, + + #[snafu(display("No input data error. 
At least one raster is required.",))] + NoInputData, + + #[snafu(display("There was an error with the creation of a new grid.",))] + DataTypesError { + source: geoengine_datatypes::error::Error, + }, + + #[snafu(display("There was an error with the joining of tokio tasks.",))] + TokioJoinError { source: tokio::task::JoinError }, +} + +impl From for XGBoostModuleError { + fn from(source: std::io::Error) -> Self { + Self::ModelFileParsingError { source } + } +} + +impl From for XGBoostModuleError { + fn from(source: XGBError) -> Self { + Self::LibraryError { source } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct XgboostParams { + model_sub_path: PathBuf, + no_data_value: f32, +} + +pub type XgboostOperator = Operator; + +impl OperatorName for XgboostOperator { + const TYPE_NAME: &'static str = "XgboostOperator"; +} + +pub struct InitializedXgboostOperator { + result_descriptor: RasterResultDescriptor, + sources: Vec>, + model: String, + no_data_value: f32, +} + +type PixelOut = f32; + +#[typetag::serde] +#[async_trait] +impl RasterOperator for XgboostOperator { + async fn _initialize( + self: Box, + context: &dyn ExecutionContext, + ) -> Result> { + let model = context.read_ml_model(self.params.model_sub_path).await?; + + let init_rasters = future::try_join_all( + self.sources + .rasters + .iter() + .map(|raster| raster.clone().initialize(context)), + ) + .await?; + + let input = init_rasters.get(0).context(self::error::NoInputData)?; + + let spatial_reference = input.result_descriptor().spatial_reference; + + let in_descriptors = init_rasters + .iter() + .map(InitializedRasterOperator::result_descriptor) + .collect::>(); + + for other_spatial_reference in in_descriptors.iter().skip(1).map(|rd| rd.spatial_reference) + { + ensure!( + spatial_reference == other_spatial_reference, + crate::error::InvalidSpatialReference { + expected: spatial_reference, + found: other_spatial_reference, + } + ); + } + + let time 
= time_interval_extent(in_descriptors.iter().map(|d| d.time)); + let bbox = partitions_extent(in_descriptors.iter().map(|d| d.bbox)); + + let resolution = in_descriptors + .iter() + .map(|d| d.resolution) + .reduce(|a, b| match (a, b) { + (Some(a), Some(b)) => { + Some(SpatialResolution::new_unchecked(a.x.min(b.x), a.y.min(b.y))) + } + _ => None, + }) + .flatten(); + + let out_desc = RasterResultDescriptor { + data_type: RasterOut, + time, + bbox, + resolution, + spatial_reference, + measurement: Measurement::Unitless, + }; + + let initialized_operator = InitializedXgboostOperator { + result_descriptor: out_desc, + sources: init_rasters, + model, + no_data_value: self.params.no_data_value, + }; + + Ok(initialized_operator.boxed()) + } + + span_fn!(XgboostOperator); +} + +impl InitializedRasterOperator for InitializedXgboostOperator { + fn result_descriptor(&self) -> &RasterResultDescriptor { + &self.result_descriptor + } + + fn query_processor(&self) -> Result { + let vec_of_rqps: Vec>> = self + .sources + .iter() + .map( + |init_raster| -> Result< + Box>, + crate::error::Error, + > { + let typed_raster_qp = init_raster + .query_processor() + .map_err(|_| crate::error::Error::QueryProcessor)?; + let converted = typed_raster_qp.into_f32(); + Ok(converted) + }, + ) + .collect::, _>>()?; + + Ok(QueryProcessorOut(Box::new(XgboostProcessor::new( + vec_of_rqps, + self.model.clone(), + self.no_data_value, + )))) + } +} + +struct XgboostProcessor +where + Q: RasterQueryProcessor, +{ + sources: Vec, + model: String, + no_data_value: f32, +} + +impl XgboostProcessor +where + Q: RasterQueryProcessor, +{ + pub fn new(sources: Vec, model_file_content: String, no_data_value: f32) -> Self { + Self { + sources, + model: model_file_content, + no_data_value, + } + } + + async fn process_tile_async( + &self, + bands_of_tile: Vec>, + model_content: Arc, + pool: Arc, + ) -> Result, f32>>, XGBoostModuleError> { + let tile = bands_of_tile.get(0).context(self::error::BaseTile)?; + + // 
gather the data + let grid_shape = tile.grid_shape(); + let n_bands = bands_of_tile.len() as i32; + let props = tile.properties.clone(); // = &tile.properties; + let time = tile.time; + let tile_position = tile.tile_position; + let global_geo_transform = tile.global_geo_transform; + let ndv = self.no_data_value; + + let predicted_grid = crate::util::spawn_blocking(move || { + process_tile( + bands_of_tile, + &pool, + &model_content, + grid_shape, + n_bands as usize, + ndv, + ) + }) + .await + .context(error::TokioJoin)??; + + let rt: BaseTile, f32>> = + RasterTile2D::new_with_properties( + time, + tile_position, + global_geo_transform, + predicted_grid.into(), + props.clone(), + ); + + Ok(rt) + } +} + +fn process_tile( + bands_of_tile: Vec>, + pool: &ThreadPool, + model_file: &Arc, + grid_shape: GridShape<[usize; 2]>, + n_bands: usize, + nan_val: f32, +) -> Result, f32>, XGBoostModuleError> { + pool.install(|| { + // get the actual tile data + let band_data: Vec<_> = bands_of_tile + .into_iter() + .map(|band| { + let mat_tile = band.into_materialized_tile(); + mat_tile.grid_array.inner_grid.data + }) + .collect(); + + let n_rows = grid_shape.axis_size_x() * grid_shape.axis_size_y(); + // we need to reshape the data and take the i-th element from each band. + let i = (0..n_rows).map(|row_idx| band_data.iter().flatten().skip(row_idx).step_by(n_rows)); + let pixels: Vec<_> = i.flatten().copied().collect::>(); + + // TODO: clarify: as of right now, this is not doing anything + // because xgboost seems to be the fastest, when the chunks are big. 
+ let chunk_size = grid_shape.number_of_elements() * n_bands; + + let res: Vec> = pixels + .par_chunks(chunk_size) + .map(|elem| { + // get xgboost style matrices + let xg_matrix = DMatrix::from_col_major_f32( + elem, + mem::size_of::() * n_bands, + mem::size_of::(), + grid_shape.number_of_elements(), + n_bands, + -1, // TODO: add this to settings.toml: # of threads for xgboost to use + nan_val, + ) + .context(self::error::CreateDMatrix)?; + + let mut out_dim: u64 = 0; + + let bst = Booster::load_buffer(model_file.as_bytes()) + .context(self::error::LoadBoosterFromModel)?; + + // measure time for prediction + let predictions: Result, XGBError> = bst.predict_from_dmat( + &xg_matrix, + &[grid_shape.number_of_elements() as u64, n_bands as u64], + &mut out_dim, + ); + + predictions.map_err(|xg_err| XGBoostModuleError::PredictionError { source: xg_err }) + }) + .collect::, _>>()?; + + let predictions_flat: Vec = res.into_iter().flatten().collect(); + + Grid2D::new(grid_shape, predictions_flat).context(self::error::DataTypes) + }) +} + +#[async_trait] +impl QueryProcessor for XgboostProcessor +where + Q: QueryProcessor, SpatialBounds = SpatialPartition2D>, +{ + type Output = RasterTile2D; + type SpatialBounds = SpatialPartition2D; + + async fn _query<'a>( + &'a self, + query: RasterQueryRectangle, + ctx: &'a dyn QueryContext, + ) -> Result>> { + let model_content = Arc::new(self.model.clone()); + + let mut band_buffer = Vec::new(); + + for band in &self.sources { + let stream = band.query(query, ctx).await?; + band_buffer.push(stream); + } + + let rs = StreamVectorZip::new(band_buffer).then(move |tile_vec| { + let arc_clone = Arc::clone(&model_content); + async move { + let tile_vec = tile_vec.into_iter().collect::>()?; + let processed_tile = + self.process_tile_async(tile_vec, arc_clone, ctx.thread_pool().clone()); + Ok(processed_tile.await?) 
+ } + }); + + Ok(rs.boxed()) + } +} + +#[cfg(test)] +mod tests { + use crate::engine::{ + MockExecutionContext, MockQueryContext, MultipleRasterSources, QueryProcessor, + RasterOperator, RasterResultDescriptor, + }; + use crate::mock::{MockRasterSource, MockRasterSourceParams}; + + use futures::StreamExt; + use geoengine_datatypes::primitives::{ + Measurement, RasterQueryRectangle, SpatialPartition2D, SpatialResolution, TimeInterval, + }; + + use geoengine_datatypes::raster::{ + Grid2D, GridOrEmpty, RasterDataType, RasterTile2D, TileInformation, TilingSpecification, + }; + use geoengine_datatypes::spatial_reference::SpatialReference; + use geoengine_datatypes::test_data; + use geoengine_datatypes::util::test::TestDefault; + use ndarray::{arr2, ArrayBase, Dim, OwnedRepr}; + use xgboost_rs::{Booster, DMatrix}; + + use crate::util::Result; + use std::collections::HashMap; + use std::path::PathBuf; + + use super::{XgboostOperator, XgboostParams}; + + /// Just a helper method to make the code less cluttery. + fn zip_bands(b1: &[f32], b2: &[f32], b3: &[f32]) -> Vec<[f32; 3]> { + b1.iter() + .zip(b2.iter()) + .zip(b3.iter()) + .map(|((x, y), z)| [*x, *y, *z]) + .collect::>() + } + + /// Helper method to generate a raster with two tiles. 
+ fn make_double_raster(t1: Vec, t2: Vec) -> Box { + let raster_tiles = vec![ + RasterTile2D::::new_with_tile_info( + TimeInterval::new_unchecked(0, 1), + TileInformation { + global_tile_position: [-1, 0].into(), + tile_size_in_pixels: [5, 5].into(), + global_geo_transform: TestDefault::test_default(), + }, + GridOrEmpty::from(Grid2D::new([5, 5].into(), t1).unwrap()), + ), + RasterTile2D::::new_with_tile_info( + TimeInterval::new_unchecked(0, 1), + TileInformation { + global_tile_position: [-1, 1].into(), + tile_size_in_pixels: [5, 5].into(), + global_geo_transform: TestDefault::test_default(), + }, + GridOrEmpty::from(Grid2D::new([5, 5].into(), t2).unwrap()), + ), + ]; + + MockRasterSource { + params: MockRasterSourceParams { + data: raster_tiles, + result_descriptor: RasterResultDescriptor { + data_type: RasterDataType::U8, + spatial_reference: SpatialReference::epsg_4326().into(), + measurement: Measurement::Unitless, + time: None, + bbox: None, + resolution: None, + }, + }, + } + .boxed() + } + + fn make_raster(tile: Vec) -> Box { + // green raster: + // 255, 255, 0, 0, 0 + // 255, 255, 0, 0, 0 + // 0, 0, 0, 0, 0 + // 0, 0, 0, 255, 255 + // 0, 0, 0, 255, 255 + let raster_tiles = vec![RasterTile2D::::new_with_tile_info( + TimeInterval::new_unchecked(0, 1), + TileInformation { + global_tile_position: [-1, 0].into(), + tile_size_in_pixels: [5, 5].into(), + global_geo_transform: TestDefault::test_default(), + }, + GridOrEmpty::from(Grid2D::new([5, 5].into(), tile).unwrap()), + )]; + + MockRasterSource { + params: MockRasterSourceParams { + data: raster_tiles, + result_descriptor: RasterResultDescriptor { + data_type: RasterDataType::U8, + spatial_reference: SpatialReference::epsg_4326().into(), + measurement: Measurement::Unitless, + time: None, + bbox: None, + resolution: None, + }, + }, + } + .boxed() + } + + // Just a helper method to extract xgboost matrix generation code. 
+ fn make_xg_matrix(data_arr_2d: &ArrayBase, Dim<[usize; 2]>>) -> DMatrix { + // define information needed for xgboost + let strides_ax_0 = data_arr_2d.strides()[0] as usize; + let strides_ax_1 = data_arr_2d.strides()[1] as usize; + let byte_size_ax_0 = std::mem::size_of::() * strides_ax_0; + let byte_size_ax_1 = std::mem::size_of::() * strides_ax_1; + + // get xgboost style matrices + let xg_matrix = DMatrix::from_col_major_f32( + data_arr_2d.as_slice_memory_order().unwrap(), + byte_size_ax_0, + byte_size_ax_1, + 25, + 3, + -1, + f32::NAN, + ) + .unwrap(); + + xg_matrix + } + + #[tokio::test] + async fn multi_tile_test() { + // 255, 255, 0, 0, 0 || 0, 0, 0, 255, 255 + // 255, 255, 0, 0, 0 || 0, 0, 0, 255, 255 + // 255, 255, 0, 0, 0 || 0, 0, 0, 255, 255 + // 255, 255, 0, 0, 0 || 0, 0, 0, 255, 255 + // 255, 255, 0, 0, 0 || 0, 0, 0, 255, 255 + let green = make_double_raster( + vec![ + 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, + 255, 0, 0, 0, + ], + vec![ + 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, + 0, 255, 255, + ], + ); + + // 0, 0, 0, 255, 255 || 255, 255, 0, 0, 0 + // 0, 0, 0, 255, 255 || 255, 255, 0, 0, 0 + // 0, 0, 0, 0, 0 || 0, 0, 0, 0, 0 + // 0, 0, 0, 255, 255 || 255, 255, 0, 0, 0 + // 0, 0, 0, 255, 255 || 255, 255, 0, 0, 0 + let blue = make_double_raster( + vec![ + 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 0, 0, 0, + 255, 255, + ], + vec![ + 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, + 0, 0, 0, + ], + ); + + // 0, 0, 255, 0, 0 || 0, 0, 255, 0, 0 + // 0, 0, 255, 0, 0 || 0, 0, 255, 0, 0 + // 0, 0, 255, 255, 255 || 255, 255, 0, 0, 0 + // 0, 0, 255, 0, 0 || 0, 0, 255, 0, 0 + // 0, 0, 255, 0, 0 || 0, 0, 255, 0, 0 + let temp = make_double_raster( + vec![ + 0, 0, 255, 0, 0, 0, 0, 255, 0, 0, 0, 0, 255, 255, 255, 0, 0, 255, 0, 0, 0, 0, 255, + 0, 0, + ], + vec![ + 0, 0, 255, 0, 0, 0, 0, 255, 0, 0, 255, 255, 255, 0, 0, 0, 
0, 255, 0, 0, 0, 0, 255, + 0, 0, + ], + ); + + let srcs = vec![green, blue, temp]; + + let model_path = PathBuf::from(test_data!("pro/ml/xgboost/s2_10m_de_marburg/model.json")); + + let xg = XgboostOperator { + params: XgboostParams { + model_sub_path: model_path.clone(), + no_data_value: f32::NAN, + }, + sources: MultipleRasterSources { rasters: srcs }, + }; + + let mut exe_ctx = MockExecutionContext::new_with_tiling_spec(TilingSpecification::new( + (0., 0.).into(), + [5, 5].into(), + )); + + exe_ctx + .initialize_ml_model(model_path) + .expect("The model file should be available."); + + let op = RasterOperator::boxed(xg) + .initialize(&exe_ctx) + .await + .unwrap(); + + let processor = op.query_processor().unwrap().get_f32().unwrap(); + + let query_rect = RasterQueryRectangle { + spatial_bounds: SpatialPartition2D::new((0., 5.).into(), (10., 0.).into()).unwrap(), + time_interval: TimeInterval::new_unchecked(0, 1), + spatial_resolution: SpatialResolution::one(), + }; + + let query_ctx = MockQueryContext::test_default(); + + let result_stream = processor.query(query_rect, &query_ctx).await.unwrap(); + + let result: Vec>> = result_stream.collect().await; + let result = result.into_iter().collect::>>().unwrap(); + + let mut all_pixels = Vec::new(); + + for tile in result { + let data_of_tile = tile.into_materialized_tile().grid_array.inner_grid.data; + for pixel in &data_of_tile { + all_pixels.push(pixel.round()); + } + } + + // expected result tiles + // tile 1 || tile 2 + // 0, 0, 2, 1, 1 || 1, 1, 2, 0, 0 + // 0, 0, 2, 1, 1 || 1, 1, 2, 0, 0 + // 0, 0, 2, 2, 2 || 2, 2, 2, 2, 2 + // 0, 0, 2, 1, 1 || 1, 1, 2, 0, 0 + // 0, 0, 2, 1, 1 || 1, 1, 2, 0, 0 + + let expected = vec![ + 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 2.0, 0.0, 0.0, + 2.0, 1.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0, 1.0, 1.0, 2.0, 0.0, + 0.0, 2.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 0.0, + ]; + + assert_eq!(all_pixels, expected); 
+ } + + /// This test is used to verify, that xgboost is updating a model on consecutive tiles. + // TODO: This test should be more extensive + #[test] + #[allow(clippy::too_many_lines)] + fn xg_train_increment() { + // DATA USED FOR TRAINING ------------- + // raster green || raster blue || raster temp + // 255, 255, 0, 0, 0 || 0, 0, 0, 255, 255 || 15, 15, 60, 5, 5 + // 255, 255, 0, 0, 0 || 0, 0, 0, 255, 255 || 15, 15, 60, 5, 5 + // 0, 0, 0, 0, 0 || 0, 0, 100, 0, 0 || 60, 60, 0, 60, 60 + // 0, 0, 0, 255, 255 || 255, 255, 0, 0, 0 || 5, 5, 60, 15, 15 + // 0, 0, 0, 255, 255 || 255, 255, 0, 0, 0 || 5, 5, 60, 15, 15 + + // Class layout looks like this: + // 0, 0, 2, 1, 1 + // 0, 0, 2, 1, 1 + // 2, 2, 3, 2, 2 + // 1, 1, 2, 0, 0 + // 1, 1, 2, 0, 0 + + let b1 = &[ + 255., 255., 0., 0., 0., 255., 255., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 255., + 255., 0., 0., 0., 255., 255., + ]; + + let b2 = &[ + 0., 0., 0., 255., 255., 0., 0., 0., 255., 255., 0., 0., 100., 0., 0., 255., 255., 0., + 0., 0., 255., 255., 0., 0., 0., + ]; + + let b3 = &[ + 15., 15., 60., 5., 5., 15., 15., 60., 5., 5., 60., 60., 0., 60., 60., 5., 5., 60., 15., + 15., 5., 5., 60., 15., 15., + ]; + + let data_arr_2d = arr2(zip_bands(b1, b2, b3).as_slice()); + + // specification of the class labels for each datapoint + let target_vec = arr2(&[ + [0., 0., 2., 1., 1.], + [0., 0., 2., 1., 1.], + [2., 2., 3., 2., 2.], + [1., 1., 2., 0., 0.], + [1., 1., 2., 0., 0.], + ]); + + // raster green || raster blue || raster temp + // 0, 0, 0, 0, 0 || 100, 100, 0, 0, 255 || 0, 0, 15, 15, 5 + // 0, 0, 0, 0, 0 || 100, 100, 100, 0, 0 || 0, 0, 0, 15, 60 + // 0, 0, 0, 0, 0 || 0, 100, 100, 100, 0 || 15, 0, 0, 0, 15 + // 0, 0, 0, 0, 0 || 0, 0, 100, 100, 100 || 15, 15, 0, 0, 0 + // 0, 0, 0, 0, 0 || 0, 0, 0, 100, 100 || 15, 15, 15, 0, 0 + + // Class layout looks like this: + // 3, 3, 0, 0, 1 + // 3, 3, 3, 0, 2 + // 0, 3, 3, 3, 0 + // 0, 0, 3, 3, 3 + // 0, 0, 0, 3, 3 + + let b1 = &[ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., + ]; + + let b2 = &[ + 100., 100., 0., 0., 255., 100., 100., 100., 0., 0., 0., 100., 100., 100., 0., 0., 0., + 100., 100., 100., 0., 0., 0., 100., 100., + ]; + + let b3 = &[ + 0., 0., 15., 15., 5., 0., 0., 0., 15., 60., 15., 0., 0., 0., 15., 15., 15., 0., 0., 0., + 15., 15., 15., 0., 0., + ]; + + let data_arr_2d2 = arr2(zip_bands(b1, b2, b3).as_slice()); + + let target_vec2 = arr2(&[ + [3., 3., 0., 0., 1.], + [3., 3., 3., 0., 2.], + [0., 3., 3., 3., 0.], + [0., 0., 3., 3., 3.], + [0., 0., 0., 3., 3.], + ]); + + // how many rounds should be trained? + let training_rounds = 2; + + // setup data/model cache + let mut booster_vec: Vec = Vec::new(); + let mut matrix_vec: Vec = Vec::new(); + + for _ in 0..training_rounds { + let mut xg_matrix = make_xg_matrix(&data_arr_2d); + + // set labels for the dmatrix + xg_matrix + .set_labels(target_vec.as_slice().unwrap()) + .unwrap(); + + let mut xg_matrix2 = make_xg_matrix(&data_arr_2d2); + + // set labels for the dmatrix + xg_matrix2 + .set_labels(target_vec2.as_slice().unwrap()) + .unwrap(); + + // start the training process + if booster_vec.is_empty() { + // in the first iteration, there is no model yet. 
+ matrix_vec.push(xg_matrix); + + let mut initial_training_config: HashMap<&str, &str> = HashMap::new(); + + initial_training_config.insert("validate_parameters", "1"); + initial_training_config.insert("process_type", "default"); + initial_training_config.insert("tree_method", "hist"); + initial_training_config.insert("max_depth", "10"); + initial_training_config.insert("objective", "multi:softmax"); + initial_training_config.insert("num_class", "4"); + initial_training_config.insert("eta", "0.75"); + + let evals = &[(matrix_vec.get(0).unwrap(), "train")]; + let bst = Booster::train( + Some(evals), + matrix_vec.get(0).unwrap(), + initial_training_config, + None, // <- No old model yet + ) + .unwrap(); + + // store the first booster + booster_vec.push(bst); + } + // update training rounds + else { + // this is a consecutive iteration, so we need the last booster instance + // to update the model + let bst = booster_vec.pop().unwrap(); + + let mut update_training_config: HashMap<&str, &str> = HashMap::new(); + + update_training_config.insert("validate_parameters", "1"); + update_training_config.insert("process_type", "update"); + update_training_config.insert("updater", "refresh"); + update_training_config.insert("refresh_leaf", "true"); + update_training_config.insert("objective", "multi:softmax"); + update_training_config.insert("num_class", "4"); + update_training_config.insert("max_depth", "15"); + + let evals = &[(matrix_vec.get(0).unwrap(), "orig"), (&xg_matrix2, "train")]; + let bst_updated = Booster::train( + Some(evals), + &xg_matrix2, + update_training_config, + Some(bst), // <- this contains the last model which is now being updated + ) + .unwrap(); + + // store the new booster instance + booster_vec.push(bst_updated); + } + } + + // lets use the trained model now to predict something + let bst = booster_vec.pop().unwrap(); + + // test tile looks like this: + // GREEN || BLUE || TEMP + // 0, 0, 0, 0, 0 || 100, 100, 100, 100, 100 || 0, 0, 0, 0, 0 + // 0, 0, 
255, 0, 0 || 100, 0, 0, 255, 100 || 0, 60, 15, 5, 0 + // 0, 255, 0, 255, 0 || 100, 0, 0, 0, 100 || 0, 15, 60, 15, 0 + // 0, 0, 255, 0, 0 || 100, 255, 0, 0, 100 || 0, 5, 15, 60, 0 + // 0, 0, 0, 0, 0 || 100, 100, 100, 100, 100 || 0, 0, 0, 0, 0 + + // with class layout like this: + // 3, 3, 3, 3, 3 + // 3, 2, 0, 1, 3 + // 3, 0, 2, 0, 3 + // 3, 1, 0, 2, 3 + // 3, 3, 3, 3, 3 + + let b1 = &[ + 0., 0., 0., 0., 0., 0., 0., 255., 0., 0., 0., 255., 0., 255., 0., 0., 0., 255., 0., 0., + 0., 0., 0., 0., 0., + ]; + + let b2 = &[ + 100., 100., 100., 100., 100., 100., 0., 0., 255., 100., 100., 0., 0., 0., 100., 100., + 255., 0., 0., 100., 100., 100., 100., 100., 100., + ]; + + let b3 = &[ + 0., 0., 0., 0., 0., 0., 60., 15., 5., 0., 0., 15., 60., 15., 0., 0., 5., 15., 60., 0., + 0., 0., 0., 0., 0., + ]; + + let test_data_arr_2d = arr2(zip_bands(b1, b2, b3).as_slice()); + let test_data = make_xg_matrix(&test_data_arr_2d); + let result = bst.predict(&test_data).unwrap(); + + // this result is not 100% desired, but most likely just a matter of a better or bigger training set + let expected = vec![ + 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 0.0, 0.0, 3.0, 3.0, 3.0, 0.0, 0.0, 0.0, 3.0, 3.0, 3.0, + 0.0, 0.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, + ]; + + assert_eq!(expected, result); + } + + /// This test verifies, that xgboost is creating meaningful result tiles from a learned model. 
+ #[tokio::test] + async fn xg_single_tile_test() { + // 255, 255, 0, 0, 0 + // 255, 255, 0, 0, 0 + // 0, 0, 0, 0, 0 + // 0, 0, 0, 255, 255 + // 0, 0, 0, 255, 255 + let green = make_raster(vec![ + 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 0, 0, 0, 255, + 255, + ]); + + // 0, 0, 0, 255, 255 + // 0, 0, 0, 255, 255 + // 0, 0, 0, 0, 0 + // 255, 255, 0, 0, 0 + // 255, 255, 0, 0, 0 + let blue = make_raster(vec![ + 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, + 0, + ]); + + // 15, 15, 60, 5, 5 + // 15, 15, 60, 5, 5 + // 60, 60, 60, 60, 60 + // 5, 5, 60, 15, 15 + // 5, 5, 60, 15, 15 + let temp = make_raster(vec![ + 15, 15, 60, 5, 5, 15, 15, 60, 5, 5, 60, 60, 60, 60, 60, 5, 5, 60, 15, 15, 5, 5, 60, 15, + 15, + ]); + + let srcs = vec![green, blue, temp]; + + let model_path = PathBuf::from(test_data!("pro/ml/xgboost/s2_10m_de_marburg/model.json")); + + let xg = XgboostOperator { + params: XgboostParams { + model_sub_path: model_path.clone(), + no_data_value: -1_000., + }, + sources: MultipleRasterSources { rasters: srcs }, + }; + + let mut exe_ctx = MockExecutionContext::new_with_tiling_spec(TilingSpecification::new( + (0., 0.).into(), + [5, 5].into(), + )); + + exe_ctx + .initialize_ml_model(model_path) + .expect("The model file should be available."); + + let op = RasterOperator::boxed(xg) + .initialize(&exe_ctx) + .await + .unwrap(); + + let processor = op.query_processor().unwrap().get_f32().unwrap(); + + let query_rect = RasterQueryRectangle { + spatial_bounds: SpatialPartition2D::new((0., 5.).into(), (5., 0.).into()).unwrap(), + time_interval: TimeInterval::new_unchecked(0, 1), + spatial_resolution: SpatialResolution::one(), + }; + + let query_ctx = MockQueryContext::test_default(); + + let result_stream = processor.query(query_rect, &query_ctx).await.unwrap(); + + let result: Vec>> = result_stream.collect().await; + let result = result.into_iter().collect::>>().unwrap(); + + let mut all_pixels = 
Vec::new(); + + for tile in result { + let data_of_tile = tile.into_materialized_tile().grid_array.inner_grid.data; + for pixel in &data_of_tile { + all_pixels.push(pixel.round()); + } + } + + let expected = vec![ + 0., 0., 2., 1., 1., 0., 0., 2., 1., 1., 2., 2., 2., 2., 2., 1., 1., 2., 0., 0., 1., 1., + 2., 0., 0., + ]; + + assert_eq!(all_pixels, expected); + } +} diff --git a/operators/src/pro/mod.rs b/operators/src/pro/mod.rs index ea33364da..97d480472 100644 --- a/operators/src/pro/mod.rs +++ b/operators/src/pro/mod.rs @@ -2,3 +2,4 @@ pub mod adapters; pub mod meta; +pub mod ml; diff --git a/services/Cargo.toml b/services/Cargo.toml index a0d30e8ef..2d2628f1a 100644 --- a/services/Cargo.toml +++ b/services/Cargo.toml @@ -96,6 +96,7 @@ rand = "0.8.4" tempfile = "3.1" wiremock-grpc = "0.0.3-alpha1" xml-rs = "0.8.3" +serial_test = "0.9.0" [build-dependencies] vergen = "7" diff --git a/services/src/contexts/mod.rs b/services/src/contexts/mod.rs index e597db43d..29af92691 100644 --- a/services/src/contexts/mod.rs +++ b/services/src/contexts/mod.rs @@ -1,11 +1,16 @@ use crate::error::Result; use crate::layers::storage::{LayerDb, LayerProviderDb}; use crate::tasks::{TaskContext, TaskManager}; +use crate::util::config::get_config_element; +use crate::util::path_with_base_path; use crate::{projects::ProjectDb, workflows::registry::WorkflowRegistry}; use async_trait::async_trait; use geoengine_datatypes::primitives::{RasterQueryRectangle, VectorQueryRectangle}; +use geoengine_datatypes::util::canonicalize_subpath; use rayon::ThreadPool; use std::sync::Arc; +use tokio::fs::File; +use tokio::io::AsyncWriteExt; use tokio::sync::RwLock; mod in_memory; @@ -169,6 +174,7 @@ where } } +#[async_trait::async_trait] impl ExecutionContext for ExecutionContextImpl where D: DatasetDb @@ -213,6 +219,51 @@ where ) -> Box { op } + + /// This method is meant to read a ml model from disk, specified by the config key `machinelearning.model_defs_path`. 
+ async fn read_ml_model( + &self, + model_sub_path: std::path::PathBuf, + ) -> geoengine_operators::util::Result { + let cfg = get_config_element::() + .map_err(|_| geoengine_operators::error::Error::InvalidMachineLearningConfig)?; + + let model_base_path = cfg.model_defs_path; + + let model_path = canonicalize_subpath(&model_base_path, &model_sub_path)?; + let model = tokio::fs::read_to_string(model_path).await?; + + Ok(model) + } + + /// This method is meant to write a ml model to disk. + /// Parent directories of the target file are created if they do not exist yet. + async fn write_ml_model( + &mut self, + model_sub_path: std::path::PathBuf, + ml_model_str: String, + ) -> geoengine_operators::util::Result<()> { + let cfg = get_config_element::() + .map_err(|_| geoengine_operators::error::Error::InvalidMachineLearningConfig)?; + + let model_base_path = cfg.model_defs_path; + + // make sure, that the model sub path is not escaping the config path + let model_path = path_with_base_path(&model_base_path, &model_sub_path) + .map_err(|_| geoengine_operators::error::Error::InvalidMlModelPath)?; + + let parent_dir = model_path + .parent() + .ok_or(geoengine_operators::error::Error::CouldNotGetMlModelDirectory)?; + + tokio::fs::create_dir_all(parent_dir).await?; + + // TODO: add routine or error, if a given modelpath would overwrite an existing model + let mut file = File::create(model_path).await?; + file.write_all(ml_model_str.as_bytes()).await?; + + Ok(()) + } } // TODO: use macro(?) 
for delegating meta_data function to DatasetDB to avoid redundant code @@ -343,3 +394,85 @@ where } } } + +#[cfg(test)] + +mod tests { + use super::*; + use std::path::PathBuf; + + use geoengine_datatypes::{test_data, util::test::TestDefault}; + use serial_test::serial; + + use crate::{ + contexts::{Context, InMemoryContext, SimpleSession}, + util::config::set_config, + }; + + #[tokio::test] + #[serial] + async fn read_model_test() { + let cfg = get_config_element::().unwrap(); + let cfg_backup = &cfg.model_defs_path; + + set_config( + "machinelearning.model_defs_path", + test_data!("pro/ml").to_str().unwrap(), + ) + .unwrap(); + + let ctx = InMemoryContext::test_default(); + let exe_ctx = ctx.execution_context(SimpleSession::default()).unwrap(); + + let model_path = PathBuf::from("xgboost/s2_10m_de_marburg/model.json"); + let mut model = exe_ctx.read_ml_model(model_path).await.unwrap(); + + let actual: String = model.drain(0..277).collect(); + + set_config( + "machinelearning.model_defs_path", + cfg_backup.to_str().unwrap(), + ) + .unwrap(); + + let expected = "{\"learner\":{\"attributes\":{},\"feature_names\":[],\"feature_types\":[],\"gradient_booster\":{\"model\":{\"gbtree_model_param\":{\"num_parallel_tree\":\"1\",\"num_trees\":\"16\",\"size_leaf_vector\":\"0\"},\"tree_info\":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],\"trees\":[{\"base_weights\":[5.192308E-1,9.722222E-1"; + + assert_eq!(actual, expected); + } + + #[tokio::test] + #[serial] + async fn write_model_test() { + let cfg = get_config_element::().unwrap(); + let cfg_backup = cfg.model_defs_path; + + let tmp_dir = tempfile::tempdir().unwrap(); + let tmp_path = tmp_dir.path(); + std::fs::create_dir_all(tmp_path.join("pro/ml/xgboost")).unwrap(); + + let temp_ml_path = tmp_path.join("pro/ml").to_str().unwrap().to_string(); + + set_config("machinelearning.model_defs_path", temp_ml_path).unwrap(); + let ctx = InMemoryContext::test_default(); + let mut exe_ctx = 
ctx.execution_context(SimpleSession::default()).unwrap(); + + let model_path = PathBuf::from("xgboost/model.json"); + + exe_ctx + .write_ml_model(model_path, String::from("model content")) + .await + .unwrap(); + + set_config( + "machinelearning.model_defs_path", + cfg_backup.to_str().unwrap(), + ) + .unwrap(); + + let actual = tokio::fs::read_to_string(tmp_path.join("pro/ml/xgboost/model.json")) + .await + .unwrap(); + + assert_eq!(actual, "model content"); + } +} diff --git a/services/src/datasets/external/netcdfcf/mod.rs b/services/src/datasets/external/netcdfcf/mod.rs index 99dc9112d..2a898226d 100644 --- a/services/src/datasets/external/netcdfcf/mod.rs +++ b/services/src/datasets/external/netcdfcf/mod.rs @@ -25,7 +25,6 @@ use crate::layers::listing::LayerCollectionProvider; use crate::projects::RasterSymbology; use crate::projects::Symbology; use crate::tasks::TaskContext; -use crate::util::canonicalize_subpath; use crate::util::user_input::Validated; use crate::workflows::workflow::Workflow; use async_trait::async_trait; @@ -39,6 +38,7 @@ use geoengine_datatypes::primitives::{ }; use geoengine_datatypes::raster::{GdalGeoTransform, RasterDataType}; use geoengine_datatypes::spatial_reference::SpatialReference; +use geoengine_datatypes::util::canonicalize_subpath; use geoengine_operators::engine::RasterOperator; use geoengine_operators::engine::TypedOperator; use geoengine_operators::source::GdalSource; diff --git a/services/src/datasets/external/netcdfcf/overviews.rs b/services/src/datasets/external/netcdfcf/overviews.rs index 0b049a7a6..22edc711b 100644 --- a/services/src/datasets/external/netcdfcf/overviews.rs +++ b/services/src/datasets/external/netcdfcf/overviews.rs @@ -3,7 +3,7 @@ use crate::{ api::model::datatypes::ResamplingMethod, datasets::{external::netcdfcf::NetCdfCfDataProvider, storage::MetaDataDefinition}, tasks::{TaskContext, TaskStatusInfo}, - util::{canonicalize_subpath, config::get_config_element, path_with_base_path}, + 
util::{config::get_config_element, path_with_base_path}, }; use gdal::{ cpl::CslStringList, @@ -14,6 +14,7 @@ use gdal::{ Dataset, DatasetOptions, GdalOpenFlags, }; use gdal_sys::GDALGetRasterStatistics; +use geoengine_datatypes::util::canonicalize_subpath; use geoengine_datatypes::{ error::BoxedResultExt, primitives::{DateTimeParseFormat, TimeInstance, TimeInterval}, diff --git a/services/src/handlers/session.rs b/services/src/handlers/session.rs index 7c85812f9..cf81767f0 100644 --- a/services/src/handlers/session.rs +++ b/services/src/handlers/session.rs @@ -134,6 +134,7 @@ mod tests { use actix_web_httpauth::headers::authorization::Bearer; use geoengine_datatypes::spatial_reference::SpatialReferenceOption; use geoengine_datatypes::util::test::TestDefault; + use serial_test::serial; #[tokio::test] async fn session() { @@ -202,7 +203,8 @@ mod tests { check_allowed_http_methods(anonymous_test_helper, &[Method::POST]).await; } - #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + #[tokio::test] + #[serial] async fn it_disables_anonymous_access() { config::set_config( "session.fixed_session_token", diff --git a/services/src/pro/contexts/mod.rs b/services/src/pro/contexts/mod.rs index 9f17204b3..64864fcfb 100644 --- a/services/src/pro/contexts/mod.rs +++ b/services/src/pro/contexts/mod.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use geoengine_datatypes::dataset::DataId; use geoengine_datatypes::primitives::{RasterQueryRectangle, VectorQueryRectangle}; use geoengine_datatypes::raster::TilingSpecification; +use geoengine_datatypes::util::canonicalize_subpath; use geoengine_operators::engine::{ CreateSpan, ExecutionContext, InitializedPlotOperator, InitializedVectorOperator, MetaData, MetaDataProvider, RasterResultDescriptor, VectorResultDescriptor, @@ -19,12 +20,15 @@ pub use in_memory::ProInMemoryContext; #[cfg(feature = "postgres")] pub use postgres::PostgresContext; use rayon::ThreadPool; +use tokio::io::AsyncWriteExt; use crate::contexts::{Context, Session}; 
use crate::datasets::listing::SessionMetaDataProvider; use crate::datasets::storage::DatasetDb; use crate::layers::storage::LayerProviderDb; use crate::pro::users::{OidcRequestDb, UserDb, UserSession}; +use crate::util::config::get_config_element; +use crate::util::path_with_base_path; use async_trait::async_trait; @@ -75,6 +79,7 @@ where } } +#[async_trait::async_trait] impl ExecutionContext for ExecutionContextImpl where D: DatasetDb @@ -120,6 +125,51 @@ where // as plots do not produce a stream of results, we have nothing to count for now op } + + /// This method is meant to read a ml model from disk, specified by the config key `machinelearning.model_defs_path`. + async fn read_ml_model( + &self, + model_sub_path: std::path::PathBuf, + ) -> geoengine_operators::util::Result { + let cfg = get_config_element::() + .map_err(|_| geoengine_operators::error::Error::InvalidMachineLearningConfig)?; + + let model_base_path = cfg.model_defs_path; + + let model_path = canonicalize_subpath(&model_base_path, &model_sub_path)?; + let model = tokio::fs::read_to_string(model_path).await?; + + Ok(model) + } + + /// This method is meant to write a ml model to disk. + /// Parent directories of the target file are created if they do not exist yet. 
+ async fn write_ml_model( + &mut self, + model_sub_path: std::path::PathBuf, + ml_model_str: String, + ) -> geoengine_operators::util::Result<()> { + let cfg = get_config_element::() + .map_err(|_| geoengine_operators::error::Error::InvalidMachineLearningConfig)?; + + let model_base_path = cfg.model_defs_path; + + // make sure, that the model sub path is not escaping the config path + let model_path = path_with_base_path(&model_base_path, &model_sub_path) + .map_err(|_| geoengine_operators::error::Error::InvalidMlModelPath)?; + + let parent_dir = model_path + .parent() + .ok_or(geoengine_operators::error::Error::CouldNotGetMlModelDirectory)?; + + tokio::fs::create_dir_all(parent_dir).await?; + + // TODO: add routine or error, if a given modelpath would overwrite an existing model + let mut file = tokio::fs::File::create(model_path).await?; + file.write_all(ml_model_str.as_bytes()).await?; + + Ok(()) + } } // TODO: use macro(?) for delegating meta_data function to DatasetDB to avoid redundant code @@ -250,3 +300,84 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + use geoengine_datatypes::{test_data, util::test::TestDefault}; + use serial_test::serial; + + use crate::{ + contexts::Context, pro::util::tests::create_session_helper, util::config::set_config, + }; + + #[tokio::test] + #[serial] + async fn read_model_test() { + let cfg = get_config_element::().unwrap(); + let cfg_backup = cfg.model_defs_path; + + set_config( + "machinelearning.model_defs_path", + test_data!("pro/ml").to_str().unwrap(), + ) + .unwrap(); + + let ctx = ProInMemoryContext::test_default(); + let session = create_session_helper(&ctx).await; + let exe_ctx = ctx.execution_context(session).unwrap(); + + let model_path = PathBuf::from("xgboost/s2_10m_de_marburg/model.json"); + let mut model = exe_ctx.read_ml_model(model_path).await.unwrap(); + + let actual: String = model.drain(0..277).collect(); + + set_config( + "machinelearning.model_defs_path", + 
cfg_backup.to_str().unwrap(), + ) + .unwrap(); + + let expected = "{\"learner\":{\"attributes\":{},\"feature_names\":[],\"feature_types\":[],\"gradient_booster\":{\"model\":{\"gbtree_model_param\":{\"num_parallel_tree\":\"1\",\"num_trees\":\"16\",\"size_leaf_vector\":\"0\"},\"tree_info\":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],\"trees\":[{\"base_weights\":[5.192308E-1,9.722222E-1"; + + assert_eq!(actual, expected); + } + #[tokio::test] + #[serial] + async fn write_model_test() { + let cfg = get_config_element::().unwrap(); + let cfg_backup = cfg.model_defs_path; + + let tmp_dir = tempfile::tempdir().unwrap(); + let tmp_path = tmp_dir.path(); + std::fs::create_dir_all(tmp_path.join("pro/ml/xgboost")).unwrap(); + + let temp_ml_path = tmp_path.join("pro/ml").to_str().unwrap().to_string(); + + set_config("machinelearning.model_defs_path", temp_ml_path).unwrap(); + + let ctx = ProInMemoryContext::test_default(); + let session = create_session_helper(&ctx).await; + let mut exe_ctx = ctx.execution_context(session).unwrap(); + + let model_path = PathBuf::from("xgboost/ml.json"); + exe_ctx + .write_ml_model(model_path, String::from("model content")) + .await + .unwrap(); + + set_config( + "machinelearning.model_defs_path", + cfg_backup.to_str().unwrap(), + ) + .unwrap(); + + let actual = tokio::fs::read_to_string(tmp_path.join("pro/ml/xgboost/ml.json")) + .await + .unwrap(); + + assert_eq!(actual, "model content"); + } +} diff --git a/services/src/pro/handlers/drone_mapping.rs b/services/src/pro/handlers/drone_mapping.rs index 2e8d3c6d2..5018aa55e 100644 --- a/services/src/pro/handlers/drone_mapping.rs +++ b/services/src/pro/handlers/drone_mapping.rs @@ -387,12 +387,14 @@ mod tests { use actix_web::{http::header, test}; use actix_web_httpauth::headers::authorization::Bearer; use geoengine_operators::engine::{MetaData, MetaDataProvider, RasterOperator}; + use serial_test::serial; use std::io::Write; use std::io::{Cursor, Read}; use std::path::PathBuf; 
#[allow(clippy::too_many_lines)] - #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + #[tokio::test] + #[serial] async fn it_works() -> Result<()> { let mut test_data = TestDataUploads::default(); // remember created folder and remove them on drop diff --git a/services/src/pro/handlers/users.rs b/services/src/pro/handlers/users.rs index e4de1db95..1656355b3 100644 --- a/services/src/pro/handlers/users.rs +++ b/services/src/pro/handlers/users.rs @@ -436,6 +436,7 @@ mod tests { use httptest::responders::status_code; use httptest::{Expectation, Server}; use serde_json::json; + use serial_test::serial; async fn register_test_helper( ctx: C, @@ -854,7 +855,8 @@ mod tests { ); } - #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + #[tokio::test] + #[serial] async fn it_disables_anonymous_access() { let ctx = ProInMemoryContext::test_default(); @@ -881,7 +883,8 @@ mod tests { .await; } - #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + #[tokio::test] + #[serial] async fn it_disables_user_registration() { let ctx = ProInMemoryContext::test_default(); diff --git a/services/src/util/config.rs b/services/src/util/config.rs index 7d5f87b9d..c8917d53f 100644 --- a/services/src/util/config.rs +++ b/services/src/util/config.rs @@ -376,6 +376,15 @@ impl ConfigElement for Session { const KEY: &'static str = "session"; } +#[derive(Debug, Deserialize)] +pub struct MachineLearning { + pub model_defs_path: PathBuf, +} + +impl ConfigElement for MachineLearning { + const KEY: &'static str = "machinelearning"; +} + #[cfg(feature = "nfdi")] #[derive(Debug, Deserialize)] pub struct GFBio { diff --git a/services/src/util/mod.rs b/services/src/util/mod.rs index 76962720a..8d2057ec2 100644 --- a/services/src/util/mod.rs +++ b/services/src/util/mod.rs @@ -78,25 +78,6 @@ where } } -/// Canonicalize `base`/`sub_path` and ensure the `sub_path` doesn't escape the `base` -/// returns an error if the `sub_path` escapes the `base` -/// -/// This only works if the 
`Path` you are referring to actually exists. -/// -pub fn canonicalize_subpath(base: &Path, sub_path: &Path) -> crate::error::Result { - let base = base.canonicalize()?; - let path = base.join(sub_path).canonicalize()?; - - if path.starts_with(&base) { - Ok(path) - } else { - Err(crate::error::Error::SubPathMustNotEscapeBasePath { - base, - sub_path: sub_path.into(), - }) - } -} - /// Join `base` and `sub_path` and ensure the `sub_path` doesn't escape the `base` /// returns an error if the `sub_path` escapes the `base` /// @@ -169,23 +150,6 @@ mod mod_tests { ); } - #[test] - fn it_doesnt_escape_base_path() { - let tmp_dir = tempfile::tempdir().unwrap(); - let tmp_path = tmp_dir.path(); - std::fs::create_dir_all(tmp_path.join("foo/bar/foobar")).unwrap(); - std::fs::create_dir_all(tmp_path.join("foo/barfoo")).unwrap(); - - assert_eq!( - canonicalize_subpath(&tmp_path.join("foo/bar"), Path::new("foobar")) - .unwrap() - .to_string_lossy(), - tmp_path.join("foo/bar/foobar").to_string_lossy() - ); - - assert!(canonicalize_subpath(&tmp_path.join("foo/bar"), Path::new("../barfoo")).is_err()); - } - #[test] fn it_doesnt_escape_base_path_too() { let tmp_dir = tempfile::tempdir().unwrap(); diff --git a/test_data/pro/ml/xgboost/s2_10m_de_marburg/model.json b/test_data/pro/ml/xgboost/s2_10m_de_marburg/model.json new file mode 100644 index 000000000..9c787d729 --- /dev/null +++ b/test_data/pro/ml/xgboost/s2_10m_de_marburg/model.json @@ -0,0 +1 @@ 
+{"learner":{"attributes":{},"feature_names":[],"feature_types":[],"gradient_booster":{"model":{"gbtree_model_param":{"num_parallel_tree":"1","num_trees":"16","size_leaf_vector":"0"},"tree_info":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"trees":[{"base_weights":[5.192308E-1,9.722222E-1,-1.3333334E-1,4.0500003E-1,1.3333334E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":0,"left_children":[1,3,-1,-1,-1],"loss_changes":[1.1782053E1,2.9888897E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-1.3333334E-1,4.0500003E-1,1.3333334E-1],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[3.7903848E-1,7.10463E-1,-9.777779E-2,2.9565004E-1,9.777779E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":1,"left_children":[1,3,-1,-1,-1],"loss_changes":[6.306263E0,1.5825138E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-9.777779E-2,2.9565004E-1,9.777779E-2],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[2.7669805E-1,5.1918113E-1,-7.170371E-2,2.1582447E-1,7.170371E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":2,"left_children":[1,3,-1,-1,-1],"loss_changes":[3.375418E0,8.378372E-1,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-7.170371E-2,2.1582447E-1,7.170371E-2],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_fea
ture":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[2.0198956E-1,3.7940055E-1,-5.2582722E-2,1.5755187E-1,5.258271E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":3,"left_children":[1,3,-1,-1,-1],"loss_changes":[1.806706E0,4.4355345E-1,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-5.2582722E-2,1.5755187E-1,5.258271E-2],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[1.4745237E-1,2.7725452E-1,-3.856066E-2,1.1501286E-1,3.8560644E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":4,"left_children":[1,3,-1,-1,-1],"loss_changes":[9.6705645E-1,2.3480415E-1,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-3.856066E-2,1.1501286E-1,3.8560644E-2],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[1.0764022E-1,2.0261002E-1,-2.8277816E-2,8.395938E-2,2.827781E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":5,"left_children":[1,3,-1,-1,-1],"loss_changes":[5.176313E-1,1.24290645E-1,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-2.8277816E-2,8.395938E-2,2.827781E-2],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[7.857737E-2,1.4806242E-1,-2.0737063E-2,6.1290357E-2,2.0737061E-2],"categories":[],"categories_nodes":[],"cate
gories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":6,"left_children":[1,3,-1,-1,-1],"loss_changes":[2.7707273E-1,6.578764E-2,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-2.0737063E-2,6.1290357E-2,2.0737061E-2],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[5.7361472E-2,1.0820076E-1,-1.5207182E-2,4.474195E-2,1.520718E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":7,"left_children":[1,3,-1,-1,-1],"loss_changes":[1.4831032E-1,3.481944E-2,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-1.5207182E-2,4.474195E-2,1.520718E-2],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[4.1873865E-2,7.907102E-2,-1.1151932E-2,3.2661613E-2,1.1151934E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":8,"left_children":[1,3,-1,-1,-1],"loss_changes":[7.938771E-2,1.8427692E-2,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-1.1151932E-2,3.2661613E-2,1.1151934E-2],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[3.0567948E-2,5.7783842E-2,-8.178083E-3,2.3842992E-2,8.178092E-3],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":9,"left_children":[1,3,-1,-1,-1],"loss_changes":[4.249522E-2,9.751983E-3,0E0,0E0,0E0],"parents":[214748
3647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-8.178083E-3,2.3842992E-2,8.178092E-3],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[2.2314575E-2,4.22276E-2,-5.997261E-3,1.740537E-2,5.997261E-3],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":10,"left_children":[1,3,-1,-1,-1],"loss_changes":[2.2747332E-2,5.1604137E-3,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-5.997261E-3,1.740537E-2,5.997261E-3],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[1.6289638E-2,3.0859463E-2,-4.3979916E-3,1.2705914E-2,4.3979967E-3],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":11,"left_children":[1,3,-1,-1,-1],"loss_changes":[1.2176588E-2,2.730526E-3,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-4.3979916E-3,1.2705914E-2,4.3979967E-3],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[1.1891436E-2,2.2551842E-2,-3.2251936E-3,9.275315E-3,3.2251994E-3],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":12,"left_children":[1,3,-1,-1,-1],"loss_changes":[6.5181646E-3,1.4447039E-3,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-3.2251936E-3,9.275315E-3,3.2251994E-3],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"su
m_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[8.680745E-3,1.6480757E-2,-2.3651419E-3,6.770979E-3,2.365144E-3],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":13,"left_children":[1,3,-1,-1,-1],"loss_changes":[3.4892275E-3,7.643318E-4,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-2.3651419E-3,6.770979E-3,2.365144E-3],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[6.336939E-3,1.2044085E-2,-1.7344374E-3,4.942818E-3,1.7344317E-3],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":14,"left_children":[1,3,-1,-1,-1],"loss_changes":[1.8678304E-3,4.043507E-4,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-1.7344374E-3,4.942818E-3,1.7344317E-3],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[4.6259793E-3,8.801837E-3,-1.2719207E-3,3.6082687E-3,1.2719154E-3],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0],"id":15,"left_children":[1,3,-1,-1,-1],"loss_changes":[9.998886E-4,2.1389709E-4,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1],"right_children":[2,4,-1,-1,-1],"split_conditions":[2.55E2,2.55E2,-1.2719207E-3,3.6082687E-3,1.2719154E-3],"split_indices":[0,1,0,0,0],"split_type":[0,0,0,0,0],"sum_hessian":[2.5E1,1.7E1,8E0,9E0,8E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"5","size_leaf_vector":"0"}}]},"name":"gbtree"},"learner_model_param":{"base_
score":"5E-1","num_class":"0","num_feature":"3","num_target":"1"},"objective":{"name":"reg:squarederror","reg_loss_param":{"scale_pos_weight":"1"}}},"version":[1,6,2]} \ No newline at end of file From 3e4540e9dfd091be4b47ed9be19808ce912aef7b Mon Sep 17 00:00:00 2001 From: Marcus Weber Date: Fri, 2 Dec 2022 13:59:37 +0100 Subject: [PATCH 2/3] remove patch versioning --- operators/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/operators/Cargo.toml b/operators/Cargo.toml index 503bdf9a1..fdc88113a 100644 --- a/operators/Cargo.toml +++ b/operators/Cargo.toml @@ -50,13 +50,13 @@ tracing = "0.1" tracing-opentelemetry = "0.18" typetag = "0.2" uuid = { version = "1.1", features = ["serde", "v4", "v5"] } -xgboost-rs = "0.1.2" +xgboost-rs = "0.1" [dev-dependencies] async-stream = "0.3" geo-rand = { git = "https://github.com/lelongg/geo-rand", tag = "v0.3.0" } rand = "0.8" -ndarray = "0.15.6" +ndarray = "0.15" [[bench]] From d52107f638a4116d28b5c3a59d37d1574f96cf36 Mon Sep 17 00:00:00 2001 From: Marcus Weber Date: Fri, 2 Dec 2022 14:57:17 +0100 Subject: [PATCH 3/3] fix: remove unnecessary type declaration --- datatypes/src/error.rs | 3 ++- datatypes/src/util/mod.rs | 2 +- operators/src/pro/ml/xgboost.rs | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/datatypes/src/error.rs b/datatypes/src/error.rs index 099839387..e4a40422d 100644 --- a/datatypes/src/error.rs +++ b/datatypes/src/error.rs @@ -7,7 +7,8 @@ use crate::{ use snafu::{prelude::*, AsErrorSource, ErrorCompat, IntoError}; use std::{any::Any, convert::Infallible, path::PathBuf, sync::Arc}; -pub type Result = std::result::Result; +use crate::util::Result; + pub trait ErrorSource: std::error::Error + Send + Sync + Any + 'static + AsErrorSource { fn boxed(self) -> Box where diff --git a/datatypes/src/util/mod.rs b/datatypes/src/util/mod.rs index a108df012..49fb25897 100644 --- a/datatypes/src/util/mod.rs +++ b/datatypes/src/util/mod.rs @@ -19,7 +19,7 @@ pub use 
result::Result; /// /// This only works if the `Path` you are referring to actually exists. /// -pub fn canonicalize_subpath(base: &Path, sub_path: &Path) -> crate::error::Result { +pub fn canonicalize_subpath(base: &Path, sub_path: &Path) -> Result { let base = base.canonicalize()?; let path = base.join(sub_path).canonicalize()?; diff --git a/operators/src/pro/ml/xgboost.rs b/operators/src/pro/ml/xgboost.rs index 0e984fd01..f5597463b 100644 --- a/operators/src/pro/ml/xgboost.rs +++ b/operators/src/pro/ml/xgboost.rs @@ -785,7 +785,7 @@ mod tests { let bst = booster_vec.pop().unwrap(); // test tile looks like this: - // GREEN || BLUE || TEMP + // GREEN || BLUE || TEMP // 0, 0, 0, 0, 0 || 100, 100, 100, 100, 100 || 0, 0, 0, 0, 0 // 0, 0, 255, 0, 0 || 100, 0, 0, 255, 100 || 0, 60, 15, 5, 0 // 0, 255, 0, 255, 0 || 100, 0, 0, 0, 100 || 0, 15, 60, 15, 0