Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to actually load the dataset #373

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions pangeo_forge_recipes/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import logging
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
from typing import Any, Optional, Tuple

import apache_beam as beam
import xarray as xr
import zarr

from .patterns import Index
from .storage import CacheFSSpecTarget, OpenFileType, get_opener
Expand Down Expand Up @@ -48,33 +47,34 @@ def expand(self, pcoll):
class OpenWithXarray(beam.PTransform):
    """Open each incoming file object as an ``xarray.Dataset``.

    Consumes ``(key, open_file)`` tuples and emits ``(key, dataset)`` tuples.

    NOTE(review): the decorators of this class are outside the visible diff
    hunk; the ``field(...)`` default below presumes ``@dataclass`` is applied,
    as it is on the sibling transforms -- confirm.
    """

    # Extra keyword arguments forwarded to ``xr.open_dataset``.
    xarray_open_kwargs: Optional[dict] = field(default_factory=dict)
    # When True, ``ds.load()`` is called before the dataset is emitted.
    load: bool = False

    def _open_with_xarray(self, element: Tuple[Index, Any]) -> Tuple[Index, xr.Dataset]:
        """Open one ``(key, open_file)`` element and return ``(key, dataset)``.

        The file and the dataset are intentionally not wrapped in ``with``
        blocks: the dataset is handed to downstream steps, which may still
        call ``.load()`` on it when ``self.load`` is False.
        """
        key, open_file = element
        logger.debug(f"_open_with_xarray({key}, {open_file})")
        # workaround fsspec inconsistencies
        if hasattr(open_file, "open"):
            open_file = open_file.open()
        # The annotation allows None; ``**None`` would raise a TypeError.
        open_kwargs = self.xarray_open_kwargs or {}
        ds = xr.open_dataset(open_file, cache=False, lock=False, **open_kwargs)
        if self.load:
            ds.load()
        return key, ds

    def expand(self, pcoll):
        """Apply ``_open_with_xarray`` to every element of ``pcoll``."""
        return pcoll | "Open with Xarray" >> beam.Map(self._open_with_xarray)


@beam.typehints.with_input_types(Tuple[Index, xr.Dataset])
@beam.typehints.with_output_types(Tuple[Index, Dict])
@dataclass
class GetXarraySchema(beam.PTransform):
def expand(self, pcoll):
pass


@beam.typehints.with_input_types(Dict)
@beam.typehints.with_output_types(zarr.Group)
@dataclass
class CreateZarrFromSchema(beam.PTransform):
def expand(self, pcoll):
pass


# all_datasets = beam.Create(file_pattern) | OpenWithFSSpec() | OpenWithXarray()
# target_zarr = all_datasets | GetXarraySchema() | CreateZarrFromSchema()
# output = all_datasets | WriteZarrChunks(target=beam.pvalue.AsSingleton(target_zarr))
# @beam.typehints.with_input_types(Tuple[Index, xr.Dataset])
# @beam.typehints.with_output_types(Tuple[Index, Dict])
# @dataclass
# class GetXarraySchema(beam.PTransform):
# def expand(self, pcoll):
# pass
#
#
# @beam.typehints.with_input_types(Dict)
# @beam.typehints.with_output_types(zarr.Group)
# @dataclass
# class CreateZarrFromSchema(beam.PTransform):
# def expand(self, pcoll):
# pass
31 changes: 26 additions & 5 deletions tests/test_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
lazy_fixture("netcdf_local_file_pattern_sequential"),
lazy_fixture("netcdf_http_file_pattern_sequential_1d"),
],
ids=["local", "http"],
)
def pattern(request):
return request.param


@pytest.fixture(params=[True, False])
@pytest.fixture(params=[True, False], ids=["with_cache", "no_cache"])
def cache(tmp_cache, request):
if request.param:
return tmp_cache
Expand Down Expand Up @@ -71,11 +72,17 @@ def _is_readable(actual):
assert cache.exists(fname)


def is_xr_dataset(in_memory=False):
    """Build a matcher that validates a collection of ``(key, dataset)`` pairs.

    The returned callable raises ``BeamAssertException`` if any value is not
    an ``xr.Dataset``, or if any data variable's ``_in_memory`` flag differs
    from ``in_memory``. The default keeps the previous zero-argument call
    ``is_xr_dataset()`` working.
    """

    def _is_xr_dataset(actual):
        for _, ds in actual:
            if not isinstance(ds, xr.Dataset):
                raise BeamAssertException(f"Object {ds} has type {type(ds)}, expected xr.Dataset.")
            # Collect every data variable whose load state is not the expected one.
            offending_vars = [
                vname for vname in ds.data_vars if ds[vname].variable._in_memory != in_memory
            ]
            if offending_vars:
                msg = "were NOT in memory" if in_memory else "were in memory"
                raise BeamAssertException(f"The following vars {msg}: {offending_vars}")

    return _is_xr_dataset

Expand All @@ -86,8 +93,22 @@ def pcoll_xarray_datasets(pcoll_opened_files):
return open_files | OpenWithXarray()


@pytest.mark.parametrize("load", [False, True])
def test_OpenWithXarray(pcoll_opened_files, load):
    """Datasets emitted by OpenWithXarray are in memory exactly when ``load`` is True."""
    # Only the first item of the fixture tuple is piped into the transform;
    # the local is not called ``input`` so the builtin is not shadowed.
    open_files, _pattern, _cache = pcoll_opened_files
    with TestPipeline() as p:
        output = p | open_files | OpenWithXarray(load=load)
        # The assertion is attached inside the ``with`` block so it is part of
        # the pipeline before the context manager exits.
        assert_that(output, is_xr_dataset(in_memory=load))


def test_OpenWithXarray_downstream_load(pcoll_opened_files):
    """A dataset opened with ``load=False`` can still be loaded by a later step."""
    # Only the first item of the fixture tuple is piped into the transform;
    # the local is not called ``input`` so the builtin is not shadowed.
    open_files, _pattern, _cache = pcoll_opened_files

    def manually_load(item):
        # Load the dataset in a separate pipeline step, after it was opened.
        key, ds = item
        return key, ds.load()

    with TestPipeline() as p:
        output = p | open_files | OpenWithXarray(load=False)
        loaded_dsets = output | beam.Map(manually_load)
        assert_that(loaded_dsets, is_xr_dataset(in_memory=True))