BUG: groupby.rank for dt64tz, period dtypes (#38187)
jbrockmendel authored Dec 2, 2020
1 parent 4749fd6 commit 23cbc47
Showing 4 changed files with 79 additions and 6 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
@@ -768,6 +768,7 @@ Groupby/resample/rolling
- Bug in :meth:`DataFrame.groupby` dropped ``nan`` groups from result with ``dropna=False`` when grouping over a single column (:issue:`35646`, :issue:`35542`)
- Bug in :meth:`.DataFrameGroupBy.head`, :meth:`.DataFrameGroupBy.tail`, :meth:`SeriesGroupBy.head`, and :meth:`SeriesGroupBy.tail` would raise when used with ``axis=1`` (:issue:`9772`)
- Bug in :meth:`.DataFrameGroupBy.transform` would raise when used with ``axis=1`` and a transformation kernel (e.g. "shift") (:issue:`36308`)
- Bug in :meth:`DataFrameGroupBy.rank` with ``datetime64tz`` or period dtype incorrectly casting results to those dtypes instead of returning ``float64`` dtype (:issue:`38187`)

Reshaping
^^^^^^^^^
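The whatsnew entry in the Groupby/resample/rolling section above captures the user-visible change: before this commit, a grouped rank on a datetime64tz or period column cast the ranks back into that dtype; with the fix it returns float64, matching Series.rank on the same data. A minimal sketch of the fixed behavior, assuming pandas >= 1.2 (the frame, column names, and timezone below are made up for illustration):

    import pandas as pd

    df = pd.DataFrame(
        {
            "key": ["a", "a", "a", "b", "b", "b"],
            "val": pd.to_datetime(
                ["2018-01-02", "2018-01-02", "2018-01-08"] * 2
            ).tz_localize("US/Pacific"),
        }
    )

    result = df.groupby("key").rank()
    print(result["val"].dtype)     # float64 (previously cast back to datetime64[ns, US/Pacific])
    print(result["val"].tolist())  # [1.5, 1.5, 3.0, 1.5, 1.5, 3.0]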
2 changes: 1 addition & 1 deletion pandas/core/groupby/groupby.py
@@ -995,7 +995,7 @@ def _cython_transform(

try:
result, _ = self.grouper._cython_operation(
"transform", obj.values, how, axis, **kwargs
"transform", obj._values, how, axis, **kwargs
)
except NotImplementedError:
continue
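The one-line change above is what routes extension-array-backed columns into the EA-aware cython path: Series.values on a tz-aware datetime column returns a plain NumPy datetime64[ns] array with the timezone dropped, whereas the internal ._values attribute returns the backing DatetimeArray and so preserves the dtype that _ea_wrap_cython_operation needs to see. A small sketch of the difference (._values is internal pandas API; the timezone is arbitrary):

    import pandas as pd

    ser = pd.Series(
        pd.to_datetime(["2018-01-02", "2018-01-06"]).tz_localize("US/Pacific")
    )

    print(type(ser.values))   # numpy.ndarray of datetime64[ns] in UTC, tz information dropped
    print(type(ser._values))  # pandas DatetimeArray, keeps datetime64[ns, US/Pacific]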
5 changes: 4 additions & 1 deletion pandas/core/groupby/ops.py
@@ -491,8 +491,11 @@ def _ea_wrap_cython_operation(
res_values, names = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
if how in ["rank"]:
# preserve float64 dtype
return res_values, names

res_values = res_values.astype("i8", copy=False)
# FIXME: this is wrong for rank, but not tested.
result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype)
return result, names

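For extension-array dtypes, _ea_wrap_cython_operation normally runs the cython kernel and then re-wraps the i8-cast result as the original dtype via _simple_new. The rank kernel, however, produces float64 ranks, so re-wrapping would reinterpret the ranks as epoch-based timestamps or periods; the added early return hands the float64 result back untouched. A rough, standalone illustration of the effect the early return avoids (not the actual code path):

    import numpy as np
    import pandas as pd

    ranks = np.array([1.5, 1.5, 3.0])   # float64 output of the rank kernel

    # Pre-fix, the result was squeezed through i8 and the original dtype,
    # turning ranks into meaningless timestamps a few nanoseconds after the epoch:
    print(pd.to_datetime(ranks.astype("i8"), unit="ns", utc=True))

    # Post-fix, the float64 ranks are returned as-is:
    print(ranks)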
77 changes: 73 additions & 4 deletions pandas/tests/groupby/test_rank.py
@@ -42,15 +42,40 @@ def test_rank_apply():
@pytest.mark.parametrize(
"vals",
[
[2, 2, 8, 2, 6],
np.array([2, 2, 8, 2, 6], dtype=dtype)
for dtype in ["i8", "i4", "i2", "i1", "u8", "u4", "u2", "u1", "f8", "f4", "f2"]
]
+ [
[
pd.Timestamp("2018-01-02"),
pd.Timestamp("2018-01-02"),
pd.Timestamp("2018-01-08"),
pd.Timestamp("2018-01-02"),
pd.Timestamp("2018-01-06"),
],
[
pd.Timestamp("2018-01-02", tz="US/Pacific"),
pd.Timestamp("2018-01-02", tz="US/Pacific"),
pd.Timestamp("2018-01-08", tz="US/Pacific"),
pd.Timestamp("2018-01-02", tz="US/Pacific"),
pd.Timestamp("2018-01-06", tz="US/Pacific"),
],
[
pd.Timestamp("2018-01-02") - pd.Timestamp(0),
pd.Timestamp("2018-01-02") - pd.Timestamp(0),
pd.Timestamp("2018-01-08") - pd.Timestamp(0),
pd.Timestamp("2018-01-02") - pd.Timestamp(0),
pd.Timestamp("2018-01-06") - pd.Timestamp(0),
],
[
pd.Timestamp("2018-01-02").to_period("D"),
pd.Timestamp("2018-01-02").to_period("D"),
pd.Timestamp("2018-01-08").to_period("D"),
pd.Timestamp("2018-01-02").to_period("D"),
pd.Timestamp("2018-01-06").to_period("D"),
],
],
ids=lambda x: type(x[0]),
)
@pytest.mark.parametrize(
"ties_method,ascending,pct,exp",
@@ -79,7 +104,12 @@ def test_rank_apply()
)
def test_rank_args(grps, vals, ties_method, ascending, pct, exp):
key = np.repeat(grps, len(vals))
vals = vals * len(grps)

orig_vals = vals
vals = list(vals) * len(grps)
if isinstance(orig_vals, np.ndarray):
vals = np.array(vals, dtype=orig_vals.dtype)

df = DataFrame({"key": key, "val": vals})
result = df.groupby("key").rank(method=ties_method, ascending=ascending, pct=pct)

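The test-side change above exists because some of the parametrized vals are now NumPy arrays rather than Python lists: vals * len(grps) repeats a list but multiplies an array element-wise, so the test converts to a list, tiles it, and then restores the parametrized dtype. A short sketch of the distinction (grps here stands in for the group-label parameter used by the test):

    import numpy as np

    grps = ["a", "b"]
    vals = np.array([2, 2, 8, 2, 6], dtype="i8")

    print(vals * len(grps))        # [ 4  4 16  4 12]  -- element-wise multiplication
    print(list(vals) * len(grps))  # [2, 2, 8, 2, 6, 2, 2, 8, 2, 6]  -- repetition

    # The updated test re-applies the original dtype after tiling:
    tiled = np.array(list(vals) * len(grps), dtype=vals.dtype)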
@@ -142,7 +172,10 @@ def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp):
@pytest.mark.parametrize(
"vals",
[
[2, 2, np.nan, 8, 2, 6, np.nan, np.nan],
np.array([2, 2, np.nan, 8, 2, 6, np.nan, np.nan], dtype=dtype)
for dtype in ["f8", "f4", "f2"]
]
+ [
[
pd.Timestamp("2018-01-02"),
pd.Timestamp("2018-01-02"),
@@ -153,7 +186,38 @@ def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp):
np.nan,
np.nan,
],
[
pd.Timestamp("2018-01-02", tz="US/Pacific"),
pd.Timestamp("2018-01-02", tz="US/Pacific"),
np.nan,
pd.Timestamp("2018-01-08", tz="US/Pacific"),
pd.Timestamp("2018-01-02", tz="US/Pacific"),
pd.Timestamp("2018-01-06", tz="US/Pacific"),
np.nan,
np.nan,
],
[
pd.Timestamp("2018-01-02") - pd.Timestamp(0),
pd.Timestamp("2018-01-02") - pd.Timestamp(0),
np.nan,
pd.Timestamp("2018-01-08") - pd.Timestamp(0),
pd.Timestamp("2018-01-02") - pd.Timestamp(0),
pd.Timestamp("2018-01-06") - pd.Timestamp(0),
np.nan,
np.nan,
],
[
pd.Timestamp("2018-01-02").to_period("D"),
pd.Timestamp("2018-01-02").to_period("D"),
np.nan,
pd.Timestamp("2018-01-08").to_period("D"),
pd.Timestamp("2018-01-02").to_period("D"),
pd.Timestamp("2018-01-06").to_period("D"),
np.nan,
np.nan,
],
],
ids=lambda x: type(x[0]),
)
@pytest.mark.parametrize(
"ties_method,ascending,na_option,pct,exp",
@@ -346,7 +410,12 @@ def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp):
)
def test_rank_args_missing(grps, vals, ties_method, ascending, na_option, pct, exp):
key = np.repeat(grps, len(vals))
vals = vals * len(grps)

orig_vals = vals
vals = list(vals) * len(grps)
if isinstance(orig_vals, np.ndarray):
vals = np.array(vals, dtype=orig_vals.dtype)

df = DataFrame({"key": key, "val": vals})
result = df.groupby("key").rank(
method=ties_method, ascending=ascending, na_option=na_option, pct=pct