diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index 3f84fa0f0670e..b93a6a0ff9b11 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -1174,6 +1174,45 @@ def var(self, ddof=1, *args, **kwargs):
             with _group_selection_context(self):
                 return self._python_agg_general(f)
 
+    @Substitution(name='groupby')
+    @Appender(_doc_template)
+    def mad(self, skipna=True):
+        """
+        Compute mean absolute deviation of groups
+
+        For multiple groupings, the result index will be a MultiIndex
+
+        Parameters
+        ----------
+        skipna : boolean, default True
+            Only True is currently supported
+
+        Raises
+        ------
+        NotImplementedError
+            If skipna is False
+        DataError
+            If computing the deviations raises a TypeError
+        """
+        if not skipna:
+            raise NotImplementedError("'skipna=False' not yet implemented")
+
+        if self.axis != 0:
+            # Other axes fall back to calling mad on each group
+            return self.apply(lambda x: x.mad(axis=self.axis))
+
+        # Wrap in a try..except to catch a TypeError with bool data
+        # Ideally this would be implemented in `mean` instead of here
+        try:
+            demeaned = np.abs(self.shift(0) - self.transform('mean'))
+            result = demeaned.groupby(self.grouper.labels).mean()
+        except TypeError:
+            raise DataError('No numeric types to aggregate')
+
+        # Kept outside the try so only the arithmetic is guarded
+        result.index = self.grouper.result_index
+        return result
+
     @Substitution(name='groupby')
     @Appender(_doc_template)
     def sem(self, ddof=1):
diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py
index f8a0f1688c64e..69b87b0689f23 100644
--- a/pandas/tests/groupby/test_function.py
+++ b/pandas/tests/groupby/test_function.py
@@ -4,6 +4,7 @@
 import pandas as pd
 from pandas import (DataFrame, Index, compat, isna,
                     Series, MultiIndex, Timestamp, date_range)
+from pandas.core.base import DataError
 from pandas.errors import UnsupportedFunctionCall
 from pandas.util import testing as tm
 import pandas.core.nanops as nanops
@@ -231,17 +232,6 @@ def test_non_cython_api():
     g = df.groupby('A')
     gni = df.groupby('A', as_index=False)
 
-    # mad
-    expected = DataFrame([[0], [np.nan]], columns=['B'], index=[1, 3])
-    expected.index.name = 'A'
-    result = g.mad()
-    tm.assert_frame_equal(result, expected)
-
-    expected = DataFrame([[0., 0.], [0, np.nan]], columns=['A', 'B'],
-                         index=[0, 1])
-    result = gni.mad()
-    tm.assert_frame_equal(result, expected)
-
     # describe
     expected_index = pd.Index([1, 3], name='A')
     expected_col = pd.MultiIndex(levels=[['B'],
@@ -481,6 +471,56 @@ def test_max_nan_bug():
     assert not r['File'].isna().any()
 
 
+@pytest.mark.parametrize("klass", [Series, DataFrame])
+@pytest.mark.parametrize("test_mi", [True, False])
+@pytest.mark.parametrize("dtype", ['int', 'float'])
+def test_groupby_mad(klass, test_mi, dtype):
+    vals = np.array(range(10)).astype(dtype)
+    df = DataFrame({'key': ['a'] * 5 + ['b'] * 5, 'val': vals})
+
+    idx = pd.Index(['a', 'b'], name='key')
+    # Each group holds five consecutive integers: (2+1+0+1+2) / 5 = 1.2
+    exp = klass([1.2, 1.2], index=idx)
+    grping = ['key']
+
+    if test_mi:
+        df = df.append(df)  # Double the size of the frame
+        df['newcol'] = ['foo'] * 10 + ['bar'] * 10
+        grping.append('newcol')
+
+        mi = pd.MultiIndex.from_product((exp.index.values,
+                                         ['bar', 'foo']),
+                                        names=['key', 'newcol'])
+        exp = exp.append(exp)
+        exp.index = mi
+
+    if klass is Series:
+        exp.name = 'val'
+        result = df.groupby(grping)['val'].mad()
+        tm.assert_series_equal(result, exp)
+    else:
+        exp = exp.rename(columns={0: 'val'})
+        result = df.groupby(grping).mad()
+        tm.assert_frame_equal(result, exp)
+
+
+@pytest.mark.parametrize("vals", [
+    ['foo'] * 10, [True] * 10])
+def test_groupby_mad_raises(vals):
+    df = DataFrame({'key': ['a'] * 5 + ['b'] * 5, 'val': vals})
+
+    with tm.assert_raises_regex(DataError,
+                                "No numeric types to aggregate"):
+        df.groupby('key').mad()
+
+
+def test_groupby_mad_skipna():
+    df = DataFrame({'key': ['a'] * 5 + ['b'] * 5, 'val': range(10)})
+    with tm.assert_raises_regex(
+            NotImplementedError, "'skipna=False' not yet implemented"):
+        df.groupby('key').mad(skipna=False)
+
+
 def test_nlargest():
     a = Series([1, 3, 5, 7, 2, 9, 0, 4, 6, 10])
     b = Series(list('a' * 5 + 'b' * 5))
diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py
index 9affd0241d028..ca3ed354809a1 100644
--- a/pandas/tests/groupby/test_groupby.py
+++ b/pandas/tests/groupby/test_groupby.py
@@ -561,6 +561,7 @@ def test_groupby_as_index_agg(df):
     with tm.assert_produces_warning(FutureWarning,
                                     check_stacklevel=False):
         result3 = grouped['C'].agg({'Q': np.sum})
+
     assert_frame_equal(result3, expected3)
 
     # multi-key
diff --git a/pandas/tests/groupby/test_whitelist.py b/pandas/tests/groupby/test_whitelist.py
index 3afc278f9bc93..f336dbb4a5f96 100644
--- a/pandas/tests/groupby/test_whitelist.py
+++ b/pandas/tests/groupby/test_whitelist.py
@@ -11,7 +11,7 @@
 
 AGG_FUNCTIONS = ['sum', 'prod', 'min', 'max', 'median', 'mean', 'skew',
                  'mad', 'std', 'var', 'sem']
-AGG_FUNCTIONS_WITH_SKIPNA = ['skew', 'mad']
+AGG_FUNCTIONS_WITH_SKIPNA = ['skew']
 
 df_whitelist = [
     'last',