BUG: fix TypeErrors raised within _python_agg_general (#29425)
jbrockmendel authored and jreback committed Nov 6, 2019
1 parent b72f928 commit 08087d6
Showing 3 changed files with 74 additions and 15 deletions.
15 changes: 13 additions & 2 deletions pandas/core/groupby/groupby.py
@@ -899,10 +899,21 @@ def _python_agg_general(self, func, *args, **kwargs):
         output = {}
         for name, obj in self._iterate_slices():
             try:
-                result, counts = self.grouper.agg_series(obj, f)
+                # if this function is invalid for this dtype, we will ignore it.
+                func(obj[:0])
             except TypeError:
                 continue
-            else:
+            except AssertionError:
+                raise
+            except Exception:
+                # Our function depends on having a non-empty argument
+                # See test_groupby_agg_err_catching
+                pass
+
+            result, counts = self.grouper.agg_series(obj, f)
+            if result is not None:
+                # TODO: only 3 test cases get None here, do something
+                # in those cases
                 output[name] = self._try_cast(result, obj, numeric_only=True)
 
         if len(output) == 0:
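
Taken together, this hunk changes what happens when a user-supplied aggregation function blows up: the function is first probed on an empty slice, only TypeError drops the column, AssertionError propagates, and anything else is swallowed so the real per-group aggregation still runs. A rough end-to-end sketch of that behaviour, modelled on the new test added at the bottom of this commit (DecimalArray comes from pandas' own test suite, and first_or_raise is an illustrative name, not part of the commit):

import pandas as pd
from pandas.tests.extension.decimal.array import DecimalArray, make_data

data = make_data()[:5]
df = pd.DataFrame({"id1": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})

def first_or_raise(x):
    # raises something other than TypeError/AssertionError on the empty probe slice
    if len(x) == 0:
        raise RuntimeError("empty slice")
    return x.iloc[0]

# With this commit, the RuntimeError raised by the func(obj[:0]) probe is
# swallowed and the aggregation still runs group by group.
result = df["decimals"].groupby(df["id1"]).agg(first_or_raise)
print(result)  # first decimal of each id1 group
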
36 changes: 23 additions & 13 deletions pandas/core/groupby/ops.py
@@ -61,8 +61,7 @@ class BaseGrouper:
 
     Parameters
     ----------
-    axis : int
-        the axis to group
+    axis : Index
     groupings : array of grouping
         all the grouping instances to handle in this grouper
         for example for grouper list to groupby, need to pass the list
@@ -78,8 +77,15 @@ class BaseGrouper:
     """
 
     def __init__(
-        self, axis, groupings, sort=True, group_keys=True, mutated=False, indexer=None
+        self,
+        axis: Index,
+        groupings,
+        sort=True,
+        group_keys=True,
+        mutated=False,
+        indexer=None,
     ):
+        assert isinstance(axis, Index), axis
         self._filter_empty_groups = self.compressed = len(groupings) != 1
         self.axis = axis
         self.groupings = groupings
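
As a point of orientation for the annotation and assert added above: the axis a BaseGrouper is handed is the grouped object's Index, not an integer axis number. A quick, hedged check through the groupby object's grouper attribute (internal API of this era, shown only for illustration):

import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2], "b": [10, 20, 30]})
gb = df.groupby("a")

# BaseGrouper.axis is the Index being grouped over, hence the new assert
print(isinstance(gb.grouper.axis, pd.Index))   # True
print(gb.grouper.axis.equals(df.index))        # True
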
@@ -623,7 +629,7 @@ def _aggregate_series_pure_python(self, obj, func):
         counts = np.zeros(ngroups, dtype=int)
         result = None
 
-        splitter = get_splitter(obj, group_index, ngroups, axis=self.axis)
+        splitter = get_splitter(obj, group_index, ngroups, axis=0)
 
         for label, group in splitter:
             res = func(group)
@@ -635,8 +641,12 @@
             counts[label] = group.shape[0]
             result[label] = res
 
-        result = lib.maybe_convert_objects(result, try_float=0)
-        # TODO: try_cast back to EA?
+        if result is not None:
+            # if splitter is empty, result can be None, in which case
+            # maybe_convert_objects would raise TypeError
+            result = lib.maybe_convert_objects(result, try_float=0)
+            # TODO: try_cast back to EA?
+
         return result, counts
 
 
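The guard added here exists because result is only allocated lazily inside the loop, so an empty splitter leaves it as None and lib.maybe_convert_objects would then raise. A stripped-down sketch of that control flow in plain Python (not the pandas internals verbatim):

import numpy as np

def aggregate_pure_python_sketch(splitter, func, ngroups):
    counts = np.zeros(ngroups, dtype=int)
    result = None
    for label, group in splitter:      # an empty splitter never enters the loop
        if result is None:
            result = np.empty(ngroups, dtype=object)
        counts[label] = len(group)
        result[label] = func(group)
    # mirrors the patch: only post-process when at least one group was seen
    if result is not None:
        result = result.tolist()       # stand-in for lib.maybe_convert_objects
    return result, counts

print(aggregate_pure_python_sketch(iter([]), sum, 0))   # result stays None
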
@@ -781,6 +791,11 @@ def groupings(self):
         ]
 
     def agg_series(self, obj: Series, func):
+        if is_extension_array_dtype(obj.dtype):
+            # pre-empt SeriesBinGrouper from raising TypeError
+            # TODO: watch out, this can return None
+            return self._aggregate_series_pure_python(obj, func)
+
         dummy = obj[:0]
         grouper = libreduction.SeriesBinGrouper(obj, func, self.bins, dummy)
         return grouper.get_result()
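
The dispatch above hinges on is_extension_array_dtype, which is public API, so the condition is easy to inspect directly; the nullable Int64 dtype below is just one example of a Series that would take the pure-Python fallback instead of SeriesBinGrouper:

import pandas as pd
from pandas.api.types import is_extension_array_dtype

s_np = pd.Series([1.0, 2.0, 3.0])            # NumPy-backed -> Cython SeriesBinGrouper path
s_ea = pd.Series([1, 2, 3], dtype="Int64")   # extension-array-backed -> pure-Python fallback

print(is_extension_array_dtype(s_np.dtype))  # False
print(is_extension_array_dtype(s_ea.dtype))  # True
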
@@ -809,12 +824,13 @@ def _is_indexed_like(obj, axes) -> bool:
 
 
 class DataSplitter:
-    def __init__(self, data, labels, ngroups, axis=0):
+    def __init__(self, data, labels, ngroups, axis: int = 0):
         self.data = data
         self.labels = ensure_int64(labels)
         self.ngroups = ngroups
 
         self.axis = axis
+        assert isinstance(axis, int), axis
 
     @cache_readonly
     def slabels(self):
@@ -837,12 +853,6 @@ def __iter__(self):
         starts, ends = lib.generate_slices(self.slabels, self.ngroups)
 
         for i, (start, end) in enumerate(zip(starts, ends)):
-            # Since I'm now compressing the group ids, it's now not "possible"
-            # to produce empty slices because such groups would not be observed
-            # in the data
-            # if start >= end:
-            #     raise AssertionError('Start %s must be less than end %s'
-            #                          % (str(start), str(end)))
             yield i, self._chop(sdata, slice(start, end))
 
     def _get_sorted_data(self):
38 changes: 38 additions & 0 deletions pandas/tests/groupby/aggregate/test_other.py
@@ -602,3 +602,41 @@ def test_agg_lambda_with_timezone():
         columns=["date"],
     )
     tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "err_cls",
+    [
+        NotImplementedError,
+        RuntimeError,
+        KeyError,
+        IndexError,
+        OSError,
+        ValueError,
+        ArithmeticError,
+        AttributeError,
+    ],
+)
+def test_groupby_agg_err_catching(err_cls):
+    # make sure we suppress anything other than TypeError or AssertionError
+    # in _python_agg_general
+
+    # Use a non-standard EA to make sure we don't go down ndarray paths
+    from pandas.tests.extension.decimal.array import DecimalArray, make_data, to_decimal
+
+    data = make_data()[:5]
+    df = pd.DataFrame(
+        {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
+    )
+
+    expected = pd.Series(to_decimal([data[0], data[3]]))
+
+    def weird_func(x):
+        # weird function that raises something other than TypeError or AssertionError
+        # in _python_agg_general
+        if len(x) == 0:
+            raise err_cls
+        return x.iloc[0]
+
+    result = df["decimals"].groupby(df["id1"]).agg(weird_func)
+    tm.assert_series_equal(result, expected, check_names=False)
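
To run just the new parametrized test locally (assuming a source checkout that includes this commit), one option is pytest's programmatic entry point:

import pytest

# exercises all eight err_cls parametrizations of the new test
pytest.main([
    "pandas/tests/groupby/aggregate/test_other.py",
    "-k", "test_groupby_agg_err_catching",
    "-q",
])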
