From 8464a96a122b054e9afd1e30490bb346e3e13f64 Mon Sep 17 00:00:00 2001 From: Aaron Zuspan Date: Tue, 26 Mar 2024 11:48:59 -0700 Subject: [PATCH 1/2] Add test to trigger pickling error Computing a dataset using "processes" or a Client scheduler currently fails due to a pickling error. This adds a corresponding test that fails until the underlying error is fixed. --- tests/test_datasets.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 2be79c6..147a68a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,3 +1,4 @@ +import pickle from dataclasses import dataclass import pytest @@ -36,6 +37,10 @@ def test_load_dataset(configuration: DatasetConfiguration, as_dataset: bool): assert y.shape == (configuration.n_samples, configuration.n_targets) if as_dataset: + # Some Dask schedulers require pickling, so ensure that the loaded dataset is + # pickleable during compute. We could try computing directly, but that is much + # slower. + assert pickle.dumps(X_image) assert list(X.columns) == list(X_image.data_vars) assert X_image.sizes == { "y": configuration.image_size[0], From 199536955856cfd8d4d05dcd74cd18cd3e4ffb44 Mon Sep 17 00:00:00 2001 From: Aaron Zuspan Date: Tue, 26 Mar 2024 11:52:47 -0700 Subject: [PATCH 2/2] Fix failing pickling test `PosixPath.open()` returns a `BufferedReader` that is not pickleable, breaking dask computation with `Client` or "processes" schedulers. This still uses `resources` to identify data file paths, but opens directly from the paths instead of file buffers to fix pickling. 
--- src/sknnr_spatial/datasets/_base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sknnr_spatial/datasets/_base.py b/src/sknnr_spatial/datasets/_base.py index 2a8f8c3..2c84f48 100644 --- a/src/sknnr_spatial/datasets/_base.py +++ b/src/sknnr_spatial/datasets/_base.py @@ -19,10 +19,13 @@ def _load_rasters_to_dataset( """Load a list of rasters from the data module as an xarray Dataset.""" das = [] for file_name, var_name in zip(file_names, var_names): - with resources.files(module_name).joinpath(file_name).open("rb") as bin: - da = rioxarray.open_rasterio(bin, chunks=chunks) + path = resources.files(module_name).joinpath(file_name) + da = ( + rioxarray.open_rasterio(path, chunks=chunks) + .to_dataset(dim="band") + .rename({1: var_name}) + ) - da = da.to_dataset(dim="band").rename({1: var_name}) das.append(da) return xr.merge(das)