BUG: nullable dtypes not preserved in Series.replace #44940

Merged
merged 9 commits into from
Dec 19, 2021
4 changes: 2 additions & 2 deletions doc/source/whatsnew/v1.4.0.rst
@@ -529,7 +529,7 @@ Other Deprecations
- Deprecated silent dropping of columns that raised a ``TypeError`` in :class:`Series.transform` and :class:`DataFrame.transform` when used with a dictionary (:issue:`43740`)
- Deprecated silent dropping of columns that raised a ``TypeError``, ``DataError``, and some cases of ``ValueError`` in :meth:`Series.aggregate`, :meth:`DataFrame.aggregate`, :meth:`Series.groupby.aggregate`, and :meth:`DataFrame.groupby.aggregate` when used with a list (:issue:`43740`)
- Deprecated casting behavior when setting timezone-aware value(s) into a timezone-aware :class:`Series` or :class:`DataFrame` column when the timezones do not match. Previously this cast to object dtype. In a future version, the values being inserted will be converted to the series or column's existing timezone (:issue:`37605`)
- Deprecated casting behavior when passing an item with mismatched-timezone to :meth:`DatetimeIndex.insert`, :meth:`DatetimeIndex.putmask`, :meth:`DatetimeIndex.where` :meth:`DatetimeIndex.fillna`, :meth:`Series.mask`, :meth:`Series.where`, :meth:`Series.fillna`, :meth:`Series.shift`, :meth:`Series.replace`, :meth:`Series.reindex` (and :class:`DataFrame` column analogues). In the past this has cast to object dtype. In a future version, these will cast the passed item to the index or series's timezone (:issue:`37605`)
- Deprecated casting behavior when passing an item with mismatched-timezone to :meth:`DatetimeIndex.insert`, :meth:`DatetimeIndex.putmask`, :meth:`DatetimeIndex.where` :meth:`DatetimeIndex.fillna`, :meth:`Series.mask`, :meth:`Series.where`, :meth:`Series.fillna`, :meth:`Series.shift`, :meth:`Series.replace`, :meth:`Series.reindex` (and :class:`DataFrame` column analogues). In the past this has cast to object dtype. In a future version, these will cast the passed item to the index or series's timezone (:issue:`37605`,:issue:`44940`)
- Deprecated the 'errors' keyword argument in :meth:`Series.where`, :meth:`DataFrame.where`, :meth:`Series.mask`, and meth:`DataFrame.mask`; in a future version the argument will be removed (:issue:`44294`)
- Deprecated the ``prefix`` keyword argument in :func:`read_csv` and :func:`read_table`, in a future version the argument will be removed (:issue:`43396`)
- Deprecated :meth:`PeriodIndex.astype` to ``datetime64[ns]`` or ``DatetimeTZDtype``, use ``obj.to_timestamp(how).tz_localize(dtype.tz)`` instead (:issue:`44398`)
@@ -837,7 +837,7 @@ ExtensionArray
- Bug in :func:`array` incorrectly raising when passed a ``ndarray`` with ``float16`` dtype (:issue:`44715`)
- Bug in calling ``np.sqrt`` on :class:`BooleanArray` returning a malformed :class:`FloatingArray` (:issue:`44715`)
- Bug in :meth:`Series.where` with ``ExtensionDtype`` when ``other`` is a NA scalar incompatible with the series dtype (e.g. ``NaT`` with a numeric dtype) incorrectly casting to a compatible NA value (:issue:`44697`)
-
- Fixed bug in :meth:`Series.replace` with ``FloatDtype``, ``string[python]``, or ``string[pyarrow]`` dtype not being preserved when possible (:issue:`33484`)
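A minimal sketch of what the whatsnew entry above describes, assuming a pandas version that includes this fix. The sample values are illustrative only; the point is that the result keeps the nullable ``string`` dtype rather than falling back to ``object``.

```python
import pandas as pd

# Nullable string Series with one missing value (illustrative data).
ser = pd.Series(["one", "two", pd.NA], dtype="string")

# Dict-style replacement; with the fix the extension dtype should be kept.
result = ser.replace({"one": "1", "two": "2"})

print(result.dtype)
print(result.tolist())
```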

Styler
^^^^^^
3 changes: 2 additions & 1 deletion pandas/core/array_algos/replace.py
@@ -80,7 +80,8 @@ def _check_comparison_types(
f"Cannot compare types {repr(type_names[0])} and {repr(type_names[1])}"
)

if not regex:
if not regex or not should_use_regex(regex, b):
# TODO: should use missing.mask_missing?
op = lambda x: operator.eq(x, b)
else:
op = np.vectorize(
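To illustrate why the condition above now also consults `should_use_regex`, here is a simplified stand-in written only with the standard library. The name `should_use_regex_sketch` is hypothetical, and the real pandas helper may differ in detail. The idea is that `regex=True` only leads to regex matching when the target is actually usable as a non-empty pattern; otherwise plain equality is used.

```python
import re
from typing import Any


def should_use_regex_sketch(regex: bool, to_replace: Any) -> bool:
    """Hypothetical stand-in for the internal helper, not the pandas code."""
    # A pre-compiled pattern asks for regex matching even if regex=False.
    if isinstance(to_replace, re.Pattern):
        regex = True
    if not regex:
        return False
    # Values that cannot be compiled (e.g. integers) fall back to equality.
    try:
        pattern = re.compile(to_replace)
    except (TypeError, re.error):
        return False
    # An empty pattern is treated as "do not use regex".
    return pattern.pattern != ""


print(should_use_regex_sketch(True, "a.c"))
print(should_use_regex_sketch(True, 3))
print(should_use_regex_sketch(True, ""))
```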
47 changes: 17 additions & 30 deletions pandas/core/internals/blocks.py
@@ -640,6 +640,8 @@ def replace(
to_replace,
value,
inplace: bool = False,
# mask may be pre-computed if we're called from replace_list
mask: npt.NDArray[np.bool_] | None = None,
) -> list[Block]:
"""
replace the to_replace value with value, possible to create new
@@ -665,7 +667,8 @@ def replace(
# replace_list instead of replace.
return [self] if inplace else [self.copy()]

mask = missing.mask_missing(values, to_replace)
if mask is None:
mask = missing.mask_missing(values, to_replace)
if not mask.any():
# Note: we get here with test_replace_extension_other incorrectly
# bc _can_hold_element is incorrect.
@@ -683,6 +686,7 @@ def replace(
to_replace=to_replace,
value=value,
inplace=True,
mask=mask,
)

else:
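The optional `mask` argument added above can be pictured with a small NumPy-only sketch. The functions `replace_values` and `replace_pairs` are hypothetical illustrations, not the pandas internals. The assumption is that a list-style caller computes every mask against the original values first, so an earlier replacement cannot change what a later pair matches.

```python
import numpy as np


def replace_values(values, to_replace, value, mask=None):
    """Replace matches of `to_replace`, reusing a caller-supplied mask if given."""
    if mask is None:
        # No pre-computed mask: compare against the current values.
        mask = values == to_replace
    out = values.copy()
    if mask.any():
        out[mask] = value
    return out


def replace_pairs(values, pairs):
    """Apply (src, dest) pairs using masks computed from the original values."""
    masks = [values == src for src, _ in pairs]
    out = values
    for (src, dest), mask in zip(pairs, masks):
        out = replace_values(out, src, dest, mask=mask)
    return out


arr = np.array([1, 2, 3])
print(replace_pairs(arr, [(1, 2), (2, 3)]))
```

Without the pre-computed masks, the second pair would also match the value written by the first pair, which is the kind of chained replacement this pattern avoids.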
@@ -746,16 +750,6 @@ def replace_list(
"""
values = self.values

# TODO: dont special-case Categorical
if (
isinstance(values, Categorical)
and len(algos.unique(dest_list)) == 1
and not regex
):
# We likely got here by tiling value inside NDFrame.replace,
# so un-tile here
return self.replace(src_list, dest_list[0], inplace)

# Exclude anything that we know we won't contain
pairs = [
(x, y) for x, y in zip(src_list, dest_list) if self._can_hold_element(x)
@@ -844,25 +838,18 @@ def _replace_coerce(
-------
List[Block]
"""
if mask.any():
if not regex:
nb = self.coerce_to_target_dtype(value)
if nb is self and not inplace:
nb = nb.copy()
putmask_inplace(nb.values, mask, value)
return [nb]
else:
regex = should_use_regex(regex, to_replace)
if regex:
return self._replace_regex(
to_replace,
value,
inplace=inplace,
convert=False,
mask=mask,
)
return self.replace(to_replace, value, inplace=inplace)
return [self]
if should_use_regex(regex, to_replace):
return self._replace_regex(
to_replace,
value,
inplace=inplace,
convert=False,
mask=mask,
)
else:
return self.replace(
to_replace=to_replace, value=value, inplace=inplace, mask=mask
)

# ---------------------------------------------------------------------

12 changes: 2 additions & 10 deletions pandas/tests/arrays/categorical/test_replace.py
@@ -1,4 +1,3 @@
import numpy as np
import pytest

import pandas as pd
@@ -20,18 +19,15 @@
([1, 2], 4, [4, 4, 3], False),
((1, 2, 4), 5, [5, 5, 3], False),
((5, 6), 2, [1, 2, 3], False),
# many-to-many, handled outside of Categorical and results in separate dtype
# except for cases with only 1 unique entry in `value`
([1], [2], [2, 2, 3], True),
([1, 4], [5, 2], [5, 2, 3], True),
([1], [2], [2, 2, 3], False),
([1, 4], [5, 2], [5, 2, 3], False),
# check_categorical sorts categories, which crashes on mixed dtypes
(3, "4", [1, 2, "4"], False),
([1, 2, "3"], "5", ["5", "5", 3], True),
],
)
def test_replace_categorical_series(to_replace, value, expected, flip_categories):
# GH 31720
stays_categorical = not isinstance(value, list) or len(pd.unique(value)) == 1

ser = pd.Series([1, 2, 3], dtype="category")
result = ser.replace(to_replace, value)
@@ -41,10 +37,6 @@ def test_replace_categorical_series(to_replace, value, expected, flip_categories
if flip_categories:
expected = expected.cat.set_categories(expected.cat.categories[::-1])

if not stays_categorical:
# the replace call loses categorical dtype
expected = pd.Series(np.asarray(expected))

tm.assert_series_equal(expected, result, check_category_order=False)
tm.assert_series_equal(expected, ser, check_category_order=False)

13 changes: 9 additions & 4 deletions pandas/tests/frame/methods/test_replace.py
@@ -624,6 +624,14 @@ def test_replace_mixed3(self):
expected.iloc[1, 1] = m[1]
tm.assert_frame_equal(result, expected)

def test_replace_nullable_int_with_string_doesnt_cast(self):
# GH#25438 don't cast df['a'] to float64
df = DataFrame({"a": [1, 2, 3, np.nan], "b": ["some", "strings", "here", "he"]})
df["a"] = df["a"].astype("Int64")

res = df.replace("", np.nan)
tm.assert_series_equal(res["a"], df["a"])
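The test above can also be read as a standalone usage sketch, assuming a pandas version that includes this change: replacing a string that never occurs in the nullable integer column should leave that column's ``Int64`` dtype untouched.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3, np.nan], "b": ["some", "strings", "here", "he"]})
df["a"] = df["a"].astype("Int64")

# "" does not occur in column "a", so that column should not be cast.
res = df.replace("", np.nan)

print(res["a"].dtype)
```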

@pytest.mark.parametrize("dtype", ["boolean", "Int64", "Float64"])
def test_replace_with_nullable_column(self, dtype):
# GH-44499
@@ -1382,15 +1390,12 @@ def test_replace_value_category_type(self):

tm.assert_frame_equal(result, expected)

@pytest.mark.xfail(
reason="category dtype gets changed to object type after replace, see #35268",
raises=AssertionError,
)
def test_replace_dict_category_type(self):
"""
Test to ensure category dtypes are maintained
after replace with dict values
"""
# GH#35268, GH#44940

# create input dataframe
input_dict = {"col1": ["a"], "col2": ["obj1"], "col3": ["cat1"]}
17 changes: 16 additions & 1 deletion pandas/tests/indexing/test_coercion.py
@@ -1177,6 +1177,7 @@ def test_replace_series_datetime_tz(self, how, to_key, from_key, replacer):
assert obj.dtype == from_key

result = obj.replace(replacer)

exp = pd.Series(self.rep[to_key], index=index, name="yyy")
assert exp.dtype == to_key

@@ -1197,7 +1198,21 @@ def test_replace_series_datetime_datetime(self, how, to_key, from_key, replacer)
obj = pd.Series(self.rep[from_key], index=index, name="yyy")
assert obj.dtype == from_key

result = obj.replace(replacer)
warn = None
rep_ser = pd.Series(replacer)
if (
isinstance(obj.dtype, pd.DatetimeTZDtype)
and isinstance(rep_ser.dtype, pd.DatetimeTZDtype)
and obj.dtype != rep_ser.dtype
):
# mismatched tz DatetimeArray behavior will change to cast
# for setitem-like methods with mismatched tzs GH#44940
warn = FutureWarning

msg = "explicitly cast to object"
with tm.assert_produces_warning(warn, match=msg):
result = obj.replace(replacer)

exp = pd.Series(self.rep[to_key], index=index, name="yyy")
assert exp.dtype == to_key

128 changes: 109 additions & 19 deletions pandas/tests/series/methods/test_replace.py
@@ -5,6 +5,7 @@

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays import IntervalArray


class TestSeriesReplace:
@@ -148,20 +149,21 @@ def test_replace_with_single_list(self):
tm.assert_series_equal(s, ser)

def test_replace_mixed_types(self):
s = pd.Series(np.arange(5), dtype="int64")
ser = pd.Series(np.arange(5), dtype="int64")

def check_replace(to_rep, val, expected):
sc = s.copy()
r = s.replace(to_rep, val)
sc = ser.copy()
result = ser.replace(to_rep, val)
return_value = sc.replace(to_rep, val, inplace=True)
assert return_value is None
tm.assert_series_equal(expected, r)
tm.assert_series_equal(expected, result)
tm.assert_series_equal(expected, sc)

# MUST upcast to float
e = pd.Series([0.0, 1.0, 2.0, 3.0, 4.0])
# 3.0 can still be held in our int64 series, so we do not upcast GH#44940
tr, v = [3], [3.0]
check_replace(tr, v, e)
check_replace(tr, v, ser)
# Note this matches what we get with the scalars 3 and 3.0
check_replace(tr[0], v[0], ser)

# MUST upcast to float
e = pd.Series([0, 1, 2, 3.5, 4])
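A short sketch of the two cases this test distinguishes, assuming a pandas version with this change. A replacement value of ``3.0`` can be stored losslessly in an ``int64`` Series, so no upcast is expected, while ``3.5`` cannot and should still force ``float64``.

```python
import numpy as np
import pandas as pd

ser = pd.Series(np.arange(5), dtype="int64")

# 3.0 is integer-valued, so the int64 dtype is expected to survive.
same = ser.replace(3, 3.0)

# 3.5 cannot be held by int64, so an upcast to float64 is expected.
upcast = ser.replace(3, 3.5)

print(same.dtype, upcast.dtype)
```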
@@ -257,10 +259,10 @@ def test_replace2(self):
assert (ser[20:30] == -1).all()

def test_replace_with_dictlike_and_string_dtype(self, nullable_string_dtype):
# GH 32621
s = pd.Series(["one", "two", np.nan], dtype=nullable_string_dtype)
expected = pd.Series(["1", "2", np.nan])
result = s.replace({"one": "1", "two": "2"})
# GH 32621, GH#44940
ser = pd.Series(["one", "two", np.nan], dtype=nullable_string_dtype)
expected = pd.Series(["1", "2", np.nan], dtype=nullable_string_dtype)
result = ser.replace({"one": "1", "two": "2"})
tm.assert_series_equal(expected, result)

def test_replace_with_empty_dictlike(self):
@@ -305,17 +307,18 @@ def test_replace_mixed_types_with_string(self):
"categorical, numeric",
[
(pd.Categorical(["A"], categories=["A", "B"]), [1]),
(pd.Categorical(("A",), categories=["A", "B"]), [1]),
(pd.Categorical(("A", "B"), categories=["A", "B"]), [1, 2]),
(pd.Categorical(["A", "B"], categories=["A", "B"]), [1, 2]),
],
)
def test_replace_categorical(self, categorical, numeric):
# GH 24971
# Do not check if dtypes are equal due to a known issue that
# Categorical.replace sometimes coerces to object (GH 23305)
s = pd.Series(categorical)
result = s.replace({"A": 1, "B": 2})
expected = pd.Series(numeric)
# GH 24971, GH#23305
ser = pd.Series(categorical)
result = ser.replace({"A": 1, "B": 2})
expected = pd.Series(numeric).astype("category")
Contributor:
umm why is this a resulting categorical?

Member Author:
bc that's what someone implemented for the Categorical.replace behavior. im on the fence about it, but for now i think we need to be consistent with it

Contributor:
hmm ok let's open an issue about this, i dont' think we should raise (or coerce to object) rather than return a new categorical (but maybe others disagree)

Member Author:
can you clarify? based on your earlier comment i expected the opposite opinion, so i think the "dont" may be an unintentional double-negative.

if 2 not in expected.cat.categories:
# i.e. categories should be [1, 2] even if there are no "B"s present
# GH#44940
expected = expected.cat.add_categories(2)
tm.assert_series_equal(expected, result)

def test_replace_categorical_single(self):
@@ -514,3 +517,90 @@ def test_pandas_replace_na(self):
result = ser.replace(regex_mapping, regex=True)
exp = pd.Series(["CC", "CC", "CC-REPL", "DD", "CC", "", pd.NA], dtype="string")
tm.assert_series_equal(result, exp)

@pytest.mark.parametrize(
"dtype, input_data, to_replace, expected_data",
[
("bool", [True, False], {True: False}, [False, False]),
("int64", [1, 2], {1: 10, 2: 20}, [10, 20]),
("Int64", [1, 2], {1: 10, 2: 20}, [10, 20]),
("float64", [1.1, 2.2], {1.1: 10.1, 2.2: 20.5}, [10.1, 20.5]),
("Float64", [1.1, 2.2], {1.1: 10.1, 2.2: 20.5}, [10.1, 20.5]),
("string", ["one", "two"], {"one": "1", "two": "2"}, ["1", "2"]),
(
pd.IntervalDtype("int64"),
IntervalArray([pd.Interval(1, 2), pd.Interval(2, 3)]),
{pd.Interval(1, 2): pd.Interval(10, 20)},
IntervalArray([pd.Interval(10, 20), pd.Interval(2, 3)]),
),
(
pd.IntervalDtype("float64"),
IntervalArray([pd.Interval(1.0, 2.7), pd.Interval(2.8, 3.1)]),
{pd.Interval(1.0, 2.7): pd.Interval(10.6, 20.8)},
IntervalArray([pd.Interval(10.6, 20.8), pd.Interval(2.8, 3.1)]),
),
(
pd.PeriodDtype("M"),
[pd.Period("2020-05", freq="M")],
{pd.Period("2020-05", freq="M"): pd.Period("2020-06", freq="M")},
[pd.Period("2020-06", freq="M")],
),
],
)
def test_replace_dtype(self, dtype, input_data, to_replace, expected_data):
# GH#33484
ser = pd.Series(input_data, dtype=dtype)
result = ser.replace(to_replace)
expected = pd.Series(expected_data, dtype=dtype)
tm.assert_series_equal(result, expected)
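One of the parametrized cases above, written out as a standalone sketch under the assumption that the installed pandas includes this fix: a dict-style replace on a ``Period`` Series should return the same ``period[M]`` dtype.

```python
import pandas as pd

old = pd.Period("2020-05", freq="M")
new = pd.Period("2020-06", freq="M")

ser = pd.Series([old], dtype=pd.PeriodDtype("M"))
result = ser.replace({old: new})

print(result.dtype)
```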

def test_replace_string_dtype(self):
# GH#40732, GH#44940
ser = pd.Series(["one", "two", np.nan], dtype="string")
res = ser.replace({"one": "1", "two": "2"})
expected = pd.Series(["1", "2", np.nan], dtype="string")
tm.assert_series_equal(res, expected)

# GH#31644
ser2 = pd.Series(["A", np.nan], dtype="string")
res2 = ser2.replace("A", "B")
expected2 = pd.Series(["B", np.nan], dtype="string")
tm.assert_series_equal(res2, expected2)

ser3 = pd.Series(["A", "B"], dtype="string")
res3 = ser3.replace("A", pd.NA)
expected3 = pd.Series([pd.NA, "B"], dtype="string")
tm.assert_series_equal(res3, expected3)

def test_replace_string_dtype_list_to_replace(self):
# GH#41215, GH#44940
ser = pd.Series(["abc", "def"], dtype="string")
res = ser.replace(["abc", "any other string"], "xyz")
expected = pd.Series(["xyz", "def"], dtype="string")
tm.assert_series_equal(res, expected)

def test_replace_string_dtype_regex(self):
# GH#31644
ser = pd.Series(["A", "B"], dtype="string")
res = ser.replace(r".", "C", regex=True)
expected = pd.Series(["C", "C"], dtype="string")
tm.assert_series_equal(res, expected)
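The regex path can be exercised the same way. This is a usage sketch mirroring the test above, assuming a pandas version with this fix: a pattern-based replace on a nullable string Series is expected to keep the ``string`` dtype.

```python
import pandas as pd

ser = pd.Series(["A", "B"], dtype="string")

# "." matches any single character, so both entries are replaced.
res = ser.replace(r".", "C", regex=True)

print(res.dtype)
print(res.tolist())
```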

def test_replace_nullable_numeric(self):
# GH#40732, GH#44940

floats = pd.Series([1.0, 2.0, 3.999, 4.4], dtype=pd.Float64Dtype())
assert floats.replace({1.0: 9}).dtype == floats.dtype
assert floats.replace(1.0, 9).dtype == floats.dtype
assert floats.replace({1.0: 9.0}).dtype == floats.dtype
assert floats.replace(1.0, 9.0).dtype == floats.dtype

res = floats.replace(to_replace=[1.0, 2.0], value=[9.0, 10.0])
assert res.dtype == floats.dtype

ints = pd.Series([1, 2, 3, 4], dtype=pd.Int64Dtype())
assert ints.replace({1: 9}).dtype == ints.dtype
assert ints.replace(1, 9).dtype == ints.dtype
assert ints.replace({1: 9.0}).dtype == ints.dtype
assert ints.replace(1, 9.0).dtype == ints.dtype
# FIXME: ints.replace({1: 9.5}) raises bc of incorrect _can_hold_element
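As a usage sketch of the nullable numeric assertions above (assuming a pandas version that includes this change), scalar, dict, and list-style replacements should all keep ``Float64`` and ``Int64``. The ``{1: 9.5}`` case flagged in the FIXME is deliberately left out.

```python
import pandas as pd

floats = pd.Series([1.0, 2.0, 3.999, 4.4], dtype="Float64")
ints = pd.Series([1, 2, 3, 4], dtype="Int64")

# Each entry exercises a different call style from the test above.
results = {
    "floats scalar": floats.replace(1.0, 9),
    "floats list": floats.replace(to_replace=[1.0, 2.0], value=[9.0, 10.0]),
    "ints scalar": ints.replace(1, 9),
    "ints dict": ints.replace({1: 9.0}),
}

for name, res in results.items():
    print(name, res.dtype)
```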