Skip to content

BUG: GH10355 groupby std() no longer sqrts grouping cols #11300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,12 @@ def std(self, ddof=1):
For multiple groupings, the result index will be a MultiIndex
"""
# todo, implement at cython level?
return np.sqrt(self.var(ddof=ddof))
if ddof == 1:
return self._cython_agg_general('std')
else:
self._set_selection_from_grouper()
f = lambda x: x.std(ddof=ddof)
return self._python_agg_general(f)

def var(self, ddof=1):
"""
Expand Down Expand Up @@ -1467,6 +1472,10 @@ def get_group_levels(self):
#------------------------------------------------------------
# Aggregation functions

def _cython_std(group_var, out, b, c, d):
group_var(out, b, c, d)
out **= 0.5 # needs to be applied in place

_cython_functions = {
'add': 'group_add',
'prod': 'group_prod',
Expand All @@ -1477,6 +1486,10 @@ def get_group_levels(self):
'name': 'group_median'
},
'var': 'group_var',
'std': {
'name': 'group_var',
'f': _cython_std,
},
'first': {
'name': 'group_nth',
'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
Expand Down Expand Up @@ -1512,7 +1525,7 @@ def get_func(fname):

# a sub-function
f = ftype.get('f')
if f is not None:
if f is not None and afunc is not None:

def wrapper(*args, **kwargs):
return f(afunc, *args, **kwargs)
Expand Down
30 changes: 30 additions & 0 deletions pandas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -5545,6 +5545,36 @@ def test_nunique_with_object(self):
expected = pd.Series([1] * 5, name='name', index=index)
tm.assert_series_equal(result, expected)

def test_std_with_as_index_false(self):
# GH 10355
df = pd.DataFrame({
'a': [1, 1, 1, 2, 2, 2, 3, 3, 3],
'b': [1, 2, 3, 4, 5, 6, 7, 8, 9],
})
sd = df.groupby('a', as_index=False).std()

expected = pd.DataFrame({
'a': [1, 2, 3],
'b': [1, 1, 1],
})
tm.assert_frame_equal(expected, sd)

def test_std_with_ddof(self):
df = pd.DataFrame({
'a': [1, 1, 1, 2, 2, 2, 3, 3, 3],
'b': [1, 2, 3, 1, 5, 6, 7, 8, 10],
})
sd = df.groupby('a', as_index=False).std(ddof=0)

expected = pd.DataFrame({
'a': [1, 2, 3],
'b': [
np.std([1, 2, 3], ddof=0),
np.std([1, 5, 6], ddof=0),
np.std([7, 8, 10], ddof=0)],
})
tm.assert_frame_equal(expected, sd)


def assert_fp_equal(a, b):
assert (np.abs(a - b) < 1e-12).all()
Expand Down