Skip to content

Commit

Permalink
Merge pull request #7056 from xudong963/cte3
Browse files Browse the repository at this point in the history
feat(planner): support common table expression(CTE)
  • Loading branch information
mergify[bot] authored Aug 11, 2022
2 parents 4af8546 + 14b4e15 commit b92dd3a
Show file tree
Hide file tree
Showing 4 changed files with 486 additions and 6 deletions.
35 changes: 32 additions & 3 deletions query/src/sql/planner/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;

use common_ast::ast::TableAlias;
use common_ast::parser::token::Token;
use common_ast::DisplayError;
Expand All @@ -21,10 +24,12 @@ use common_datavalues::DataSchemaRefExt;
use common_datavalues::DataTypeImpl;
use common_exception::ErrorCode;
use common_exception::Result;
use parking_lot::RwLock;

use super::AggregateInfo;
use crate::sql::common::IndexType;
use crate::sql::normalize_identifier;
use crate::sql::optimizer::SExpr;
use crate::sql::plans::Scalar;
use crate::sql::NameResolutionContext;

Expand Down Expand Up @@ -54,7 +59,7 @@ pub enum NameResolutionResult {
}

/// `BindContext` stores all the free variables in a query and tracks the context of binding procedure.
#[derive(Clone, Default, Debug)]
#[derive(Clone, Debug)]
pub struct BindContext {
pub parent: Option<Box<BindContext>>,

Expand All @@ -69,27 +74,45 @@ pub struct BindContext {

/// Format type of query output.
pub format: Option<String>,

pub ctes_map: Arc<RwLock<HashMap<String, CteInfo>>>,
}

#[derive(Clone, Debug)]
pub struct CteInfo {
pub columns_alias: Vec<String>,
pub s_expr: SExpr,
pub bind_context: BindContext,
}

impl BindContext {
pub fn new() -> Self {
Self::default()
Self {
parent: None,
columns: Vec::new(),
aggregate_info: AggregateInfo::default(),
in_grouping: false,
format: None,
ctes_map: Arc::new(RwLock::new(HashMap::new())),
}
}

pub fn with_parent(parent: Box<BindContext>) -> Self {
BindContext {
parent: Some(parent),
parent: Some(parent.clone()),
columns: vec![],
aggregate_info: Default::default(),
in_grouping: false,
format: None,
ctes_map: parent.ctes_map.clone(),
}
}

/// Create a new BindContext with self's parent as its parent
pub fn replace(&self) -> Self {
let mut bind_context = BindContext::new();
bind_context.parent = self.parent.clone();
bind_context.ctes_map = self.ctes_map.clone();
bind_context
}

Expand Down Expand Up @@ -267,3 +290,9 @@ impl BindContext {
DataSchemaRefExt::create(fields)
}
}

impl Default for BindContext {
fn default() -> Self {
BindContext::new()
}
}
21 changes: 21 additions & 0 deletions query/src/sql/planner/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use common_exception::ErrorCode;
use common_exception::Result;

use crate::sql::binder::scalar_common::split_conjunctions;
use crate::sql::binder::CteInfo;
use crate::sql::optimizer::SExpr;
use crate::sql::planner::binder::scalar::ScalarBinder;
use crate::sql::planner::binder::BindContext;
Expand Down Expand Up @@ -149,6 +150,7 @@ impl<'a> Binder {
let mut output_context = BindContext::new();
output_context.parent = from_context.parent;
output_context.columns = from_context.columns;
output_context.ctes_map = from_context.ctes_map;

Ok((s_expr, output_context))
}
Expand Down Expand Up @@ -176,11 +178,30 @@ impl<'a> Binder {
}
}

#[async_recursion]
pub(crate) async fn bind_query(
&mut self,
bind_context: &BindContext,
query: &Query<'_>,
) -> Result<(SExpr, BindContext)> {
if let Some(with) = &query.with {
for cte in with.ctes.iter() {
let table_name = cte.alias.name.name.clone();
if bind_context.ctes_map.read().contains_key(&table_name) {
return Err(ErrorCode::SemanticError(format!(
"duplicate cte {table_name}"
)));
}
let (s_expr, cte_bind_context) = self.bind_query(bind_context, &cte.query).await?;
let cte_info = CteInfo {
columns_alias: cte.alias.columns.iter().map(|c| c.name.clone()).collect(),
s_expr,
bind_context: cte_bind_context.clone(),
};
let mut ctes_map = bind_context.ctes_map.write();
ctes_map.insert(table_name, cte_info);
}
}
let (mut s_expr, mut bind_context) = match query.body {
SetExpr::Select(_) | SetExpr::Query(_) => {
self.bind_set_expr(bind_context, &query.body, &query.order_by)
Expand Down
52 changes: 49 additions & 3 deletions query/src/sql/planner/binder/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use common_ast::ast::Indirection;
use common_ast::ast::SelectStmt;
use common_ast::ast::SelectTarget;
use common_ast::ast::Statement;
use common_ast::ast::TableAlias;
use common_ast::ast::TableReference;
use common_ast::ast::TimeTravelPoint;
use common_ast::parser::parse_sql;
Expand All @@ -35,6 +36,7 @@ use crate::sessions::TableContext;
use crate::sql::binder::scalar::ScalarBinder;
use crate::sql::binder::Binder;
use crate::sql::binder::ColumnBinding;
use crate::sql::binder::CteInfo;
use crate::sql::optimizer::SExpr;
use crate::sql::planner::semantic::normalize_identifier;
use crate::sql::planner::semantic::TypeChecker;
Expand Down Expand Up @@ -94,6 +96,11 @@ impl<'a> Binder {
alias,
travel_point,
} => {
let table_name = normalize_identifier(table, &self.name_resolution_ctx).name;
// Check and bind common table expression
if let Some(cte_info) = bind_context.ctes_map.read().get(&table_name) {
return self.bind_cte(bind_context, &table_name, alias, cte_info);
}
// Get catalog name
let catalog = catalog
.as_ref()
Expand All @@ -106,8 +113,6 @@ impl<'a> Binder {
.map(|ident| normalize_identifier(ident, &self.name_resolution_ctx).name)
.unwrap_or_else(|| self.ctx.get_current_database());

let table = normalize_identifier(table, &self.name_resolution_ctx).name;

let tenant = self.ctx.get_tenant();

let navigation_point = match travel_point {
Expand All @@ -121,7 +126,7 @@ impl<'a> Binder {
tenant.as_str(),
catalog.as_str(),
database.as_str(),
table.as_str(),
table_name.as_str(),
&navigation_point,
)
.await?;
Expand Down Expand Up @@ -233,6 +238,47 @@ impl<'a> Binder {
}
}

fn bind_cte(
&mut self,
bind_context: &BindContext,
table_name: &str,
alias: &Option<TableAlias>,
cte_info: &CteInfo,
) -> Result<(SExpr, BindContext)> {
let mut new_bind_context = bind_context.clone();
new_bind_context.columns = cte_info.bind_context.columns.clone();
let mut cols_alias = cte_info.columns_alias.clone();
if let Some(alias) = alias {
for (idx, col_alias) in alias.columns.iter().enumerate() {
if idx < cte_info.columns_alias.len() {
cols_alias[idx] = col_alias.name.clone();
} else {
cols_alias.push(col_alias.name.clone());
}
}
}
let alias_table_name = alias
.as_ref()
.map(|alias| normalize_identifier(&alias.name, &self.name_resolution_ctx).name)
.unwrap_or_else(|| table_name.to_string());
for column in new_bind_context.columns.iter_mut() {
column.database_name = None;
column.table_name = Some(alias_table_name.clone());
}

if cols_alias.len() > new_bind_context.columns.len() {
return Err(ErrorCode::SemanticError(format!(
"table has {} columns available but {} columns specified",
new_bind_context.columns.len(),
cols_alias.len()
)));
}
for (index, column_name) in cols_alias.iter().enumerate() {
new_bind_context.columns[index].column_name = column_name.clone();
}
Ok((cte_info.s_expr.clone(), new_bind_context))
}

fn bind_base_table(
&mut self,
bind_context: &BindContext,
Expand Down
Loading

0 comments on commit b92dd3a

Please sign in to comment.