Skip to content

Commit 4940c51

Browse files
rhshadrach authored and jorisvandenbossche committed
BUG(string dtype): groupby/resampler.min/max returns float on all NA strings (pandas-dev#60985)
* BUG(string dtype): groupby/resampler.min/max returns float on all NA strings
* Merge cleanup
* whatsnew
* Add type-ignore
* Remove condition
1 parent 2cc3762 commit 4940c51

File tree

2 files changed

+100
-1
lines changed

2 files changed

+100
-1
lines changed

pandas/core/groupby/groupby.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class providing the base-class of operations.
8585
is_numeric_dtype,
8686
is_object_dtype,
8787
is_scalar,
88+
is_string_dtype,
8889
needs_i8_conversion,
8990
pandas_dtype,
9091
)
@@ -1945,8 +1946,13 @@ def _agg_py_fallback(
19451946
# preserve the kind of exception that raised
19461947
raise type(err)(msg) from err
19471948

1948-
if ser.dtype == object:
1949+
dtype = ser.dtype
1950+
if dtype == object:
19491951
res_values = res_values.astype(object, copy=False)
1952+
elif is_string_dtype(dtype):
1953+
# mypy doesn't infer dtype is an ExtensionDtype
1954+
string_array_cls = dtype.construct_array_type() # type: ignore[union-attr]
1955+
res_values = string_array_cls._from_sequence(res_values, dtype=dtype)
19501956

19511957
# If we are DataFrameGroupBy and went through a SeriesGroupByPath
19521958
# then we need to reshape

pandas/tests/groupby/test_reductions.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
isna,
2121
)
2222
import pandas._testing as tm
23+
from pandas.tests.groupby import get_groupby_method_args
2324
from pandas.util import _test_decorators as td
2425

2526

@@ -710,6 +711,98 @@ def test_min_empty_string_dtype(func, string_dtype_no_object):
710711
tm.assert_frame_equal(result, expected)
711712

712713

714+
@pytest.mark.parametrize("min_count", [0, 1])
@pytest.mark.parametrize("test_series", [True, False])
def test_string_dtype_all_na(
    string_dtype_no_object, reduction_func, skipna, min_count, test_series
):
    """Check groupby reductions over a single all-NA string-dtype group.

    Builds a one-row frame whose value column ``b`` holds only ``pd.NA``,
    groups it by column ``a``, and calls ``reduction_func`` on either the
    Series ``b`` or the whole frame.  Depending on the reduction, the call
    must either raise a specific error or return a one-element result with
    the expected dtype and value.
    """
    # https://github.com/pandas-dev/pandas/issues/60985
    func = reduction_func
    if func == "corrwith":
        # corrwith is deprecated.
        return

    str_dtype = string_dtype_no_object

    # Pick the keyword arguments passed to the reduction under test.
    skipna_only = {
        "any",
        "all",
        "idxmin",
        "idxmax",
        "mean",
        "median",
        "std",
        "var",
    }
    without_kwargs = {"count", "nunique", "quantile", "sem", "size"}
    if func in skipna_only:
        call_kwargs = {"skipna": skipna}
    elif func == "kurt":
        call_kwargs = {"min_count": min_count}
    elif func in without_kwargs:
        call_kwargs = {}
    else:
        call_kwargs = {"skipna": skipna, "min_count": min_count}

    # Default expectation: the string dtype is kept and the result is NA.
    want_dtype = str_dtype
    want_value = pd.NA
    if func == "any":
        # TODO: For skipna=False, bool(pd.NA) raises; should groupby?
        want_dtype, want_value = "bool", not skipna
    elif func == "all":
        want_dtype, want_value = "bool", True
    elif func in ("count", "nunique", "size"):
        # TODO: Should be more consistent - return Int64 when dtype.na_value is pd.NA?
        is_size = func == "size"
        nullable_size = (
            test_series
            and is_size
            and str_dtype.storage == "pyarrow"
            and str_dtype.na_value is pd.NA
        )
        want_dtype = "Int64" if nullable_size else "int64"
        want_value = 1 if is_size else 0
    elif func in ("idxmin", "idxmax"):
        want_dtype, want_value = "float64", np.nan
    elif skipna and not min_count > 0 and func == "sum":
        # https://github.com/pandas-dev/pandas/pull/60936
        want_value = ""

    frame = DataFrame({"a": ["x"], "b": [pd.NA]}, dtype=str_dtype)
    target = frame["b"] if test_series else frame
    call_args = get_groupby_method_args(func, target)
    grouped = target.groupby(frame["a"])
    reducer = getattr(grouped, func)

    # Work out whether the call is expected to raise, and with what.
    unsupported = {
        "mean",
        "median",
        "kurt",
        "prod",
        "quantile",
        "sem",
        "skew",
        "std",
        "var",
    }
    expected_error = None
    if func in unsupported:
        expected_error = TypeError
        msg = f"dtype '{str_dtype}' does not support operation '{func}'"
    elif func in ("idxmin", "idxmax") and not skipna:
        expected_error = ValueError
        msg = f"{func} with skipna=False encountered an NA value."

    if expected_error is not None:
        with pytest.raises(expected_error, match=msg):
            reducer(*call_args, **call_kwargs)
        return

    result = reducer(*call_args, **call_kwargs)
    group_index = pd.Index(["x"], name="a", dtype=str_dtype)
    if test_series or func == "size":
        # Reaching this branch without test_series means DataFrameGroupBy.size,
        # whose Series result is unnamed.
        result_name = "b" if test_series else None
        expected = Series(
            want_value, index=group_index, dtype=want_dtype, name=result_name
        )
    else:
        expected = DataFrame({"b": want_value}, index=group_index, dtype=want_dtype)
    tm.assert_equal(result, expected)
804+
805+
713806
def test_max_nan_bug():
714807
df = DataFrame(
715808
{

0 commit comments

Comments
 (0)