1010from functools import partial
1111from textwrap import dedent
1212import typing
13- from typing import Any , Callable , FrozenSet , Iterable , Sequence , Type , Union , cast
13+ from typing import (
14+ Any ,
15+ Callable ,
16+ FrozenSet ,
17+ Iterable ,
18+ Mapping ,
19+ Sequence ,
20+ Type ,
21+ Union ,
22+ cast ,
23+ )
1424
1525import numpy as np
1626
@@ -309,28 +319,91 @@ def _aggregate_multiple_funcs(self, arg):
309319
310320 return DataFrame (results , columns = columns )
311321
312- def _wrap_series_output (self , output , index , names = None ):
313- """ common agg/transform wrapping logic """
314- output = output [self ._selection_name ]
322+ def _wrap_series_output (
323+ self , output : Mapping [base .OutputKey , Union [Series , np .ndarray ]], index : Index ,
324+ ) -> Union [Series , DataFrame ]:
325+ """
326+ Wraps the output of a SeriesGroupBy operation into the expected result.
327+
328+ Parameters
329+ ----------
330+ output : Mapping[base.OutputKey, Union[Series, np.ndarray]]
331+ Data to wrap.
332+ index : pd.Index
333+ Index to apply to the output.
315334
316- if names is not None :
317- return DataFrame (output , index = index , columns = names )
335+ Returns
336+ -------
337+ Series or DataFrame
338+
339+ Notes
340+ -----
341+ In the vast majority of cases output and columns will only contain one
342+ element. The exception is operations that expand dimensions, like ohlc.
343+ """
344+ indexed_output = {key .position : val for key , val in output .items ()}
345+ columns = Index (key .label for key in output )
346+
347+ result : Union [Series , DataFrame ]
348+ if len (output ) > 1 :
349+ result = DataFrame (indexed_output , index = index )
350+ result .columns = columns
318351 else :
319- name = self ._selection_name
320- if name is None :
321- name = self ._selected_obj .name
322- return Series (output , index = index , name = name )
352+ result = Series (indexed_output [0 ], index = index , name = columns [0 ])
353+
354+ return result
355+
356+ def _wrap_aggregated_output (
357+ self , output : Mapping [base .OutputKey , Union [Series , np .ndarray ]]
358+ ) -> Union [Series , DataFrame ]:
359+ """
360+ Wraps the output of a SeriesGroupBy aggregation into the expected result.
323361
324- def _wrap_aggregated_output (self , output , names = None ):
362+ Parameters
363+ ----------
364+ output : Mapping[base.OutputKey, Union[Series, np.ndarray]]
365+ Data to wrap.
366+
367+ Returns
368+ -------
369+ Series or DataFrame
370+
371+ Notes
372+ -----
373+ In the vast majority of cases output will only contain one element.
374+ The exception is operations that expand dimensions, like ohlc.
375+ """
325376 result = self ._wrap_series_output (
326- output = output , index = self .grouper .result_index , names = names
377+ output = output , index = self .grouper .result_index
327378 )
328379 return self ._reindex_output (result )._convert (datetime = True )
329380
330- def _wrap_transformed_output (self , output , names = None ):
331- return self ._wrap_series_output (
332- output = output , index = self .obj .index , names = names
333- )
381+ def _wrap_transformed_output (
382+ self , output : Mapping [base .OutputKey , Union [Series , np .ndarray ]]
383+ ) -> Series :
384+ """
385+ Wraps the output of a SeriesGroupBy aggregation into the expected result.
386+
387+ Parameters
388+ ----------
389+ output : dict[base.OutputKey, Union[Series, np.ndarray]]
390+ Dict with a sole key of 0 and a value of the result values.
391+
392+ Returns
393+ -------
394+ Series
395+
396+ Notes
397+ -----
398+ output should always contain one element. It is specified as a dict
399+ for consistency with DataFrame methods and _wrap_aggregated_output.
400+ """
401+ assert len (output ) == 1
402+ result = self ._wrap_series_output (output = output , index = self .obj .index )
403+
404+ # No transformations increase the ndim of the result
405+ assert isinstance (result , Series )
406+ return result
334407
335408 def _wrap_applied_output (self , keys , values , not_indexed_same = False ):
336409 if len (keys ) == 0 :
@@ -1084,17 +1157,6 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
10841157
10851158 return DataFrame (result , columns = result_columns )
10861159
1087- def _decide_output_index (self , output , labels ):
1088- if len (output ) == len (labels ):
1089- output_keys = labels
1090- else :
1091- output_keys = sorted (output )
1092-
1093- if isinstance (labels , MultiIndex ):
1094- output_keys = MultiIndex .from_tuples (output_keys , names = labels .names )
1095-
1096- return output_keys
1097-
10981160 def _wrap_applied_output (self , keys , values , not_indexed_same = False ):
10991161 if len (keys ) == 0 :
11001162 return DataFrame (index = keys )
@@ -1561,27 +1623,62 @@ def _insert_inaxis_grouper_inplace(self, result):
15611623 if in_axis :
15621624 result .insert (0 , name , lev )
15631625
1564- def _wrap_aggregated_output (self , output , names = None ):
1565- agg_axis = 0 if self .axis == 1 else 1
1566- agg_labels = self ._obj_with_exclusions ._get_axis (agg_axis )
1626+ def _wrap_aggregated_output (
1627+ self , output : Mapping [base .OutputKey , Union [Series , np .ndarray ]]
1628+ ) -> DataFrame :
1629+ """
1630+ Wraps the output of DataFrameGroupBy aggregations into the expected result.
15671631
1568- output_keys = self ._decide_output_index (output , agg_labels )
1632+ Parameters
1633+ ----------
1634+ output : Mapping[base.OutputKey, Union[Series, np.ndarray]]
1635+ Data to wrap.
1636+
1637+ Returns
1638+ -------
1639+ DataFrame
1640+ """
1641+ indexed_output = {key .position : val for key , val in output .items ()}
1642+ columns = Index (key .label for key in output )
1643+
1644+ result = DataFrame (indexed_output )
1645+ result .columns = columns
15691646
15701647 if not self .as_index :
1571- result = DataFrame (output , columns = output_keys )
15721648 self ._insert_inaxis_grouper_inplace (result )
15731649 result = result ._consolidate ()
15741650 else :
15751651 index = self .grouper .result_index
1576- result = DataFrame ( output , index = index , columns = output_keys )
1652+ result . index = index
15771653
15781654 if self .axis == 1 :
15791655 result = result .T
15801656
15811657 return self ._reindex_output (result )._convert (datetime = True )
15821658
1583- def _wrap_transformed_output (self , output , names = None ) -> DataFrame :
1584- return DataFrame (output , index = self .obj .index )
1659+ def _wrap_transformed_output (
1660+ self , output : Mapping [base .OutputKey , Union [Series , np .ndarray ]]
1661+ ) -> DataFrame :
1662+ """
1663+ Wraps the output of DataFrameGroupBy transformations into the expected result.
1664+
1665+ Parameters
1666+ ----------
1667+ output : Mapping[base.OutputKey, Union[Series, np.ndarray]]
1668+ Data to wrap.
1669+
1670+ Returns
1671+ -------
1672+ DataFrame
1673+ """
1674+ indexed_output = {key .position : val for key , val in output .items ()}
1675+ columns = Index (key .label for key in output )
1676+
1677+ result = DataFrame (indexed_output )
1678+ result .columns = columns
1679+ result .index = self .obj .index
1680+
1681+ return result
15851682
15861683 def _wrap_agged_blocks (self , items , blocks ):
15871684 if not self .as_index :
@@ -1701,9 +1798,11 @@ def groupby_series(obj, col=None):
17011798 if isinstance (obj , Series ):
17021799 results = groupby_series (obj )
17031800 else :
1801+ # TODO: this is duplicative of how GroupBy naturally works
1802+ # Try to consolidate with normal wrapping functions
17041803 from pandas .core .reshape .concat import concat
17051804
1706- results = [groupby_series (obj [ col ], col ) for col in obj .columns ]
1805+ results = [groupby_series (content , label ) for label , content in obj .items () ]
17071806 results = concat (results , axis = 1 )
17081807 results .columns .names = obj .columns .names
17091808
@@ -1745,7 +1844,7 @@ def _normalize_keyword_aggregation(kwargs):
17451844 """
17461845 Normalize user-provided "named aggregation" kwargs.
17471846
1748- Transforms from the new ``Dict [str, NamedAgg]`` style kwargs
1847+ Transforms from the new ``Mapping [str, NamedAgg]`` style kwargs
17491848 to the old OrderedDict[str, List[scalar]]].
17501849
17511850 Parameters
@@ -1766,7 +1865,7 @@ def _normalize_keyword_aggregation(kwargs):
17661865 >>> _normalize_keyword_aggregation({'output': ('input', 'sum')})
17671866 (OrderedDict([('input', ['sum'])]), ('output',), [('input', 'sum')])
17681867 """
1769- # Normalize the aggregation functions as Dict [column, List[func]],
1868+ # Normalize the aggregation functions as Mapping [column, List[func]],
17701869 # process normally, then fixup the names.
17711870 # TODO(Py35): When we drop python 3.5, change this to
17721871 # defaultdict(list)
0 commit comments