diff --git a/query/src/sql/planner/binder/bind_context.rs b/query/src/sql/planner/binder/bind_context.rs index 2f77e58f9368..eabea1b410a6 100644 --- a/query/src/sql/planner/binder/bind_context.rs +++ b/query/src/sql/planner/binder/bind_context.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; +use std::sync::Arc; + use common_ast::ast::TableAlias; use common_ast::parser::token::Token; use common_ast::DisplayError; @@ -21,10 +24,12 @@ use common_datavalues::DataSchemaRefExt; use common_datavalues::DataTypeImpl; use common_exception::ErrorCode; use common_exception::Result; +use parking_lot::RwLock; use super::AggregateInfo; use crate::sql::common::IndexType; use crate::sql::normalize_identifier; +use crate::sql::optimizer::SExpr; use crate::sql::plans::Scalar; use crate::sql::NameResolutionContext; @@ -54,7 +59,7 @@ pub enum NameResolutionResult { } /// `BindContext` stores all the free variables in a query and tracks the context of binding procedure. -#[derive(Clone, Default, Debug)] +#[derive(Clone, Debug)] pub struct BindContext { pub parent: Option>, @@ -69,20 +74,37 @@ pub struct BindContext { /// Format type of query output. pub format: Option, + + pub ctes_map: Arc>>, +} + +#[derive(Clone, Debug)] +pub struct CteInfo { + pub columns_alias: Vec, + pub s_expr: SExpr, + pub bind_context: BindContext, } impl BindContext { pub fn new() -> Self { - Self::default() + Self { + parent: None, + columns: Vec::new(), + aggregate_info: AggregateInfo::default(), + in_grouping: false, + format: None, + ctes_map: Arc::new(RwLock::new(HashMap::new())), + } } pub fn with_parent(parent: Box) -> Self { BindContext { - parent: Some(parent), + parent: Some(parent.clone()), columns: vec![], aggregate_info: Default::default(), in_grouping: false, format: None, + ctes_map: parent.ctes_map.clone(), } } @@ -90,6 +112,7 @@ impl BindContext { pub fn replace(&self) -> Self { let mut bind_context = BindContext::new(); bind_context.parent = self.parent.clone(); + bind_context.ctes_map = self.ctes_map.clone(); bind_context } @@ -267,3 +290,9 @@ impl BindContext { DataSchemaRefExt::create(fields) } } + +impl Default for BindContext { + fn default() -> Self { + BindContext::new() + } +} diff --git a/query/src/sql/planner/binder/select.rs b/query/src/sql/planner/binder/select.rs index bf5515455c2a..89cfdc939ef0 100644 --- a/query/src/sql/planner/binder/select.rs +++ b/query/src/sql/planner/binder/select.rs @@ -31,6 +31,7 @@ use common_exception::ErrorCode; use common_exception::Result; use crate::sql::binder::scalar_common::split_conjunctions; +use crate::sql::binder::CteInfo; use crate::sql::optimizer::SExpr; use crate::sql::planner::binder::scalar::ScalarBinder; use crate::sql::planner::binder::BindContext; @@ -149,6 +150,7 @@ impl<'a> Binder { let mut output_context = BindContext::new(); output_context.parent = from_context.parent; output_context.columns = from_context.columns; + output_context.ctes_map = from_context.ctes_map; Ok((s_expr, output_context)) } @@ -176,11 +178,30 @@ impl<'a> Binder { } } + #[async_recursion] pub(crate) async fn bind_query( &mut self, bind_context: &BindContext, query: &Query<'_>, ) -> Result<(SExpr, BindContext)> { + if let Some(with) = &query.with { + for cte in with.ctes.iter() { + let table_name = cte.alias.name.name.clone(); + if bind_context.ctes_map.read().contains_key(&table_name) { + return Err(ErrorCode::SemanticError(format!( + "duplicate cte {table_name}" + ))); + } + let (s_expr, cte_bind_context) = self.bind_query(bind_context, &cte.query).await?; + let cte_info = CteInfo { + columns_alias: cte.alias.columns.iter().map(|c| c.name.clone()).collect(), + s_expr, + bind_context: cte_bind_context.clone(), + }; + let mut ctes_map = bind_context.ctes_map.write(); + ctes_map.insert(table_name, cte_info); + } + } let (mut s_expr, mut bind_context) = match query.body { SetExpr::Select(_) | SetExpr::Query(_) => { self.bind_set_expr(bind_context, &query.body, &query.order_by) diff --git a/query/src/sql/planner/binder/table.rs b/query/src/sql/planner/binder/table.rs index b5174f10e228..01cf8da85d81 100644 --- a/query/src/sql/planner/binder/table.rs +++ b/query/src/sql/planner/binder/table.rs @@ -18,6 +18,7 @@ use common_ast::ast::Indirection; use common_ast::ast::SelectStmt; use common_ast::ast::SelectTarget; use common_ast::ast::Statement; +use common_ast::ast::TableAlias; use common_ast::ast::TableReference; use common_ast::ast::TimeTravelPoint; use common_ast::parser::parse_sql; @@ -35,6 +36,7 @@ use crate::sessions::TableContext; use crate::sql::binder::scalar::ScalarBinder; use crate::sql::binder::Binder; use crate::sql::binder::ColumnBinding; +use crate::sql::binder::CteInfo; use crate::sql::optimizer::SExpr; use crate::sql::planner::semantic::normalize_identifier; use crate::sql::planner::semantic::TypeChecker; @@ -94,6 +96,11 @@ impl<'a> Binder { alias, travel_point, } => { + let table_name = normalize_identifier(table, &self.name_resolution_ctx).name; + // Check and bind common table expression + if let Some(cte_info) = bind_context.ctes_map.read().get(&table_name) { + return self.bind_cte(bind_context, &table_name, alias, cte_info); + } // Get catalog name let catalog = catalog .as_ref() @@ -106,8 +113,6 @@ impl<'a> Binder { .map(|ident| normalize_identifier(ident, &self.name_resolution_ctx).name) .unwrap_or_else(|| self.ctx.get_current_database()); - let table = normalize_identifier(table, &self.name_resolution_ctx).name; - let tenant = self.ctx.get_tenant(); let navigation_point = match travel_point { @@ -121,7 +126,7 @@ impl<'a> Binder { tenant.as_str(), catalog.as_str(), database.as_str(), - table.as_str(), + table_name.as_str(), &navigation_point, ) .await?; @@ -233,6 +238,47 @@ impl<'a> Binder { } } + fn bind_cte( + &mut self, + bind_context: &BindContext, + table_name: &str, + alias: &Option, + cte_info: &CteInfo, + ) -> Result<(SExpr, BindContext)> { + let mut new_bind_context = bind_context.clone(); + new_bind_context.columns = cte_info.bind_context.columns.clone(); + let mut cols_alias = cte_info.columns_alias.clone(); + if let Some(alias) = alias { + for (idx, col_alias) in alias.columns.iter().enumerate() { + if idx < cte_info.columns_alias.len() { + cols_alias[idx] = col_alias.name.clone(); + } else { + cols_alias.push(col_alias.name.clone()); + } + } + } + let alias_table_name = alias + .as_ref() + .map(|alias| normalize_identifier(&alias.name, &self.name_resolution_ctx).name) + .unwrap_or_else(|| table_name.to_string()); + for column in new_bind_context.columns.iter_mut() { + column.database_name = None; + column.table_name = Some(alias_table_name.clone()); + } + + if cols_alias.len() > new_bind_context.columns.len() { + return Err(ErrorCode::SemanticError(format!( + "table has {} columns available but {} columns specified", + new_bind_context.columns.len(), + cols_alias.len() + ))); + } + for (index, column_name) in cols_alias.iter().enumerate() { + new_bind_context.columns[index].column_name = column_name.clone(); + } + Ok((cte_info.s_expr.clone(), new_bind_context)) + } + fn bind_base_table( &mut self, bind_context: &BindContext, diff --git a/tests/logictest/suites/base/15_query/cte.test b/tests/logictest/suites/base/15_query/cte.test new file mode 100644 index 000000000000..a20672ba0038 --- /dev/null +++ b/tests/logictest/suites/base/15_query/cte.test @@ -0,0 +1,384 @@ +statement ok +drop table if exists t1 all; + +statement ok +create table t1(a integer, b integer, c integer, d integer, e integer); + +statement ok +insert into t1(e,c,b,d,a) values(103,102,100,101,104); + +statement ok +insert into t1(a,c,d,e,b) values(107,106,108,109,105); + +statement ok +insert into t1(e,d,b,a,c) values(110,114,112,111,113); + +statement ok +insert into t1(d,c,e,a,b) values(116,119,117,115,118); + +statement query III +with t2(tt) as (select a from t1), t3 as (select * from t1), t4 as (select a from t1 where a > 105) select t2.tt, t3.a, t4.a from t2, t3, t4 where t2.tt > 107 order by t2.tt, t3.a, t4.a; + +---- +111 104 107 +111 104 111 +111 104 115 +111 107 107 +111 107 111 +111 107 115 +111 111 107 +111 111 111 +111 111 115 +111 115 107 +111 115 111 +111 115 115 +115 104 107 +115 104 111 +115 104 115 +115 107 107 +115 107 111 +115 107 115 +115 111 107 +115 111 111 +115 111 115 +115 115 107 +115 115 111 +115 115 115 + +statement query I +with t2(tt) as (select a from t1) select t2.tt from t2 where t2.tt > 105 order by t2.tt; + +---- +107 +111 +115 + +statement query I +with t2 as (select a from t1) select t2.a from t2 where t2.a > 107 order by t2.a; + +---- +111 +115 + +statement query II +with t2(tt) as (select a from t1) SELECT t1.a, t1.b FROM t1 WHERE EXISTS(SELECT * FROM t2 WHERE t2.tt=t1.a) order by t1.a, t1.b; + +---- +104 100 +107 105 +111 112 +115 118 + +statement ok +DROP TABLE IF EXISTS test1; + +statement ok +CREATE TABLE test1(i int, j int); + +statement ok +INSERT INTO test1 VALUES (1, 2), (3, 4); + +statement query I +WITH test1 AS (SELECT * FROM numbers(5)) SELECT * FROM test1; + +---- +0 +1 +2 +3 +4 + +statement query II +WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1; + +---- +2 3 +4 5 + +statement query II +WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM (SELECT * FROM test1); + +---- +2 3 +4 5 + +statement query III +SELECT * FROM (WITH t1 AS (SELECT to_int32(number) i FROM numbers(5)) SELECT * FROM t1) l INNER JOIN test1 r on l.i = r.i; + +---- +1 1 2 +3 3 4 + +statement ok +DROP TABLE IF EXISTS test1; + +statement query I +WITH test1 AS (SELECT number-1 as n FROM numbers(42)) +SELECT max(n+1)+1 z FROM test1; + +---- +42 + +statement query I +WITH test1 AS (SELECT number-1 as n FROM numbers(4442) order by n limit 100) SELECT max(n) FROM test1 where n=422; + +---- +0 + +statement query I +WITH test1 AS (SELECT number-1 as n FROM numbers(4442) order by n limit 100) SELECT max(n) FROM test1 where n=42; + +---- +42 + +statement ok +drop table if exists with_test; + +statement ok +create table with_test(n int64 null); + +statement ok +insert into with_test select number - 1 from numbers(10000); + +statement query I +WITH test1 AS (SELECT n FROM with_test order by n limit 100) +SELECT max(n) FROM test1 where n=422; + +---- +NULL + +statement query I +WITH test1 AS (SELECT n FROM with_test order by n limit 100) +SELECT max(n) FROM test1 where n=42; + +---- +42 + +statement query I +WITH test1 AS (SELECT n FROM with_test where n = 42 order by n limit 100) +SELECT max(n) FROM test1 where n=42; + +---- +42 + +statement query I +WITH test1 AS (SELECT n FROM with_test where n = 42 or 1=1 order by n limit 100) +SELECT max(n) FROM test1 where n=42; + +---- +42 + +statement query I +WITH test1 AS (SELECT n, null b FROM with_test where 1=1 and n = 42 order by n) +SELECT max(n) FROM test1 where n=45; + +---- +NULL + +statement query I +WITH test1 AS (SELECT n, null b, n+1 m FROM with_test where 1=0 or n = 42 order by n limit 4) +SELECT max(n) m FROM test1 where test1.m=43 having max(n)=42; + +---- +42 + +statement query I +with + test1 as (select n, null b, n+1 m from with_test where n = 42 order by n limit 4), + test2 as (select n + 1 as x, n - 1 as y from test1), + test3 as (select x * y as z from test2) +select z + 1 as q from test3; + +---- +1764 + +statement ok +drop table with_test; + +statement query I +WITH +x AS (SELECT number AS a FROM numbers(10)), +y AS (SELECT number AS a FROM numbers(5)) +SELECT * FROM x WHERE a in (SELECT a FROM y) +ORDER BY a; + +---- +0 +1 +2 +3 +4 + +statement query I +WITH +x AS (SELECT number AS a FROM numbers(10)), +y AS (SELECT number AS a FROM numbers(5)) +SELECT x.a FROM x left JOIN y ON x.a = y.a +ORDER BY a; + +---- +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + +statement query I +WITH +x AS (SELECT number AS a FROM numbers(10)), +y AS (SELECT number AS a FROM numbers(5)) +SELECT * FROM x JOIN y USING (a) +ORDER BY x.a; + +---- +0 +1 +2 +3 +4 + +statement query I +WITH +x AS (SELECT number AS a FROM numbers(10)), +y AS (SELECT number AS a FROM numbers(5)), +z AS (SELECT * FROM x WHERE a % 2), +w AS (SELECT * FROM y WHERE a > 0) +SELECT * FROM x JOIN y USING (a) WHERE x.a in (SELECT * FROM z) +ORDER BY x.a; + +---- +1 +3 + +statement query I +WITH +x AS (SELECT number AS a FROM numbers(10)), +y AS (SELECT number AS a FROM numbers(5)), +z AS (SELECT * FROM x WHERE a % 2), +w AS (SELECT * FROM y WHERE a > 0) +SELECT x.a FROM x JOIN y USING (a) WHERE x.a in (SELECT * FROM z) +HAVING x.a <= (SELECT max(a) FROM w) +ORDER BY x.a; + +---- +1 +3 + +statement ok +CREATE TABLE cte1(a Int64); + +statement ok +CREATE TABLE cte2(a Int64); + +statement ok +INSERT INTO cte1 SELECT * FROM numbers(10000); + +statement ok +INSERT INTO cte2 SELECT * FROM numbers(5000); + +statement query I +WITH +x AS (SELECT * FROM cte1), +y AS (SELECT * FROM cte2), +z AS (SELECT * FROM x WHERE a % 2 = 1), +w AS (SELECT * FROM y WHERE a > 333) +SELECT max(x.a) +FROM x JOIN y USING (a) +WHERE x.a in (SELECT * FROM z) AND x.a <= (SELECT max(a) FROM w); + +---- +4999 + +statement query I +WITH +x AS (SELECT * FROM cte1), +y AS (SELECT * FROM cte2), +z AS (SELECT * FROM x WHERE a % 3 = 1), +w AS (SELECT * FROM y WHERE a > 333 AND a < 1000) +SELECT count(x.a) +FROM x left JOIN y USING (a) +WHERE x.a in (SELECT * FROM z) AND x.a <= (SELECT max(a) FROM w); + +---- +333 + +statement query I +WITH +x AS (SELECT * FROM cte1), +y AS (SELECT * FROM cte2), +z AS (SELECT * FROM x WHERE a % 3 = 1), +w AS (SELECT * FROM y WHERE a > 333 AND a < 1000) +SELECT count(x.a) +FROM x left JOIN y USING (a) +WHERE x.a in (SELECT * FROM z); + +---- +3333 + +statement query I +WITH +x AS (SELECT a-4000 a FROM cte1 WHERE cte1.a >700), +y AS (SELECT * FROM cte2), +z AS (SELECT * FROM x WHERE a % 3 = 1), +w AS (SELECT * FROM y WHERE a > 333 AND a < 1000) +SELECT count(*) +FROM x left JOIN y USING (a) +WHERE x.a in (SELECT * FROM z); + +---- +2000 + +statement query III +WITH +x AS (SELECT a-4000 a FROM cte1 WHERE cte1.a >700), +y AS (SELECT * FROM cte2), +z AS (SELECT * FROM x WHERE a % 3 = 1), +w AS (SELECT * FROM y WHERE a > 333 AND a < 1000) +SELECT max(a), min(a), count(*) +FROM x +WHERE a in (SELECT * FROM z) AND a <100; + +---- +97 1 33 + +statement query III +WITH +x AS (SELECT a-4000 a FROM cte1 WHERE cte1.a >700), +y AS (SELECT * FROM cte2), +z AS (SELECT * FROM x WHERE a % 3 = 1), +w AS (SELECT * FROM y WHERE a > 333 AND a < 1000) +SELECT max(a), min(a), count(*) FROM x +WHERE a <100; + +---- +99 -3299 3399 + +statement query III +WITH +x AS (SELECT a-4000 a FROM cte1 t WHERE t.a >700), +y AS (SELECT x.a a FROM x left JOIN cte1 USING (a)), +z AS (SELECT * FROM x WHERE a % 3 = 1), +w AS (SELECT * FROM y WHERE a > 333 AND a < 1000) +SELECT max(a), min(a), count(*) +FROM y +WHERE a <100; + +---- +99 -3299 3399 + +statement ok +DROP TABLE cte1; + +statement ok +DROP TABLE cte2; + +statement query I +with it as ( select * from numbers(1) ) select i.number from it as i; + +---- +0