Skip to content

Commit f1dc278

Browse files
authored
PERF: support mask in libgroupby.group_nth (#46163)
1 parent 9094a43 commit f1dc278

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

pandas/_libs/groupby.pyi

+2
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,8 @@ def group_nth(
111111
counts: np.ndarray, # int64_t[::1]
112112
values: np.ndarray, # ndarray[rank_t, ndim=2]
113113
labels: np.ndarray, # const int64_t[:]
114+
mask: npt.NDArray[np.bool_] | None,
115+
result_mask: npt.NDArray[np.bool_] | None,
114116
min_count: int = ..., # int64_t
115117
rank: int = ..., # int64_t
116118
) -> None: ...

pandas/_libs/groupby.pyx

+20-3
Original file line number | Diff line number | Diff line change
@@ -1066,6 +1066,8 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
10661066
int64_t[::1] counts,
10671067
ndarray[iu_64_floating_obj_t, ndim=2] values,
10681068
const intp_t[::1] labels,
1069+
const uint8_t[:, :] mask,
1070+
uint8_t[:, ::1] result_mask=None,
10691071
int64_t min_count=-1,
10701072
int64_t rank=1,
10711073
) -> None:
@@ -1078,6 +1080,8 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
10781080
ndarray[iu_64_floating_obj_t, ndim=2] resx
10791081
ndarray[int64_t, ndim=2] nobs
10801082
bint runtime_error = False
1083+
bint uses_mask = mask is not None
1084+
bint isna_entry
10811085

10821086
# TODO(cython3):
10831087
# Instead of `labels.shape[0]` use `len(labels)`
@@ -1104,7 +1108,12 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
11041108
for j in range(K):
11051109
val = values[i, j]
11061110

1107-
if not checknull(val):
1111+
if uses_mask:
1112+
isna_entry = mask[i, j]
1113+
else:
1114+
isna_entry = checknull(val)
1115+
1116+
if not isna_entry:
11081117
# NB: use _treat_as_na here once
11091118
# conditional-nogil is available.
11101119
nobs[lab, j] += 1
@@ -1129,16 +1138,24 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
11291138
for j in range(K):
11301139
val = values[i, j]
11311140

1132-
if not _treat_as_na(val, True):
1141+
if uses_mask:
1142+
isna_entry = mask[i, j]
1143+
else:
1144+
isna_entry = _treat_as_na(val, True)
11331145
# TODO: Sure we always want is_datetimelike=True?
1146+
1147+
if not isna_entry:
11341148
nobs[lab, j] += 1
11351149
if nobs[lab, j] == rank:
11361150
resx[lab, j] = val
11371151

11381152
for i in range(ncounts):
11391153
for j in range(K):
11401154
if nobs[i, j] < min_count:
1141-
if iu_64_floating_obj_t is int64_t:
1155+
if uses_mask:
1156+
result_mask[i, j] = True
1157+
elif iu_64_floating_obj_t is int64_t:
1158+
# TODO: only if datetimelike?
11421159
out[i, j] = NPY_NAT
11431160
elif iu_64_floating_obj_t is uint64_t:
11441161
runtime_error = True

pandas/core/groupby/ops.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ def __init__(self, kind: str, how: str):
140140

141141
# "group_any" and "group_all" also support masks, but don't go
142142
# through WrappedCythonOp
143-
_MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax", "min", "max", "last"}
143+
_MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax", "min", "max", "last", "first"}
144144

145145
_cython_arity = {"ohlc": 4} # OHLC
146146

@@ -532,7 +532,7 @@ def _call_cython_op(
532532
result_mask=result_mask,
533533
is_datetimelike=is_datetimelike,
534534
)
535-
elif self.how in ["last"]:
535+
elif self.how in ["first", "last"]:
536536
func(
537537
out=result,
538538
counts=counts,

0 commit comments

Comments
 (0)