Skip to content

Commit

Permalink
refactor isnull
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Sep 8, 2016
1 parent ae90fa1 commit 7e9005a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 32 deletions.
39 changes: 16 additions & 23 deletions xarray/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,20 @@ def _fail_on_dask_array_input(values, msg=None, func_name=None):

around = _dask_or_eager_func('around')
isclose = _dask_or_eager_func('isclose')
isnull = _dask_or_eager_func('isnull', pd)
notnull = _dask_or_eager_func('notnull', pd)
_isnull = _dask_or_eager_func('isnull', pd)


def isnull(data):
    """Return a boolean mask flagging the null entries of ``data``.

    The work is delegated to ``_isnull``. If that call raises ``TypeError``,
    an all-False boolean array with the same shape as ``data`` is returned
    instead, i.e. every value is treated as present.
    """
    # GH837, GH861
    # The pandas isnull function raises TypeError when given a numpy
    # structured array, so for such data we assume nothing is missing.
    try:
        mask = _isnull(data)
    except TypeError:
        mask = np.zeros(data.shape, dtype=bool)
    return mask


transpose = _dask_or_eager_func('transpose')
where = _dask_or_eager_func('where', n_array_args=3)
Expand Down Expand Up @@ -125,17 +137,7 @@ def array_equiv(arr1, arr2):
return False

flag_array = (arr1 == arr2)

# GH837, GH861
# isnull fcn from pandas will throw TypeError when run on numpy structured array
# therefore for dims that are np structured arrays we skip testing for nan

try:

flag_array |= (isnull(arr1) & isnull(arr2))

except TypeError:
pass
flag_array |= (isnull(arr1) & isnull(arr2))

return bool(flag_array.all())

Expand All @@ -149,17 +151,8 @@ def array_notnull_equiv(arr1, arr2):
return False

flag_array = (arr1 == arr2)

# GH837, GH861
# isnull fcn from pandas will throw TypeError when run on numpy structured
# array therefore for dims that are np structured arrays we skip testing
# for nan

try:
flag_array |= (isnull(arr1) | isnull(arr2))

except TypeError:
pass
flag_array |= isnull(arr1)
flag_array |= isnull(arr2)

return bool(flag_array.all())

Expand Down
35 changes: 33 additions & 2 deletions xarray/test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pytest import mark
import numpy as np
from numpy import array, nan
from xarray.core import ops
from xarray.core.ops import (
first, last, count, mean
first, last, count, mean, array_notnull_equiv,
)

from . import TestCase
Expand Down Expand Up @@ -74,3 +74,34 @@ def test_count(self):

def test_all_nan_arrays(self):
assert np.isnan(mean([np.nan, np.nan]))


class TestArrayNotNullEquiv():
    """Tests for ``array_notnull_equiv``."""

    @mark.parametrize("arr1, arr2", [
        (np.array([1, 2, 3]), np.array([1, 2, 3])),
        (np.array([1, 2, np.nan]), np.array([1, np.nan, 3])),
        (np.array([np.nan, 2, np.nan]), np.array([1, np.nan, np.nan])),
    ])
    def test_equal(self, arr1, arr2):
        """Pairs that agree wherever both are non-null are equivalent."""
        assert array_notnull_equiv(arr1, arr2)

    def test_some_not_equal(self):
        """A mismatch between two non-null entries breaks equivalence."""
        left = np.array([1, 2, 4])
        right = np.array([1, np.nan, 3])
        assert not array_notnull_equiv(left, right)

    def test_wrong_shape(self):
        """Arrays of different shapes are not equivalent."""
        wide = np.array([[1, np.nan, np.nan, 4]])
        square = np.array([[1, 2], [np.nan, 4]])
        assert not array_notnull_equiv(wide, square)

    @mark.parametrize("val1, val2, val3, null", [
        (1, 2, 3, None),
        (1., 2., 3., np.nan),
        (1., 2., 3., None),
        ('foo', 'bar', 'baz', None),
    ])
    def test_types(self, val1, val2, val3, null):
        """Equivalence holds for several value types and null markers."""
        first = np.array([val1, null, val3, null])
        second = np.array([val1, val2, null, null])
        assert array_notnull_equiv(first, second)
11 changes: 4 additions & 7 deletions xarray/test/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,13 @@ def test_equals_all_dtypes(self):
v2 = v.copy()
self.assertTrue(v.equals(v2))
self.assertTrue(v.identical(v2))
self.assertTrue(v.notnull_equals(v2))
self.assertTrue(v[0].equals(v2[0]))
self.assertTrue(v[0].identical(v2[0]))
self.assertTrue(v[0].notnull_equals(v2[0]))
self.assertTrue(v[:2].equals(v2[:2]))
self.assertTrue(v[:2].identical(v2[:2]))
self.assertTrue(v[:2].notnull_equals(v2[:2]))

def test_eq_all_dtypes(self):
# ensure that we don't choke on comparisons for which numpy returns
Expand Down Expand Up @@ -570,15 +573,9 @@ def test_notnull_equals(self):

self.assertFalse(v1.notnull_equals(None))

v3 = Variable(('x'), [np.nan, 1, 3, np.nan])
v3 = Variable(('y'), [np.nan, 2, 3, np.nan])
self.assertFalse(v3.notnull_equals(v1))

v4 = Variable(('y'), [np.nan, 2, 3, np.nan])
self.assertFalse(v4.notnull_equals(v1))

v5 = Variable(('x', 'y'), [[1, 2], [np.nan, np.nan]])
self.assertFalse(v1.notnull_equals(v5))

d = np.array([1, 2, np.nan, np.nan])
self.assertFalse(v1.notnull_equals(d))
self.assertFalse(v2.notnull_equals(d))
Expand Down

0 comments on commit 7e9005a

Please sign in to comment.