Skip to content
This repository has been archived by the owner on Aug 29, 2023. It is now read-only.

Commit

Permalink
Merge pull request #753 from CCI-Tools/jg-746-pearsonr
Browse files Browse the repository at this point in the history
Make sure pearson_correlation_scalar input is validated
  • Loading branch information
forman authored Sep 21, 2018
2 parents 49aaf3f + a044555 commit 6540e61
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cate/ops/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def pearson_correlation_scalar(ds_x: DatasetLike.TYPE,
array_y = ds_y[var_y]
array_x = ds_x[var_x]

if ((len(array_x.dims) != len(array_y.dims)) and
if ((len(array_x.dims) != len(array_y.dims)) or
(len(array_x.dims) != 1)):
raise ValidationError('To calculate simple correlation, both provided'
' datasets should be simple 1d timeseries. To'
Expand Down
31 changes: 31 additions & 0 deletions test/ops/test_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,37 @@ def test_error(self):
pearson_correlation_scalar(ds1, ds2, 'first', 'first')
self.assertIn('dimension should not be less', str(err.exception))

def test_3D(self):
"""
Test nominal run
"""
# Test general functionality 3D dataset variables
ds1 = xr.Dataset({
'first': (['time', 'lat', 'lon'], np.array([np.ones([4, 8]),
np.ones([4, 8]) * 2,
np.ones([4, 8]) * 3])),
'second': (['time', 'lat', 'lon'], np.array([np.ones([4, 8]) * 2,
np.ones([4, 8]) * 3,
np.ones([4, 8])])),
'lat': np.linspace(-67.5, 67.5, 4),
'lon': np.linspace(-157.5, 157.5, 8),
'time': np.array([1, 2, 3])}).chunk(chunks={'lat': 2, 'lon': 4})

ds2 = xr.Dataset({
'second': (['time', 'lat', 'lon'], np.array([np.ones([4, 8]),
np.ones([4, 8]) * 2,
np.ones([4, 8]) * 3])),
'first': (['time', 'lat', 'lon'], np.array([np.ones([4, 8]) * 2,
np.ones([4, 8]) * 3,
np.ones([4, 8])])),
'lat': np.linspace(-67.5, 67.5, 4),
'lon': np.linspace(-157.5, 157.5, 8),
'time': np.array([1, 2, 3])}).chunk(chunks={'lat': 2, 'lon': 4})

with self.assertRaises(ValueError) as err:
pearson_correlation_scalar(ds1, ds2, 'first', 'first')
self.assertIn('should be simple 1d', str(err.exception))


class TestPearson(TestCase):
def test_nominal(self):
Expand Down

0 comments on commit 6540e61

Please sign in to comment.