Skip to content

Commit e41c02c

Browse files
Garamda and jayzhan211 authored
Support WITHIN GROUP syntax to standardize certain existing aggregate functions (apache#13511)
* Add within group variable to aggregate function and arguments
* Support within group and disable null handling for ordered set aggregate functions (apache#13511)
* Refactored function to match updated signature
* Modify proto to support within group clause
* Modify physical planner and accumulator to support ordered set aggregate function
* Support session management for ordered set aggregate functions
* Align code, tests, and examples with changes to aggregate function logic
* Ensure compatibility with new `within_group` and `order_by` handling.
* Adjust tests and examples to align with the new logic.
* Fix typo in existing comments
* Enhance test
* Add test cases for changed signature
* Update signature in docs
* Fix bug: handle missing within_group when applying children tree node
* Change the signature of approx_percentile_cont for consistency
* Add missing within_group for expr display
* Handle edge case when over and within group clause are used together
* Apply clippy advice: avoids too many arguments
* Add new test cases using descending order
* Apply cargo fmt
* Revert unintended submodule changes
* Apply prettier guidance
* Apply doc guidance by update_function_doc.sh
* Rollback WITHIN GROUP and related logic after converting it into expr
* Make it not to handle redundant logic
* Rollback ordered set aggregate functions from session to save same info in udf itself
* Convert within group to order by when converting sql to expr
* Add function to determine it is ordered-set aggregate function
* Rollback within group from proto
* Utilize within group as order by in functions-aggregate
* Apply clippy
* Convert order by to within group
* Apply cargo fmt
* Remove plain line breaks
* Remove duplicated column arg in schema name
* Refactor boolean functions to just return primitive type
* Make within group necessary in the signature of existing ordered set aggr funcs
* Apply cargo fmt
* Support a single ordering expression in the signature
* Apply cargo fmt
* Add dataframe function test cases to verify descending ordering
* Apply cargo fmt
* Apply code reviews
* Uses order by consistently after done with sql
* Remove redundant comment
* Serve more clear error msg
* Handle error cases in the same code block
* Update error msg in test as corresponding code changed
* fix

---------

Co-authored-by: Jay Zhan <jayzhan211@gmail.com>
1 parent 230f31b commit e41c02c

File tree

12 files changed

+342
-123
lines changed

12 files changed

+342
-123
lines changed

datafusion/core/benches/aggregate_query_sql.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) {
158158
query(
159159
ctx.clone(),
160160
&rt,
161-
"SELECT utf8, approx_percentile_cont(u64_wide, 0.5, 2500) \
161+
"SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY u64_wide) \
162162
FROM t GROUP BY utf8",
163163
)
164164
})
@@ -169,7 +169,7 @@ fn criterion_benchmark(c: &mut Criterion) {
169169
query(
170170
ctx.clone(),
171171
&rt,
172-
"SELECT utf8, approx_percentile_cont(f32, 0.5, 2500) \
172+
"SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY f32) \
173173
FROM t GROUP BY utf8",
174174
)
175175
})

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -384,19 +384,34 @@ async fn test_fn_approx_median() -> Result<()> {
384384

385385
#[tokio::test]
386386
async fn test_fn_approx_percentile_cont() -> Result<()> {
387-
let expr = approx_percentile_cont(col("b"), lit(0.5), None);
387+
let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), None);
388388

389389
let df = create_test_table().await?;
390390
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
391391

392392
assert_snapshot!(
393393
batches_to_string(&batches),
394394
@r"
395-
+---------------------------------------------+
396-
| approx_percentile_cont(test.b,Float64(0.5)) |
397-
+---------------------------------------------+
398-
| 10 |
399-
+---------------------------------------------+
395+
+---------------------------------------------------------------------------+
396+
| approx_percentile_cont(Float64(0.5)) WITHIN GROUP [test.b ASC NULLS LAST] |
397+
+---------------------------------------------------------------------------+
398+
| 10 |
399+
+---------------------------------------------------------------------------+
400+
");
401+
402+
let expr = approx_percentile_cont(col("b").sort(false, false), lit(0.1), None);
403+
404+
let df = create_test_table().await?;
405+
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
406+
407+
assert_snapshot!(
408+
batches_to_string(&batches),
409+
@r"
410+
+----------------------------------------------------------------------------+
411+
| approx_percentile_cont(Float64(0.1)) WITHIN GROUP [test.b DESC NULLS LAST] |
412+
+----------------------------------------------------------------------------+
413+
| 100 |
414+
+----------------------------------------------------------------------------+
400415
");
401416

402417
// the arg2 parameter is a complex expr, but it can be evaluated to the literal value
@@ -405,35 +420,71 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
405420
None::<&str>,
406421
"arg_2".to_string(),
407422
));
408-
let expr = approx_percentile_cont(col("b"), alias_expr, None);
423+
let expr = approx_percentile_cont(col("b").sort(true, false), alias_expr, None);
409424
let df = create_test_table().await?;
410425
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
411426

412427
assert_snapshot!(
413428
batches_to_string(&batches),
414429
@r"
415-
+--------------------------------------+
416-
| approx_percentile_cont(test.b,arg_2) |
417-
+--------------------------------------+
418-
| 10 |
419-
+--------------------------------------+
430+
+--------------------------------------------------------------------+
431+
| approx_percentile_cont(arg_2) WITHIN GROUP [test.b ASC NULLS LAST] |
432+
+--------------------------------------------------------------------+
433+
| 10 |
434+
+--------------------------------------------------------------------+
435+
"
436+
);
437+
438+
let alias_expr = Expr::Alias(Alias::new(
439+
cast(lit(0.1), DataType::Float32),
440+
None::<&str>,
441+
"arg_2".to_string(),
442+
));
443+
let expr = approx_percentile_cont(col("b").sort(false, false), alias_expr, None);
444+
let df = create_test_table().await?;
445+
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
446+
447+
assert_snapshot!(
448+
batches_to_string(&batches),
449+
@r"
450+
+---------------------------------------------------------------------+
451+
| approx_percentile_cont(arg_2) WITHIN GROUP [test.b DESC NULLS LAST] |
452+
+---------------------------------------------------------------------+
453+
| 100 |
454+
+---------------------------------------------------------------------+
420455
"
421456
);
422457

423458
// with number of centroids set
424-
let expr = approx_percentile_cont(col("b"), lit(0.5), Some(lit(2)));
459+
let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), Some(lit(2)));
460+
461+
let df = create_test_table().await?;
462+
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
463+
464+
assert_snapshot!(
465+
batches_to_string(&batches),
466+
@r"
467+
+------------------------------------------------------------------------------------+
468+
| approx_percentile_cont(Float64(0.5),Int32(2)) WITHIN GROUP [test.b ASC NULLS LAST] |
469+
+------------------------------------------------------------------------------------+
470+
| 30 |
471+
+------------------------------------------------------------------------------------+
472+
");
473+
474+
let expr =
475+
approx_percentile_cont(col("b").sort(false, false), lit(0.1), Some(lit(2)));
425476

426477
let df = create_test_table().await?;
427478
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
428479

429480
assert_snapshot!(
430481
batches_to_string(&batches),
431482
@r"
432-
+------------------------------------------------------+
433-
| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) |
434-
+------------------------------------------------------+
435-
| 30 |
436-
+------------------------------------------------------+
483+
+-------------------------------------------------------------------------------------+
484+
| approx_percentile_cont(Float64(0.1),Int32(2)) WITHIN GROUP [test.b DESC NULLS LAST] |
485+
+-------------------------------------------------------------------------------------+
486+
| 69 |
487+
+-------------------------------------------------------------------------------------+
437488
");
438489

439490
Ok(())

datafusion/expr/src/udaf.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,16 @@ impl AggregateUDF {
315315
self.inner.default_value(data_type)
316316
}
317317

318+
/// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
319+
pub fn supports_null_handling_clause(&self) -> bool {
320+
self.inner.supports_null_handling_clause()
321+
}
322+
323+
/// See [`AggregateUDFImpl::is_ordered_set_aggregate`] for more details.
324+
pub fn is_ordered_set_aggregate(&self) -> bool {
325+
self.inner.is_ordered_set_aggregate()
326+
}
327+
318328
/// Returns the documentation for this Aggregate UDF.
319329
///
320330
/// Documentation can be accessed programmatically as well as
@@ -432,6 +442,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
432442
null_treatment,
433443
} = params;
434444

445+
// exclude the first function argument(= column) in ordered set aggregate function,
446+
// because it is duplicated with the WITHIN GROUP clause in schema name.
447+
let args = if self.is_ordered_set_aggregate() {
448+
&args[1..]
449+
} else {
450+
&args[..]
451+
};
452+
435453
let mut schema_name = String::new();
436454

437455
schema_name.write_fmt(format_args!(
@@ -450,8 +468,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
450468
};
451469

452470
if let Some(order_by) = order_by {
471+
let clause = match self.is_ordered_set_aggregate() {
472+
true => "WITHIN GROUP",
473+
false => "ORDER BY",
474+
};
475+
453476
schema_name.write_fmt(format_args!(
454-
" ORDER BY [{}]",
477+
" {} [{}]",
478+
clause,
455479
schema_name_from_sorts(order_by)?
456480
))?;
457481
};
@@ -891,6 +915,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
891915
ScalarValue::try_from(data_type)
892916
}
893917

918+
/// If this function supports `[IGNORE NULLS | RESPECT NULLS]` clause, return true
919+
/// If the function does not, return false
920+
fn supports_null_handling_clause(&self) -> bool {
921+
true
922+
}
923+
924+
/// If this function is ordered-set aggregate function, return true
925+
/// If the function is not, return false
926+
fn is_ordered_set_aggregate(&self) -> bool {
927+
false
928+
}
929+
894930
/// Returns the documentation for this Aggregate UDF.
895931
///
896932
/// Documentation can be accessed programmatically as well as

datafusion/functions-aggregate/src/approx_median.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ make_udaf_expr_and_func!(
4545
/// APPROX_MEDIAN aggregate expression
4646
#[user_doc(
4747
doc_section(label = "Approximate Functions"),
48-
description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`.",
48+
description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`.",
4949
syntax_example = "approx_median(expression)",
5050
sql_example = r#"```sql
5151
> SELECT approx_median(column_name) FROM table_name;

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use datafusion_common::{
3434
downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
3535
Result, ScalarValue,
3636
};
37+
use datafusion_expr::expr::{AggregateFunction, Sort};
3738
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3839
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
3940
use datafusion_expr::utils::format_state_name;
@@ -51,29 +52,39 @@ create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
5152

5253
/// Computes the approximate percentile continuous of a set of numbers
5354
pub fn approx_percentile_cont(
54-
expression: Expr,
55+
order_by: Sort,
5556
percentile: Expr,
5657
centroids: Option<Expr>,
5758
) -> Expr {
59+
let expr = order_by.expr.clone();
60+
5861
let args = if let Some(centroids) = centroids {
59-
vec![expression, percentile, centroids]
62+
vec![expr, percentile, centroids]
6063
} else {
61-
vec![expression, percentile]
64+
vec![expr, percentile]
6265
};
63-
approx_percentile_cont_udaf().call(args)
66+
67+
Expr::AggregateFunction(AggregateFunction::new_udf(
68+
approx_percentile_cont_udaf(),
69+
args,
70+
false,
71+
None,
72+
Some(vec![order_by]),
73+
None,
74+
))
6475
}
6576

6677
#[user_doc(
6778
doc_section(label = "Approximate Functions"),
6879
description = "Returns the approximate percentile of input values using the t-digest algorithm.",
69-
syntax_example = "approx_percentile_cont(expression, percentile, centroids)",
80+
syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)",
7081
sql_example = r#"```sql
71-
> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
72-
+-------------------------------------------------+
73-
| approx_percentile_cont(column_name, 0.75, 100) |
74-
+-------------------------------------------------+
75-
| 65.0 |
76-
+-------------------------------------------------+
82+
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
83+
+-----------------------------------------------------------------------+
84+
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
85+
+-----------------------------------------------------------------------+
86+
| 65.0 |
87+
+-----------------------------------------------------------------------+
7788
```"#,
7889
standard_argument(name = "expression",),
7990
argument(
@@ -130,6 +141,19 @@ impl ApproxPercentileCont {
130141
args: AccumulatorArgs,
131142
) -> Result<ApproxPercentileAccumulator> {
132143
let percentile = validate_input_percentile_expr(&args.exprs[1])?;
144+
145+
let is_descending = args
146+
.ordering_req
147+
.first()
148+
.map(|sort_expr| sort_expr.options.descending)
149+
.unwrap_or(false);
150+
151+
let percentile = if is_descending {
152+
1.0 - percentile
153+
} else {
154+
percentile
155+
};
156+
133157
let tdigest_max_size = if args.exprs.len() == 3 {
134158
Some(validate_input_max_size_expr(&args.exprs[2])?)
135159
} else {
@@ -292,6 +316,14 @@ impl AggregateUDFImpl for ApproxPercentileCont {
292316
Ok(arg_types[0].clone())
293317
}
294318

319+
fn supports_null_handling_clause(&self) -> bool {
320+
false
321+
}
322+
323+
fn is_ordered_set_aggregate(&self) -> bool {
324+
true
325+
}
326+
295327
fn documentation(&self) -> Option<&Documentation> {
296328
self.doc()
297329
}

datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ make_udaf_expr_and_func!(
5252
#[user_doc(
5353
doc_section(label = "Approximate Functions"),
5454
description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
55-
syntax_example = "approx_percentile_cont_with_weight(expression, weight, percentile)",
55+
syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)",
5656
sql_example = r#"```sql
57-
> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
58-
+----------------------------------------------------------------------+
59-
| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
60-
+----------------------------------------------------------------------+
61-
| 78.5 |
62-
+----------------------------------------------------------------------+
57+
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
58+
+---------------------------------------------------------------------------------------------+
59+
| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
60+
+---------------------------------------------------------------------------------------------+
61+
| 78.5 |
62+
+---------------------------------------------------------------------------------------------+
6363
```"#,
6464
standard_argument(name = "expression", prefix = "The"),
6565
argument(
@@ -178,6 +178,14 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
178178
self.approx_percentile_cont.state_fields(args)
179179
}
180180

181+
fn supports_null_handling_clause(&self) -> bool {
182+
false
183+
}
184+
185+
fn is_ordered_set_aggregate(&self) -> bool {
186+
true
187+
}
188+
181189
fn documentation(&self) -> Option<&Documentation> {
182190
self.doc()
183191
}

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -973,8 +973,8 @@ async fn roundtrip_expr_api() -> Result<()> {
973973
stddev_pop(lit(2.2)),
974974
approx_distinct(lit(2)),
975975
approx_median(lit(2)),
976-
approx_percentile_cont(lit(2), lit(0.5), None),
977-
approx_percentile_cont(lit(2), lit(0.5), Some(lit(50))),
976+
approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None),
977+
approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))),
978978
approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)),
979979
grouping(lit(1)),
980980
bit_and(lit(2)),

datafusion/proto/tests/cases/roundtrip_physical_plan.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> {
504504
vec![col("b", &schema)?, lit(0.5)],
505505
)
506506
.schema(Arc::clone(&schema))
507-
.alias("APPROX_PERCENTILE_CONT(b, 0.5)")
507+
.alias("APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY b)")
508508
.build()
509509
.map(Arc::new)?];
510510

0 commit comments

Comments
 (0)