Skip to content

Commit

Permalink
BUG: preserve categorical & sparse types when grouping / pivot (#27071)
Browse files Browse the repository at this point in the history
  • Loading branch information
jreback authored Jun 27, 2019
1 parent de0867f commit ce86c21
Show file tree
Hide file tree
Showing 15 changed files with 205 additions and 71 deletions.
31 changes: 30 additions & 1 deletion doc/source/whatsnew/v0.25.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,35 @@ of ``object`` dtype. :attr:`Series.str` will now infer the dtype data *within* t
s
s.str.startswith(b'a')
.. _whatsnew_0250.api_breaking.groupby_categorical:

Categorical dtypes are preserved during groupby
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Previously, columns that were categorical, but not the groupby key(s), would be converted to ``object`` dtype during groupby operations. pandas will now preserve these dtypes. (:issue:`18502`)

.. ipython:: python
df = pd.DataFrame(
{'payload': [-1, -2, -1, -2],
'col': pd.Categorical(["foo", "bar", "bar", "qux"], ordered=True)})
df
df.dtypes
*Previous Behavior*:

.. code-block:: python
In [5]: df.groupby('payload').first().col.dtype
Out[5]: dtype('O')
*New Behavior*:

.. ipython:: python
df.groupby('payload').first().col.dtype
.. _whatsnew_0250.api_breaking.incompatible_index_unions:

Incompatible Index type unions
Expand Down Expand Up @@ -809,7 +838,7 @@ ExtensionArray

- Bug in :func:`factorize` when passing an ``ExtensionArray`` with a custom ``na_sentinel`` (:issue:`25696`).
- :meth:`Series.count` miscounts NA values in ExtensionArrays (:issue:`26835`)
- Keyword argument ``deep`` has been removed from :method:`ExtensionArray.copy` (:issue:`27083`)
- Keyword argument ``deep`` has been removed from :meth:`ExtensionArray.copy` (:issue:`27083`)

Other
^^^^^
Expand Down
11 changes: 9 additions & 2 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,19 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True,

obj = self.obj[data.items[locs]]
s = groupby(obj, self.grouper)
result = s.aggregate(lambda x: alt(x, axis=self.axis))
try:
result = s.aggregate(lambda x: alt(x, axis=self.axis))
except TypeError:
# we may have an exception in trying to aggregate
# continue and exclude the block
pass

finally:

dtype = block.values.dtype

# see if we can cast the block back to the original dtype
result = block._try_coerce_and_cast_result(result)
result = block._try_coerce_and_cast_result(result, dtype=dtype)
newb = block.make_block(result)

new_items.append(locs)
Expand Down
42 changes: 32 additions & 10 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,8 @@ def _try_cast(self, result, obj, numeric_only=False):
elif is_extension_array_dtype(dtype):
# The function can return something of any type, so check
# if the type is compatible with the calling EA.

# return the same type (Series) as our caller
try:
result = obj._values._from_sequence(result, dtype=dtype)
except Exception:
Expand Down Expand Up @@ -1157,7 +1159,8 @@ def mean(self, *args, **kwargs):
"""
nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
try:
return self._cython_agg_general('mean', **kwargs)
return self._cython_agg_general(
'mean', alt=lambda x, axis: Series(x).mean(**kwargs), **kwargs)
except GroupByError:
raise
except Exception: # pragma: no cover
Expand All @@ -1179,7 +1182,11 @@ def median(self, **kwargs):
Median of values within each group.
"""
try:
return self._cython_agg_general('median', **kwargs)
return self._cython_agg_general(
'median',
alt=lambda x,
axis: Series(x).median(axis=axis, **kwargs),
**kwargs)
except GroupByError:
raise
except Exception: # pragma: no cover
Expand Down Expand Up @@ -1235,7 +1242,10 @@ def var(self, ddof=1, *args, **kwargs):
nv.validate_groupby_func('var', args, kwargs)
if ddof == 1:
try:
return self._cython_agg_general('var', **kwargs)
return self._cython_agg_general(
'var',
alt=lambda x, axis: Series(x).var(ddof=ddof, **kwargs),
**kwargs)
except Exception:
f = lambda x: x.var(ddof=ddof, **kwargs)
with _group_selection_context(self):
Expand Down Expand Up @@ -1263,7 +1273,6 @@ def sem(self, ddof=1):
Series or DataFrame
Standard error of the mean of values within each group.
"""

return self.std(ddof=ddof) / np.sqrt(self.count())

@Substitution(name='groupby')
Expand All @@ -1290,7 +1299,7 @@ def _add_numeric_operations(cls):
"""

def groupby_function(name, alias, npfunc,
numeric_only=True, _convert=False,
numeric_only=True,
min_count=-1):

_local_template = """
Expand All @@ -1312,17 +1321,30 @@ def f(self, **kwargs):
kwargs['min_count'] = min_count

self._set_group_selection()

# try a cython aggregation if we can
try:
return self._cython_agg_general(
alias, alt=npfunc, **kwargs)
except AssertionError as e:
raise SpecificationError(str(e))
except Exception:
result = self.aggregate(
lambda x: npfunc(x, axis=self.axis))
if _convert:
result = result._convert(datetime=True)
return result
pass

# apply a non-cython aggregation
result = self.aggregate(
lambda x: npfunc(x, axis=self.axis))

# coerce the resulting columns if we can
if isinstance(result, DataFrame):
for col in result.columns:
result[col] = self._try_cast(
result[col], self.obj[col])
else:
result = self._try_cast(
result, self.obj)

return result

set_function_name(f, name, cls)

Expand Down
6 changes: 3 additions & 3 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pandas.core.dtypes.common import (
ensure_float64, ensure_int64, ensure_int_or_float, ensure_object,
ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype, is_sparse,
is_timedelta64_dtype, needs_i8_conversion)
from pandas.core.dtypes.missing import _maybe_fill, isna

Expand Down Expand Up @@ -451,9 +451,9 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1,

# categoricals are only 1d, so we
# are not setup for dim transforming
if is_categorical_dtype(values):
if is_categorical_dtype(values) or is_sparse(values):
raise NotImplementedError(
"categoricals are not support in cython ops ATM")
"{} are not support in cython ops".format(values.dtype))
elif is_datetime64_any_dtype(values):
if how in ['add', 'prod', 'cumsum', 'cumprod']:
raise NotImplementedError(
Expand Down
24 changes: 23 additions & 1 deletion pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
values = self.get_values(dtype=dtype)

# _astype_nansafe works fine with 1-d only
values = astype_nansafe(values.ravel(), dtype, copy=True)
values = astype_nansafe(
values.ravel(), dtype, copy=True, **kwargs)

# TODO(extension)
# should we make this attribute?
Expand Down Expand Up @@ -1746,6 +1747,27 @@ def _slice(self, slicer):

return self.values[slicer]

def _try_cast_result(self, result, dtype=None):
    """
    Attempt to cast an operation result back to this block's
    ExtensionArray type.

    ``result`` may be a 2-D numpy array of shape (1, X) — the output of
    a numeric operation, since ExtensionBlocks are by definition
    operated on one at a time — or it may already be a 1-D EA array.

    Best-effort: when reconstruction through the EA's
    ``_from_sequence`` fails for any reason, the result is handed back
    unmodified rather than raising.
    """
    try:
        return self._holder._from_sequence(result.ravel(), dtype=dtype)
    except Exception:
        # could not round-trip into the EA type; keep the raw result
        return result

def formatting_values(self):
# Deprecating the ability to override _formatting_values.
# Do the warning here, it's only user in pandas, since we
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,10 @@ def sanitize_array(data, index, dtype=None, copy=False,
data = np.array(data, dtype=dtype, copy=False)
subarr = np.array(data, dtype=object, copy=copy)

if is_object_dtype(subarr.dtype) and dtype != 'object':
if (not (is_extension_array_dtype(subarr.dtype) or
is_extension_array_dtype(dtype)) and
is_object_dtype(subarr.dtype) and
not is_object_dtype(dtype)):
inferred = lib.infer_dtype(subarr, skipna=False)
if inferred == 'period':
try:
Expand Down
9 changes: 5 additions & 4 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ def _f(*args, **kwargs):

class bottleneck_switch:

def __init__(self, **kwargs):
def __init__(self, name=None, **kwargs):
    # Explicit bottleneck function name to look up; when None, the
    # decorated function's own __name__ is used instead (see __call__:
    # ``bn_name = self.name or alt.__name__``).
    self.name = name
    # Remaining keyword options are retained on the instance for later
    # use by the decorator (consumption not shown in this hunk).
    self.kwargs = kwargs

def __call__(self, alt):
bn_name = alt.__name__
bn_name = self.name or alt.__name__

try:
bn_func = getattr(bn, bn_name)
Expand Down Expand Up @@ -804,7 +805,8 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):


def _nanminmax(meth, fill_value_typ):
@bottleneck_switch()

@bottleneck_switch(name='nan' + meth)
def reduction(values, axis=None, skipna=True, mask=None):

values, mask, dtype, dtype_max, fill_value = _get_values(
Expand All @@ -824,7 +826,6 @@ def reduction(values, axis=None, skipna=True, mask=None):
result = _wrap_results(result, dtype, fill_value)
return _maybe_null_out(result, axis, mask, values.shape)

reduction.__name__ = 'nan' + meth
return reduction


Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/extension/base/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ def test_groupby_extension_apply(
df.groupby("A").apply(groupby_apply_op)
df.groupby("A").B.apply(groupby_apply_op)

def test_groupby_apply_identity(self, data_for_grouping):
    """Applying the identity (``x -> x.array``) per group must return
    each group's extension-array values unchanged."""
    df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
                       "B": data_for_grouping})
    result = df.groupby('A').B.apply(lambda x: x.array)

    # positional rows belonging to each group key, in key order
    group_positions = {1: [0, 1, 6], 2: [2, 3], 3: [4, 5], 4: [7]}
    expected = pd.Series(
        [df.B.iloc[pos].array for pos in group_positions.values()],
        index=pd.Index(list(group_positions), name='A'),
        name='B')
    self.assert_series_equal(result, expected)

def test_in_numeric_groupby(self, data_for_grouping):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
"B": data_for_grouping,
Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,11 @@ class TestCasting(BaseDecimal, base.BaseCastingTests):


class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
pass

@pytest.mark.xfail(
reason="needs to correctly define __eq__ to handle nans, xref #27081.")
def test_groupby_apply_identity(self, data_for_grouping):
super().test_groupby_apply_identity(data_for_grouping)


class TestSetitem(BaseDecimal, base.BaseSetitemTests):
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/groupby/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,27 @@ def test_preserve_categorical_dtype():
tm.assert_frame_equal(result2, expected)


@pytest.mark.parametrize(
    'func, values',
    [('first', ['second', 'first']),
     ('last', ['fourth', 'third']),
     ('min', ['fourth', 'first']),
     ('max', ['second', 'third'])])
def test_preserve_on_ordered_ops(func, values):
    # gh-18502: reductions on an ordered Categorical column must keep
    # the categorical dtype (including its ordering) instead of
    # decaying to object dtype.
    categories = pd.Categorical(
        ['first', 'second', 'third', 'fourth'], ordered=True)
    df = pd.DataFrame({'payload': [-1, -2, -1, -2], 'col': categories})

    result = getattr(df.groupby('payload'), func)()

    expected = pd.DataFrame(
        {'payload': [-2, -1],
         'col': pd.Series(values, dtype=categories.dtype)}
    ).set_index('payload')
    tm.assert_frame_equal(result, expected)


def test_categorical_no_compress():
data = Series(np.random.randn(9))

Expand Down
Loading

0 comments on commit ce86c21

Please sign in to comment.