Skip to content

Commit

Permalink
Bypass open_dataset in to_xarray
Browse files Browse the repository at this point in the history
  • Loading branch information
sandorkertesz committed Oct 27, 2024
1 parent 5843ad1 commit 1df7bd8
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 42 deletions.
11 changes: 10 additions & 1 deletion src/earthkit/data/readers/grib/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ def to_xarray(self, engine="earthkit", xarray_open_dataset_kwargs=None, **kwargs
Typecode or data-type of the array data.
* array_module: module
The module to use for array operations. Default is numpy.
* direct_backend: bool, None
If True, the backend is used directly bypassing :py:meth:`xarray.open_dataset`
and ignoring all non-backend related kwargs. If False, the data is read via
:py:meth:`xarray.open_dataset`. The default value (None) resolves
to False unless the ``profile`` overrides it.
When ``engine="cfgrib"`` the following engine specific kwargs are supported:
Expand Down Expand Up @@ -367,9 +372,13 @@ def to_xarray_earthkit(self, user_kwargs):
# print(f"{kwargs=}")
# print(f"{xarray_open_dataset_kwargs=}")

# separate backend_kwargs from other_kwargs
backend_kwargs = xarray_open_dataset_kwargs.pop("backend_kwargs", None)
other_kwargs = xarray_open_dataset_kwargs

from earthkit.data.utils.xarray.builder import from_earthkit

return from_earthkit(self, **xarray_open_dataset_kwargs)
return from_earthkit(self, backend_kwargs=backend_kwargs, other_kwargs=other_kwargs)

def to_xarray_cfgrib(self, user_kwargs):
xarray_open_dataset_kwargs = {}
Expand Down
126 changes: 87 additions & 39 deletions src/earthkit/data/utils/xarray/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@
import xarray
import xarray.core.indexing as indexing

from earthkit.data.utils import ensure_dict
from earthkit.data.utils import ensure_iterable

from .profile import Profile

LOG = logging.getLogger(__name__)

# These backend_kwargs are also direct xarray.open_dataset kwargs
BACKEND_AND_XR_OPEN_DS_KWARGS = ["decode_times", "decode_timedelta", "drop_variables"]

# These kwargs cannot be passed to xarray.open_dataset (not even inside backend_kwargs)
NON_XR_OPEN_DS_KWARGS = ["split_dims", "direct_backend"]


class VariableBuilder:
def __init__(self, name, var_dims, data_maker, local_attr_keys, tensor, remapping):
Expand Down Expand Up @@ -407,25 +414,31 @@ class DatasetBuilder:
def __init__(
self,
ds,
profile="mars",
**kwargs,
backend_kwargs=None,
other_kwargs=None,
):
self.profile = Profile.make(profile, **kwargs)
self.ds = ds
backend_kwargs = ensure_dict(backend_kwargs)
other_kwargs = ensure_dict(other_kwargs)

# collect backend and other kwargs that can be passed to xarray.open_dataset
self.backend_kwargs = dict(**backend_kwargs)
for k in NON_XR_OPEN_DS_KWARGS:
self.backend_kwargs.pop(k, None)
self.xr_open_dataset_kwargs = dict(**other_kwargs)

profile = backend_kwargs.pop("profile", Profile.DEFAULT_PROFILE_NAME)
self.profile = Profile.make(profile, **backend_kwargs)

if self.profile.lazy_load:
self.builder = TensorBackendDataBuilder
else:
self.builder = MemoryBackendDataBuilder

self.kwargs = kwargs
self.grids = {}

# @cached_property
# def profile(self):
# from .profile import Profile
self.split_dims = self.profile.dims.split_dims
self.direct_backend = self.profile.direct_backend

# return Profile.make(self.profile_name, **self.kwargs)
self.grids = {}

def parse(self):
assert not hasattr(self.ds, "_ek_builder")
Expand All @@ -438,7 +451,6 @@ def parse(self):

# create a new fieldlist for optimised access to unique values
ds = XArrayInputFieldList(self.ds, keys=self.profile.index_keys, remapping=remapping)

# LOG.debug(f"{ds.db=}")
LOG.debug(f"before update: {self.profile.dim_keys=}")

Expand Down Expand Up @@ -470,12 +482,16 @@ def grid(self, ds):


class SingleDatasetBuilder(DatasetBuilder):
def __init__(self, *args, **kwargs):
split_dims = kwargs.get("split_dims", None)
if split_dims:
def __init__(self, *args, from_xr=False, **kwargs):
super().__init__(*args, **kwargs)

if self.split_dims:
raise ValueError("SingleDatasetMaker does not support splitting")

super().__init__(*args, **kwargs)
if from_xr and self.direct_backend:
raise ValueError(
"SingleDatasetMaker does not support direct_backend=True when invoked from xarray"
)

def build(self):
ds_sorted = self.parse()
Expand All @@ -492,21 +508,17 @@ def build(self):


class SplitDatasetBuilder(DatasetBuilder):
def __init__(self, *args, backend_kwargs=None, **kwargs):
def __init__(self, *args, **kwargs):
"""
split_dims: str, or iterable of str, None
Dimension or list of dimensions to use for splitting the data into multiple hypercubes.
Default is None.
"""
self.split_dims = backend_kwargs.pop("split_dims", None)
self.backend_kwargs = dict(**backend_kwargs)
self.xr_open_dataset_kwargs = dict(**kwargs)
super().__init__(*args, **kwargs)

if not self.split_dims:
raise ValueError("SplitDatasetMaker requires split_dims")

super().__init__(*args, split_dims=self.split_dims, **backend_kwargs)

def build(self):
from .splitter import Splitter

Expand All @@ -523,29 +535,65 @@ def build(self):
)

s_ds._ek_builder = builder
datasets.append(
xarray.open_dataset(s_ds, backend_kwargs=self.backend_kwargs, **self.xr_open_dataset_kwargs)
)
s_ds._ek_builder = None
if self.direct_backend:
datasets.append(builder.build())
else:
s_ds._ek_builder = builder
datasets.append(
xarray.open_dataset(
s_ds, backend_kwargs=self.backend_kwargs, **self.xr_open_dataset_kwargs
)
)
s_ds._ek_builder = None

return datasets[0] if len(datasets) == 1 else datasets


def from_earthkit(ds, **kwargs):
"""Create an xarray dataset from an earthkit fieldlist."""
backend_kwargs = kwargs.get("backend_kwargs", {})
split_dims = backend_kwargs.get("split_dims", None)
def from_earthkit(ds, backend_kwargs=None, other_kwargs=None):
"""Create an xarray dataset from an earthkit fieldlist.
assert kwargs["engine"] == "earthkit"
Parameters
----------
ds: FieldList
The input fieldlist.
backend_kwargs: dict, optional
Backend kwargs that can be passed to
:py:meth:`xarray.open_dataset` as "backend_kwargs".
other_kwargs: dict, optional
Additional kwargs passed to :py:meth:`xarray.open_dataset`. Cannot contain
any of the keys in ``backend_kwargs``.
"""
backend_kwargs = ensure_dict(backend_kwargs)
other_kwargs = ensure_dict(other_kwargs)

# Certain kwargs are valid both as backend kwargs and as direct
# xarray.open_dataset kwargs (see BACKEND_AND_XR_OPEN_DS_KWARGS). When given in
# other_kwargs, copy them into backend_kwargs_full.
backend_kwargs_full = dict(**backend_kwargs)
for k in BACKEND_AND_XR_OPEN_DS_KWARGS:
if k in other_kwargs:
backend_kwargs_full[k] = other_kwargs[k]

# to create the profile we need all the possible backend_kwargs (except "profile" itself)
profile = backend_kwargs_full.pop("profile", Profile.DEFAULT_PROFILE_NAME)
profile = Profile.make(profile, **backend_kwargs_full)

# the backend builder is directly called bypassing xarray.open_dataset
if profile.direct_backend:
backend_kwargs_full["profile"] = profile
if not profile.dims.split_dims:
return SingleDatasetBuilder(ds, backend_kwargs=backend_kwargs_full).build()
else:
return SplitDatasetBuilder(ds, backend_kwargs=backend_kwargs_full).build()
# xarray.open_dataset is called
else:
assert other_kwargs["engine"] == "earthkit"
backend_kwargs["profile"] = profile

if not split_dims:
backend_kwargs.pop("split_dims", None)
# certain kwargs are not allowed in xarray.open_dataset
for k in NON_XR_OPEN_DS_KWARGS:
backend_kwargs.pop(k, None)

if len(kwargs) == 1:
assert kwargs["engine"] == "earthkit"
return SingleDatasetBuilder(ds, **backend_kwargs).build()
if not profile.dims.split_dims:
return xarray.open_dataset(ds, backend_kwargs=backend_kwargs, **other_kwargs)
else:
return xarray.open_dataset(ds, **kwargs)
else:
backend_kwargs = kwargs.pop("backend_kwargs", {})
return SplitDatasetBuilder(ds, backend_kwargs=backend_kwargs, **kwargs).build()
return SplitDatasetBuilder(ds, backend_kwargs=backend_kwargs, other_kwargs=other_kwargs).build()
1 change: 1 addition & 0 deletions src/earthkit/data/utils/xarray/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ array_module: numpy
# other
lazy_load: true
release_source: false
direct_backend: false
strict: false
errors: raise

Expand Down
2 changes: 1 addition & 1 deletion src/earthkit/data/utils/xarray/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def open_dataset(
else:
from .builder import SingleDatasetBuilder

return SingleDatasetBuilder(fieldlist, **_kwargs).build()
return SingleDatasetBuilder(fieldlist, from_xr=True, backend_kwargs=_kwargs).build()

@classmethod
def guess_can_open(cls, filename_or_obj):
Expand Down
6 changes: 6 additions & 0 deletions src/earthkit/data/utils/xarray/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def defaults(self):

class Profile:
USER_ONLY_OPTIONS = ["remapping", "patches"]
DEFAULT_PROFILE_NAME = "mars"

def __init__(
self,
Expand Down Expand Up @@ -147,6 +148,7 @@ def __init__(
self.decode_timedelta = kwargs.pop("decode_timedelta")
self.lazy_load = kwargs.pop("lazy_load")
self.release_source = kwargs.pop("release_source")
self.direct_backend = kwargs.pop("direct_backend")
self.strict = kwargs.pop("strict")
self.errors = kwargs.pop("errors")

Expand Down Expand Up @@ -174,6 +176,10 @@ def __init__(
@staticmethod
def make(name_or_def, *args, **kwargs):
# print("name_or_def", name_or_def)

if isinstance(name_or_def, Profile):
return name_or_def

if name_or_def is None:
name_or_def = {}

Expand Down
30 changes: 29 additions & 1 deletion tests/xr_engine/test_xr_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ def test_xr_engine_detailed_check(api):
@pytest.mark.parametrize("stream", [False, True])
@pytest.mark.parametrize("lazy_load", [False, True])
@pytest.mark.parametrize("release_source", [False, True])
def test_xr_engine_detailed_flatten_check(stream, lazy_load, release_source):
@pytest.mark.parametrize("direct_backend", [False, True])
def test_xr_engine_detailed_flatten_check(stream, lazy_load, release_source, direct_backend):
filename = "test-data/xr_engine/level/pl.grib"
ds_ek, ds_ek_ref = load_grib_data(filename, "url", stream=stream)

Expand All @@ -257,6 +258,7 @@ def test_xr_engine_detailed_flatten_check(stream, lazy_load, release_source):
"add_valid_time_coord": False,
"lazy_load": lazy_load,
"release_source": release_source,
"direct_backend": direct_backend,
}
}
}
Expand Down Expand Up @@ -415,6 +417,32 @@ def test_xr_engine_detailed_flatten_check(stream, lazy_load, release_source):
assert np.allclose(r.values, vals_ref)


@pytest.mark.cache
@pytest.mark.parametrize(
"kwargs",
[
{"split_dims": ["step"]},
{"split_dims": None},
{"direct_backend": None},
{"direct_backend": True},
{"direct_backend": False},
],
)
def test_xr_engine_invalid_kwargs(kwargs):
ds_ek = from_source("url", earthkit_remote_test_data_file("test-data", "xr_engine", "level", "pl.grib"))

import xarray as xr

with pytest.raises(TypeError):
xr.open_dataset(
ds_ek.path,
engine="earthkit",
time_dim_mode="raw",
**kwargs,
)


@pytest.mark.cache
def test_xr_engine_dtype():
ds_ek = from_source("url", earthkit_remote_test_data_file("test-data/xr_engine/level/pl.grib"))

Expand Down

0 comments on commit 1df7bd8

Please sign in to comment.