tokenize() should ignore difference between None and {} attrs (pydata…
crusaderky authored Mar 1, 2024
1 parent a241845 commit 604bb6d
Showing 7 changed files with 38 additions and 28 deletions.
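The changes below make dask tokenization treat an object whose attrs were never set (stored internally as None) the same as one whose attrs are an explicitly assigned empty dict. A minimal sketch of the resulting behaviour, assuming xarray with this commit and dask are installed (the variable name v is illustrative, not from the diff):

    import dask.base
    import xarray as xr

    v = xr.Variable(("x",), [1, 2, 3])  # freshly created: _attrs is stored as None
    t_before = dask.base.tokenize(v)

    v.attrs = {}  # with this commit, empty attrs are normalized away again
    t_after = dask.base.tokenize(v)

    # The two tokens now match, so empty-but-present attrs no longer change dask keys.
    assert t_before == t_after

The new test_tokenize_empty_attrs test at the bottom of this diff exercises the same invariant for Dataset, Variable, and IndexVariable.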
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
@@ -1070,7 +1070,7 @@ def reset_coords(
         dataset[self.name] = self.variable
         return dataset

-    def __dask_tokenize__(self):
+    def __dask_tokenize__(self) -> object:
         from dask.base import normalize_token

         return normalize_token((type(self), self._variable, self._coords, self._name))
8 changes: 4 additions & 4 deletions xarray/core/dataset.py
@@ -694,7 +694,7 @@ def __init__(
             data_vars, coords
         )

-        self._attrs = dict(attrs) if attrs is not None else None
+        self._attrs = dict(attrs) if attrs else None
         self._close = None
         self._encoding = None
         self._variables = variables
@@ -739,7 +739,7 @@ def attrs(self) -> dict[Any, Any]:

     @attrs.setter
     def attrs(self, value: Mapping[Any, Any]) -> None:
-        self._attrs = dict(value)
+        self._attrs = dict(value) if value else None

     @property
     def encoding(self) -> dict[Any, Any]:
@@ -856,11 +856,11 @@ def load(self, **kwargs) -> Self:

         return self

-    def __dask_tokenize__(self):
+    def __dask_tokenize__(self) -> object:
         from dask.base import normalize_token

         return normalize_token(
-            (type(self), self._variables, self._coord_names, self._attrs)
+            (type(self), self._variables, self._coord_names, self._attrs or None)
         )

     def __dask_graph__(self):
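Both Dataset hunks above rely on the normalization idiom that recurs in the files below: the attrs setter collapses an empty mapping to None, and __dask_tokenize__ hashes self._attrs or None, so {} and None can never yield different tokens. A rough illustration of why this matters, assuming dask is installed (attrs_a and attrs_b are illustrative names):

    import dask.base

    # dask hashes {} and None to different tokens, which is what made otherwise
    # identical xarray objects get different keys depending on how attrs were stored.
    assert dask.base.tokenize({}) != dask.base.tokenize(None)

    attrs_a = None
    attrs_b = {}
    # Collapsing the empty dict to None before hashing makes both forms agree.
    assert dask.base.tokenize(attrs_a or None) == dask.base.tokenize(attrs_b or None)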
6 changes: 4 additions & 2 deletions xarray/core/variable.py
@@ -2592,11 +2592,13 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
         if not isinstance(self._data, PandasIndexingAdapter):
             self._data = PandasIndexingAdapter(self._data)

-    def __dask_tokenize__(self):
+    def __dask_tokenize__(self) -> object:
         from dask.base import normalize_token

         # Don't waste time converting pd.Index to np.ndarray
-        return normalize_token((type(self), self._dims, self._data.array, self._attrs))
+        return normalize_token(
+            (type(self), self._dims, self._data.array, self._attrs or None)
+        )

     def load(self):
         # data is already loaded into memory for IndexVariable
7 changes: 3 additions & 4 deletions xarray/namedarray/core.py
@@ -511,7 +511,7 @@ def attrs(self) -> dict[Any, Any]:

     @attrs.setter
     def attrs(self, value: Mapping[Any, Any]) -> None:
-        self._attrs = dict(value)
+        self._attrs = dict(value) if value else None

     def _check_shape(self, new_data: duckarray[Any, _DType_co]) -> None:
         if new_data.shape != self.shape:
@@ -570,13 +570,12 @@ def real(
             return real(self)
         return self._new(data=self._data.real)

-    def __dask_tokenize__(self) -> Hashable:
+    def __dask_tokenize__(self) -> object:
         # Use v.data, instead of v._data, in order to cope with the wrappers
         # around NetCDF and the like
         from dask.base import normalize_token

-        s, d, a, attrs = type(self), self._dims, self.data, self.attrs
-        return normalize_token((s, d, a, attrs))  # type: ignore[no-any-return]
+        return normalize_token((type(self), self._dims, self.data, self._attrs or None))

     def __dask_graph__(self) -> Graph | None:
         if is_duck_dask_array(self._data):
4 changes: 2 additions & 2 deletions xarray/namedarray/utils.py
@@ -218,7 +218,7 @@ def __eq__(self, other: ReprObject | Any) -> bool:
     def __hash__(self) -> int:
         return hash((type(self), self._value))

-    def __dask_tokenize__(self) -> Hashable:
+    def __dask_tokenize__(self) -> object:
         from dask.base import normalize_token

-        return normalize_token((type(self), self._value))  # type: ignore[no-any-return]
+        return normalize_token((type(self), self._value))
35 changes: 24 additions & 11 deletions xarray/tests/test_dask.py
@@ -299,17 +299,6 @@ def test_persist(self):
         self.assertLazyAndAllClose(u + 1, v)
         self.assertLazyAndAllClose(u + 1, v2)

-    def test_tokenize_empty_attrs(self) -> None:
-        # Issue #6970
-        assert self.eager_var._attrs is None
-        expected = dask.base.tokenize(self.eager_var)
-        assert self.eager_var.attrs == self.eager_var._attrs == {}
-        assert (
-            expected
-            == dask.base.tokenize(self.eager_var)
-            == dask.base.tokenize(self.lazy_var.compute())
-        )
-
     @requires_pint
     def test_tokenize_duck_dask_array(self):
         import pint
@@ -1573,6 +1562,30 @@ def test_token_identical(obj, transform):
     )


+@pytest.mark.parametrize(
+    "obj",
+    [
+        make_ds(),  # Dataset
+        make_ds().variables["c2"],  # Variable
+        make_ds().variables["x"],  # IndexVariable
+    ],
+)
+def test_tokenize_empty_attrs(obj):
+    """Issues #6970 and #8788"""
+    obj.attrs = {}
+    assert obj._attrs is None
+    a = dask.base.tokenize(obj)
+
+    assert obj.attrs == {}
+    assert obj._attrs == {}  # attrs getter changed None to dict
+    b = dask.base.tokenize(obj)
+    assert a == b
+
+    obj2 = obj.copy()
+    c = dask.base.tokenize(obj2)
+    assert a == c
+
+
 def test_recursive_token():
     """Test that tokenization is invoked recursively, and doesn't just rely on the
     output of str()
4 changes: 0 additions & 4 deletions xarray/tests/test_sparse.py
@@ -878,10 +878,6 @@ def test_dask_token():
     import dask

     s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
-
-    # https://github.com/pydata/sparse/issues/300
-    s.__dask_tokenize__ = lambda: dask.base.normalize_token(s.__dict__)
-
     a = DataArray(s)
     t1 = dask.base.tokenize(a)
     t2 = dask.base.tokenize(a)
