Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/xgboost #639

Merged
merged 3 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added a new operator `XGBoost`. This operator allows using a pretrained model to make predictions based on a set of raster tile data.

- https://github.com/geo-engine/geoengine/pull/639

- Added a handler (`/available`) to the API to check if the service is available.

- https://github.com/geo-engine/geoengine/pull/681
Expand Down
3 changes: 3 additions & 0 deletions Settings-default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ provider_defs_path = "./test_data/provider_defs"
layer_defs_path = "./test_data/layer_defs"
layer_collection_defs_path = "./test_data/layer_collection_defs"

[machinelearning]
model_defs_path = "./test_data/pro/ml"

[gdal]
# TODO: find good default
# Use 0 for `ALL_CPUS` option or a number >0 for a specific number of threads.
Expand Down
4 changes: 3 additions & 1 deletion datatypes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ async-trait = "0.1"
chrono = "0.4"
float-cmp = "0.9"
gdal = "0.14"
geo-types = "0.7"
geo = "0.23"
geojson = "0.24"
image = "0.24"
Expand All @@ -37,6 +38,7 @@ uuid = { version = "1.1", features = ["serde", "v4", "v5"] }

[dev-dependencies]
criterion = "0.4"
tempfile = "3.1"

[[bench]]
name = "multi_point_collection"
Expand All @@ -52,4 +54,4 @@ harness = false

[[bench]]
name = "masked_grid_mapping"
harness = false
harness = false
19 changes: 18 additions & 1 deletion datatypes/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::{
spatial_reference::SpatialReference,
};
use snafu::{prelude::*, AsErrorSource, ErrorCompat, IntoError};
use std::{any::Any, convert::Infallible, sync::Arc};
use std::{any::Any, convert::Infallible, path::PathBuf, sync::Arc};

use crate::util::Result;

pub trait ErrorSource: std::error::Error + Send + Sync + Any + 'static + AsErrorSource {
fn boxed(self) -> Box<dyn ErrorSource>
Expand Down Expand Up @@ -313,6 +315,15 @@ pub enum Error {
a: SpatialReference,
b: SpatialReference,
},

Io {
source: std::io::Error,
},

SubPathMustNotEscapeBasePath {
base: PathBuf,
sub_path: PathBuf,
},
}

impl From<arrow::error::ArrowError> for Error {
Expand All @@ -338,3 +349,9 @@ impl From<gdal::errors::GdalError> for Error {
Self::Gdal { source: gdal_error }
}
}

/// Wraps a standard I/O error in the `Io` variant so that `?` can be used on
/// I/O results in functions returning this crate's `Error`.
impl From<std::io::Error> for Error {
    fn from(source: std::io::Error) -> Self {
        Self::Io { source }
    }
}
42 changes: 42 additions & 0 deletions datatypes/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,48 @@ mod result;
pub mod well_known_data;

pub mod test;
use std::path::{Path, PathBuf};

pub use self::identifiers::Identifier;
pub use any::{AsAny, AsAnyArc};
pub use result::Result;

/// Canonicalize `base`/`sub_path` and ensure the `sub_path` doesn't escape the `base`
/// returns an error if the `sub_path` escapes the `base`
///
/// This only works if the `Path` you are referring to actually exists.
///
pub fn canonicalize_subpath(base: &Path, sub_path: &Path) -> Result<PathBuf> {
let base = base.canonicalize()?;
let path = base.join(sub_path).canonicalize()?;

if path.starts_with(&base) {
Ok(path)
} else {
Err(crate::error::Error::SubPathMustNotEscapeBasePath {
base,
sub_path: sub_path.into(),
})
}
}

#[cfg(test)]
mod mod_tests {
    use super::*;

    /// Checks that `canonicalize_subpath` resolves a sub path located inside the
    /// base directory and rejects one that leaves it via `..`.
    #[test]
    fn it_doesnt_escape_base_path() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path();
        std::fs::create_dir_all(tmp_path.join("foo/bar/foobar")).unwrap();
        std::fs::create_dir_all(tmp_path.join("foo/barfoo")).unwrap();

        let base = tmp_path.join("foo/bar");

        // `canonicalize_subpath` returns a canonicalized path. The temporary
        // directory may sit behind a symlink (or canonicalize to a differently
        // spelled path) on some platforms, so the expected value has to be
        // canonicalized as well before comparing.
        let expected = tmp_path.join("foo/bar/foobar").canonicalize().unwrap();

        assert_eq!(
            canonicalize_subpath(&base, Path::new("foobar")).unwrap(),
            expected
        );

        // `../barfoo` exists, but resolves to a sibling of the base directory.
        assert!(canonicalize_subpath(&base, Path::new("../barfoo")).is_err());
    }
}
2 changes: 2 additions & 0 deletions operators/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ tracing = "0.1"
tracing-opentelemetry = "0.18"
typetag = "0.2"
uuid = { version = "1.1", features = ["serde", "v4", "v5"] }
xgboost-rs = "0.1"

[dev-dependencies]
async-stream = "0.3"
geo-rand = { git = "https://github.com/lelongg/geo-rand", tag = "v0.3.0" }
rand = "0.8"
ndarray = "0.15"


[[bench]]
Expand Down
35 changes: 35 additions & 0 deletions operators/src/engine/execution_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::sync::Arc;

/// A context that provides certain utility access during operator initialization
#[async_trait::async_trait]
pub trait ExecutionContext: Send
+ Sync
+ MetaDataProvider<MockDatasetDataSourceLoadingInfo, VectorResultDescriptor, VectorQueryRectangle>
Expand All @@ -50,6 +52,10 @@ pub trait ExecutionContext: Send
op: Box<dyn InitializedPlotOperator>,
span: CreateSpan,
) -> Box<dyn InitializedPlotOperator>;

/// Returns the machine learning model associated with `path` as a `String`.
///
/// NOTE(review): how `path` is resolved and which error is reported for an
/// unknown model is left to the implementor — confirm against the concrete contexts.
async fn read_ml_model(&self, path: PathBuf) -> Result<String>;

/// Stores `ml_model_str` as the machine learning model for `path`.
///
/// NOTE(review): whether an existing model at `path` is overwritten is left to
/// the implementor — confirm against the concrete contexts.
async fn write_ml_model(&mut self, path: PathBuf, ml_model_str: String) -> Result<()>;
}

#[async_trait]
Expand Down Expand Up @@ -84,6 +90,7 @@ pub struct MockExecutionContext {
pub thread_pool: Arc<ThreadPool>,
pub meta_data: HashMap<DataId, Box<dyn Any + Send + Sync>>,
pub tiling_specification: TilingSpecification,
pub ml_models: HashMap<PathBuf, String>,
}

impl TestDefault for MockExecutionContext {
Expand All @@ -92,6 +99,7 @@ impl TestDefault for MockExecutionContext {
thread_pool: create_rayon_thread_pool(0),
meta_data: HashMap::default(),
tiling_specification: TilingSpecification::test_default(),
ml_models: HashMap::default(),
}
}
}
Expand All @@ -102,6 +110,7 @@ impl MockExecutionContext {
thread_pool: create_rayon_thread_pool(0),
meta_data: HashMap::default(),
tiling_specification,
ml_models: HashMap::default(),
}
}

Expand All @@ -113,6 +122,7 @@ impl MockExecutionContext {
thread_pool: create_rayon_thread_pool(num_threads),
meta_data: HashMap::default(),
tiling_specification,
ml_models: HashMap::default(),
}
}

Expand All @@ -136,8 +146,17 @@ impl MockExecutionContext {
abort_trigger: Some(abort_trigger),
}
}

/// Reads the file at `model_path` into a string and registers its content in
/// `ml_models` under that same path.
///
/// # Errors
///
/// Returns an error if the file cannot be read as a string.
pub fn initialize_ml_model(&mut self, model_path: PathBuf) -> Result<()> {
    let model_content = std::fs::read_to_string(model_path.as_path())?;

    // The path itself is the lookup key; a previously stored model is replaced.
    let _previous = self.ml_models.insert(model_path, model_content);

    Ok(())
}
}

#[async_trait::async_trait]
impl ExecutionContext for MockExecutionContext {
fn thread_pool(&self) -> &Arc<ThreadPool> {
&self.thread_pool
Expand Down Expand Up @@ -170,6 +189,22 @@ impl ExecutionContext for MockExecutionContext {
) -> Box<dyn InitializedPlotOperator> {
op
}

async fn read_ml_model(&self, path: PathBuf) -> Result<String> {
let res = self
.ml_models
.get(&path)
.ok_or(Error::MachineLearningModelNotFound)?
.clone();

Ok(res)
}

/// Stores `ml_model_str` in the in-memory `ml_models` map under `path`,
/// replacing any model previously stored for that path. Always succeeds.
async fn write_ml_model(&mut self, path: PathBuf, ml_model_str: String) -> Result<()> {
    let _previous = self.ml_models.insert(path, ml_model_str);

    Ok(())
}
}

#[async_trait]
Expand Down
20 changes: 20 additions & 0 deletions operators/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,26 @@ pub enum Error {
QueryCanceled,

AbortTriggerAlreadyUsed,

SubPathMustNotEscapeBasePath {
base: PathBuf,
sub_path: PathBuf,
},

InvalidDataProviderConfig,

InvalidMachineLearningConfig,

MachineLearningModelNotFound,

InvalidMlModelPath,
CouldNotGetMlModelDirectory,

#[cfg(feature = "pro")]
#[snafu(context(false))]
XGBoost {
source: crate::pro::ml::xgboost::XGBoostModuleError,
},
}

impl From<crate::adapters::SparseTilesFillAdapterError> for Error {
Expand Down
1 change: 1 addition & 0 deletions operators/src/pro/ml/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod xgboost;
Loading