diff --git a/pangeo_forge/recipe.py b/pangeo_forge/recipe.py index 19afe66a..a2d8cb05 100644 --- a/pangeo_forge/recipe.py +++ b/pangeo_forge/recipe.py @@ -142,6 +142,10 @@ class NetCDFtoZarrSequentialRecipe(BaseRecipe): :param delete_input_encoding: Whether to remove Xarray encoding from variables in the input dataset :param fsspec_open_kwargs: Extra options for opening the inputs with fsspec. + :param process_input: Function to call on each opened input, with signature + `(ds: xr.Dataset, filename: str) -> xr.Dataset`. + :param process_chunk: Function to call on each concatenated chunk, with signature + `(ds: xr.Dataset) -> xr.Dataset`. """ input_urls: Iterable[str] = field(repr=False) @@ -156,6 +160,8 @@ class NetCDFtoZarrSequentialRecipe(BaseRecipe): xarray_concat_kwargs: dict = field(default_factory=dict) delete_input_encoding: bool = True fsspec_open_kwargs: dict = field(default_factory=dict) + process_input: Optional[Callable[[xr.Dataset, str], xr.Dataset]] = None + process_chunk: Optional[Callable[[xr.Dataset], xr.Dataset]] = None def __post_init__(self): self._chunks_inputs = { @@ -255,7 +261,11 @@ def open_input(self, fname: str): for var in ds.variables: ds[var].encoding = {} + if self.process_input is not None: + ds = self.process_input(ds, str(fname)) + logger.debug(f"{ds}") + return ds def open_chunk(self, chunk_key): @@ -265,6 +275,10 @@ def open_chunk(self, chunk_key): # CONCAT DELETES ENCODING!!! # OR NO IT DOESN'T! Not in the latest version of xarray? ds = xr.concat(dsets, self.sequence_dim, **self.xarray_concat_kwargs) + + if self.process_chunk is not None: + ds = self.process_chunk(ds) + logger.debug(f"{ds}") # TODO: maybe do some chunking here? 
diff --git a/tests/test_recipe.py b/tests/test_recipe.py index 6ea1ae35..1cf2a51c 100644 --- a/tests/test_recipe.py +++ b/tests/test_recipe.py @@ -8,6 +8,13 @@ dummy_fnames = ["a.nc", "b.nc", "c.nc"] +def incr_date(ds, filename=""): + # add one day + t = [d + int(24 * 3600e9) for d in ds.time.values] + ds = ds.assign_coords(time=t) + return ds + + @pytest.mark.skip(reason="Removed this class for now") @pytest.mark.parametrize( "file_urls, files_per_chunk, expected_keys, expected_filenames", @@ -72,8 +79,12 @@ def test_NetCDFtoZarrSequentialRecipeHttpAuth( assert ds_target.identical(ds_expected) +@pytest.mark.parametrize( + "process_input, process_chunk", + [(None, None), (incr_date, None), (None, incr_date), (incr_date, incr_date)], +) def test_NetCDFtoZarrSequentialRecipe( - daily_xarray_dataset, netcdf_local_paths, tmp_target, tmp_cache + daily_xarray_dataset, netcdf_local_paths, tmp_target, tmp_cache, process_input, process_chunk ): # the same recipe is created as a fixture in conftest.py @@ -85,6 +96,8 @@ def test_NetCDFtoZarrSequentialRecipe( nitems_per_input=daily_xarray_dataset.attrs["items_per_file"], target=tmp_target, input_cache=tmp_cache, + process_input=process_input, + process_chunk=process_chunk, ) # this is the cannonical way to manually execute a recipe @@ -97,6 +110,18 @@ def test_NetCDFtoZarrSequentialRecipe( ds_target = xr.open_zarr(tmp_target.get_mapper(), consolidated=True).load() ds_expected = daily_xarray_dataset.compute() + + if process_input is not None: + # check that the process_input hook made some changes + assert not ds_target.identical(ds_expected) + # apply these changes to the expected dataset + ds_expected = process_input(ds_expected) + if process_chunk is not None: + # check that the process_chunk hook made some changes + assert not ds_target.identical(ds_expected) + # apply these changes to the expected dataset + ds_expected = process_chunk(ds_expected) + assert ds_target.identical(ds_expected)