diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index fa565aa802faf..55a72585acbe5 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -261,7 +261,7 @@ def fillna(self, value=None, method=None, limit=None): ------- filled : ExtensionArray with NA/NaN filled """ - from pandas.api.types import is_scalar + from pandas.api.types import is_array_like from pandas.util._validators import validate_fillna_kwargs from pandas.core.missing import pad_1d, backfill_1d @@ -269,7 +269,7 @@ def fillna(self, value=None, method=None, limit=None): mask = self.isna() - if not is_scalar(value): + if is_array_like(value): if len(value) != len(self): raise ValueError("Length of 'value' does not match. Got ({}) " " expected {}".format(len(value), len(self))) diff --git a/pandas/tests/extension/base/base.py b/pandas/tests/extension/base/base.py index d29587e635ebd..beb7948f2c14b 100644 --- a/pandas/tests/extension/base/base.py +++ b/pandas/tests/extension/base/base.py @@ -4,3 +4,6 @@ class BaseExtensionTests(object): assert_series_equal = staticmethod(tm.assert_series_equal) assert_frame_equal = staticmethod(tm.assert_frame_equal) + assert_extension_array_equal = staticmethod( + tm.assert_extension_array_equal + ) diff --git a/pandas/tests/extension/base/missing.py b/pandas/tests/extension/base/missing.py index bf404ac01bf2b..d3360eb199a89 100644 --- a/pandas/tests/extension/base/missing.py +++ b/pandas/tests/extension/base/missing.py @@ -47,6 +47,12 @@ def test_dropna_frame(self, data_missing): expected = df.iloc[:0] self.assert_frame_equal(result, expected) + def test_fillna_scalar(self, data_missing): + valid = data_missing[1] + result = data_missing.fillna(valid) + expected = data_missing.fillna(valid) + self.assert_extension_array_equal(result, expected) + def test_fillna_limit_pad(self, data_missing): arr = data_missing.take([1, 0, 0, 0, 1]) result = pd.Series(arr).fillna(method='ffill', limit=2) diff --git a/pandas/util/testing.py b/pandas/util/testing.py index a223e4d8fd23e..a1e9dcff38ec7 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -20,6 +20,7 @@ import numpy as np import pandas as pd +from pandas.core.arrays import ExtensionArray from pandas.core.dtypes.missing import array_equivalent from pandas.core.dtypes.common import ( is_datetimelike_v_numeric, @@ -1083,6 +1084,32 @@ def _raise(left, right, err_msg): return True +def assert_extension_array_equal(left, right): + """Check that left and right ExtensionArrays are equal. + + Parameters + ---------- + left, right : ExtensionArray + The two arrays to compare + + Notes + ----- + Missing values are checked separately from valid values. + A mask of missing values is computed for each and checked to match. + The remaining all-valid values are cast to object dtype and checked. + """ + assert isinstance(left, ExtensionArray) + assert left.dtype == right.dtype + left_na = left.isna() + right_na = right.isna() + assert_numpy_array_equal(left_na, right_na) + + left_valid = left[~left_na].astype(object) + right_valid = right[~right_na].astype(object) + + assert_numpy_array_equal(left_valid, right_valid) + + # This could be refactored to use the NDFrame.equals method def assert_series_equal(left, right, check_dtype=True, check_index_type='equiv',