Skip to content

cov() and corr() #2652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from . import computation, groupby, indexing, ops, resample, rolling, utils
from ..plot.plot import _PlotMethods
from .accessors import DatetimeAccessor
from .alignment import align, reindex_like_indexers
from .alignment import align, reindex_like_indexers, broadcast
from .common import AbstractArray, DataWithCoords
from .coordinates import (
DataArrayCoordinates, Indexes, LevelCoordinatesSource,
Expand All @@ -29,8 +29,8 @@
def _infer_coords_and_dims(shape, coords, dims):
"""All the logic for creating a new DataArray"""

if (coords is not None and not utils.is_dict_like(coords) and
len(coords) != len(shape)):
if (coords is not None and not utils.is_dict_like(coords)
and len(coords) != len(shape)):
raise ValueError('coords is not dict-like, but it has %s items, '
'which does not match the %s dimensions of the '
'data' % (len(coords), len(shape)))
Expand Down Expand Up @@ -1873,8 +1873,8 @@ def _all_compat(self, other, compat_str):
def compat(x, y):
return getattr(x.variable, compat_str)(y.variable)

return (utils.dict_equiv(self.coords, other.coords, compat=compat) and
compat(self, other))
return (utils.dict_equiv(self.coords, other.coords, compat=compat)
and compat(self, other))

def broadcast_equals(self, other):
"""Two DataArrays are broadcast equal if they are equal after
Expand Down Expand Up @@ -1921,8 +1921,8 @@ def identical(self, other):
DataArray.equal
"""
try:
return (self.name == other.name and
self._all_compat(other, 'identical'))
return (self.name == other.name
and self._all_compat(other, 'identical'))
except (TypeError, AttributeError):
return False

Expand Down Expand Up @@ -2413,6 +2413,64 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None):
coord, edge_order, datetime_unit)
return self._from_temp_dataset(ds)

def cov(self, other, dim=None):
    """Compute covariance between two DataArray objects along a shared dimension.

    The two arrays are broadcast against each other, and missing values
    are ignored pairwise: an element contributes only where both arrays
    are non-null (consistent with ``pandas.Series.cov``).

    Parameters
    ----------
    other : DataArray
        The other array with which the covariance will be computed.
    dim : str, optional
        The dimension along which the covariance will be computed. By
        default the covariance is computed over all shared dimensions.

    Returns
    -------
    covariance : DataArray
    """
    # 1. Broadcast the two arrays so they share dimensions and get
    # aligned; extra dimensions in one array are inserted in the other.
    self, other = broadcast(self, other)

    # 2. Mask positions where either array is missing. ``drop=True`` is
    # deliberately avoided: dropping requires computing the arrays,
    # which would break lazy evaluation for dask-backed data, whereas
    # plain masking keeps the computation graph lazy.
    valid_values = self.notnull() & other.notnull()
    self = self.where(valid_values)
    other = other.where(valid_values)
    valid_count = valid_values.sum(dim)

    # 3. Demean along the given dim (``mean`` skips NaNs by default).
    demeaned_self = self - self.mean(dim=dim)
    demeaned_other = other - other.mean(dim=dim)

    # 4. Average the products of anomalies. ``sum`` skips the NaN pairs,
    # so dividing by the per-position count of valid pairs yields the
    # covariance normalized by N (i.e. ddof=0).
    cov = (demeaned_self * demeaned_other).sum(dim=dim) / valid_count

    return cov

def corr(self, other, dim=None):
    """Compute Pearson correlation between two DataArray objects along a shared dimension.

    The two arrays are broadcast against each other, and missing values
    are ignored pairwise (consistent with ``pandas.Series.corr``).

    Parameters
    ----------
    other : DataArray
        The other array with which the correlation will be computed.
    dim : str, optional
        The dimension along which the correlation will be computed. By
        default the correlation is computed over all shared dimensions.

    Returns
    -------
    correlation : DataArray
    """
    # 1. Broadcast the two arrays so they share dimensions and get aligned.
    self, other = broadcast(self, other)

    # 2. Mask out positions where either array is missing. ``drop=True``
    # is avoided so that dask-backed computations stay lazy.
    valid_values = self.notnull() & other.notnull()
    self = self.where(valid_values)
    other = other.where(valid_values)

    # 3. Standard deviations along the given dim. ``std`` defaults to
    # ddof=0, matching cov()'s normalization by the valid-pair count.
    self_std = self.std(dim=dim)
    other_std = other.std(dim=dim)

    # 4. Pearson correlation: cov(x, y) / (std(x) * std(y)).
    return self.cov(other, dim=dim) / (self_std * other_std)


# priority must be higher than Variable to properly work with binary ufuncs
ops.inject_all_ops_and_reduce_methods(DataArray, priority=60)
59 changes: 49 additions & 10 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3305,6 +3305,45 @@ def test_rank(self):
y = DataArray([0.75, 0.25, np.nan, 0.5, 1.0], dims=('z',))
assert_equal(y.rank('z', pct=True), y)

def test_corr(self):
    # Use synthetic data instead of xr.tutorial.load_dataset: loading
    # the tutorial datasets requires network access, which tests avoid.
    rs = np.random.RandomState(42)
    # Two misaligned 1-D time series with a missing value.
    ts1 = DataArray(rs.randn(20), dims='time',
                    coords={'time': np.arange(20)})
    ts2 = DataArray(rs.randn(20), dims='time',
                    coords={'time': np.arange(2, 22)})
    ts2[3] = np.nan

    def pd_corr(a, b):
        """Reference result: align, drop missing pairs, defer to pandas."""
        a, b = xr.align(a, b)
        valid_values = a.notnull() & b.notnull()
        a = a.where(valid_values, drop=True)
        b = b.where(valid_values, drop=True)
        return a.to_series().corr(b.to_series())

    # Test #1: misaligned 1-D dataarrays with missing values.
    # np.allclose must be wrapped in an assert -- otherwise a mismatch
    # would go unnoticed and the test would verify nothing.
    expected = pd_corr(ts1, ts2)
    actual = ts1.corr(ts2)
    assert np.allclose(expected, actual)

    # Test #2: misaligned N-D dataarrays with missing values.
    da1 = DataArray(rs.randn(20, 3), dims=('time', 'x'),
                    coords={'time': np.arange(20)})
    da2 = DataArray(rs.randn(20, 3), dims=('time', 'x'),
                    coords={'time': np.arange(2, 22)})
    da2[3, 1] = np.nan
    actual_nd = da1.corr(da2, dim='time')
    for i in range(3):
        expected_i = pd_corr(da1[:, i], da2[:, i])
        assert np.allclose(expected_i, actual_nd[i])

    # Test #3: one 1-D dataarray and one N-D dataarray; misaligned and
    # containing missing values (exercises broadcasting in corr()).
    actual_bc = da2.corr(ts1, dim='time')
    for i in range(3):
        expected_i = pd_corr(da2[:, i], ts1)
        assert np.allclose(expected_i, actual_bc[i])


@pytest.fixture(params=[1])
def da(request):
Expand Down Expand Up @@ -3640,14 +3679,14 @@ def test_to_and_from_iris(self):
assert coord.var_name == original_coord.name
assert_array_equal(
coord.points, CFDatetimeCoder().encode(original_coord).values)
assert (actual.coord_dims(coord) ==
original.get_axis_num(
assert (actual.coord_dims(coord)
== original.get_axis_num(
original.coords[coord.var_name].dims))

assert (actual.coord('distance2').attributes['foo'] ==
original.coords['distance2'].attrs['foo'])
assert (actual.coord('distance').units ==
cf_units.Unit(original.coords['distance'].units))
assert (actual.coord('distance2').attributes['foo']
== original.coords['distance2'].attrs['foo'])
assert (actual.coord('distance').units
== cf_units.Unit(original.coords['distance'].units))
assert actual.attributes['baz'] == original.attrs['baz']
assert actual.standard_name == original.attrs['standard_name']

Expand Down Expand Up @@ -3705,14 +3744,14 @@ def test_to_and_from_iris_dask(self):
assert coord.var_name == original_coord.name
assert_array_equal(
coord.points, CFDatetimeCoder().encode(original_coord).values)
assert (actual.coord_dims(coord) ==
original.get_axis_num(
assert (actual.coord_dims(coord)
== original.get_axis_num(
original.coords[coord.var_name].dims))

assert (actual.coord('distance2').attributes['foo'] == original.coords[
'distance2'].attrs['foo'])
assert (actual.coord('distance').units ==
cf_units.Unit(original.coords['distance'].units))
assert (actual.coord('distance').units
== cf_units.Unit(original.coords['distance'].units))
assert actual.attributes['baz'] == original.attrs['baz']
assert actual.standard_name == original.attrs['standard_name']

Expand Down