
assert_equal and dask #3350

Closed · dcherian opened this issue Sep 27, 2019 · 3 comments · Fixed by #3584

MCVE Code Sample

Example 1

import xarray as xr
import numpy as np

da = xr.DataArray(np.random.randn(10, 20), name="a")
ds = da.to_dataset()
xr.testing.assert_equal(da, da.chunk({"dim_0": 2})) # works
xr.testing.assert_equal(da.chunk(), da.chunk({"dim_0": 2})) # works

xr.testing.assert_equal(ds, ds.chunk({"dim_0": 2})) # works
xr.testing.assert_equal(ds.chunk(), ds.chunk({"dim_0": 2}))  # does not work

I get

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-bc8216a67408> in <module>
      8 
      9 xr.testing.assert_equal(ds, ds.chunk({"dim_0": 2})) # works
---> 10 xr.testing.assert_equal(ds.chunk(), ds.chunk({"dim_0": 2}))  # does not work

~/work/python/xarray/xarray/testing.py in assert_equal(a, b)
     56         assert a.equals(b), formatting.diff_array_repr(a, b, "equals")
     57     elif isinstance(a, Dataset):
---> 58         assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals")
     59     else:
     60         raise TypeError("{} not supported by assertion comparison".format(type(a)))

~/work/python/xarray/xarray/core/dataset.py in equals(self, other)
   1322         """
   1323         try:
-> 1324             return self._all_compat(other, "equals")
   1325         except (TypeError, AttributeError):
   1326             return False

~/work/python/xarray/xarray/core/dataset.py in _all_compat(self, other, compat_str)
   1285 
   1286         return self._coord_names == other._coord_names and utils.dict_equiv(
-> 1287             self._variables, other._variables, compat=compat
   1288         )
   1289 

~/work/python/xarray/xarray/core/utils.py in dict_equiv(first, second, compat)
    335     """
    336     for k in first:
--> 337         if k not in second or not compat(first[k], second[k]):
    338             return False
    339     for k in second:

~/work/python/xarray/xarray/core/dataset.py in compat(x, y)
   1282         # require matching order for equality
   1283         def compat(x: Variable, y: Variable) -> bool:
-> 1284             return getattr(x, compat_str)(y)
   1285 
   1286         return self._coord_names == other._coord_names and utils.dict_equiv(

~/work/python/xarray/xarray/core/variable.py in equals(self, other, equiv)
   1558         try:
   1559             return self.dims == other.dims and (
-> 1560                 self._data is other._data or equiv(self.data, other.data)
   1561             )
   1562         except (TypeError, AttributeError):

~/work/python/xarray/xarray/core/duck_array_ops.py in array_equiv(arr1, arr2)
    201         warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
    202         flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
--> 203         return bool(flag_array.all())
    204 
    205 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/array/core.py in __bool__(self)
   1380             )
   1381         else:
-> 1382             return bool(self.compute())
   1383 
   1384     __nonzero__ = __bool__  # python 2

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    173         dask.base.compute
    174         """
--> 175         (result,) = compute(self, traverse=False, **kwargs)
    176         return result
    177 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    444     keys = [x.__dask_keys__() for x in collections]
    445     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 446     results = schedule(dsk, keys, **kwargs)
    447     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    448 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     80         get_id=_thread_get_id,
     81         pack_exception=pack_exception,
---> 82         **kwargs
     83     )
     84 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    489                         _execute_task(task, data)  # Re-execute locally
    490                     else:
--> 491                         raise_exception(exc, tb)
    492                 res, worker_id = loads(res_info)
    493                 state["cache"][key] = res

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/compatibility.py in reraise(exc, tb)
    128         if exc.__traceback__ is not tb:
    129             raise exc.with_traceback(tb)
--> 130         raise exc
    131 
    132     import pickle as cPickle

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    231     try:
    232         task, data = loads(task_info)
--> 233         result = _execute_task(task, data)
    234         id = get_id()
    235         result = dumps((result, id))

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/optimization.py in __call__(self, *args)
   1057         if not len(args) == len(self.inkeys):
   1058             raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
-> 1059         return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
   1060 
   1061     def __reduce__(self):

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/core.py in get(dsk, out, cache)
    147     for key in toposort(dsk):
    148         task = dsk[key]
--> 149         result = _execute_task(task, cache)
    150         cache[key] = result
    151     result = _execute_task(out, cache)

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

ValueError: operands could not be broadcast together with shapes (0,20) (2,20) 

Example 2

The relevant xarray line in the previous traceback is flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)), so I tried

(ds.isnull() & ds.chunk({"dim_0": 1}).isnull()).compute() # works
(ds.chunk().isnull() & ds.chunk({"dim_0": 1}).isnull()).compute() # does not work?!
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-abdfbeda355a> in <module>
      1 (ds.isnull() & ds.chunk({"dim_0": 1}).isnull()).compute() # works
----> 2 (ds.chunk().isnull() & ds.chunk({"dim_0": 1}).isnull()).compute() # does not work?!

~/work/python/xarray/xarray/core/dataset.py in compute(self, **kwargs)
    791         """
    792         new = self.copy(deep=False)
--> 793         return new.load(**kwargs)
    794 
    795     def _persist_inplace(self, **kwargs) -> "Dataset":

~/work/python/xarray/xarray/core/dataset.py in load(self, **kwargs)
    645 
    646             for k, data in zip(lazy_data, evaluated_data):
--> 647                 self.variables[k].data = data
    648 
    649         # load everything else sequentially

~/work/python/xarray/xarray/core/variable.py in data(self, data)
    331         data = as_compatible_data(data)
    332         if data.shape != self.shape:
--> 333             raise ValueError("replacement data must match the Variable's shape")
    334         self._data = data
    335 

ValueError: replacement data must match the Variable's shape

Problem Description

I don't know what's going on here. I expect assert_equal to pass (i.e. not raise) for all of these examples.

Our test for isnull with dask always calls load before comparing:

def test_isnull_with_dask():
    da = construct_dataarray(2, np.float32, contains_nan=True, dask=True)
    assert isinstance(da.isnull().data, dask_array_type)
    assert_equal(da.isnull().load(), da.load().isnull())
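Loading before comparing, as that test does, sidesteps the bad graph merge entirely, since the comparison then happens on NumPy-backed objects. A minimal sketch of that workaround (the Dataset here is made up for illustration):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"x": (("y",), np.zeros(10))})
# Loading first turns both sides into NumPy-backed objects, so no dask
# graphs (and no colliding graph names) are involved in the comparison.
xr.testing.assert_equal(ds.chunk().load(), ds.chunk({"y": 2}).load())
```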

Output of xr.show_versions()

xarray master & dask 2.3.0

INSTALLED VERSIONS
------------------
commit: 6ece6a1
python: 3.7.3 | packaged by conda-forge | (default, Jul 1 2019, 21:52:21) [GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 4.15.0-64-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.6.2

xarray: 0.13.0+13.g6ece6a1c
pandas: 0.25.1
numpy: 1.17.2
scipy: 1.3.1
netCDF4: 1.5.1.2
pydap: None
h5netcdf: 0.7.4
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.3.4
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: 0.9.7.1
iris: None
bottleneck: 1.2.1
dask: 2.3.0
distributed: 2.3.2
matplotlib: 3.1.1
cartopy: 0.17.0
seaborn: 0.9.0
numbagg: None
setuptools: 41.2.0
pip: 19.2.3
conda: 4.7.11
pytest: 5.1.2
IPython: 7.8.0
sphinx: 2.2.0


shoyer commented Sep 27, 2019

Here's a slightly simpler case:

In [28]: ds = xr.Dataset({'x': (('y',), np.zeros(10))})

In [29]: (ds.chunk().isnull() & ds.chunk(5).isnull()).compute()
ValueError: operands could not be broadcast together with shapes (0,) (5,)
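The collision can be sketched in plain dask by forcing two differently-chunked arrays to share a graph name, which is effectively what Dataset.chunk was doing (the name "xarray-x-deadbeef" is made up for illustration):

```python
import numpy as np
import dask.array as da

x = np.zeros(10)
# Two arrays over the same data but with different chunking, forced to
# share a dask name.
a = da.from_array(x, chunks=10, name="xarray-x-deadbeef")
b = da.from_array(x, chunks=5, name="xarray-x-deadbeef")

# Their graphs now share keys, so merging them (as any binary op does)
# silently lets one array's tasks shadow the other's.
shared = set(dict(a.__dask_graph__())) & set(dict(b.__dask_graph__()))
print(sorted(shared, key=str))
```

With shared keys, whichever task definition survives the merge feeds both output shapes, producing broadcast errors like the one above.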

shoyer added the bug label Sep 27, 2019

shoyer commented Sep 27, 2019

Interestingly, it looks like the difference comes down to whether we chunk DataArrays or Datasets. The latter produces graphs with fixed (reproducible) keys, the former doesn't:

In [57]: dict(ds.chunk().x.data.dask)
Out[57]:
{('xarray-x-a46bb46a12a44073da484c1311d00dec',
  0): array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}

In [58]: dict(ds.chunk().x.data.dask)
Out[58]:
{('xarray-x-a46bb46a12a44073da484c1311d00dec',
  0): array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}

In [59]: dict(ds.x.chunk().data.dask)
Out[59]:
{('xarray-<this-array>-d75d5cc0f0ce1b56590d80702339c0f0',
  0): array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}

In [60]: dict(ds.x.chunk().data.dask)
Out[60]:
{('xarray-<this-array>-0f78e51941cfb0e25d41ac24ef330a50',
  0): array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}

But clearly this should work either way. The size zero dimension is a give-away that the problem has something to do with dask's _meta propagation.


dcherian commented Dec 1, 2019

The size zero dimension is a give-away that the problem has something to do with dask's _meta propagation.

I think the size 0 results from chunk(). With chunk(2) other weird errors come up:

TypeError: tuple indices must be integers or slices, not tuple

We were specifying a name for the chunked array in Dataset.chunk, but this name was independent of the chunk sizes, i.e. ds.chunk() and ds.chunk(2) produced arrays with the same name, which ends up confusing dask (I think). #3584 fixes this by providing chunks as an input to tokenize. I also needed to add __dask_tokenize__ to ReprObject so that names were reproducible after going through a to_temp_dataset transformation.
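The fix can be illustrated with dask.base.tokenize: once the chunk spec participates in the token, differently-chunked versions of the same data get different names (the "xarray-x-" prefix here just mimics the names in the tracebacks above):

```python
import numpy as np
import dask.base

data = np.zeros(10)
# Tokenizing the data alone is deterministic, so chunking the same
# variable twice yields the same name regardless of chunk sizes...
assert dask.base.tokenize(data) == dask.base.tokenize(data)
# ...but feeding the chunk spec into tokenize, as #3584 does, makes the
# name depend on the chunking, so ds.chunk() and ds.chunk(2) can't collide.
name_full = "xarray-x-" + dask.base.tokenize(data, (10,))
name_small = "xarray-x-" + dask.base.tokenize(data, (2,))
assert name_full != name_small
```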
