@@ -166,6 +166,7 @@ impl AggregateUDFImpl for FirstValue {
166166 }
167167
168168 fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
169+ // TODO: extract to function
169170 use DataType :: * ;
170171 matches ! (
171172 args. return_type,
@@ -193,6 +194,7 @@ impl AggregateUDFImpl for FirstValue {
193194 & self ,
194195 args : AccumulatorArgs ,
195196 ) -> Result < Box < dyn GroupsAccumulator > > {
197+ // TODO: extract to function
196198 fn create_accumulator < T > (
197199 args : AccumulatorArgs ,
198200 ) -> Result < Box < dyn GroupsAccumulator > >
@@ -210,6 +212,7 @@ impl AggregateUDFImpl for FirstValue {
210212 args. ignore_nulls ,
211213 args. return_type ,
212214 & ordering_dtypes,
215+ true ,
213216 ) ?) )
214217 }
215218
@@ -258,10 +261,12 @@ impl AggregateUDFImpl for FirstValue {
258261 create_accumulator :: < Time64NanosecondType > ( args)
259262 }
260263
261- _ => internal_err ! (
262- "GroupsAccumulator not supported for first({})" ,
263- args. return_type
264- ) ,
264+ _ => {
265+ internal_err ! (
266+ "GroupsAccumulator not supported for first_value({})" ,
267+ args. return_type
268+ )
269+ }
265270 }
266271 }
267272
@@ -291,6 +296,7 @@ impl AggregateUDFImpl for FirstValue {
291296 }
292297}
293298
299+ // TODO: rename to PrimitiveGroupsAccumulator
294300struct FirstPrimitiveGroupsAccumulator < T >
295301where
296302 T : ArrowPrimitiveType + Send ,
@@ -316,12 +322,16 @@ where
316322 // buffer for `get_filtered_min_of_each_group`
317323 // min_of_each_group_buf.0[group_idx] -> idx_in_val
318324 // only valid if min_of_each_group_buf.1[group_idx] == true
325+ // TODO: rename to extreme_of_each_group_buf
319326 min_of_each_group_buf : ( Vec < usize > , BooleanBufferBuilder ) ,
320327
321328 // =========== option ============
322329
323330 // Stores the applicable ordering requirement.
324331 ordering_req : LexOrdering ,
332+ // true: take first element in an aggregation group according to the requested ordering.
333+ // false: take last element in an aggregation group according to the requested ordering.
334+ pick_first_in_group : bool ,
325335 // derived from `ordering_req`.
326336 sort_options : Vec < SortOptions > ,
327337 // Stores whether incoming data already satisfies the ordering requirement.
@@ -342,6 +352,7 @@ where
342352 ignore_nulls : bool ,
343353 data_type : & DataType ,
344354 ordering_dtypes : & [ DataType ] ,
355+ pick_first_in_group : bool ,
345356 ) -> Result < Self > {
346357 let requirement_satisfied = ordering_req. is_empty ( ) ;
347358
@@ -365,6 +376,7 @@ where
365376 is_sets : BooleanBufferBuilder :: new ( 0 ) ,
366377 size_of_orderings : 0 ,
367378 min_of_each_group_buf : ( Vec :: new ( ) , BooleanBufferBuilder :: new ( 0 ) ) ,
379+ pick_first_in_group,
368380 } )
369381 }
370382
@@ -391,8 +403,13 @@ where
391403
392404 assert ! ( new_ordering_values. len( ) == self . ordering_req. len( ) ) ;
393405 let current_ordering = & self . orderings [ group_idx] ;
394- compare_rows ( current_ordering, new_ordering_values, & self . sort_options )
395- . map ( |x| x. is_gt ( ) )
406+ compare_rows ( current_ordering, new_ordering_values, & self . sort_options ) . map ( |x| {
407+ if self . pick_first_in_group {
408+ x. is_gt ( )
409+ } else {
410+ x. is_lt ( )
411+ }
412+ } )
396413 }
397414
398415 fn take_orderings ( & mut self , emit_to : EmitTo ) -> Vec < Vec < ScalarValue > > {
@@ -501,10 +518,10 @@ where
501518 . map ( ScalarValue :: size_of_vec)
502519 . sum :: < usize > ( )
503520 }
504-
505521 /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
506522 /// extreme value in `orderings` for each group (the minimum when `pick_first_in_group` is true, otherwise the maximum), using lexicographical comparison.
507523 /// Values are filtered using `opt_filter` and `is_set_arr` if provided.
524+ /// TODO: rename to get_filtered_extreme_of_each_group
508525 fn get_filtered_min_of_each_group (
509526 & mut self ,
510527 orderings : & [ ArrayRef ] ,
@@ -556,15 +573,19 @@ where
556573 }
557574
558575 let is_valid = self . min_of_each_group_buf . 1 . get_bit ( group_idx) ;
559- if is_valid
560- && comparator
561- . compare ( self . min_of_each_group_buf . 0 [ group_idx] , idx_in_val)
562- . is_gt ( )
563- {
564- self . min_of_each_group_buf . 0 [ group_idx] = idx_in_val;
565- } else if !is_valid {
576+
577+ if !is_valid {
566578 self . min_of_each_group_buf . 1 . set_bit ( group_idx, true ) ;
567579 self . min_of_each_group_buf . 0 [ group_idx] = idx_in_val;
580+ } else {
581+ let ordering = comparator
582+ . compare ( self . min_of_each_group_buf . 0 [ group_idx] , idx_in_val) ;
583+
584+ if ( ordering. is_gt ( ) && self . pick_first_in_group )
585+ || ( ordering. is_lt ( ) && !self . pick_first_in_group )
586+ {
587+ self . min_of_each_group_buf . 0 [ group_idx] = idx_in_val;
588+ }
568589 }
569590 }
570591
@@ -1052,6 +1073,109 @@ impl AggregateUDFImpl for LastValue {
10521073 fn documentation ( & self ) -> Option < & Documentation > {
10531074 self . doc ( )
10541075 }
1076+
1077+ fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
1078+ use DataType :: * ;
1079+ matches ! (
1080+ args. return_type,
1081+ Int8 | Int16
1082+ | Int32
1083+ | Int64
1084+ | UInt8
1085+ | UInt16
1086+ | UInt32
1087+ | UInt64
1088+ | Float16
1089+ | Float32
1090+ | Float64
1091+ | Decimal128 ( _, _)
1092+ | Decimal256 ( _, _)
1093+ | Date32
1094+ | Date64
1095+ | Time32 ( _)
1096+ | Time64 ( _)
1097+ | Timestamp ( _, _)
1098+ )
1099+ }
1100+
1101+ fn create_groups_accumulator (
1102+ & self ,
1103+ args : AccumulatorArgs ,
1104+ ) -> Result < Box < dyn GroupsAccumulator > > {
1105+ fn create_accumulator < T > (
1106+ args : AccumulatorArgs ,
1107+ ) -> Result < Box < dyn GroupsAccumulator > >
1108+ where
1109+ T : ArrowPrimitiveType + Send ,
1110+ {
1111+ let ordering_dtypes = args
1112+ . ordering_req
1113+ . iter ( )
1114+ . map ( |e| e. expr . data_type ( args. schema ) )
1115+ . collect :: < Result < Vec < _ > > > ( ) ?;
1116+
1117+ Ok ( Box :: new ( FirstPrimitiveGroupsAccumulator :: < T > :: try_new (
1118+ args. ordering_req . clone ( ) ,
1119+ args. ignore_nulls ,
1120+ args. return_type ,
1121+ & ordering_dtypes,
1122+ false ,
1123+ ) ?) )
1124+ }
1125+
1126+ match args. return_type {
1127+ DataType :: Int8 => create_accumulator :: < Int8Type > ( args) ,
1128+ DataType :: Int16 => create_accumulator :: < Int16Type > ( args) ,
1129+ DataType :: Int32 => create_accumulator :: < Int32Type > ( args) ,
1130+ DataType :: Int64 => create_accumulator :: < Int64Type > ( args) ,
1131+ DataType :: UInt8 => create_accumulator :: < UInt8Type > ( args) ,
1132+ DataType :: UInt16 => create_accumulator :: < UInt16Type > ( args) ,
1133+ DataType :: UInt32 => create_accumulator :: < UInt32Type > ( args) ,
1134+ DataType :: UInt64 => create_accumulator :: < UInt64Type > ( args) ,
1135+ DataType :: Float16 => create_accumulator :: < Float16Type > ( args) ,
1136+ DataType :: Float32 => create_accumulator :: < Float32Type > ( args) ,
1137+ DataType :: Float64 => create_accumulator :: < Float64Type > ( args) ,
1138+
1139+ DataType :: Decimal128 ( _, _) => create_accumulator :: < Decimal128Type > ( args) ,
1140+ DataType :: Decimal256 ( _, _) => create_accumulator :: < Decimal256Type > ( args) ,
1141+
1142+ DataType :: Timestamp ( TimeUnit :: Second , _) => {
1143+ create_accumulator :: < TimestampSecondType > ( args)
1144+ }
1145+ DataType :: Timestamp ( TimeUnit :: Millisecond , _) => {
1146+ create_accumulator :: < TimestampMillisecondType > ( args)
1147+ }
1148+ DataType :: Timestamp ( TimeUnit :: Microsecond , _) => {
1149+ create_accumulator :: < TimestampMicrosecondType > ( args)
1150+ }
1151+ DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => {
1152+ create_accumulator :: < TimestampNanosecondType > ( args)
1153+ }
1154+
1155+ DataType :: Date32 => create_accumulator :: < Date32Type > ( args) ,
1156+ DataType :: Date64 => create_accumulator :: < Date64Type > ( args) ,
1157+ DataType :: Time32 ( TimeUnit :: Second ) => {
1158+ create_accumulator :: < Time32SecondType > ( args)
1159+ }
1160+ DataType :: Time32 ( TimeUnit :: Millisecond ) => {
1161+ create_accumulator :: < Time32MillisecondType > ( args)
1162+ }
1163+
1164+ DataType :: Time64 ( TimeUnit :: Microsecond ) => {
1165+ create_accumulator :: < Time64MicrosecondType > ( args)
1166+ }
1167+ DataType :: Time64 ( TimeUnit :: Nanosecond ) => {
1168+ create_accumulator :: < Time64NanosecondType > ( args)
1169+ }
1170+
1171+ _ => {
1172+ internal_err ! (
1173+ "GroupsAccumulator not supported for last_value({})" ,
1174+ args. return_type
1175+ )
1176+ }
1177+ }
1178+ }
10551179}
10561180
10571181#[ derive( Debug ) ]
@@ -1411,6 +1535,7 @@ mod tests {
14111535 true ,
14121536 & DataType :: Int64 ,
14131537 & [ DataType :: Int64 ] ,
1538+ true ,
14141539 ) ?;
14151540
14161541 let mut val_with_orderings = {
@@ -1485,7 +1610,7 @@ mod tests {
14851610 }
14861611
14871612 #[ test]
1488- fn test_frist_group_acc_size_of_ordering ( ) -> Result < ( ) > {
1613+ fn test_group_acc_size_of_ordering ( ) -> Result < ( ) > {
14891614 let schema = Arc :: new ( Schema :: new ( vec ! [
14901615 Field :: new( "a" , DataType :: Int64 , true ) ,
14911616 Field :: new( "b" , DataType :: Int64 , true ) ,
@@ -1504,6 +1629,7 @@ mod tests {
15041629 true ,
15051630 & DataType :: Int64 ,
15061631 & [ DataType :: Int64 ] ,
1632+ true ,
15071633 ) ?;
15081634
15091635 let val_with_orderings = {
@@ -1563,4 +1689,79 @@ mod tests {
15631689
15641690 Ok ( ( ) )
15651691 }
1692+
/// Exercises the primitive groups accumulator in "last" mode
/// (`pick_first_in_group == false`) through update, state, merge and evaluate.
#[test]
fn test_last_group_acc() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Int64, true),
        Field::new("d", DataType::Int32, true),
        Field::new("e", DataType::Boolean, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    // Order by column "c" with the default sort options.
    let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
        expr: col("c", &schema).unwrap(),
        options: SortOptions::default(),
    }]);

    let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
        sort_key,
        true,
        &DataType::Int64,
        &[DataType::Int64],
        false,
    )?;

    // First batch: values followed by their ordering column.
    let first_batch: Vec<ArrayRef> = vec![
        Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)])) as ArrayRef,
        Arc::new(Int64Array::from(vec![1, -9, 3, -6])) as ArrayRef,
    ];
    let first_filter = BooleanArray::from(vec![true, true, false, true]);

    group_acc.update_batch(&first_batch, &[0, 1, 2, 1], Some(&first_filter), 3)?;

    // Group 2 only saw a filtered-out row, so it must remain unset.
    let state = group_acc.state(EmitTo::All)?;
    let expected_state: Vec<Arc<dyn Array>> = vec![
        Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
        Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
        Arc::new(BooleanArray::from(vec![true, true, false])),
    ];
    assert_eq!(state, expected_state);

    // Feed the emitted state back in; only group 0 passes the filter.
    let merge_filter = BooleanArray::from(vec![true, false, false]);
    group_acc.merge_batch(&state, &[0, 1, 2], Some(&merge_filter), 3)?;

    // Second batch targets groups 1 and 2 and grows the group count to 4.
    let second_batch: Vec<ArrayRef> = vec![
        Arc::new(Int64Array::from(vec![66, 6])) as ArrayRef,
        Arc::new(Int64Array::from(vec![66, 6])) as ArrayRef,
    ];
    group_acc.update_batch(&second_batch, &[1, 2], None, 4)?;

    let evaluated = group_acc.evaluate(EmitTo::All)?;
    let actual = evaluated.as_any().downcast_ref::<Int64Array>().unwrap();

    let expected: PrimitiveArray<Int64Type> =
        Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
    assert_eq!(actual, &expected);

    Ok(())
}
15661767}
0 commit comments