diff --git a/pangeo_forge_recipes/transforms.py b/pangeo_forge_recipes/transforms.py
index d1aab73f..c171fcb5 100644
--- a/pangeo_forge_recipes/transforms.py
+++ b/pangeo_forge_recipes/transforms.py
@@ -2,11 +2,10 @@
 import logging
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional, Tuple

 import apache_beam as beam
 import xarray as xr
-import zarr

 from .patterns import Index
 from .storage import CacheFSSpecTarget, OpenFileType, get_opener
@@ -48,33 +47,34 @@ def expand(self, pcoll):
 class OpenWithXarray(beam.PTransform):

     xarray_open_kwargs: Optional[dict] = field(default_factory=dict)
+    load: bool = False

     def _open_with_xarray(self, element: Tuple[Index, Any]) -> Tuple[Index, xr.Dataset]:
         key, open_file = element
-        with open_file as fp:
-            with xr.open_dataset(fp, **self.xarray_open_kwargs) as ds:
-                return key, ds
+        logger.debug(f"_open_with_xarray({key}, {open_file})")
+        # workaround fsspec inconsistencies
+        if hasattr(open_file, "open"):
+            open_file = open_file.open()
+        ds = xr.open_dataset(open_file, cache=False, lock=False, **self.xarray_open_kwargs)
+        if self.load:
+            ds.load()
+        return key, ds

     def expand(self, pcoll):
         return pcoll | "Open with Xarray" >> beam.Map(self._open_with_xarray)


-@beam.typehints.with_input_types(Tuple[Index, xr.Dataset])
-@beam.typehints.with_output_types(Tuple[Index, Dict])
-@dataclass
-class GetXarraySchema(beam.PTransform):
-    def expand(self, pcoll):
-        pass
-
-
-@beam.typehints.with_input_types(Dict)
-@beam.typehints.with_output_types(zarr.Group)
-@dataclass
-class CreateZarrFromSchema(beam.PTransform):
-    def expand(self, pcoll):
-        pass
-
-
-# all_datasets = beam.Create(file_pattern) | OpenWithFSSpec() | OpenWithXarray()
-# target_zarr = all_datasets | GetXarraySchema() | CreateZarrFromSchema()
-# output = all_datasets | WriteZarrChunks(target=beam.pvalue.AsSingleton(target_zarr))
+# @beam.typehints.with_input_types(Tuple[Index, xr.Dataset])
+# @beam.typehints.with_output_types(Tuple[Index, Dict])
+# @dataclass
+# class GetXarraySchema(beam.PTransform):
+#     def expand(self, pcoll):
+#         pass
+#
+#
+# @beam.typehints.with_input_types(Dict)
+# @beam.typehints.with_output_types(zarr.Group)
+# @dataclass
+# class CreateZarrFromSchema(beam.PTransform):
+#     def expand(self, pcoll):
+#         pass
diff --git a/tests/test_beam.py b/tests/test_beam.py
index edd0242e..17337aeb 100644
--- a/tests/test_beam.py
+++ b/tests/test_beam.py
@@ -16,12 +16,13 @@
         lazy_fixture("netcdf_local_file_pattern_sequential"),
         lazy_fixture("netcdf_http_file_pattern_sequential_1d"),
     ],
+    ids=["local", "http"],
 )
 def pattern(request):
     return request.param


-@pytest.fixture(params=[True, False])
+@pytest.fixture(params=[True, False], ids=["with_cache", "no_cache"])
 def cache(tmp_cache, request):
     if request.param:
         return tmp_cache
@@ -71,11 +72,17 @@ def _is_readable(actual):
             assert cache.exists(fname)


-def is_xr_dataset():
+def is_xr_dataset(in_memory=False):
     def _is_xr_dataset(actual):
         for _, ds in actual:
             if not isinstance(ds, xr.Dataset):
                 raise BeamAssertException(f"Object {ds} has type {type(ds)}, expected xr.Dataset.")
+            offending_vars = [
+                vname for vname in ds.data_vars if ds[vname].variable._in_memory != in_memory
+            ]
+            if offending_vars:
+                msg = "were NOT in memory" if in_memory else "were in memory"
+                raise BeamAssertException(f"The following vars {msg}: {offending_vars}")

     return _is_xr_dataset

@@ -86,8 +93,22 @@ def pcoll_xarray_datasets(pcoll_opened_files):
     return open_files | OpenWithXarray()


-def test_OpenWithXarray(pcoll_xarray_datasets):
+@pytest.mark.parametrize("load", [False, True])
+def test_OpenWithXarray(pcoll_opened_files, load):
+    input, pattern, cache = pcoll_opened_files
     with TestPipeline() as p:
-        output = p | pcoll_xarray_datasets
+        output = p | input | OpenWithXarray(load=load)
+        assert_that(output, is_xr_dataset(in_memory=load))

-        assert_that(output, is_xr_dataset(), label="is xr.Dataset")
+
+def test_OpenWithXarray_downstream_load(pcoll_opened_files):
+    input, pattern, cache = pcoll_opened_files
+
+    def manually_load(item):
+        key, ds = item
+        return key, ds.load()
+
+    with TestPipeline() as p:
+        output = p | input | OpenWithXarray(load=False)
+        loaded_dsets = output | beam.Map(manually_load)
+        assert_that(loaded_dsets, is_xr_dataset(in_memory=True))
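
Usage note: a minimal sketch of the new `load` flag outside the test suite. This is not taken from the repo; the tiny netCDF file and the string key are illustrative assumptions (in a real pipeline the key would be a `patterns.Index` and the open file would come from the upstream fsspec-opening transform), and it assumes an xarray backend that accepts file-like objects (e.g. h5netcdf) is installed.

```python
import os
import tempfile

import apache_beam as beam
import fsspec
import xarray as xr

from pangeo_forge_recipes.transforms import OpenWithXarray

# Illustrative input: write a tiny dataset to a local netCDF file.
path = os.path.join(tempfile.mkdtemp(), "tiny.nc")
xr.Dataset({"foo": ("x", [1, 2, 3])}).to_netcdf(path)

with beam.Pipeline() as p:
    datasets = (
        p
        # fsspec.open returns an OpenFile; the transform's
        # hasattr(open_file, "open") workaround unwraps it to a
        # file-like object before handing it to xarray.
        | beam.Create([("key-0", fsspec.open(path, mode="rb"))])
        | OpenWithXarray(load=True)  # load=False (the default) keeps data lazy
    )
```

Keeping `load=False` as the default preserves lazy datasets; `load=True` (or a downstream `ds.load()`, as `test_OpenWithXarray_downstream_load` exercises) forces the read to happen while the underlying file is still open.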