From b9fe6f0dfca9b779290fb0cb57d6c8ad7bb8d3da Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Fri, 29 Nov 2024 15:33:39 +0800 Subject: [PATCH 1/7] fix(query): keep remaining_predicates when filtering grouping sets --- src/query/ast/src/ast/format/syntax/query.rs | 9 + src/query/ast/src/ast/query.rs | 86 ++- src/query/ast/src/parser/query.rs | 29 +- src/query/ast/tests/it/parser.rs | 4 + src/query/ast/tests/it/testdata/stmt.txt | 543 ++++++++++++++++++ src/query/expression/benches/bench.rs | 26 + src/query/functions/src/scalars/arithmetic.rs | 43 +- src/query/functions/src/scalars/other.rs | 10 +- src/query/sql/src/planner/binder/aggregate.rs | 103 +++- .../rule_push_down_filter_aggregate.rs | 7 +- src/tests/sqlsmith/src/sql_gen/query.rs | 12 +- .../03_common/03_0003_select_group_by.test | 85 +++ 12 files changed, 865 insertions(+), 92 deletions(-) diff --git a/src/query/ast/src/ast/format/syntax/query.rs b/src/query/ast/src/ast/format/syntax/query.rs index deb6fafe30da..0e93bfc94acb 100644 --- a/src/query/ast/src/ast/format/syntax/query.rs +++ b/src/query/ast/src/ast/format/syntax/query.rs @@ -272,6 +272,15 @@ fn pretty_group_by(group_by: Option) -> RcDoc<'static> { ) .append(RcDoc::line()) .append(RcDoc::text(")")), + + GroupBy::Combined(sets) => RcDoc::line() + .append(RcDoc::text("GROUP BY ").append(RcDoc::line().nest(NEST_FACTOR))) + .append( + interweave_comma(sets.into_iter().map(|s| RcDoc::text(s.to_string()))) + .nest(NEST_FACTOR) + .group(), + ) + .append(RcDoc::line()), } } else { RcDoc::nil() diff --git a/src/query/ast/src/ast/query.rs b/src/query/ast/src/ast/query.rs index 76893ba29623..165b66fd6b22 100644 --- a/src/query/ast/src/ast/query.rs +++ b/src/query/ast/src/ast/query.rs @@ -189,38 +189,8 @@ impl Display for SelectStmt { // GROUP BY clause if self.group_by.is_some() { write!(f, " GROUP BY ")?; - match self.group_by.as_ref().unwrap() { - GroupBy::Normal(exprs) => { - write_comma_separated_list(f, exprs)?; - } - GroupBy::All => { - write!(f, "ALL")?; - } - GroupBy::GroupingSets(sets) => { - write!(f, "GROUPING SETS (")?; - for (i, set) in sets.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "(")?; - write_comma_separated_list(f, set)?; - write!(f, ")")?; - } - write!(f, ")")?; - } - GroupBy::Cube(exprs) => { - write!(f, "CUBE (")?; - write_comma_separated_list(f, exprs)?; - write!(f, ")")?; - } - GroupBy::Rollup(exprs) => { - write!(f, "ROLLUP (")?; - write_comma_separated_list(f, exprs)?; - write!(f, ")")?; - } - } + write!(f, "{}", self.group_by.as_ref().unwrap())?; } - // HAVING clause if let Some(having) = &self.having { write!(f, " HAVING {having}")?; @@ -254,6 +224,60 @@ pub enum GroupBy { Cube(Vec), /// GROUP BY ROLLUP ( expr [, expr]* ) Rollup(Vec), + Combined(Vec), +} + +impl GroupBy { + pub fn normal_items(&self) -> Vec { + match self { + GroupBy::Normal(exprs) => exprs.clone(), + _ => Vec::new(), + } + } +} + +impl Display for GroupBy { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + GroupBy::Normal(exprs) => { + write_comma_separated_list(f, exprs)?; + } + GroupBy::All => { + write!(f, "ALL")?; + } + GroupBy::GroupingSets(sets) => { + write!(f, "GROUPING SETS (")?; + for (i, set) in sets.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "(")?; + write_comma_separated_list(f, set)?; + write!(f, ")")?; + } + write!(f, ")")?; + } + GroupBy::Cube(exprs) => { + write!(f, "CUBE (")?; + write_comma_separated_list(f, exprs)?; + write!(f, ")")?; + } + GroupBy::Rollup(exprs) => { + 
write!(f, "ROLLUP (")?; + write_comma_separated_list(f, exprs)?; + write!(f, ")")?; + } + GroupBy::Combined(group_bys) => { + for (i, group_by) in group_bys.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", group_by)?; + } + } + } + Ok(()) + } } /// A relational set expression, like `SELECT ... FROM ... {UNION|EXCEPT|INTERSECT} SELECT ... FROM ...` diff --git a/src/query/ast/src/parser/query.rs b/src/query/ast/src/parser/query.rs index cd9cb44be96b..efa318b11726 100644 --- a/src/query/ast/src/parser/query.rs +++ b/src/query/ast/src/parser/query.rs @@ -1073,10 +1073,6 @@ impl<'a, I: Iterator>> PrattParser } pub fn group_by_items(i: Input) -> IResult { - let normal = map(rule! { ^#comma_separated_list1(expr) }, |groups| { - GroupBy::Normal(groups) - }); - let all = map(rule! { ALL }, |_| GroupBy::All); let cube = map( @@ -1096,10 +1092,31 @@ pub fn group_by_items(i: Input) -> IResult { map(rule! { #expr }, |e| vec![e]), )); let group_sets = map( - rule! { GROUPING ~ SETS ~ "(" ~ ^#comma_separated_list1(group_set) ~ ")" }, + rule! { GROUPING ~ ^SETS ~ "(" ~ ^#comma_separated_list1(group_set) ~ ")" }, |(_, _, _, sets, _)| GroupBy::GroupingSets(sets), ); - rule!(#all | #group_sets | #cube | #rollup | #normal)(i) + + // New rule to handle multiple GroupBy items + let single_normal = map(rule! { #expr }, |group| GroupBy::Normal(vec![group])); + let group_by_item = alt((all, group_sets, cube, rollup, single_normal)); + map(rule! { ^#comma_separated_list1(group_by_item) }, |items| { + if items.len() > 1 { + if items.iter().all(|item| matches!(item, GroupBy::Normal(_))) { + let items = items + .into_iter() + .flat_map(|item| match item { + GroupBy::Normal(exprs) => exprs, + _ => unreachable!(), + }) + .collect(); + GroupBy::Normal(items) + } else { + GroupBy::Combined(items) + } + } else { + items.into_iter().next().unwrap() + } + })(i) } pub fn window_frame_bound(i: Input) -> IResult { diff --git a/src/query/ast/tests/it/parser.rs b/src/query/ast/tests/it/parser.rs index 9a536b2c2ee4..e82030fa112d 100644 --- a/src/query/ast/tests/it/parser.rs +++ b/src/query/ast/tests/it/parser.rs @@ -589,12 +589,16 @@ fn test_statement() { "#, r#"SHOW FILE FORMATS"#, r#"DROP FILE FORMAT my_csv"#, + r#"SELECT * FROM t GROUP BY all"#, + r#"SELECT * FROM t GROUP BY a, b, c, d"#, r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, c, d)"#, r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, (c, d))"#, r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (c), (d, e))"#, r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (), (d, e))"#, r#"SELECT * FROM t GROUP BY CUBE (a, b, c)"#, r#"SELECT * FROM t GROUP BY ROLLUP (a, b, c)"#, + r#"SELECT * FROM t GROUP BY a, ROLLUP (b, c)"#, + r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b)), a, ROLLUP (b, c)"#, r#"CREATE MASKING POLICY email_mask AS (val STRING) RETURNS STRING -> CASE WHEN current_role() IN ('ANALYST') THEN VAL ELSE '*********'END comment = 'this is a masking policy'"#, r#"CREATE OR REPLACE MASKING POLICY email_mask AS (val STRING) RETURNS STRING -> CASE WHEN current_role() IN ('ANALYST') THEN VAL ELSE '*********'END comment = 'this is a masking policy'"#, r#"DESC MASKING POLICY email_mask"#, diff --git a/src/query/ast/tests/it/testdata/stmt.txt b/src/query/ast/tests/it/testdata/stmt.txt index 206022cca94e..96ee590dac59 100644 --- a/src/query/ast/tests/it/testdata/stmt.txt +++ b/src/query/ast/tests/it/testdata/stmt.txt @@ -17358,6 +17358,227 @@ DropFileFormat { } +---------- Input ---------- +SELECT * FROM t GROUP BY all +---------- Output 
--------- +SELECT * FROM t GROUP BY ALL +---------- AST ------------ +Query( + Query { + span: Some( + 0..28, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..28, + ), + hints: None, + distinct: false, + top_n: None, + select_list: [ + StarColumns { + qualified: [ + Star( + Some( + 7..8, + ), + ), + ], + column_filter: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + span: Some( + 14..15, + ), + name: "t", + quote: None, + ident_type: None, + }, + alias: None, + temporal: None, + with_options: None, + pivot: None, + unpivot: None, + sample: None, + }, + ], + selection: None, + group_by: Some( + All, + ), + having: None, + window_list: None, + qualify: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + +---------- Input ---------- +SELECT * FROM t GROUP BY a, b, c, d +---------- Output --------- +SELECT * FROM t GROUP BY a, b, c, d +---------- AST ------------ +Query( + Query { + span: Some( + 0..35, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..35, + ), + hints: None, + distinct: false, + top_n: None, + select_list: [ + StarColumns { + qualified: [ + Star( + Some( + 7..8, + ), + ), + ], + column_filter: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + span: Some( + 14..15, + ), + name: "t", + quote: None, + ident_type: None, + }, + alias: None, + temporal: None, + with_options: None, + pivot: None, + unpivot: None, + sample: None, + }, + ], + selection: None, + group_by: Some( + Normal( + [ + ColumnRef { + span: Some( + 25..26, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 25..26, + ), + name: "a", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 28..29, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 28..29, + ), + name: "b", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 31..32, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 31..32, + ), + name: "c", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 34..35, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 34..35, + ), + name: "d", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + ), + having: None, + window_list: None, + qualify: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + ---------- Input ---------- SELECT * FROM t GROUP BY GROUPING SETS (a, b, c, d) ---------- Output --------- @@ -18264,6 +18485,328 @@ Query( ) +---------- Input ---------- +SELECT * FROM t GROUP BY a, ROLLUP (b, c) +---------- Output --------- +SELECT * FROM t GROUP BY a, ROLLUP (b, c) +---------- AST ------------ +Query( + Query { + span: Some( + 0..41, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..41, + ), + hints: None, + distinct: false, + top_n: None, + select_list: [ + StarColumns { + qualified: [ + Star( + Some( + 7..8, + ), + ), + ], + column_filter: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + span: Some( + 14..15, + ), + name: "t", + quote: None, + ident_type: None, + }, + alias: None, + temporal: None, + 
with_options: None, + pivot: None, + unpivot: None, + sample: None, + }, + ], + selection: None, + group_by: Some( + Combined( + [ + Normal( + [ + ColumnRef { + span: Some( + 25..26, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 25..26, + ), + name: "a", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + Rollup( + [ + ColumnRef { + span: Some( + 36..37, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 36..37, + ), + name: "b", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 39..40, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 39..40, + ), + name: "c", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + ], + ), + ), + having: None, + window_list: None, + qualify: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + +---------- Input ---------- +SELECT * FROM t GROUP BY GROUPING SETS ((a, b)), a, ROLLUP (b, c) +---------- Output --------- +SELECT * FROM t GROUP BY GROUPING SETS ((a, b)), a, ROLLUP (b, c) +---------- AST ------------ +Query( + Query { + span: Some( + 0..65, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..65, + ), + hints: None, + distinct: false, + top_n: None, + select_list: [ + StarColumns { + qualified: [ + Star( + Some( + 7..8, + ), + ), + ], + column_filter: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + span: Some( + 14..15, + ), + name: "t", + quote: None, + ident_type: None, + }, + alias: None, + temporal: None, + with_options: None, + pivot: None, + unpivot: None, + sample: None, + }, + ], + selection: None, + group_by: Some( + Combined( + [ + GroupingSets( + [ + [ + ColumnRef { + span: Some( + 41..42, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 41..42, + ), + name: "a", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 44..45, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 44..45, + ), + name: "b", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ], + ), + Normal( + [ + ColumnRef { + span: Some( + 49..50, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 49..50, + ), + name: "a", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + Rollup( + [ + ColumnRef { + span: Some( + 60..61, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 60..61, + ), + name: "b", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 63..64, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 63..64, + ), + name: "c", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + ], + ), + ), + having: None, + window_list: None, + qualify: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + ---------- Input ---------- CREATE MASKING POLICY email_mask AS (val STRING) RETURNS STRING -> CASE WHEN current_role() IN ('ANALYST') THEN VAL ELSE '*********'END comment = 'this is a masking policy' ---------- Output --------- diff --git a/src/query/expression/benches/bench.rs 
b/src/query/expression/benches/bench.rs index bfc342bc322a..a92e7d30f423 100644 --- a/src/query/expression/benches/bench.rs +++ b/src/query/expression/benches/bench.rs @@ -20,11 +20,14 @@ use criterion::Criterion; use databend_common_column::buffer::Buffer; use databend_common_expression::arrow::deserialize_column; use databend_common_expression::arrow::serialize_column; +use databend_common_expression::types::ArgType; use databend_common_expression::types::BinaryType; +use databend_common_expression::types::Int32Type; use databend_common_expression::types::StringType; use databend_common_expression::Column; use databend_common_expression::DataBlock; use databend_common_expression::FromData; +use databend_common_expression::Value; use rand::rngs::StdRng; use rand::Rng; use rand::SeedableRng; @@ -222,6 +225,29 @@ fn bench(c: &mut Criterion) { .collect::>(); }) }); + + let value1 = Value::::Column(left.clone()); + let value2 = Value::::Column(right.clone()); + + group.bench_function(format!("register_new/{length}"), |b| { + b.iter(|| { + let iter = (0..length).map(|i| { + let a = unsafe { value1.index_unchecked(i) }; + let b = unsafe { value2.index_unchecked(i) }; + a + b + }); + let _c = Int32Type::column_from_iter(iter, &[]); + }) + }); + + group.bench_function(format!("register_old/{length}"), |b| { + b.iter(|| { + let a = value1.as_column().unwrap(); + let b = value2.as_column().unwrap(); + let iter = a.iter().zip(b.iter()).map(|(a, b)| *a + b); + let _c = Int32Type::column_from_iter(iter, &[]); + }) + }); } } diff --git a/src/query/functions/src/scalars/arithmetic.rs b/src/query/functions/src/scalars/arithmetic.rs index bd42635b3e06..e3c60a6c3c84 100644 --- a/src/query/functions/src/scalars/arithmetic.rs +++ b/src/query/functions/src/scalars/arithmetic.rs @@ -21,7 +21,6 @@ use std::str::FromStr; use std::sync::Arc; use databend_common_expression::serialize::read_decimal_with_size; -use databend_common_expression::types::binary::BinaryColumnBuilder; use databend_common_expression::types::decimal::DecimalDomain; use databend_common_expression::types::decimal::DecimalType; use databend_common_expression::types::nullable::NullableColumn; @@ -38,7 +37,6 @@ use databend_common_expression::types::NullableType; use databend_common_expression::types::NumberClass; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::SimpleDomain; -use databend_common_expression::types::StringColumn; use databend_common_expression::types::StringType; use databend_common_expression::types::ALL_FLOAT_TYPES; use databend_common_expression::types::ALL_INTEGER_TYPES; @@ -972,25 +970,24 @@ pub fn register_number_to_string(registry: &mut FunctionRegistry) { Value::Column(from) => { let options = NUM_TYPE::lexical_options(); const FORMAT: u128 = lexical_core::format::STANDARD; - type Native = ::Native; - let mut builder = StringColumnBuilder::with_capacity(from.len()); unsafe { + builder.row_buffer.resize( + ::Native::FORMATTED_SIZE_DECIMAL, + 0, + ); for x in from.iter() { - builder.row_buffer.resize( - ::Native::FORMATTED_SIZE_DECIMAL, - 0, - ); let len = lexical_core::write_with_options::<_, FORMAT>( Native::from(*x), - &mut builder.row_buffer, + &mut builder.row_buffer[0..], &options, ) .len(); - builder.row_buffer.truncate(len); - builder.commit_row(); + builder.data.push_value(std::str::from_utf8_unchecked( + &builder.row_buffer[0..len], + )); } } Value::Column(builder.build()) @@ -1005,29 +1002,27 @@ pub fn register_number_to_string(registry: &mut FunctionRegistry) { 
Value::Column(from) => { let options = NUM_TYPE::lexical_options(); const FORMAT: u128 = lexical_core::format::STANDARD; - let mut builder = - BinaryColumnBuilder::with_capacity(from.len(), from.len() + 1); - let values = &mut builder.data; - type Native = ::Native; - let mut offset: usize = 0; + let mut builder = StringColumnBuilder::with_capacity(from.len()); + unsafe { + builder.row_buffer.resize( + ::Native::FORMATTED_SIZE_DECIMAL, + 0, + ); for x in from.iter() { - values.reserve(offset + Native::FORMATTED_SIZE_DECIMAL); - values.set_len(offset + Native::FORMATTED_SIZE_DECIMAL); - let bytes = &mut values[offset..]; let len = lexical_core::write_with_options::<_, FORMAT>( Native::from(*x), - bytes, + &mut builder.row_buffer[0..], &options, ) .len(); - offset += len; - builder.offsets.push(offset as u64); + builder.data.push_value(std::str::from_utf8_unchecked( + &builder.row_buffer[0..len], + )); } - values.set_len(offset); } - let result = StringColumn::try_from(builder.build()).unwrap(); + let result = builder.build(); Value::Column(NullableColumn::new( result, Bitmap::new_constant(true, from.len()), diff --git a/src/query/functions/src/scalars/other.rs b/src/query/functions/src/scalars/other.rs index 0c4a44c38d55..ace497ed6a58 100644 --- a/src/query/functions/src/scalars/other.rs +++ b/src/query/functions/src/scalars/other.rs @@ -22,7 +22,6 @@ use databend_common_base::base::convert_number_size; use databend_common_base::base::uuid::Uuid; use databend_common_base::base::OrderedFloat; use databend_common_expression::error_to_null; -use databend_common_expression::types::binary::BinaryColumnBuilder; use databend_common_expression::types::boolean::BooleanDomain; use databend_common_expression::types::nullable::NullableColumn; use databend_common_expression::types::number::Float32Type; @@ -31,7 +30,7 @@ use databend_common_expression::types::number::Int64Type; use databend_common_expression::types::number::UInt32Type; use databend_common_expression::types::number::UInt8Type; use databend_common_expression::types::number::F64; -use databend_common_expression::types::string::StringColumn; +use databend_common_expression::types::string::StringColumnBuilder; use databend_common_expression::types::ArgType; use databend_common_expression::types::DataType; use databend_common_expression::types::DateType; @@ -229,16 +228,15 @@ pub fn register(registry: &mut FunctionRegistry) { "gen_random_uuid", |_| FunctionDomain::Full, |ctx| { - let mut builder = BinaryColumnBuilder::with_capacity(ctx.num_rows, 0); + let mut builder = StringColumnBuilder::with_capacity(ctx.num_rows); for _ in 0..ctx.num_rows { let value = Uuid::now_v7(); - write!(&mut builder.data, "{}", value).unwrap(); + write!(&mut builder.row_buffer, "{}", value).unwrap(); builder.commit_row(); } - let col = StringColumn::try_from(builder.build()).unwrap(); - Value::Column(col) + Value::Column(builder.build()) }, ); } diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 306223240b7b..cb9ae84e1090 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -393,7 +393,9 @@ impl Binder { let original_context = bind_context.expr_context.clone(); bind_context.set_expr_context(ExprContext::GroupClaue); - match group_by { + + let group_by = Self::expand_group(group_by.clone())?; + match &group_by { GroupBy::Normal(exprs) => self.resolve_group_items( bind_context, select_list, @@ -416,25 +418,91 @@ impl Binder { 
GroupBy::GroupingSets(sets) => { self.resolve_grouping_sets(bind_context, select_list, sets, &available_aliases)?; } - // TODO: avoid too many clones. + _ => unreachable!(), + } + bind_context.set_expr_context(original_context); + Ok(()) + } + + pub fn expand_group(group_by: GroupBy) -> Result { + match group_by { + GroupBy::Normal(_) | GroupBy::All | GroupBy::GroupingSets(_) => Ok(group_by), + GroupBy::Cube(exprs) => { + // Expand CUBE to GroupingSets + let sets = Self::generate_cube_sets(exprs); + Ok(GroupBy::GroupingSets(sets)) + } GroupBy::Rollup(exprs) => { - // ROLLUP (a,b,c) => GROUPING SETS ((a,b,c), (a,b), (a), ()) - let mut sets = Vec::with_capacity(exprs.len() + 1); - for i in (0..=exprs.len()).rev() { - sets.push(exprs[0..i].to_vec()); + // Expand ROLLUP to GroupingSets + let sets = Self::generate_rollup_sets(exprs); + Ok(GroupBy::GroupingSets(sets)) + } + GroupBy::Combined(groups) => { + // Flatten and expand all nested GroupBy variants + let mut combined_sets = Vec::new(); + for group in groups { + match Self::expand_group(group)? { + GroupBy::Normal(exprs) => { + combined_sets = Self::cartesian_product(combined_sets, vec![exprs]); + } + GroupBy::GroupingSets(sets) => { + combined_sets = Self::cartesian_product(combined_sets, sets); + } + _other => unreachable!(), + } } - self.resolve_grouping_sets(bind_context, select_list, &sets, &available_aliases)?; + Ok(GroupBy::GroupingSets(combined_sets)) } - GroupBy::Cube(exprs) => { - // CUBE (a,b) => GROUPING SETS ((a,b),(a),(b),()) // All subsets - let sets = (0..=exprs.len()) - .flat_map(|count| exprs.clone().into_iter().combinations(count)) - .collect::>(); - self.resolve_grouping_sets(bind_context, select_list, &sets, &available_aliases)?; + } + } + + /// Generate GroupingSets from CUBE (expr1, expr2, ...) + fn generate_cube_sets(exprs: Vec) -> Vec> { + let mut result = Vec::new(); + let n = exprs.len(); + + // Iterate through all possible subsets of the given expressions + for i in 0..(1 << n) { + let mut subset = Vec::new(); + for j in 0..n { + if (i & (1 << j)) != 0 { + subset.push(exprs[j].clone()); + } } + result.push(subset); } - bind_context.set_expr_context(original_context); - Ok(()) + + result + } + + /// Generate GroupingSets from ROLLUP (expr1, expr2, ...) + fn generate_rollup_sets(exprs: Vec) -> Vec> { + let mut result = Vec::new(); + for i in (0..=exprs.len()).rev() { + result.push(exprs[..i].to_vec()); + } + result + } + + /// Perform Cartesian product of two sets of grouping sets + fn cartesian_product(set1: Vec>, set2: Vec>) -> Vec> { + if set1.is_empty() { + return set2; + } + + if set2.is_empty() { + return set1; + } + + let mut result = Vec::new(); + for s1 in set1 { + for s2 in &set2 { + let mut combined = s1.clone(); + combined.extend(s2.clone()); + result.push(combined); + } + } + result } pub fn bind_aggregate( @@ -525,6 +593,11 @@ impl Binder { set }) .collect::>(); + + // Because we are not using union all to implement grouping sets + // We will remove the duplicated grouping sets here. + // For example: SELECT brand, segment, SUM (quantity) FROM sales GROUP BY GROUPING sets(brand, segment), GROUPING sets(brand, segment); + // brand X segment will not appear twice in the result, the results are not standard but accpetable. 
let grouping_sets = grouping_sets.into_iter().unique().collect(); let mut dup_group_items = Vec::with_capacity(agg_info.group_items.len()); for (i, item) in agg_info.group_items.iter().enumerate() { diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_aggregate.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_aggregate.rs index 599122e2a97d..06ad9643a736 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_aggregate.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_aggregate.rs @@ -89,7 +89,11 @@ impl Rule for RulePushDownFilterAggregate { if predicate_used_columns.is_subset(&aggregate_child_prop.output_columns) && predicate_used_columns.is_subset(&aggregate_group_columns) { - pushed_down_predicates.push(predicate); + pushed_down_predicates.push(predicate.clone()); + // Keep full remaining_predicates cause grouping_sets exists + if aggregate.grouping_sets.is_some() { + remaining_predicates.push(predicate); + } } else { remaining_predicates.push(predicate) } @@ -98,6 +102,7 @@ impl Rule for RulePushDownFilterAggregate { let pushed_down_filter = Filter { predicates: pushed_down_predicates, }; + let mut result = if remaining_predicates.is_empty() { SExpr::create_unary( Arc::new(aggregate.into()), diff --git a/src/tests/sqlsmith/src/sql_gen/query.rs b/src/tests/sqlsmith/src/sql_gen/query.rs index 9ef1f0b68c66..31217eb15eb0 100644 --- a/src/tests/sqlsmith/src/sql_gen/query.rs +++ b/src/tests/sqlsmith/src/sql_gen/query.rs @@ -207,9 +207,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { nulls_first: Some(self.flip_coin()), })) } - GroupBy::Rollup(group_by) - | GroupBy::Cube(group_by) - | GroupBy::Normal(group_by) => { + GroupBy::Normal(group_by) => { orders.extend(group_by.iter().map(|expr| OrderByExpr { expr: expr.clone(), asc: Some(self.flip_coin()), @@ -351,12 +349,10 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { groupby_items.push(groupby_item); } - match self.rng.gen_range(0..=4) { + match self.rng.gen_range(0..=2) { 0 => Some(GroupBy::Normal(groupby_items)), 1 => Some(GroupBy::All), 2 => Some(GroupBy::GroupingSets(vec![groupby_items])), - 3 => Some(GroupBy::Cube(groupby_items)), - 4 => Some(GroupBy::Rollup(groupby_items)), _ => unreachable!(), } } @@ -370,9 +366,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { }; match group_by { - Some(GroupBy::Normal(group_by)) - | Some(GroupBy::Cube(group_by)) - | Some(GroupBy::Rollup(group_by)) => { + Some(GroupBy::Normal(group_by)) => { let ty = self.gen_data_type(); let agg_expr = self.gen_agg_func(&ty); targets.push(SelectTarget::AliasedExpr { diff --git a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test index 5d68e422f37d..9aefcb755bc7 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test +++ b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test @@ -300,3 +300,88 @@ select sum(number + 3 ), number % 3 from numbers(10) group by sum(number + 3 ) statement error (?s)1065.*GROUP BY items can't contain aggregate functions or window functions select sum(number + 3 ), number % 3 from numbers(10) group by 1, 2; + + +# test grouping sets, rollup + +statement ok +CREATE OR REPLACE TABLE sales ( + brand VARCHAR NOT NULL, + segment VARCHAR NOT NULL, + quantity INT NOT NULL +); + +statement ok +INSERT INTO sales (brand, segment, quantity) +VALUES + ('ABC', 'Premium', 100), + ('ABC', 'Basic', 200), + ('XYZ', 'Premium', 
100), + ('XYZ', 'Basic', 300); + +query ITTI +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, GROUPING SETS(segment, quantity) order by 1,2,3; +---- +100 ABC NULL 100 +100 XYZ NULL 100 +200 ABC NULL 200 +300 XYZ NULL 300 +NULL ABC Basic 200 +NULL ABC Premium 100 +NULL XYZ Basic 300 +NULL XYZ Premium 100 + +query ITTI +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(segment, quantity) order by 1,2,3; +---- +100 ABC Premium 100 +100 XYZ Premium 100 +200 ABC Basic 200 +300 XYZ Basic 300 +NULL ABC Basic 200 +NULL ABC Premium 100 +NULL ABC NULL 300 +NULL XYZ Basic 300 +NULL XYZ Premium 100 +NULL XYZ NULL 400 + +query ITTI +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity) , segment order by 1,2,3; +---- +100 ABC Premium 100 +100 XYZ Premium 100 +200 ABC Basic 200 +300 XYZ Basic 300 +NULL ABC Basic 200 +NULL ABC Premium 100 +NULL XYZ Basic 300 +NULL XYZ Premium 100 + + + + +## results are deduplicated in grouping sets, the results are not standard but accpetable. +query ITTI +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), GROUPING sets(brand, segment, quantity) order by 1,2,3; +---- +100 ABC Premium 100 +100 ABC NULL 100 +100 XYZ Premium 100 +100 XYZ NULL 100 +200 ABC Basic 200 +200 ABC NULL 200 +300 XYZ Basic 300 +300 XYZ NULL 300 +NULL ABC Basic 200 +NULL ABC Premium 100 +NULL ABC NULL 300 +NULL XYZ Basic 300 +NULL XYZ Premium 100 +NULL XYZ NULL 400 + +## filter push down into grouping sets +query ITTI +select * from (SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY rollup( quantity, brand, segment) ) e where e.segment = 'Basic'; +---- +300 XYZ Basic 300 +200 ABC Basic 200 From 4b618fe2db3e0bb713cd255b6ab012dd0413c8cc Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Fri, 29 Nov 2024 15:38:37 +0800 Subject: [PATCH 2/7] update --- src/query/sql/src/planner/binder/aggregate.rs | 8 ++++++-- .../suites/base/03_common/03_0003_select_group_by.test | 5 ++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index cb9ae84e1090..8957c9c792fe 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -448,7 +448,11 @@ impl Binder { GroupBy::GroupingSets(sets) => { combined_sets = Self::cartesian_product(combined_sets, sets); } - _other => unreachable!(), + other => { + return Err(ErrorCode::SyntaxException( + "COMBINED GROUP BY does not support {:?}", + )); + } } } Ok(GroupBy::GroupingSets(combined_sets)) @@ -597,7 +601,7 @@ impl Binder { // Because we are not using union all to implement grouping sets // We will remove the duplicated grouping sets here. // For example: SELECT brand, segment, SUM (quantity) FROM sales GROUP BY GROUPING sets(brand, segment), GROUPING sets(brand, segment); - // brand X segment will not appear twice in the result, the results are not standard but accpetable. + // brand X segment will not appear twice in the result, the results are not standard but acceptable. 
let grouping_sets = grouping_sets.into_iter().unique().collect(); let mut dup_group_items = Vec::with_capacity(agg_info.group_items.len()); for (i, item) in agg_info.group_items.iter().enumerate() { diff --git a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test index 9aefcb755bc7..790254638e1b 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test +++ b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test @@ -360,7 +360,7 @@ NULL XYZ Premium 100 -## results are deduplicated in grouping sets, the results are not standard but accpetable. +## results are deduplicated in grouping sets, the results are not standard but acceptable. query ITTI SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), GROUPING sets(brand, segment, quantity) order by 1,2,3; ---- @@ -385,3 +385,6 @@ select * from (SELECT quantity, brand, segment, SUM (quantity) FROM ---- 300 XYZ Basic 300 200 ABC Basic 200 + +statement error +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), all; From 83ddc5dba0b6468eeec699f9aaf27926aec15a47 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Fri, 29 Nov 2024 15:42:30 +0800 Subject: [PATCH 3/7] update --- src/query/sql/src/planner/binder/aggregate.rs | 6 +++--- src/tests/sqlsmith/src/sql_gen/query.rs | 2 +- .../base/03_common/03_0003_select_group_by.test | 16 ++++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 8957c9c792fe..3d39811ca0be 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -449,9 +449,9 @@ impl Binder { combined_sets = Self::cartesian_product(combined_sets, sets); } other => { - return Err(ErrorCode::SyntaxException( - "COMBINED GROUP BY does not support {:?}", - )); + return Err(ErrorCode::SyntaxException(format!( + "COMBINED GROUP BY does not support {other:?}" + ))); } } } diff --git a/src/tests/sqlsmith/src/sql_gen/query.rs b/src/tests/sqlsmith/src/sql_gen/query.rs index 31217eb15eb0..4c7142fbdde0 100644 --- a/src/tests/sqlsmith/src/sql_gen/query.rs +++ b/src/tests/sqlsmith/src/sql_gen/query.rs @@ -403,7 +403,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { alias: None, })); } - None => { + _ => { let select_num = self.rng.gen_range(1..=7); for _ in 0..select_num { let ty = self.gen_data_type(); diff --git a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test index 790254638e1b..f9573b89f178 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test +++ b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test @@ -142,7 +142,7 @@ statement ok CREATE TABLE IF NOT EXISTS t_variant(id Int null, var Variant null) Engine = Fuse statement ok -INSERT INTO t_variant VALUES(1, parse_json('{"k":"v"}')), (2, parse_json('{"k":"v"}')), (3, parse_json('"abcd"')), (4, parse_json('"abcd"')), (5, parse_json('12')), (6, parse_json('12')), (7, parse_json('[1,2,3]')), (8, parse_json('[1,2,3]')) +INSERT INTO t_variant VALUES(1, parse_json('{"k":"v"}')), (2, parse_json('{"k":"v"}')), (3, parse_json('"abcd"')), (4, parse_json('"abcd"')), (5, parse_json('12')), (6, parse_json('12')), (7, parse_json('[1,2,3]')), (8, parse_json('[1,2,3]')) query IIT 
SELECT max(id) as n, min(id), var FROM t_variant GROUP BY var ORDER BY n ASC @@ -159,7 +159,7 @@ statement ok CREATE TABLE IF NOT EXISTS t_array(id Int null, arr Array(Int32)) Engine = Fuse statement ok -INSERT INTO t_array VALUES(1, []), (2, []), (3, [1,2,3]), (4, [1,2,3]), (5, [4,5,6]), (6, [4,5,6]) +INSERT INTO t_array VALUES(1, []), (2, []), (3, [1,2,3]), (4, [1,2,3]), (5, [4,5,6]), (6, [4,5,6]) query I select count() from numbers(10) group by 'ab' @@ -320,7 +320,7 @@ VALUES ('XYZ', 'Basic', 300); query ITTI -SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, GROUPING SETS(segment, quantity) order by 1,2,3; +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, GROUPING SETS(segment, quantity) order by 1,2,3; ---- 100 ABC NULL 100 100 XYZ NULL 100 @@ -332,7 +332,7 @@ NULL XYZ Basic 300 NULL XYZ Premium 100 query ITTI -SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(segment, quantity) order by 1,2,3; +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(segment, quantity) order by 1,2,3; ---- 100 ABC Premium 100 100 XYZ Premium 100 @@ -346,7 +346,7 @@ NULL XYZ Premium 100 NULL XYZ NULL 400 query ITTI -SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity) , segment order by 1,2,3; +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity) , segment order by 1,2,3; ---- 100 ABC Premium 100 100 XYZ Premium 100 @@ -362,7 +362,7 @@ NULL XYZ Premium 100 ## results are deduplicated in grouping sets, the results are not standard but acceptable. query ITTI -SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), GROUPING sets(brand, segment, quantity) order by 1,2,3; +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), GROUPING sets(brand, segment, quantity) order by 1,2,3; ---- 100 ABC Premium 100 100 ABC NULL 100 @@ -381,10 +381,10 @@ NULL XYZ NULL 400 ## filter push down into grouping sets query ITTI -select * from (SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY rollup( quantity, brand, segment) ) e where e.segment = 'Basic'; +select * from (SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY rollup( quantity, brand, segment) ) e where e.segment = 'Basic'; ---- 300 XYZ Basic 300 200 ABC Basic 200 statement error -SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), all; +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), all; From 114f62312e29d3119173add403f8ba2b7fb26374 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Fri, 29 Nov 2024 15:45:47 +0800 Subject: [PATCH 4/7] update --- src/tests/sqlsmith/src/sql_gen/query.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/tests/sqlsmith/src/sql_gen/query.rs b/src/tests/sqlsmith/src/sql_gen/query.rs index 4c7142fbdde0..c931a8cbd2e0 100644 --- a/src/tests/sqlsmith/src/sql_gen/query.rs +++ b/src/tests/sqlsmith/src/sql_gen/query.rs @@ -207,7 +207,9 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { nulls_first: Some(self.flip_coin()), })) } - GroupBy::Normal(group_by) => { + GroupBy::Rollup(group_by) + | GroupBy::Cube(group_by) + | GroupBy::Normal(group_by) => { orders.extend(group_by.iter().map(|expr| OrderByExpr { expr: expr.clone(), asc: Some(self.flip_coin()), @@ -234,6 +236,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { 
orders.push(order_by_expr); } } + _ => unimplemented!(), } } else { for _ in 0..order_nums { From b22de8cdb9623d8b88962e4e422b35ec30071207 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Sat, 30 Nov 2024 21:07:38 +0800 Subject: [PATCH 5/7] update --- src/query/sql/src/planner/binder/aggregate.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 3d39811ca0be..fe5a85696197 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -468,9 +468,9 @@ impl Binder { // Iterate through all possible subsets of the given expressions for i in 0..(1 << n) { let mut subset = Vec::new(); - for j in 0..n { + for (j, expr) in exprs.iter().enumerate() { if (i & (1 << j)) != 0 { - subset.push(exprs[j].clone()); + subset.push(expr.clone()); } } result.push(subset); From 8272d577c0724331f2c87cdd19804a6bd41583dd Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Sun, 1 Dec 2024 13:00:37 +0800 Subject: [PATCH 6/7] Update 03_0003_select_group_by.test --- .../suites/base/03_common/03_0003_select_group_by.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test index f9573b89f178..b2e84947356b 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test +++ b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test @@ -381,7 +381,7 @@ NULL XYZ NULL 400 ## filter push down into grouping sets query ITTI -select * from (SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY rollup( quantity, brand, segment) ) e where e.segment = 'Basic'; +select * from (SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY rollup( quantity, brand, segment) ) e where e.segment = 'Basic' order by quantity desc; ---- 300 XYZ Basic 300 200 ABC Basic 200 From 41819595225a5b980c34779f9d0650a9956df8f6 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Sun, 1 Dec 2024 23:35:02 +0800 Subject: [PATCH 7/7] update --- src/query/sql/src/planner/binder/aggregate.rs | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index fe5a85696197..9433c6c4ee43 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -462,21 +462,9 @@ impl Binder { /// Generate GroupingSets from CUBE (expr1, expr2, ...) fn generate_cube_sets(exprs: Vec) -> Vec> { - let mut result = Vec::new(); - let n = exprs.len(); - - // Iterate through all possible subsets of the given expressions - for i in 0..(1 << n) { - let mut subset = Vec::new(); - for (j, expr) in exprs.iter().enumerate() { - if (i & (1 << j)) != 0 { - subset.push(expr.clone()); - } - } - result.push(subset); - } - - result + (0..=exprs.len()) + .flat_map(|count| exprs.clone().into_iter().combinations(count)) + .collect::>() } /// Generate GroupingSets from ROLLUP (expr1, expr2, ...)
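
Illustrative sketch (not part of the patch series): the binder changes above expand a
combined GROUP BY by rewriting each item into grouping sets and crossing them with a
cartesian product, so that GROUP BY a, ROLLUP (b, c) becomes GROUPING SETS ((a, b, c),
(a, b), (a)). The standalone Rust sketch below mirrors that expansion with plain strings
in place of AST expressions; rollup_sets and cartesian_product here are simplified
stand-ins for the generate_rollup_sets and cartesian_product helpers added in
aggregate.rs, not the actual binder code.

fn rollup_sets(exprs: &[&str]) -> Vec<Vec<String>> {
    // ROLLUP (b, c) => ((b, c), (b), ())
    (0..=exprs.len())
        .rev()
        .map(|i| exprs[..i].iter().map(|s| s.to_string()).collect())
        .collect()
}

fn cartesian_product(set1: Vec<Vec<String>>, set2: Vec<Vec<String>>) -> Vec<Vec<String>> {
    // Cross every grouping set on the left with every grouping set on the right.
    if set1.is_empty() {
        return set2;
    }
    if set2.is_empty() {
        return set1;
    }
    let mut result = Vec::new();
    for s1 in &set1 {
        for s2 in &set2 {
            let mut combined = s1.clone();
            combined.extend(s2.clone());
            result.push(combined);
        }
    }
    result
}

fn main() {
    // GROUP BY a, ROLLUP (b, c): the plain item a contributes the single set (a),
    // which is then crossed with the rollup sets of (b, c).
    let normal = vec![vec!["a".to_string()]];
    let rollup = rollup_sets(&["b", "c"]);
    let combined = cartesian_product(normal, rollup);
    // Prints [["a", "b", "c"], ["a", "b"], ["a"]]
    println!("{:?}", combined);
}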