diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 7c0b730814e..9432ccf1904 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -174,8 +174,10 @@ def as_shared_dtype(scalars_or_arrays): return [x.astype(out_type, copy=False) for x in arrays] -def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): - """Like np.allclose, but also allows values to be NaN in both arrays +def lazy_array_equiv(arr1, arr2): + """Like array_equal, but doesn't actually compare values. + Returns True or False when equality can be determined without computing. + Returns None when equality cannot determined (e.g. one or both of arr1, arr2 are numpy arrays) """ arr1 = asarray(arr1) arr2 = asarray(arr2) @@ -189,26 +191,19 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): # GH3068 if arr1.name == arr2.name: return True - return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + return None -def lazy_array_equiv(arr1, arr2): - """Like array_equal, but doesn't actually compare values. - Returns True or False when equality can be determined without computing. - Returns None when equality cannot determined (e.g. one or both of arr1, arr2 are numpy arrays) +def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): + """Like np.allclose, but also allows values to be NaN in both arrays """ arr1 = asarray(arr1) arr2 = asarray(arr2) - if arr1.shape != arr2.shape: - return False - if ( - dask_array - and isinstance(arr1, dask_array.Array) - and isinstance(arr2, dask_array.Array) - ): - # GH3068 - if arr1.name == arr2.name: - return True + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + else: + return lazy_equiv def array_equiv(arr1, arr2): @@ -216,20 +211,14 @@ def array_equiv(arr1, arr2): """ arr1 = asarray(arr1) arr2 = asarray(arr2) - if arr1.shape != arr2.shape: - return False - if ( - dask_array - and isinstance(arr1, dask_array.Array) - and isinstance(arr2, dask_array.Array) - ): - # GH3068 - if arr1.name == arr2.name: - return True - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "In the future, 'NAT == x'") - flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)) - return bool(flag_array.all()) + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") + flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)) + return bool(flag_array.all()) + else: + return lazy_equiv def array_notnull_equiv(arr1, arr2): @@ -238,20 +227,14 @@ def array_notnull_equiv(arr1, arr2): """ arr1 = asarray(arr1) arr2 = asarray(arr2) - if arr1.shape != arr2.shape: - return False - if ( - dask_array - and isinstance(arr1, dask_array.Array) - and isinstance(arr2, dask_array.Array) - ): - # GH3068 - if arr1.name == arr2.name: - return True - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "In the future, 'NAT == x'") - flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2) - return bool(flag_array.all()) + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") + flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2) + return bool(flag_array.all()) + else: + return lazy_equiv def count(data, axis=None):