@@ -1066,6 +1066,8 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
1066
1066
int64_t[::1] counts ,
1067
1067
ndarray[iu_64_floating_obj_t , ndim = 2 ] values,
1068
1068
const intp_t[::1] labels ,
1069
+ const uint8_t[:, :] mask ,
1070
+ uint8_t[:, ::1] result_mask = None ,
1069
1071
int64_t min_count = - 1 ,
1070
1072
int64_t rank = 1 ,
1071
1073
) -> None:
@@ -1078,6 +1080,8 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
1078
1080
ndarray[iu_64_floating_obj_t , ndim = 2 ] resx
1079
1081
ndarray[int64_t , ndim = 2 ] nobs
1080
1082
bint runtime_error = False
1083
+ bint uses_mask = mask is not None
1084
+ bint isna_entry
1081
1085
1082
1086
# TODO(cython3 ):
1083
1087
# Instead of `labels.shape[0]` use `len(labels)`
@@ -1104,7 +1108,12 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
1104
1108
for j in range (K):
1105
1109
val = values[i, j]
1106
1110
1107
- if not checknull(val):
1111
+ if uses_mask:
1112
+ isna_entry = mask[i, j]
1113
+ else :
1114
+ isna_entry = checknull(val)
1115
+
1116
+ if not isna_entry:
1108
1117
# NB: use _treat_as_na here once
1109
1118
# conditional-nogil is available.
1110
1119
nobs[lab, j] += 1
@@ -1129,16 +1138,24 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
1129
1138
for j in range (K):
1130
1139
val = values[i, j]
1131
1140
1132
- if not _treat_as_na(val, True ):
1141
+ if uses_mask:
1142
+ isna_entry = mask[i, j]
1143
+ else :
1144
+ isna_entry = _treat_as_na(val, True )
1133
1145
# TODO: Sure we always want is_datetimelike=True?
1146
+
1147
+ if not isna_entry:
1134
1148
nobs[lab, j] += 1
1135
1149
if nobs[lab, j] == rank:
1136
1150
resx[lab, j] = val
1137
1151
1138
1152
for i in range (ncounts):
1139
1153
for j in range (K):
1140
1154
if nobs[i, j] < min_count:
1141
- if iu_64_floating_obj_t is int64_t:
1155
+ if uses_mask:
1156
+ result_mask[i, j] = True
1157
+ elif iu_64_floating_obj_t is int64_t:
1158
+ # TODO: only if datetimelike?
1142
1159
out[i, j] = NPY_NAT
1143
1160
elif iu_64_floating_obj_t is uint64_t:
1144
1161
runtime_error = True
0 commit comments