Skip to content

Commit

Permalink
Merge #639
Browse files Browse the repository at this point in the history
639: Feature/xgboost r=michaelmattig a=MarWeUMR

- [x] I added an entry to [`CHANGELOG.md`](CHANGELOG.md) if knowledge of this change could be valuable to users.

---

This PR is intended to add functionality for using `XGBoost` with a pre-trained model to generate predictions based on some raster tiles.

There are three main parts:

- the operator itself: xgboost.rs
- the xgboost-sys crate to generate the bindings and build the library
- the bindings module where all the higher level functionality is added to the bindings

The operator contains a test method that uses the `marburg` dataset to predict land usage. This test takes some time to finish and should be replaced with a simpler dataset intended purely for testing purposes.


Co-authored-by: Marcus Weber <weberma@students.uni-marburg.de>
  • Loading branch information
bors[bot] and MarWeUMR authored Dec 7, 2022
2 parents 48190fb + d52107f commit 406ef23
Show file tree
Hide file tree
Showing 22 changed files with 1,338 additions and 44 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added a new operator `XGBoost`. This operator allows using a pretrained model to make predictions based on a set of raster tile data.

- https://github.com/geo-engine/geoengine/pull/639

- Added a handler (`/available`) to the API to check if the service is available.

- https://github.com/geo-engine/geoengine/pull/681
Expand Down
3 changes: 3 additions & 0 deletions Settings-default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ provider_defs_path = "./test_data/provider_defs"
layer_defs_path = "./test_data/layer_defs"
layer_collection_defs_path = "./test_data/layer_collection_defs"

[machinelearning]
model_defs_path = "./test_data/pro/ml"

[gdal]
# TODO: find good default
# Use 0 for `ALL_CPUS` option or a number >0 for a specific number of threads.
Expand Down
4 changes: 3 additions & 1 deletion datatypes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ async-trait = "0.1"
chrono = "0.4"
float-cmp = "0.9"
gdal = "0.14"
geo-types = "0.7"
geo = "0.23"
geojson = "0.24"
image = "0.24"
Expand All @@ -37,6 +38,7 @@ uuid = { version = "1.1", features = ["serde", "v4", "v5"] }

[dev-dependencies]
criterion = "0.4"
tempfile = "3.1"

[[bench]]
name = "multi_point_collection"
Expand All @@ -52,4 +54,4 @@ harness = false

[[bench]]
name = "masked_grid_mapping"
harness = false
harness = false
19 changes: 18 additions & 1 deletion datatypes/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::{
spatial_reference::SpatialReference,
};
use snafu::{prelude::*, AsErrorSource, ErrorCompat, IntoError};
use std::{any::Any, convert::Infallible, sync::Arc};
use std::{any::Any, convert::Infallible, path::PathBuf, sync::Arc};

use crate::util::Result;

pub trait ErrorSource: std::error::Error + Send + Sync + Any + 'static + AsErrorSource {
fn boxed(self) -> Box<dyn ErrorSource>
Expand Down Expand Up @@ -313,6 +315,15 @@ pub enum Error {
a: SpatialReference,
b: SpatialReference,
},

Io {
source: std::io::Error,
},

SubPathMustNotEscapeBasePath {
base: PathBuf,
sub_path: PathBuf,
},
}

impl From<arrow::error::ArrowError> for Error {
Expand All @@ -338,3 +349,9 @@ impl From<gdal::errors::GdalError> for Error {
Self::Gdal { source: gdal_error }
}
}

/// Converts a standard I/O error into this crate's `Error::Io` variant,
/// which lets `?` propagate `std::io::Error` values from fallible I/O calls.
impl From<std::io::Error> for Error {
    fn from(source: std::io::Error) -> Self {
        Self::Io { source }
    }
}
42 changes: 42 additions & 0 deletions datatypes/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,48 @@ mod result;
pub mod well_known_data;

pub mod test;
use std::path::{Path, PathBuf};

pub use self::identifiers::Identifier;
pub use any::{AsAny, AsAnyArc};
pub use result::Result;

/// Canonicalize `base`/`sub_path` and ensure the `sub_path` doesn't escape the `base`
/// returns an error if the `sub_path` escapes the `base`
///
/// This only works if the `Path` you are referring to actually exists.
///
pub fn canonicalize_subpath(base: &Path, sub_path: &Path) -> Result<PathBuf> {
let base = base.canonicalize()?;
let path = base.join(sub_path).canonicalize()?;

if path.starts_with(&base) {
Ok(path)
} else {
Err(crate::error::Error::SubPathMustNotEscapeBasePath {
base,
sub_path: sub_path.into(),
})
}
}

#[cfg(test)]
mod mod_tests {
    use super::*;

    /// Checks that a sub path inside the base is resolved and that a sub path
    /// leaving the base via `..` is rejected.
    #[test]
    fn it_doesnt_escape_base_path() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path();
        std::fs::create_dir_all(tmp_path.join("foo/bar/foobar")).unwrap();
        std::fs::create_dir_all(tmp_path.join("foo/barfoo")).unwrap();

        // `canonicalize_subpath` returns a canonical path, so the expectation has to be
        // canonicalized as well. Otherwise the comparison fails whenever the temp dir is
        // reached through a symlink or canonicalization changes the path prefix.
        let expected = tmp_path.join("foo/bar/foobar").canonicalize().unwrap();

        assert_eq!(
            canonicalize_subpath(&tmp_path.join("foo/bar"), Path::new("foobar")).unwrap(),
            expected
        );

        // `../barfoo` exists, but resolves to a sibling of the base directory.
        assert!(canonicalize_subpath(&tmp_path.join("foo/bar"), Path::new("../barfoo")).is_err());
    }
}
2 changes: 2 additions & 0 deletions operators/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ tracing = "0.1"
tracing-opentelemetry = "0.18"
typetag = "0.2"
uuid = { version = "1.1", features = ["serde", "v4", "v5"] }
xgboost-rs = "0.1"

[dev-dependencies]
async-stream = "0.3"
geo-rand = { git = "https://github.com/lelongg/geo-rand", tag = "v0.3.0" }
rand = "0.8"
ndarray = "0.15"


[[bench]]
Expand Down
35 changes: 35 additions & 0 deletions operators/src/engine/execution_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::sync::Arc;

/// A context that provides certain utility access during operator initialization
#[async_trait::async_trait]
pub trait ExecutionContext: Send
+ Sync
+ MetaDataProvider<MockDatasetDataSourceLoadingInfo, VectorResultDescriptor, VectorQueryRectangle>
Expand All @@ -50,6 +52,10 @@ pub trait ExecutionContext: Send
op: Box<dyn InitializedPlotOperator>,
span: CreateSpan,
) -> Box<dyn InitializedPlotOperator>;

async fn read_ml_model(&self, path: PathBuf) -> Result<String>;

async fn write_ml_model(&mut self, path: PathBuf, ml_model_str: String) -> Result<()>;
}

#[async_trait]
Expand Down Expand Up @@ -84,6 +90,7 @@ pub struct MockExecutionContext {
pub thread_pool: Arc<ThreadPool>,
pub meta_data: HashMap<DataId, Box<dyn Any + Send + Sync>>,
pub tiling_specification: TilingSpecification,
pub ml_models: HashMap<PathBuf, String>,
}

impl TestDefault for MockExecutionContext {
Expand All @@ -92,6 +99,7 @@ impl TestDefault for MockExecutionContext {
thread_pool: create_rayon_thread_pool(0),
meta_data: HashMap::default(),
tiling_specification: TilingSpecification::test_default(),
ml_models: HashMap::default(),
}
}
}
Expand All @@ -102,6 +110,7 @@ impl MockExecutionContext {
thread_pool: create_rayon_thread_pool(0),
meta_data: HashMap::default(),
tiling_specification,
ml_models: HashMap::default(),
}
}

Expand All @@ -113,6 +122,7 @@ impl MockExecutionContext {
thread_pool: create_rayon_thread_pool(num_threads),
meta_data: HashMap::default(),
tiling_specification,
ml_models: HashMap::default(),
}
}

Expand All @@ -136,8 +146,17 @@ impl MockExecutionContext {
abort_trigger: Some(abort_trigger),
}
}

/// Reads the file at `model_path` into a string and registers its content in
/// `ml_models`, using the path itself as the lookup key.
///
/// # Errors
///
/// Fails if the file cannot be read as a UTF-8 string.
pub fn initialize_ml_model(&mut self, model_path: PathBuf) -> Result<()> {
    let contents = std::fs::read_to_string(model_path.as_path())?;
    // An entry already stored under the same path is replaced.
    let _previous = self.ml_models.insert(model_path, contents);
    Ok(())
}
}

#[async_trait::async_trait]
impl ExecutionContext for MockExecutionContext {
fn thread_pool(&self) -> &Arc<ThreadPool> {
&self.thread_pool
Expand Down Expand Up @@ -170,6 +189,22 @@ impl ExecutionContext for MockExecutionContext {
) -> Box<dyn InitializedPlotOperator> {
op
}

async fn read_ml_model(&self, path: PathBuf) -> Result<String> {
let res = self
.ml_models
.get(&path)
.ok_or(Error::MachineLearningModelNotFound)?
.clone();

Ok(res)
}

/// Stores `ml_model_str` in the in-memory model map under `path`,
/// replacing any model that was previously stored for the same path.
async fn write_ml_model(&mut self, path: PathBuf, ml_model_str: String) -> Result<()> {
    let _replaced = self.ml_models.insert(path, ml_model_str);
    Ok(())
}
}

#[async_trait]
Expand Down
20 changes: 20 additions & 0 deletions operators/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,26 @@ pub enum Error {
QueryCanceled,

AbortTriggerAlreadyUsed,

SubPathMustNotEscapeBasePath {
base: PathBuf,
sub_path: PathBuf,
},

InvalidDataProviderConfig,

InvalidMachineLearningConfig,

MachineLearningModelNotFound,

InvalidMlModelPath,
CouldNotGetMlModelDirectory,

#[cfg(feature = "pro")]
#[snafu(context(false))]
XGBoost {
source: crate::pro::ml::xgboost::XGBoostModuleError,
},
}

impl From<crate::adapters::SparseTilesFillAdapterError> for Error {
Expand Down
1 change: 1 addition & 0 deletions operators/src/pro/ml/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod xgboost;
Loading

0 comments on commit 406ef23

Please sign in to comment.