Skip to content

Commit

Permalink
simplify logic + no_data_val for xgboost
Browse files Browse the repository at this point in the history
  • Loading branch information
MarWeUMR committed Oct 26, 2022
1 parent 5927879 commit 76fe4df
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions operators/src/pro/ml/xgboost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use geoengine_datatypes::primitives::{
SpatialResolution,
};
use geoengine_datatypes::raster::{
BaseTile, Grid2D, GridOrEmpty, GridShape, GridShapeAccess, Pixel, RasterDataType, RasterTile2D,
BaseTile, Grid2D, GridOrEmpty, GridShape, GridShapeAccess, GridSize, Pixel, RasterDataType,
RasterTile2D,
};
use rayon::prelude::ParallelIterator;
use rayon::slice::ParallelSlice;
Expand All @@ -35,6 +36,7 @@ use TypedRasterQueryProcessor::F32 as QueryProcessorOut;
#[serde(rename_all = "camelCase")]
pub struct XgboostParams {
model_file_path: PathBuf,
no_data_value: Option<f32>,
}

pub type XgboostOperator = Operator<XgboostParams, MultipleRasterSources>;
Expand All @@ -44,6 +46,7 @@ pub struct InitializedXgboostOperator {
result_descriptor: RasterResultDescriptor,
sources: Vec<Box<dyn InitializedRasterOperator>>,
model_file_path: PathBuf,
no_data_value: Option<f32>,
}

type PixelOut = f32;
Expand Down Expand Up @@ -125,6 +128,7 @@ impl RasterOperator for XgboostOperator {
result_descriptor: out_desc,
sources: init_rasters,
model_file_path: self.params.model_file_path,
no_data_value: self.params.no_data_value,
};

Ok(initialized_operator.boxed())
Expand All @@ -137,8 +141,6 @@ impl InitializedRasterOperator for InitializedXgboostOperator {
}

fn query_processor(&self) -> Result<TypedRasterQueryProcessor> {
// given a set of raster bands, if they have different types,
// where should they be converted?
let vec_of_rqps: Vec<Box<dyn RasterQueryProcessor<RasterType = f32>>> = self
.sources
.iter()
Expand All @@ -148,6 +150,7 @@ impl InitializedRasterOperator for InitializedXgboostOperator {
Ok(QueryProcessorOut(Box::new(XgboostProcessor::new(
vec_of_rqps,
self.model_file_path.clone(),
self.no_data_value
))))
}
}
Expand All @@ -158,17 +161,19 @@ where
{
sources: Vec<Q>,
model_file: PathBuf,
no_data_value: Option<f32>,
}

impl<Q, P> XgboostProcessor<Q, P>
where
Q: RasterQueryProcessor<RasterType = P>,
P: Pixel,
{
pub fn new(sources: Vec<Q>, model_file_path: PathBuf) -> Self {
pub fn new(sources: Vec<Q>, model_file_path: PathBuf, no_data_value: Option<f32>) -> Self {
Self {
sources,
model_file: model_file_path,
no_data_value
}
}

Expand All @@ -180,16 +185,16 @@ where
let tile = bands_of_tile.pop().ok_or(error::Error::EmptyInput)??;

// gather the data
let n_rows = tile.grid_shape_array()[0];
let n_cols = tile.grid_shape_array()[1];
let grid_shape = tile.grid_shape();
let n_bands = bands_of_tile.len() as i32;
let n_bands = bands_of_tile.len() as i32 + 1; // +1 because of .pop()
let props = tile.properties.clone(); // = &tile.properties;
let time = tile.time;
let tile_position = tile.tile_position;
let global_geo_transform = tile.global_geo_transform;

let model = self.model_file.clone();
let ndv = self.no_data_value.or(Some(f32::NAN));

let predicted_grid: geoengine_datatypes::raster::Grid<GridShape<[usize; 2]>, f32> =
crate::util::spawn_blocking(move || {
// put the tile back into place
Expand Down Expand Up @@ -217,24 +222,16 @@ where
})
.collect();

let mut pixels: Vec<_> = Vec::new();
for row in 0..(n_rows * n_cols) {
let mut row_data: Vec<f32> = Vec::new();
for col in 0..n_bands {
let pxl = rasters[col as usize][row as usize].to_owned();
row_data.push(pxl);
}
pixels.extend_from_slice(&row_data);
}

let pixels: Vec<_> = rasters.clone().into_iter().flatten().collect();
let model_content = std::fs::read_to_string(model).unwrap();

process_tile(
&pixels.into(),
&pixels,
&pool,
model_content.as_bytes(),
grid_shape,
n_bands as usize,
ndv.unwrap(),
)
})
.await
Expand All @@ -259,11 +256,11 @@ fn process_tile(
model_file: &[u8],
grid_shape: GridShape<[usize; 2]>,
n_bands: usize,
nan_val: f32,
) -> geoengine_datatypes::raster::Grid<GridShape<[usize; 2]>, f32> {
pool.install(|| {
// to get one row of data means taking (n_pixels * n_bands) elements
let n_parallel_pixels = grid_shape.shape_array[0] * grid_shape.shape_array[1];
let chunk_size = n_bands * n_parallel_pixels;
let chunk_size = grid_shape.number_of_elements() * n_bands;

let res: Vec<_> = bands_of_tile
.par_chunks(chunk_size)
Expand All @@ -273,10 +270,10 @@ fn process_tile(
elem,
mem::size_of::<f32>() * n_bands,
mem::size_of::<f32>(),
n_parallel_pixels,
grid_shape.number_of_elements(),
n_bands,
-1,
0.0,
nan_val,
) {
Ok(matrix) => matrix,
Err(err) => panic!(
Expand All @@ -297,7 +294,7 @@ fn process_tile(
// measure time for prediction
match bst.predict_from_dmat(
&xg_matrix,
&[n_parallel_pixels as u64, n_bands as u64],
&[grid_shape.number_of_elements() as u64, n_bands as u64],
&mut out_dim,
) {
Ok(res) => res,
Expand Down Expand Up @@ -599,12 +596,11 @@ mod tests {

let model_path = project_path.join("test_data/s2_10m_de_marburg/model.json");



// xg-operator takes the input data for further processing
let xg = XgboostOperator {
params: XgboostParams {
model_file_path: model_path,
no_data_value: Some(-1000.),
},
sources: MultipleRasterSources { rasters: srcs },
};
Expand Down

0 comments on commit 76fe4df

Please sign in to comment.