diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index 063e0c7512f4d..db1982b3b77c8 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -266,6 +266,7 @@ Indexing - Bug in :meth:`DataFrame.iloc` where indexing a single row on a :class:`DataFrame` with a single ExtensionDtype column gave a copy instead of a view on the underlying data (:issue:`45241`) - Bug in setting a NA value (``None`` or ``np.nan``) into a :class:`Series` with int-based :class:`IntervalDtype` incorrectly casting to object dtype instead of a float-based :class:`IntervalDtype` (:issue:`45568`) - Bug in :meth:`Series.__setitem__` with a non-integer :class:`Index` when using an integer key to set a value that cannot be set inplace where a ``ValueError`` was raised insead of casting to a common dtype (:issue:`45070`) +- Bug in :meth:`Series.__setitem__` when setting incompatible values into a ``PeriodDtype`` or ``IntervalDtype`` :class:`Series` raising when indexing with a boolean mask but coercing when indexing with otherwise-equivalent indexers; these now consistently coerce, along with :meth:`Series.mask` and :meth:`Series.where` (:issue:`45768`) - Bug in :meth:`Series.loc.__setitem__` and :meth:`Series.loc.__getitem__` not raising when using multiple keys without using a :class:`MultiIndex` (:issue:`13831`) - Bug when setting a value too large for a :class:`Series` dtype failing to coerce to a common type (:issue:`26049`, :issue:`32878`) - Bug in :meth:`loc.__setitem__` treating ``range`` keys as positional instead of label-based (:issue:`45479`) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 298a59d53ea05..0bbc5fa866771 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1380,6 +1380,8 @@ def where(self, other, cond) -> list[Block]: cond = extract_bool_array(cond) + orig_other = other + orig_cond = cond other = self._maybe_squeeze_arg(other) cond = self._maybe_squeeze_arg(cond) @@ -1399,21 +1401,15 @@ def where(self, other, cond) -> list[Block]: if is_interval_dtype(self.dtype): # TestSetitemFloatIntervalWithIntIntervalValues - blk = self.coerce_to_target_dtype(other) - if blk.dtype == _dtype_obj: - # For now at least only support casting e.g. - # Interval[int64]->Interval[float64] - raise - return blk.where(other, cond) + blk = self.coerce_to_target_dtype(orig_other) + nbs = blk.where(orig_other, orig_cond) + return self._maybe_downcast(nbs, "infer") elif isinstance(self, NDArrayBackedExtensionBlock): # NB: not (yet) the same as # isinstance(values, NDArrayBackedExtensionArray) - if isinstance(self.dtype, PeriodDtype): - # TODO: don't special-case - raise - blk = self.coerce_to_target_dtype(other) - nbs = blk.where(other, cond) + blk = self.coerce_to_target_dtype(orig_other) + nbs = blk.where(orig_other, orig_cond) return self._maybe_downcast(nbs, "infer") else: @@ -1430,6 +1426,8 @@ def putmask(self, mask, new) -> list[Block]: values = self.values + orig_new = new + orig_mask = mask new = self._maybe_squeeze_arg(new) mask = self._maybe_squeeze_arg(mask) @@ -1442,21 +1440,14 @@ def putmask(self, mask, new) -> list[Block]: if is_interval_dtype(self.dtype): # Discussion about what we want to support in the general # case GH#39584 - blk = self.coerce_to_target_dtype(new) - if blk.dtype == _dtype_obj: - # For now at least, only support casting e.g. - # Interval[int64]->Interval[float64], - raise - return blk.putmask(mask, new) + blk = self.coerce_to_target_dtype(orig_new) + return blk.putmask(orig_mask, orig_new) elif isinstance(self, NDArrayBackedExtensionBlock): # NB: not (yet) the same as # isinstance(values, NDArrayBackedExtensionArray) - if isinstance(self.dtype, PeriodDtype): - # TODO: don't special-case - raise - blk = self.coerce_to_target_dtype(new) - return blk.putmask(mask, new) + blk = self.coerce_to_target_dtype(orig_new) + return blk.putmask(orig_mask, orig_new) else: raise diff --git a/pandas/tests/arrays/interval/test_interval.py b/pandas/tests/arrays/interval/test_interval.py index 57aae96f38ac7..5ef69106f0278 100644 --- a/pandas/tests/arrays/interval/test_interval.py +++ b/pandas/tests/arrays/interval/test_interval.py @@ -76,10 +76,16 @@ def test_set_closed(self, closed, new_closed): ], ) def test_where_raises(self, other): + # GH#45768 The IntervalArray methods raises; the Series method coerces ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4], closed="left")) + mask = np.array([True, False, True]) match = "'value.closed' is 'right', expected 'left'." with pytest.raises(ValueError, match=match): - ser.where([True, False, True], other=other) + ser.array._where(mask, other) + + res = ser.where(mask, other=other) + expected = ser.astype(object).where(mask, other) + tm.assert_series_equal(res, expected) def test_shift(self): # https://github.com/pandas-dev/pandas/issues/31495, GH#22428, GH#31502 diff --git a/pandas/tests/arrays/test_period.py b/pandas/tests/arrays/test_period.py index 2592a0263c585..de0e766e4a2aa 100644 --- a/pandas/tests/arrays/test_period.py +++ b/pandas/tests/arrays/test_period.py @@ -124,10 +124,16 @@ def test_sub_period(): [pd.Period("2000", freq="H"), period_array(["2000", "2001", "2000"], freq="H")], ) def test_where_different_freq_raises(other): + # GH#45768 The PeriodArray method raises, the Series method coerces ser = pd.Series(period_array(["2000", "2001", "2002"], freq="D")) cond = np.array([True, False, True]) + with pytest.raises(IncompatibleFrequency, match="freq"): - ser.where(cond, other) + ser.array._where(cond, other) + + res = ser.where(cond, other) + expected = ser.astype(object).where(cond, other) + tm.assert_series_equal(res, expected) # ---------------------------------------------------------------------------- diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index bd43328e0c16c..d2fa187106e1b 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -706,6 +706,20 @@ def test_where_interval_noop(self): res = ser.where(ser.notna()) tm.assert_series_equal(res, ser) + def test_where_interval_fullop_downcast(self, frame_or_series): + # GH#45768 + obj = frame_or_series([pd.Interval(0, 0)] * 2) + other = frame_or_series([1.0, 2.0]) + res = obj.where(~obj.notna(), other) + + # since all entries are being changed, we will downcast result + # from object to ints (not floats) + tm.assert_equal(res, other.astype(np.int64)) + + # unlike where, Block.putmask does not downcast + obj.mask(obj.notna(), other, inplace=True) + tm.assert_equal(obj, other.astype(object)) + @pytest.mark.parametrize( "dtype", [ @@ -736,6 +750,16 @@ def test_where_datetimelike_noop(self, dtype): res4 = df.mask(mask2, "foo") tm.assert_frame_equal(res4, df) + # opposite case where we are replacing *all* values -> we downcast + # from object dtype # GH#45768 + res5 = df.where(mask2, 4) + expected = DataFrame(4, index=df.index, columns=df.columns) + tm.assert_frame_equal(res5, expected) + + # unlike where, Block.putmask does not downcast + df.mask(~mask2, 4, inplace=True) + tm.assert_frame_equal(df, expected.astype(object)) + def test_where_try_cast_deprecated(frame_or_series): obj = DataFrame(np.random.randn(4, 3)) @@ -894,14 +918,29 @@ def test_where_period_invalid_na(frame_or_series, as_cat, request): else: msg = "value should be a 'Period'" - with pytest.raises(TypeError, match=msg): - obj.where(mask, tdnat) + if as_cat: + with pytest.raises(TypeError, match=msg): + obj.where(mask, tdnat) - with pytest.raises(TypeError, match=msg): - obj.mask(mask, tdnat) + with pytest.raises(TypeError, match=msg): + obj.mask(mask, tdnat) + + with pytest.raises(TypeError, match=msg): + obj.mask(mask, tdnat, inplace=True) + + else: + # With PeriodDtype, ser[i] = tdnat coerces instead of raising, + # so for consistency, ser[mask] = tdnat must as well + expected = obj.astype(object).where(mask, tdnat) + result = obj.where(mask, tdnat) + tm.assert_equal(result, expected) + + expected = obj.astype(object).mask(mask, tdnat) + result = obj.mask(mask, tdnat) + tm.assert_equal(result, expected) - with pytest.raises(TypeError, match=msg): obj.mask(mask, tdnat, inplace=True) + tm.assert_equal(obj, expected) def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype): diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index b3b4af165d297..dd871796b36b6 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -18,6 +18,7 @@ IntervalIndex, MultiIndex, NaT, + Period, Series, Timedelta, Timestamp, @@ -1275,6 +1276,22 @@ def obj(self): return Series(timedelta_range("1 day", periods=4)) +@pytest.mark.parametrize( + "val", ["foo", Period("2016", freq="Y"), Interval(1, 2, closed="both")] +) +@pytest.mark.parametrize("exp_dtype", [object]) +class TestPeriodIntervalCoercion(CoercionTest): + # GH#45768 + @pytest.fixture( + params=[ + period_range("2016-01-01", periods=3, freq="D"), + interval_range(1, 5), + ] + ) + def obj(self, request): + return Series(request.param) + + def test_20643(): # closed by GH#45121 orig = Series([0, 1, 2], index=["a", "b", "c"])