-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
cov() and corr() #2652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
cov() and corr() #2652
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
from . import computation, groupby, indexing, ops, resample, rolling, utils | ||
from ..plot.plot import _PlotMethods | ||
from .accessors import DatetimeAccessor | ||
from .alignment import align, reindex_like_indexers | ||
from .alignment import align, reindex_like_indexers, broadcast | ||
from .common import AbstractArray, DataWithCoords | ||
from .coordinates import ( | ||
DataArrayCoordinates, Indexes, LevelCoordinatesSource, | ||
|
@@ -29,8 +29,8 @@ | |
def _infer_coords_and_dims(shape, coords, dims): | ||
"""All the logic for creating a new DataArray""" | ||
|
||
if (coords is not None and not utils.is_dict_like(coords) and | ||
len(coords) != len(shape)): | ||
if (coords is not None and not utils.is_dict_like(coords) | ||
and len(coords) != len(shape)): | ||
raise ValueError('coords is not dict-like, but it has %s items, ' | ||
'which does not match the %s dimensions of the ' | ||
'data' % (len(coords), len(shape))) | ||
|
@@ -1873,8 +1873,8 @@ def _all_compat(self, other, compat_str): | |
def compat(x, y): | ||
return getattr(x.variable, compat_str)(y.variable) | ||
|
||
return (utils.dict_equiv(self.coords, other.coords, compat=compat) and | ||
compat(self, other)) | ||
return (utils.dict_equiv(self.coords, other.coords, compat=compat) | ||
and compat(self, other)) | ||
|
||
def broadcast_equals(self, other): | ||
"""Two DataArrays are broadcast equal if they are equal after | ||
|
@@ -1921,8 +1921,8 @@ def identical(self, other): | |
DataArray.equal | ||
""" | ||
try: | ||
return (self.name == other.name and | ||
self._all_compat(other, 'identical')) | ||
return (self.name == other.name | ||
and self._all_compat(other, 'identical')) | ||
except (TypeError, AttributeError): | ||
return False | ||
|
||
|
@@ -2413,6 +2413,64 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None): | |
coord, edge_order, datetime_unit) | ||
return self._from_temp_dataset(ds) | ||
|
||
def cov(self, other, dim=None): | ||
"""Compute covariance between two DataArray objects along a shared dimension. | ||
|
||
Parameters | ||
---------- | ||
other: DataArray | ||
The other array with which the covariance will be computed | ||
dim: The dimension along which the covariance will be computed | ||
|
||
Returns | ||
------- | ||
covariance: DataArray | ||
""" | ||
# 1. Broadcast the two arrays | ||
self, other = broadcast(self, other) | ||
|
||
# 2. Ignore the nans | ||
valid_values = self.notnull() & other.notnull() | ||
self = self.where(valid_values, drop=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be best to avoid |
||
other = other.where(valid_values, drop=True) | ||
valid_count = valid_values.sum(dim) | ||
|
||
# 3. Compute mean and standard deviation along the given dim | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove 'and standard deviation' There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I can implement it as xr.cov(x,y). However, I made the implementation to be consistent with pd.Series cov() and corr(). https://pandas.pydata.org/pandas-docs/stable/generated/pandas.Series.cov.html. So I think users might be more familiar with this implementation. |
||
demeaned_self = self - self.mean(dim=dim) | ||
demeaned_other = other - other.mean(dim=dim) | ||
|
||
# 4. Compute covariance along the given dim | ||
cov = (demeaned_self * demeaned_other).sum(dim=dim) / (valid_count) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this slightly simpler version would work: self, other = broadcast(self, other)
valid_values = self.notnull() & other.notnull()
self = self.where(valid_values)
other = self.where(valid_values)
demeaned_self = self - self.mean(dim=dim)
demeaned_other = other - other.mean(dim=dim)
cov = (demeaned_self * demeaned_other).mean(dim=dim) Or maybe we want to keep using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Testing this version. Will look into |
||
|
||
return cov | ||
|
||
def corr(self, other, dim=None): | ||
"""Compute correlation between two DataArray objects along a shared dimension. | ||
|
||
Parameters | ||
---------- | ||
other: DataArray | ||
The other array with which the correlation will be computed | ||
dim: The dimension along which the correlation will be computed | ||
|
||
Returns | ||
------- | ||
correlation: DataArray | ||
""" | ||
# 1. Broadcast the two arrays | ||
self, other = broadcast(self, other) | ||
|
||
# 2. Ignore the nans | ||
valid_values = self.notnull() & other.notnull() | ||
self = self.where(valid_values, drop=True) | ||
other = other.where(valid_values, drop=True) | ||
|
||
# 3. Compute correlation based on standard deviations and cov() | ||
self_std = self.std(dim=dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What value do we use for |
||
other_std = other.std(dim=dim) | ||
|
||
return self.cov(other, dim=dim) / (self_std * other_std) | ||
|
||
|
||
# priority most be higher than Variable to properly work with binary ufuncs | ||
ops.inject_all_ops_and_reduce_methods(DataArray, priority=60) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3305,6 +3305,45 @@ def test_rank(self): | |
y = DataArray([0.75, 0.25, np.nan, 0.5, 1.0], dims=('z',)) | ||
assert_equal(y.rank('z', pct=True), y) | ||
|
||
def test_corr(self): | ||
# self: Load demo data and trim it's size | ||
ds = xr.tutorial.load_dataset('air_temperature') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Loading the tutorial datasets requires network access, which we try to avoid for tests. Can you write this test using synthetic data instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, will do. The tutorial code can be moved into an 'example' in documentation/ user guide later. |
||
air = ds.air[:18, ...] | ||
# other: select missaligned data, and smooth it to dampen the correlation with self. | ||
air_smooth = ds.air[2:20, ...].rolling(time=3, center=True).mean(dim='time') # . | ||
# A handy function to select an example grid | ||
|
||
def select_pts(da): | ||
return da.sel(lat=45, lon=250) | ||
|
||
# Test #1: Misaligned 1-D dataarrays with missing values | ||
ts1 = select_pts(air.copy()) | ||
ts2 = select_pts(air_smooth.copy()) | ||
|
||
def pd_corr(ts1, ts2): | ||
"""Ensure the ts are aligned and missing values ignored""" | ||
# ts1,ts2 = xr.align(ts1,ts2) | ||
valid_values = ts1.notnull() & ts2.notnull() | ||
|
||
ts1 = ts1.where(valid_values, drop=True) | ||
ts2 = ts2.where(valid_values, drop=True) | ||
|
||
return ts1.to_series().corr(ts2.to_series()) | ||
|
||
expected = pd_corr(ts1, ts2) | ||
actual = ts1.corr(ts2) | ||
np.allclose(expected, actual) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @max-sixty There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to use an actual assertion here. Otherwise this isn't testing anything -- |
||
|
||
# Test #2: Misaligned N-D dataarrays with missing values | ||
actual_ND = air.corr(air_smooth, dim='time') | ||
actual = select_pts(actual_ND) | ||
np.allclose(expected, actual) | ||
|
||
# Test #3: One 1-D dataarray and another N-D dataarray; misaligned and having missing values | ||
actual_ND = air_smooth.corr(ts1, dim='time') | ||
actual = select_pts(actual_ND) | ||
np.allclose(actual, expected) | ||
|
||
|
||
@pytest.fixture(params=[1]) | ||
def da(request): | ||
|
@@ -3640,14 +3679,14 @@ def test_to_and_from_iris(self): | |
assert coord.var_name == original_coord.name | ||
assert_array_equal( | ||
coord.points, CFDatetimeCoder().encode(original_coord).values) | ||
assert (actual.coord_dims(coord) == | ||
original.get_axis_num( | ||
assert (actual.coord_dims(coord) | ||
== original.get_axis_num( | ||
original.coords[coord.var_name].dims)) | ||
|
||
assert (actual.coord('distance2').attributes['foo'] == | ||
original.coords['distance2'].attrs['foo']) | ||
assert (actual.coord('distance').units == | ||
cf_units.Unit(original.coords['distance'].units)) | ||
assert (actual.coord('distance2').attributes['foo'] | ||
== original.coords['distance2'].attrs['foo']) | ||
assert (actual.coord('distance').units | ||
== cf_units.Unit(original.coords['distance'].units)) | ||
assert actual.attributes['baz'] == original.attrs['baz'] | ||
assert actual.standard_name == original.attrs['standard_name'] | ||
|
||
|
@@ -3705,14 +3744,14 @@ def test_to_and_from_iris_dask(self): | |
assert coord.var_name == original_coord.name | ||
assert_array_equal( | ||
coord.points, CFDatetimeCoder().encode(original_coord).values) | ||
assert (actual.coord_dims(coord) == | ||
original.get_axis_num( | ||
assert (actual.coord_dims(coord) | ||
== original.get_axis_num( | ||
original.coords[coord.var_name].dims)) | ||
|
||
assert (actual.coord('distance2').attributes['foo'] == original.coords[ | ||
'distance2'].attrs['foo']) | ||
assert (actual.coord('distance').units == | ||
cf_units.Unit(original.coords['distance'].units)) | ||
assert (actual.coord('distance').units | ||
== cf_units.Unit(original.coords['distance'].units)) | ||
assert actual.attributes['baz'] == original.attrs['baz'] | ||
assert actual.standard_name == original.attrs['standard_name'] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It allocates larger memory than
dot
ortensordot
.Can we use
xr.dot
instead of broadcasting?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used
broadcast
to ensure that the dataarrays get aligned and extra dimensions (if any) in one get inserted into the other. So,broadcast
implemented here doesn't do any arithmetic computation, as such. I didn't knowxr.dot
could be used in such a context. Could it?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xarray.dot
does do alignment/broadcasting, but it definitely doesn't skip missing values so I'm not sure it would would work well here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the typical case, I would expect arguments for which correlation is being computed will have the same dimensions. So I don't think
xarray.dot
would be much faster.