Skip to content

Commit 452c7fb

Browse files
authored
ENH: add value_counts to EA interface (#62254)
1 parent bcc60a4 commit 452c7fb

File tree

7 files changed

+29
-59
lines changed

7 files changed

+29
-59
lines changed

ci/code_checks.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ if [[ -z "$CHECK" || "$CHECK" == "docstrings" ]]; then
7373
-i "pandas.Period.freq GL08" \
7474
-i "pandas.Period.ordinal GL08" \
7575
-i "pandas.errors.IncompatibleFrequency SA01,SS06,EX01" \
76+
-i "pandas.api.extensions.ExtensionArray.value_counts EX01,RT03,SA01" \
7677
-i "pandas.core.groupby.DataFrameGroupBy.plot PR02" \
7778
-i "pandas.core.groupby.SeriesGroupBy.plot PR02" \
7879
-i "pandas.core.resample.Resampler.quantile PR01,PR07" \

pandas/core/arrays/base.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@
9999
npt,
100100
)
101101

102-
from pandas import Index
102+
from pandas import (
103+
Index,
104+
Series,
105+
)
103106

104107
_extension_array_shared_docs: dict[str, str] = {}
105108

@@ -1673,6 +1676,25 @@ def repeat(self, repeats: int | Sequence[int], axis: AxisInt | None = None) -> S
16731676
ind = np.arange(len(self)).repeat(repeats)
16741677
return self.take(ind)
16751678

1679+
def value_counts(self, dropna: bool = True) -> Series:
1680+
"""
1681+
Return a Series containing counts of unique values.
1682+
1683+
Parameters
1684+
----------
1685+
dropna : bool, default True
1686+
Don't include counts of NA values.
1687+
1688+
Returns
1689+
-------
1690+
Series
1691+
"""
1692+
from pandas.core.algorithms import value_counts_internal as value_counts
1693+
1694+
result = value_counts(self.to_numpy(copy=False), sort=False, dropna=dropna)
1695+
result.index = result.index.astype(self.dtype)
1696+
return result
1697+
16761698
# ------------------------------------------------------------------------
16771699
# Indexing methods
16781700
# ------------------------------------------------------------------------

pandas/core/arrays/interval.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
isin,
7676
take,
7777
unique,
78-
value_counts_internal as value_counts,
7978
)
8079
from pandas.core.arrays import ArrowExtensionArray
8180
from pandas.core.arrays.base import (
@@ -105,7 +104,6 @@
105104

106105
from pandas import (
107106
Index,
108-
Series,
109107
)
110108

111109

@@ -1197,28 +1195,6 @@ def _validate_setitem_value(self, value):
11971195

11981196
return value_left, value_right
11991197

1200-
def value_counts(self, dropna: bool = True) -> Series:
1201-
"""
1202-
Returns a Series containing counts of each interval.
1203-
1204-
Parameters
1205-
----------
1206-
dropna : bool, default True
1207-
Don't include counts of NaN.
1208-
1209-
Returns
1210-
-------
1211-
counts : Series
1212-
1213-
See Also
1214-
--------
1215-
Series.value_counts
1216-
"""
1217-
# TODO: implement this is a non-naive way!
1218-
result = value_counts(np.asarray(self), dropna=dropna)
1219-
result.index = result.index.astype(self.dtype)
1220-
return result
1221-
12221198
# ---------------------------------------------------------------------
12231199
# Rendering Methods
12241200

pandas/core/arrays/string_.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,10 +1037,7 @@ def sum(
10371037
return self._wrap_reduction_result(axis, result)
10381038

10391039
def value_counts(self, dropna: bool = True) -> Series:
1040-
from pandas.core.algorithms import value_counts_internal as value_counts
1041-
1042-
result = value_counts(self._ndarray, sort=False, dropna=dropna)
1043-
result.index = result.index.astype(self.dtype)
1040+
result = super().value_counts(dropna=dropna)
10441041

10451042
if self.dtype.na_value is libmissing.NA:
10461043
result = result.astype("Int64")

pandas/tests/extension/decimal/array.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
is_scalar,
2626
)
2727
from pandas.core import arraylike
28-
from pandas.core.algorithms import value_counts_internal as value_counts
2928
from pandas.core.arraylike import OpsMixin
3029
from pandas.core.arrays import (
3130
ExtensionArray,
@@ -291,9 +290,6 @@ def convert_values(param):
291290

292291
return np.asarray(res, dtype=bool)
293292

294-
def value_counts(self, dropna: bool = True):
295-
return value_counts(self.to_numpy(), dropna=dropna)
296-
297293
# We override fillna here to simulate a 3rd party EA that has done so. This
298294
# lets us test a 3rd-party EA that has not yet updated to include a "copy"
299295
# keyword in its fillna method.

pandas/tests/extension/decimal/test_decimal.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -171,26 +171,6 @@ def test_fillna_limit_series(self, data_missing):
171171
):
172172
super().test_fillna_limit_series(data_missing)
173173

174-
@pytest.mark.parametrize("dropna", [True, False])
175-
def test_value_counts(self, all_data, dropna):
176-
all_data = all_data[:10]
177-
if dropna:
178-
other = np.array(all_data[~all_data.isna()])
179-
else:
180-
other = all_data
181-
182-
vcs = pd.Series(all_data).value_counts(dropna=dropna)
183-
vcs_ex = pd.Series(other).value_counts(dropna=dropna)
184-
185-
with decimal.localcontext() as ctx:
186-
# avoid raising when comparing Decimal("NAN") < Decimal(2)
187-
ctx.traps[decimal.InvalidOperation] = False
188-
189-
result = vcs.sort_index()
190-
expected = vcs_ex.sort_index()
191-
192-
tm.assert_series_equal(result, expected)
193-
194174
def test_series_repr(self, data):
195175
# Overriding this base test to explicitly test that
196176
# the custom _formatter is used

pandas/tests/extension/json/test_json.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,12 @@ def test_ffill_limit_area(
189189
data_missing, limit_area, input_ilocs, expected_ilocs
190190
)
191191

192-
@unhashable
193-
def test_value_counts(self, all_data, dropna):
192+
def test_value_counts(self, all_data, dropna, request):
193+
if len(all_data) == 100 or dropna:
194+
mark = pytest.mark.xfail(reason="unhashable")
195+
request.applymarker(mark)
194196
super().test_value_counts(all_data, dropna)
195197

196-
@unhashable
197-
def test_value_counts_with_normalize(self, data):
198-
super().test_value_counts_with_normalize(data)
199-
200198
@unhashable
201199
def test_sort_values_frame(self):
202200
# TODO (EA.factorize): see if _values_for_factorize allows this.

0 commit comments

Comments
 (0)