@@ -653,37 +653,37 @@ def _iterate_slices(self):
653653 def transform (self , func , * args , ** kwargs ):
654654 raise AbstractMethodError (self )
655655
656- def _cumcount_array (self , arr = None , ascending = True ):
656+ def _cumcount_array (self , ascending = True ):
657657 """
658- arr is where cumcount gets its values from
658+ Parameters
659+ ----------
660+ ascending : bool, default True
661+ If False, number in reverse, from length of group - 1 to 0.
659662
660663 Note
661664 ----
662665 this is currently implementing sort=False
663666 (though the default is sort=True) for groupby in general
664667 """
665- if arr is None :
666- arr = np .arange (self .grouper ._max_groupsize , dtype = 'int64' )
667-
668- len_index = len (self ._selected_obj .index )
669- cumcounts = np .zeros (len_index , dtype = arr .dtype )
670- if not len_index :
671- return cumcounts
668+ ids , _ , ngroups = self .grouper .group_info
669+ sorter = _get_group_index_sorter (ids , ngroups )
670+ ids , count = ids [sorter ], len (ids )
672671
673- indices , values = [], []
674- for v in self .indices .values ():
675- indices .append (v )
672+ if count == 0 :
673+ return np .empty (0 , dtype = np .int64 )
676674
677- if ascending :
678- values .append (arr [:len (v )])
679- else :
680- values .append (arr [len (v ) - 1 ::- 1 ])
675+ run = np .r_ [True , ids [:- 1 ] != ids [1 :]]
676+ rep = np .diff (np .r_ [np .nonzero (run )[0 ], count ])
677+ out = (~ run ).cumsum ()
681678
682- indices = np .concatenate (indices )
683- values = np .concatenate (values )
684- cumcounts [indices ] = values
679+ if ascending :
680+ out -= np .repeat (out [run ], rep )
681+ else :
682+ out = np .repeat (out [np .r_ [run [1 :], True ]], rep ) - out
685683
686- return cumcounts
684+ rev = np .empty (count , dtype = np .intp )
685+ rev [sorter ] = np .arange (count , dtype = np .intp )
686+ return out [rev ].astype (np .int64 , copy = False )
687687
688688 def _index_with_as_index (self , b ):
689689 """
@@ -1170,47 +1170,21 @@ def nth(self, n, dropna=None):
11701170 else :
11711171 raise TypeError ("n needs to be an int or a list/set/tuple of ints" )
11721172
1173- m = self .grouper ._max_groupsize
1174- # filter out values that are outside [-m, m)
1175- pos_nth_values = [i for i in nth_values if i >= 0 and i < m ]
1176- neg_nth_values = [i for i in nth_values if i < 0 and i >= - m ]
1177-
1173+ nth_values = np .array (nth_values , dtype = np .intp )
11781174 self ._set_selection_from_grouper ()
1179- if not dropna : # good choice
1180- if not pos_nth_values and not neg_nth_values :
1181- # no valid nth values
1182- return self ._selected_obj .loc [[]]
1183-
1184- rng = np .zeros (m , dtype = bool )
1185- for i in pos_nth_values :
1186- rng [i ] = True
1187- is_nth = self ._cumcount_array (rng )
11881175
1189- if neg_nth_values :
1190- rng = np .zeros (m , dtype = bool )
1191- for i in neg_nth_values :
1192- rng [- i - 1 ] = True
1193- is_nth |= self ._cumcount_array (rng , ascending = False )
1176+ if not dropna :
1177+ mask = np .in1d (self ._cumcount_array (), nth_values ) | \
1178+ np .in1d (self ._cumcount_array (ascending = False ) + 1 , - nth_values )
11941179
1195- result = self ._selected_obj [is_nth ]
1180+ out = self ._selected_obj [mask ]
1181+ if not self .as_index :
1182+ return out
11961183
1197- # the result index
1198- if self .as_index :
1199- ax = self .obj ._info_axis
1200- names = self .grouper .names
1201- if self .obj .ndim == 1 :
1202- # this is a pass-thru
1203- pass
1204- elif all ([x in ax for x in names ]):
1205- indicies = [self .obj [name ][is_nth ] for name in names ]
1206- result .index = MultiIndex .from_arrays (
1207- indicies ).set_names (names )
1208- elif self ._group_selection is not None :
1209- result .index = self .obj ._get_axis (self .axis )[is_nth ]
1210-
1211- result = result .sort_index ()
1184+ ids , _ , _ = self .grouper .group_info
1185+ out .index = self .grouper .result_index [ids [mask ]]
12121186
1213- return result
1187+ return out . sort_index () if self . sort else out
12141188
12151189 if isinstance (self ._selected_obj , DataFrame ) and \
12161190 dropna not in ['any' , 'all' ]:
@@ -1241,8 +1215,8 @@ def nth(self, n, dropna=None):
12411215 axis = self .axis , level = self .level ,
12421216 sort = self .sort )
12431217
1244- sizes = dropped .groupby (grouper ). size ( )
1245- result = dropped . groupby ( grouper ) .nth (n )
1218+ grb = dropped .groupby (grouper , as_index = self . as_index , sort = self . sort )
1219+ sizes , result = grb . size (), grb .nth (n )
12461220 mask = (sizes < max_len ).values
12471221
12481222 # set the results which don't meet the criteria
@@ -1380,11 +1354,8 @@ def head(self, n=5):
13801354 0 1 2
13811355 2 5 6
13821356 """
1383-
1384- obj = self ._selected_obj
1385- in_head = self ._cumcount_array () < n
1386- head = obj [in_head ]
1387- return head
1357+ mask = self ._cumcount_array () < n
1358+ return self ._selected_obj [mask ]
13881359
13891360 @Substitution (name = 'groupby' )
13901361 @Appender (_doc_template )
@@ -1409,12 +1380,8 @@ def tail(self, n=5):
14091380 0 a 1
14101381 2 b 1
14111382 """
1412-
1413- obj = self ._selected_obj
1414- rng = np .arange (0 , - self .grouper ._max_groupsize , - 1 , dtype = 'int64' )
1415- in_tail = self ._cumcount_array (rng , ascending = False ) > - n
1416- tail = obj [in_tail ]
1417- return tail
1383+ mask = self ._cumcount_array (ascending = False ) < n
1384+ return self ._selected_obj [mask ]
14181385
14191386
14201387@Appender (GroupBy .__doc__ )
0 commit comments