Implement __dask_tokenize__
dcherian committed Oct 24, 2019
1 parent 652dd3c commit 4ab6a66
Showing 6 changed files with 81 additions and 0 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -60,6 +60,8 @@ Internal Changes

- Use Python 3.6 idioms throughout the codebase. (:pull:`3419`)
  By `Maximilian Roos <https://github.com/max-sixty>`_
- Implement :py:func:`__dask_tokenize__` for xarray objects.
  By `Deepak Cherian <https://github.com/dcherian>`_

.. _whats-new.0.14.0:

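For context, a minimal sketch of the protocol this commit hooks into: when no type-specific normalizer is registered, dask's generic object handler in dask.base.tokenize() calls __dask_tokenize__ if the object defines it, so returning a tuple of already-hashable contents yields deterministic tokens. The Point class below is purely illustrative and assumes dask is installed.

import dask.base


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __dask_tokenize__(self):
        # Return something dask already knows how to normalize; equal
        # contents then produce equal tokens.
        return ("Point", self.x, self.y)


assert dask.base.tokenize(Point(1, 2)) == dask.base.tokenize(Point(1, 2))
assert dask.base.tokenize(Point(1, 2)) != dask.base.tokenize(Point(1, 3))
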
3 changes: 3 additions & 0 deletions xarray/core/dataarray.py
@@ -754,6 +754,9 @@ def reset_coords(
        dataset[self.name] = self.variable
        return dataset

    def __dask_tokenize__(self):
        return (DataArray, self._variable, self._coords, self._name)

    def __dask_graph__(self):
        return self._to_temp_dataset().__dask_graph__()

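A hedged usage sketch of the DataArray behaviour added above (the names da1/da2 are illustrative, not part of the commit): structurally identical DataArrays now tokenize identically, while changing the name, coordinates, or values changes the token.

import dask.base
import numpy as np
import xarray as xr

da1 = xr.DataArray(np.arange(4), dims="x", name="a")
da2 = xr.DataArray(np.arange(4), dims="x", name="a")

# Same variable, coords and name -> same token.
assert dask.base.tokenize(da1) == dask.base.tokenize(da2)
# Renaming (or editing the data or coords) changes the token.
assert dask.base.tokenize(da1) != dask.base.tokenize(da1.rename("b"))
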
3 changes: 3 additions & 0 deletions xarray/core/dataset.py
@@ -648,6 +648,9 @@ def load(self, **kwargs) -> "Dataset":

        return self

    def __dask_tokenize__(self):
        return (Dataset, self._variables, self._coord_names, self._attrs)

    def __dask_graph__(self):
        graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
        graphs = {k: v for k, v in graphs.items() if v is not None}
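
A sketch of why the Dataset token is built from _variables, _coord_names and _attrs (assuming dask and xarray are installed; values are illustrative): tokenizing stays cheap even for chunked datasets because dask arrays are normalized by their graph names rather than by computing values.

import dask.base
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(8))}, attrs={"title": "demo"}).chunk({"x": 4})

# Deterministic and lazy: nothing is computed to produce the token.
assert dask.base.tokenize(ds) == dask.base.tokenize(ds)
# Renaming a variable (or touching coords/attrs) changes the token.
assert dask.base.tokenize(ds) != dask.base.tokenize(ds.rename({"a": "b"}))
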
6 changes: 6 additions & 0 deletions xarray/core/variable.py
@@ -389,6 +389,9 @@ def compute(self, **kwargs):
        new = self.copy(deep=False)
        return new.load(**kwargs)

    def __dask_tokenize__(self):
        return Variable, self._dims, self.data, self._attrs

    def __dask_graph__(self):
        if isinstance(self._data, dask_array_type):
            return self._data.__dask_graph__()
@@ -1961,6 +1964,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
        if not isinstance(self._data, PandasIndexAdapter):
            self._data = PandasIndexAdapter(self._data)

    def __dask_tokenize__(self):
        return (IndexVariable, self._dims, self._data.array, self._attrs)

    def load(self):
        # data is already loaded into memory for IndexVariable
        return self
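
A small sketch of the Variable hunks above (illustrative values, not from the commit): tokens depend on dims, data and attrs, so equal Variables hash the same and adding attrs changes the hash. For IndexVariable the tuple uses self._data.array, presumably so the wrapped pandas.Index is hashed directly instead of being converted to NumPy first.

import dask.base
import numpy as np
import xarray as xr

v1 = xr.Variable(("x",), np.arange(4), attrs={"units": "m"})
v2 = xr.Variable(("x",), np.arange(4), attrs={"units": "m"})

assert dask.base.tokenize(v1) == dask.base.tokenize(v2)
assert dask.base.tokenize(v1) != dask.base.tokenize(xr.Variable(("x",), np.arange(4)))
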
55 changes: 55 additions & 0 deletions xarray/tests/test_dask.py
@@ -22,6 +22,7 @@
    assert_identical,
    raises_regex,
)
from .test_backends import create_tmp_file

dask = pytest.importorskip("dask")
da = pytest.importorskip("dask.array")
@@ -1135,3 +1136,57 @@ def test_make_meta(map_ds):
    for variable in map_ds.data_vars:
        assert variable in meta.data_vars
        assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim


@pytest.mark.parametrize("obj", [make_da(), make_ds()])
@pytest.mark.parametrize(
    "transform",
    [
        lambda x: x.reset_coords(),
        lambda x: x.reset_coords(drop=True),
        lambda x: x.isel(x=1),
        lambda x: x.attrs.update(new_attrs=1),
        lambda x: x.assign_coords(cxy=1),
        lambda x: x.rename({"x": "xnew"}),
        lambda x: x.rename({"cxy": "cxynew"}),
    ],
)
def test_normalize_token_not_identical(obj, transform):
    with raise_if_dask_computes():
        assert not dask.base.tokenize(obj) == dask.base.tokenize(transform(obj))
        assert not dask.base.tokenize(obj.compute()) == dask.base.tokenize(
            transform(obj.compute())
        )


@pytest.mark.parametrize("transform", [lambda x: x, lambda x: x.compute()])
def test_normalize_differently_when_data_changes(transform):
    obj = transform(make_ds())
    new = obj.copy(deep=True)
    new["a"] *= 2
    with raise_if_dask_computes():
        assert not dask.base.tokenize(obj) == dask.base.tokenize(new)

    obj = transform(make_da())
    new = obj.copy(deep=True)
    new *= 2
    with raise_if_dask_computes():
        assert not dask.base.tokenize(obj) == dask.base.tokenize(new)


@pytest.mark.parametrize(
    "transform", [lambda x: x, lambda x: x.copy(), lambda x: x.copy(deep=True)]
)
@pytest.mark.parametrize(
    "obj", [make_da(), make_ds(), make_da().indexes["x"], make_ds().variables["a"]]
)
def test_normalize_token_identical(obj, transform):
    with raise_if_dask_computes():
        assert dask.base.tokenize(obj) == dask.base.tokenize(transform(obj))


def test_normalize_token_netcdf_backend(map_ds):
    with create_tmp_file() as tmp_file:
        map_ds.to_netcdf(tmp_file)
        read = xr.open_dataset(tmp_file)
        assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read)
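
One practical payoff that these tests exercise indirectly (a sketch only; pure=True and the names below are illustrative, not from the commit): deterministic tokens mean equal xarray inputs produce identical dask.delayed task keys, so dask can deduplicate the work.

import dask
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4), dims="x", name="a")

# With pure=True the key is derived from tokenize(func, *args), so two calls
# with equal inputs collapse onto the same task key.
d1 = dask.delayed(float, pure=True)(da.sum())
d2 = dask.delayed(float, pure=True)(da.sum())
assert d1.key == d2.key
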
12 changes: 12 additions & 0 deletions xarray/tests/test_sparse.py
@@ -22,6 +22,7 @@
)

sparse = pytest.importorskip("sparse")
dask = pytest.importorskip("dask")


def assert_sparse_equal(a, b):
@@ -849,3 +850,14 @@ def test_chunk():
    dsc = ds.chunk(2)
    assert dsc.chunks == {"dim_0": (2, 2)}
    assert_identical(dsc, ds)


def test_normalize_token():
    s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
    a = DataArray(s)
    dask.base.tokenize(a)
    assert isinstance(a.data, sparse.COO)

    ac = a.chunk(2)
    dask.base.tokenize(ac)
    assert isinstance(ac.data._meta, sparse.COO)
