3232 Shape ,
3333 npt ,
3434)
35+ from pandas .compat import (
36+ IS64 ,
37+ is_platform_windows ,
38+ )
3539from pandas .errors import AbstractMethodError
3640from pandas .util ._decorators import doc
3741from pandas .util ._validators import validate_fillna_kwargs
@@ -1081,21 +1085,31 @@ def _quantile(
10811085 # ------------------------------------------------------------------
10821086 # Reductions
10831087
1084- def _reduce (self , name : str , * , skipna : bool = True , ** kwargs ):
1088+ def _reduce (
1089+ self , name : str , * , skipna : bool = True , keepdims : bool = False , ** kwargs
1090+ ):
10851091 if name in {"any" , "all" , "min" , "max" , "sum" , "prod" , "mean" , "var" , "std" }:
1086- return getattr (self , name )(skipna = skipna , ** kwargs )
1087-
1088- data = self ._data
1089- mask = self ._mask
1090-
1091- # median, skew, kurt, sem
1092- op = getattr (nanops , f"nan{ name } " )
1093- result = op (data , axis = 0 , skipna = skipna , mask = mask , ** kwargs )
1092+ result = getattr (self , name )(skipna = skipna , ** kwargs )
1093+ else :
1094+ # median, skew, kurt, sem
1095+ data = self ._data
1096+ mask = self ._mask
1097+ op = getattr (nanops , f"nan{ name } " )
1098+ axis = kwargs .pop ("axis" , None )
1099+ result = op (data , axis = axis , skipna = skipna , mask = mask , ** kwargs )
1100+
1101+ if keepdims :
1102+ if isna (result ):
1103+ return self ._wrap_na_result (name = name , axis = 0 , mask_size = (1 ,))
1104+ else :
1105+ result = result .reshape (1 )
1106+ mask = np .zeros (1 , dtype = bool )
1107+ return self ._maybe_mask_result (result , mask )
10941108
1095- if np . isnan (result ):
1109+ if isna (result ):
10961110 return libmissing .NA
1097-
1098- return result
1111+ else :
1112+ return result
10991113
11001114 def _wrap_reduction_result (self , name : str , result , * , skipna , axis ):
11011115 if isinstance (result , np .ndarray ):
@@ -1108,6 +1122,32 @@ def _wrap_reduction_result(self, name: str, result, *, skipna, axis):
11081122 return self ._maybe_mask_result (result , mask )
11091123 return result
11101124
1125+ def _wrap_na_result (self , * , name , axis , mask_size ):
1126+ mask = np .ones (mask_size , dtype = bool )
1127+
1128+ float_dtyp = "float32" if self .dtype == "Float32" else "float64"
1129+ if name in ["mean" , "median" , "var" , "std" , "skew" ]:
1130+ np_dtype = float_dtyp
1131+ elif name in ["min" , "max" ] or self .dtype .itemsize == 8 :
1132+ np_dtype = self .dtype .numpy_dtype .name
1133+ else :
1134+ is_windows_or_32bit = is_platform_windows () or not IS64
1135+ int_dtyp = "int32" if is_windows_or_32bit else "int64"
1136+ uint_dtyp = "uint32" if is_windows_or_32bit else "uint64"
1137+ np_dtype = {"b" : int_dtyp , "i" : int_dtyp , "u" : uint_dtyp , "f" : float_dtyp }[
1138+ self .dtype .kind
1139+ ]
1140+
1141+ value = np .array ([1 ], dtype = np_dtype )
1142+ return self ._maybe_mask_result (value , mask = mask )
1143+
1144+ def _wrap_min_count_reduction_result (
1145+ self , name : str , result , * , skipna , min_count , axis
1146+ ):
1147+ if min_count == 0 and isinstance (result , np .ndarray ):
1148+ return self ._maybe_mask_result (result , np .zeros (result .shape , dtype = bool ))
1149+ return self ._wrap_reduction_result (name , result , skipna = skipna , axis = axis )
1150+
11111151 def sum (
11121152 self ,
11131153 * ,
@@ -1125,7 +1165,9 @@ def sum(
11251165 min_count = min_count ,
11261166 axis = axis ,
11271167 )
1128- return self ._wrap_reduction_result ("sum" , result , skipna = skipna , axis = axis )
1168+ return self ._wrap_min_count_reduction_result (
1169+ "sum" , result , skipna = skipna , min_count = min_count , axis = axis
1170+ )
11291171
11301172 def prod (
11311173 self ,
@@ -1136,14 +1178,17 @@ def prod(
11361178 ** kwargs ,
11371179 ):
11381180 nv .validate_prod ((), kwargs )
1181+
11391182 result = masked_reductions .prod (
11401183 self ._data ,
11411184 self ._mask ,
11421185 skipna = skipna ,
11431186 min_count = min_count ,
11441187 axis = axis ,
11451188 )
1146- return self ._wrap_reduction_result ("prod" , result , skipna = skipna , axis = axis )
1189+ return self ._wrap_min_count_reduction_result (
1190+ "prod" , result , skipna = skipna , min_count = min_count , axis = axis
1191+ )
11471192
11481193 def mean (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
11491194 nv .validate_mean ((), kwargs )
@@ -1183,23 +1228,25 @@ def std(
11831228
11841229 def min (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
11851230 nv .validate_min ((), kwargs )
1186- return masked_reductions .min (
1231+ result = masked_reductions .min (
11871232 self ._data ,
11881233 self ._mask ,
11891234 skipna = skipna ,
11901235 axis = axis ,
11911236 )
1237+ return self ._wrap_reduction_result ("min" , result , skipna = skipna , axis = axis )
11921238
11931239 def max (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
11941240 nv .validate_max ((), kwargs )
1195- return masked_reductions .max (
1241+ result = masked_reductions .max (
11961242 self ._data ,
11971243 self ._mask ,
11981244 skipna = skipna ,
11991245 axis = axis ,
12001246 )
1247+ return self ._wrap_reduction_result ("max" , result , skipna = skipna , axis = axis )
12011248
1202- def any (self , * , skipna : bool = True , ** kwargs ):
1249+ def any (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
12031250 """
12041251 Return whether any element is truthy.
12051252
@@ -1218,6 +1265,7 @@ def any(self, *, skipna: bool = True, **kwargs):
12181265 If `skipna` is False, the result will still be True if there is
12191266 at least one element that is truthy, otherwise NA will be returned
12201267 if there are NA's present.
1268+ axis : int, optional, default 0
12211269 **kwargs : any, default None
12221270 Additional keywords have no effect but might be accepted for
12231271 compatibility with NumPy.
@@ -1261,7 +1309,6 @@ def any(self, *, skipna: bool = True, **kwargs):
12611309 >>> pd.array([0, 0, pd.NA]).any(skipna=False)
12621310 <NA>
12631311 """
1264- kwargs .pop ("axis" , None )
12651312 nv .validate_any ((), kwargs )
12661313
12671314 values = self ._data .copy ()
@@ -1280,7 +1327,7 @@ def any(self, *, skipna: bool = True, **kwargs):
12801327 else :
12811328 return self .dtype .na_value
12821329
1283- def all (self , * , skipna : bool = True , ** kwargs ):
1330+ def all (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
12841331 """
12851332 Return whether all elements are truthy.
12861333
@@ -1299,6 +1346,7 @@ def all(self, *, skipna: bool = True, **kwargs):
12991346 If `skipna` is False, the result will still be False if there is
13001347 at least one element that is falsey, otherwise NA will be returned
13011348 if there are NA's present.
1349+ axis : int, optional, default 0
13021350 **kwargs : any, default None
13031351 Additional keywords have no effect but might be accepted for
13041352 compatibility with NumPy.
@@ -1342,7 +1390,6 @@ def all(self, *, skipna: bool = True, **kwargs):
13421390 >>> pd.array([1, 0, pd.NA]).all(skipna=False)
13431391 False
13441392 """
1345- kwargs .pop ("axis" , None )
13461393 nv .validate_all ((), kwargs )
13471394
13481395 values = self ._data .copy ()
@@ -1352,7 +1399,7 @@ def all(self, *, skipna: bool = True, **kwargs):
13521399 # bool, int, float, complex, str, bytes,
13531400 # _NestedSequence[Union[bool, int, float, complex, str, bytes]]]"
13541401 np .putmask (values , self ._mask , self ._truthy_value ) # type: ignore[arg-type]
1355- result = values .all ()
1402+ result = values .all (axis = axis )
13561403
13571404 if skipna :
13581405 return result
0 commit comments