xr.dot requires equal indexes (join="exact") #3694

Closed
mathause opened this issue Jan 14, 2020 · 5 comments · Fixed by #3699

@mathause
Collaborator

MCVE Code Sample

import xarray as xr
import numpy as np

d1 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[0, 1, 2, 3]))
d2 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[0, 1, 2, 3]))

# note: different coords
d3 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[1, 2, 3, 4]))

(d1 * d2).sum() # -> array(14)
xr.dot(d1, d2) # -> array(14)

(d2 * d3).sum() # -> array(8)
xr.dot(d2, d3) # -> ValueError

Expected Output

<xarray.DataArray ()>
array(8)

Problem Description

The last statement results in an

ValueError: indexes along dimension 'a' are not equal

because xr.apply_ufunc defaults to join='exact'. I think this should work, though maybe there is a good reason for it to fail?

This is a problem for #2922 (weighted operations): I think it is fine for the weights and the data not to align.

Fixing this may be as easy as specifying join='inner' in

result = apply_ufunc(
    func,
    *arrays,
    input_core_dims=input_core_dims,
    output_core_dims=output_core_dims,
    dask="allowed",
)
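To see what specifying join='inner' would change, here is a minimal sketch that calls xr.apply_ufunc directly on the d2/d3 arrays from the MCVE. The helpers `inner_product` and `dot_along` are hypothetical names that only handle two arrays sharing one dimension; this is not xarray's actual `dot` implementation, although it contracts the core dimension with np.einsum in a similar spirit.

```python
import numpy as np
import xarray as xr

d2 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[0, 1, 2, 3]))
d3 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[1, 2, 3, 4]))


def inner_product(x, y):
    # apply_ufunc moves the core dimension to the last axis; contract it.
    return np.einsum("...i,...i->...", x, y)


def dot_along(a, b, dim, join):
    # Hypothetical two-argument dot: `join` controls how the indexes are
    # aligned before `inner_product` receives the raw numpy arrays.
    return xr.apply_ufunc(
        inner_product,
        a,
        b,
        input_core_dims=[[dim], [dim]],
        join=join,
    )


# join="inner" keeps only the shared labels a=1, 2, 3.
inner_result = int(dot_along(d2, d3, "a", join="inner"))
print(inner_result)  # 8

# join="exact" (apply_ufunc's default) refuses to align different indexes.
exact_failed = False
try:
    dot_along(d2, d3, "a", join="exact")
except ValueError:
    exact_failed = True
print(exact_failed)  # True
```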

@fujiisoup

Output of xr.show_versions()

INSTALLED VERSIONS

commit: 5afc6f3
python: 3.7.3 | packaged by conda-forge | (default, Jul 1 2019, 21:52:21)
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 4.12.14-lp151.28.36-default
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_GB.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.6.2

xarray: 0.14.0+164.g5afc6f32.dirty
pandas: 0.25.2
numpy: 1.17.3
scipy: 1.3.1
netCDF4: 1.5.1.2
pydap: installed
h5netcdf: 0.7.4
h5py: 2.10.0
Nio: 1.5.5
zarr: 2.3.2
cftime: 1.0.4.2
nc_time_axis: 1.2.0
PseudoNetCDF: installed
rasterio: 1.1.0
cfgrib: 0.9.7.2
iris: 2.2.0
bottleneck: 1.2.1
dask: 2.6.0
distributed: 2.6.0
matplotlib: 3.1.1
cartopy: 0.17.0
seaborn: 0.9.0
numbagg: installed
setuptools: 41.6.0.post20191029
pip: 19.3.1
conda: None
pytest: 5.2.2
IPython: 7.9.0
sphinx: None

@dcherian changed the title from "xr.dot requires equal dimensions" to "xr.dot requires equal indexes (join="exact")" on Jan 14, 2020
@mathause mentioned this issue on Jan 14, 2020
@fujiisoup
Member

I have no strong opinion, but if most of the arithmetic in xarray uses join='inner', then it would be nicer to do so here too.
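A quick check of how binary arithmetic aligns, using the d2/d3 arrays from the MCVE and assuming default options: arithmetic follows the `arithmetic_join` option, which defaults to "inner", and switching it to "exact" via xr.set_options makes `d2 * d3` fail the same way xr.dot does. This is only a sketch of the behaviour under discussion; the exact error message varies between xarray versions.

```python
import numpy as np
import xarray as xr

d2 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[0, 1, 2, 3]))
d3 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[1, 2, 3, 4]))

# Default arithmetic_join is "inner": only the shared labels a=1, 2, 3 remain.
prod = d2 * d3
print(prod.a.values)  # [1 2 3]
print(int(prod.sum()))  # 8

# With strict alignment, arithmetic raises just like xr.dot in the MCVE.
exact_failed = False
try:
    with xr.set_options(arithmetic_join="exact"):
        _ = d2 * d3
except ValueError:
    exact_failed = True
print(exact_failed)  # True
```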

@shoyer
Member

shoyer commented Jan 15, 2020

You could dig through the original PRs to be sure, but I think we mostly picked join='exact' out of an abundance of caution. In principle I think it would be reasonable to change it; there is a pretty good case that (d1 * d2).sum() and d1 @ d2 should be consistent.

@mathause
Collaborator Author

I started a PR using join=OPTIONS["arithmetic_join"], but then I realized that dot does not support skipna. Thus join="left", join="right", and join="outer" return NaN for non-equal coords, which I think defeats the purpose. I can:

  1. Use OPTIONS["arithmetic_join"] anyway
  2. Only support join="inner" (i.e. hard-code it in the call to apply_ufunc)
  3. Try to implement skipna for dot (in a separate PR)
  4. Other?

(3) is certainly the largest change, but it may be as easy as da.fillna(0.). Thoughts?

import numpy as np
import xarray as xr

d1 = xr.DataArray([2, 3, 5, np.nan])
d2 = xr.DataArray([2, 3, 5, 7])

xr.dot(d1, d2) # -> NaN

xr.dot(d1.fillna(0.), d2) # -> 38
(d1 * d2).sum() # -> 38
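The NaN problem with the non-inner joins can be reproduced without touching dot itself: aligning the MCVE arrays d2 and d3 with xr.align(..., join="outer") pads the labels present on only one side with NaN. xr.dot propagates those NaN values, while .sum() skips them, and filling with 0 first recovers the inner-join answer. A minimal sketch:

```python
import numpy as np
import xarray as xr

d2 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[0, 1, 2, 3]))
d3 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[1, 2, 3, 4]))

# Outer alignment covers a=0..4 and fills the missing labels with NaN.
a_outer, b_outer = xr.align(d2, d3, join="outer")

dot_outer = float(xr.dot(a_outer, b_outer))  # NaN: dot has no skipna
sum_outer = float((a_outer * b_outer).sum())  # 8.0: sum skips NaN
dot_filled = float(xr.dot(a_outer.fillna(0.0), b_outer.fillna(0.0)))  # 8.0

print(dot_outer, sum_outer, dot_filled)
```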

I use this at:

# need to mask invalid DATA as `dot` does not implement skipna
if skipna or (skipna is None and da.dtype.kind in "cfO"):
    da = da.fillna(0.0)
# `dot` does not broadcast arrays, so this avoids creating a large
# DataArray (if `weights` has additional dimensions)
# TODO: maybe add fasttrack (`(da * weights).sum(dim=dim, skipna=skipna)`)
return dot(da, self.weights, dims=dim)
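Outside of a class, the same masking idea can be sketched as a free function. `weighted_sum` is a hypothetical name, and unlike the snippet above it omits the dimension argument, so xr.dot reduces over every dimension that `da` and `weights` share.

```python
import numpy as np
import xarray as xr


def weighted_sum(da, weights, skipna=None):
    """Hypothetical helper: sum of da * weights over all shared dimensions."""
    # dot has no skipna, so mask invalid data by replacing NaN with 0,
    # using the same default as .sum(): skip NaN for float/complex/object.
    if skipna or (skipna is None and da.dtype.kind in "cfO"):
        da = da.fillna(0.0)
    # Dimensions that only `weights` has are kept in the result.
    return xr.dot(da, weights)


da = xr.DataArray([2, 3, 5, np.nan], dims="x")
weights = xr.DataArray([2, 3, 5, 7], dims="x")

print(float(weighted_sum(da, weights)))  # 38.0
print(float(weighted_sum(da, weights, skipna=False)))  # nan

# Weights with an extra dimension "y": the result keeps "y".
weights_2d = xr.DataArray(np.ones((4, 2)), dims=("x", "y"))
print(weighted_sum(da, weights_2d).values)  # [10. 10.]
```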

@mathause
Collaborator Author

mathause commented Jan 16, 2020

On second thought, with skipna implemented the result would not differ between join="inner", join="left", join="right", and join="outer"; e.g.:

with xr.set_options(arithmetic_join="outer"):
    print((d1 * d2).sum()) # -> 38

So I guess (2) is fine.
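The d1/d2 arrays above have no coordinates, so there is nothing to align. To check the claim with indexes that actually differ, here is a sketch that loops over the four non-exact joins with the MCVE arrays d2 and d3. Because .sum() skips the NaN that the left, right, and outer joins introduce, every join should give the same total.

```python
import numpy as np
import xarray as xr

d2 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[0, 1, 2, 3]))
d3 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[1, 2, 3, 4]))

results = {}
for join in ["inner", "left", "right", "outer"]:
    with xr.set_options(arithmetic_join=join):
        # Labels present on only one side become NaN; .sum() skips them.
        results[join] = float((d2 * d3).sum())

print(results)  # 8.0 for every join
```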

@mathause
Collaborator Author

Sorry for the noise, I'm thinking out loud...


However, if I do (2), (d1 * d2).sum() and d1 @ d2 behave differently in one case:

with xr.set_options(arithmetic_join="exact"):
    xr.dot(d2, d3) # -> array(8)
    (d2 * d3).sum() # -> ValueError

So how about this?

join = OPTIONS["arithmetic_join"]
if join not in ["exact", "inner"]:
    join = "inner"
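The proposed rule can be sketched as a pair of hypothetical helpers: `resolve_dot_join` maps the arithmetic_join setting to the join used for dot, and `dot_with_join` aligns with it via xr.align before calling xr.dot. The join is passed in explicitly instead of being read from xarray's internal OPTIONS dict; that is an assumption made to keep the example self-contained.

```python
import numpy as np
import xarray as xr


def resolve_dot_join(arithmetic_join):
    # Keep "exact" (strict alignment) and "inner"; since dot has no skipna,
    # map "left", "right" and "outer" to "inner" to avoid NaN results.
    if arithmetic_join not in ["exact", "inner"]:
        return "inner"
    return arithmetic_join


def dot_with_join(a, b, arithmetic_join="inner"):
    join = resolve_dot_join(arithmetic_join)
    # xr.align raises ValueError for join="exact" if the indexes differ.
    a, b = xr.align(a, b, join=join)
    return xr.dot(a, b)


d2 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[0, 1, 2, 3]))
d3 = xr.DataArray(np.arange(4), dims=["a"], coords=dict(a=[1, 2, 3, 4]))

for option in ["inner", "left", "right", "outer"]:
    print(option, int(dot_with_join(d2, d3, option)))  # 8 for each option

exact_failed = False
try:
    dot_with_join(d2, d3, "exact")
except ValueError:
    exact_failed = True
print(exact_failed)  # True
```

With this rule, setting arithmetic_join="exact" makes both (d2 * d3).sum() and the dot raise, which addresses the inconsistency shown above.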
