From 3daa60d4d9bfa1ab7419a1e494ed2f28be450588 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Wed, 21 Dec 2022 18:29:56 +0800 Subject: [PATCH 1/4] feat(planner): support logical filter Signed-off-by: Fedomn --- .../expression/bind_column_ref_expression.rs | 8 ++- .../expression_binder/column_alias_binder.rs | 32 ++++++++++++ .../binder/expression_binder/mod.rs | 4 ++ .../binder/expression_binder/where_binder.rs | 49 +++++++++++++++++ src/planner_v2/binder/mod.rs | 2 + .../binder/query_node/bind_select_node.rs | 52 +++++++++++++++++-- tests/slt/filter.slt | 14 +++++ 7 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 src/planner_v2/binder/expression_binder/column_alias_binder.rs create mode 100644 src/planner_v2/binder/expression_binder/mod.rs create mode 100644 src/planner_v2/binder/expression_binder/where_binder.rs diff --git a/src/planner_v2/binder/expression/bind_column_ref_expression.rs b/src/planner_v2/binder/expression/bind_column_ref_expression.rs index b4c79c1..a792b11 100644 --- a/src/planner_v2/binder/expression/bind_column_ref_expression.rs +++ b/src/planner_v2/binder/expression/bind_column_ref_expression.rs @@ -1,8 +1,9 @@ use derive_new::new; use itertools::Itertools; +use log::debug; use super::{BoundExpression, BoundExpressionBase, ColumnBinding}; -use crate::planner_v2::{BindError, ExpressionBinder}; +use crate::planner_v2::{BindError, ExpressionBinder, LOGGING_TARGET}; use crate::types_v2::LogicalType; /// A BoundColumnRef expression represents a ColumnRef expression that was bound to an actual table @@ -64,7 +65,10 @@ impl ExpressionBinder<'_> { result_types.push(bound_col_ref.base.return_type.clone()); Ok(BoundExpression::BoundColumnRefExpression(bound_col_ref)) } else { - println!("current binder context: {:#?}", self.binder.bind_context); + debug!( + target: LOGGING_TARGET, + "Planner binder context: {:#?}", self.binder.bind_context + ); Err(BindError::Internal(format!( "column not found: {}", column_name diff --git 
a/src/planner_v2/binder/expression_binder/column_alias_binder.rs b/src/planner_v2/binder/expression_binder/column_alias_binder.rs new file mode 100644 index 0000000..a8d8d8d --- /dev/null +++ b/src/planner_v2/binder/expression_binder/column_alias_binder.rs @@ -0,0 +1,32 @@ +use std::collections::HashMap; + +use derive_new::new; +use expression_binder::ExpressionBinder; + +use crate::planner_v2::{expression_binder, BindError, BoundExpression}; + +/// A helper binder for WhereBinder and HavingBinder which support alias as a columnref. +#[derive(new)] +pub struct ColumnAliasBinder<'a> { + pub(crate) original_select_items: &'a [sqlparser::ast::Expr], + pub(crate) alias_map: &'a HashMap, +} + +impl<'a> ColumnAliasBinder<'a> { + pub fn bind_alias( + &self, + expression_binder: &mut ExpressionBinder, + expr: &sqlparser::ast::Expr, + ) -> Result { + if let sqlparser::ast::Expr::Identifier(ident) = expr { + let alias = ident.to_string(); + if let Some(alias_entry) = self.alias_map.get(&alias) { + let expr = self.original_select_items[*alias_entry].clone(); + let bound_expr = + expression_binder.bind_expression(&expr, &mut vec![], &mut vec![])?; + return Ok(bound_expr); + } + } + Err(BindError::Internal(format!("column not found: {}", expr))) + } +} diff --git a/src/planner_v2/binder/expression_binder/mod.rs b/src/planner_v2/binder/expression_binder/mod.rs new file mode 100644 index 0000000..1c143d7 --- /dev/null +++ b/src/planner_v2/binder/expression_binder/mod.rs @@ -0,0 +1,4 @@ +mod column_alias_binder; +mod where_binder; +pub use column_alias_binder::*; +pub use where_binder::*; diff --git a/src/planner_v2/binder/expression_binder/where_binder.rs b/src/planner_v2/binder/expression_binder/where_binder.rs new file mode 100644 index 0000000..8b720c8 --- /dev/null +++ b/src/planner_v2/binder/expression_binder/where_binder.rs @@ -0,0 +1,49 @@ +use derive_new::new; + +use super::ColumnAliasBinder; +use crate::planner_v2::{BindError, BoundExpression, ExpressionBinder}; +use 
crate::types_v2::LogicalType; + +/// The WHERE binder is responsible for binding an expression within the WHERE clause of a SQL +/// statement +#[derive(new)] +pub struct WhereBinder<'a> { + internal_binder: ExpressionBinder<'a>, + column_alias_binder: ColumnAliasBinder<'a>, +} + +impl<'a> WhereBinder<'a> { + pub fn bind_expression( + &mut self, + expr: &sqlparser::ast::Expr, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + match expr { + sqlparser::ast::Expr::Identifier(..) | sqlparser::ast::Expr::CompoundIdentifier(..) => { + self.bind_column_ref_expr(expr, result_names, result_types) + } + other => self + .internal_binder + .bind_expression(other, result_names, result_types), + } + } + + fn bind_column_ref_expr( + &mut self, + expr: &sqlparser::ast::Expr, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + // bind column ref expr first + let bind_res = self + .internal_binder + .bind_expression(expr, result_names, result_types); + if bind_res.is_ok() { + return bind_res; + } + // try to bind as alias + self.column_alias_binder + .bind_alias(&mut self.internal_binder, expr) + } +} diff --git a/src/planner_v2/binder/mod.rs b/src/planner_v2/binder/mod.rs index 2722cf3..e741775 100644 --- a/src/planner_v2/binder/mod.rs +++ b/src/planner_v2/binder/mod.rs @@ -2,6 +2,7 @@ mod bind_context; mod binding; mod errors; mod expression; +mod expression_binder; mod query_node; mod sqlparser_util; mod statement; @@ -13,6 +14,7 @@ pub use bind_context::*; pub use binding::*; pub use errors::*; pub use expression::*; +pub use expression_binder::*; pub use query_node::*; pub use sqlparser_util::*; pub use statement::*; diff --git a/src/planner_v2/binder/query_node/bind_select_node.rs b/src/planner_v2/binder/query_node/bind_select_node.rs index 0aeae63..f7a5e1b 100644 --- a/src/planner_v2/binder/query_node/bind_select_node.rs +++ b/src/planner_v2/binder/query_node/bind_select_node.rs @@ -1,9 +1,11 @@ +use std::collections::HashMap; + use 
derive_new::new; use sqlparser::ast::{Ident, Query}; use crate::planner_v2::{ - BindError, Binder, BoundExpression, BoundTableRef, ExpressionBinder, SqlparserResolver, - VALUES_LIST_ALIAS, + BindError, Binder, BoundExpression, BoundTableRef, ColumnAliasBinder, ExpressionBinder, + SqlparserResolver, WhereBinder, VALUES_LIST_ALIAS, }; use crate::types_v2::LogicalType; @@ -17,6 +19,13 @@ pub struct BoundSelectNode { pub(crate) select_list: Vec, /// The FROM clause pub(crate) from_table: BoundTableRef, + /// The WHERE clause + #[allow(dead_code)] + pub(crate) where_clause: Option, + /// The original unparsed expressions. This is exported after binding, because the binding + /// might change the expressions (e.g. when a * clause is present) + #[allow(dead_code)] + pub(crate) original_select_items: Option>, /// Index used by the LogicalProjection #[new(default)] pub(crate) projection_index: usize, @@ -57,7 +66,7 @@ impl Binder { .try_collect::>()?; let bound_table_ref = BoundTableRef::BoundExpressionListRef(bound_expression_list_ref); - let node = BoundSelectNode::new(names, types, select_list, bound_table_ref); + let node = BoundSelectNode::new(names, types, select_list, bound_table_ref, None, None); Ok(node) } @@ -65,6 +74,7 @@ impl Binder { &mut self, select: &sqlparser::ast::Select, ) -> Result { + // first bind the FROM table statement let from_table = self.bind_table_ref(select.from.as_slice())?; let mut result_names = vec![]; @@ -75,6 +85,40 @@ impl Binder { return Err(BindError::Internal("empty select list".to_string())); } + // create a mapping of (alias -> index) and a mapping of (Expression -> index) for the + // SELECT list + let mut original_select_items = vec![]; + let mut alias_map = HashMap::new(); + for (idx, item) in new_select_list.iter().enumerate() { + match item { + sqlparser::ast::SelectItem::UnnamedExpr(expr) => { + original_select_items.push(expr.clone()); + } + sqlparser::ast::SelectItem::ExprWithAlias { expr, alias } => { + 
alias_map.insert(alias.to_string(), idx); + original_select_items.push(expr.clone()); + } + sqlparser::ast::SelectItem::Wildcard(..) + | sqlparser::ast::SelectItem::QualifiedWildcard(..) => { + return Err(BindError::Internal( + "wildcard should be expanded before".to_string(), + )) + } + } + } + + // first visit the WHERE clause + // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses + let where_clause = if let Some(where_expr) = &select.selection { + let column_alias_binder = ColumnAliasBinder::new(&original_select_items, &alias_map); + let mut where_binder = + WhereBinder::new(ExpressionBinder::new(self), column_alias_binder); + let bound_expr = where_binder.bind_expression(where_expr, &mut vec![], &mut vec![])?; + Some(bound_expr) + } else { + None + }; + let select_list = new_select_list .iter() .map(|item| self.bind_select_item(item, &mut result_names, &mut result_types)) @@ -85,6 +129,8 @@ impl Binder { result_types, select_list, from_table, + where_clause, + Some(original_select_items), )) } diff --git a/tests/slt/filter.slt b/tests/slt/filter.slt index 897331e..58122b6 100644 --- a/tests/slt/filter.slt +++ b/tests/slt/filter.slt @@ -17,3 +17,17 @@ select id, first_name from employee where id > 3 or id = 1 ---- 1 Bill 4 Von + + +onlyif sqlrs_v2 +statement ok +create table t1(v1 int, v2 int, v3 int); +insert into t1(v3, v2, v1) values (0, 4, 1), (1, 5, 2); + + +# onlyif sqlrs_v2 +# query III +# select v1, v2 from t1 where v1 >= 1; +# ---- +# 1 4 +# 2 5 From 990a2760c03d7b406da926dec062217a7a87969f Mon Sep 17 00:00:00 2001 From: Fedomn Date: Sun, 25 Dec 2022 20:39:25 +0800 Subject: [PATCH 2/4] feat(planner): enhance logical filter Signed-off-by: Fedomn --- src/common/cast.rs | 18 +++ src/common/mod.rs | 2 + src/execution/expression_executor.rs | 4 +- src/execution/physical_plan/mod.rs | 5 + .../physical_plan/physical_expression_scan.rs | 11 +- .../physical_plan/physical_filter.rs | 33 +++++ src/execution/physical_plan_generator.rs | 1 + 
src/execution/volcano_executor/dummy_scan.rs | 17 +-- .../volcano_executor/expression_scan.rs | 15 +- src/execution/volcano_executor/filter.rs | 32 +++++ src/execution/volcano_executor/mod.rs | 11 +- .../conjunction/conjunction_function.rs | 8 ++ .../conjunction/default_conjunction.rs | 12 +- .../expression/bind_conjunction_expression.rs | 25 ++++ .../binder/query_node/bind_select_node.rs | 1 + .../binder/query_node/plan_select_node.rs | 14 +- src/planner_v2/binder/sqlparser_util.rs | 21 ++- .../binder/statement/bind_explain_table.rs | 4 +- .../tableref/plan_expression_list_ref.rs | 10 +- src/planner_v2/operator/logical_dummy_scan.rs | 2 +- src/planner_v2/operator/logical_filter.rs | 135 ++++++++++++++++++ src/planner_v2/operator/mod.rs | 11 ++ src/storage_v2/local_storage.rs | 20 +++ src/util/tree_render.rs | 11 ++ tests/slt/filter.slt | 16 ++- tests/slt/pragma.slt | 12 +- tests/slt/table_function.slt | 4 +- 27 files changed, 414 insertions(+), 41 deletions(-) create mode 100644 src/common/cast.rs create mode 100644 src/execution/physical_plan/physical_filter.rs create mode 100644 src/execution/volcano_executor/filter.rs create mode 100644 src/planner_v2/operator/logical_filter.rs diff --git a/src/common/cast.rs b/src/common/cast.rs new file mode 100644 index 0000000..2501fa9 --- /dev/null +++ b/src/common/cast.rs @@ -0,0 +1,18 @@ +use arrow::array::{Array, BooleanArray}; + +use crate::function::FunctionError; + +/// Downcast an Arrow Array to a concrete type +macro_rules! downcast_value { + ($Value:expr, $Type:ident) => {{ + use std::any::type_name; + $Value.as_any().downcast_ref::<$Type>().ok_or_else(|| { + FunctionError::CastError(format!("could not cast value to {}", type_name::<$Type>())) + })? 
+ }}; +} + +/// Downcast ArrayRef to BooleanArray +pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, FunctionError> { + Ok(downcast_value!(array, BooleanArray)) +} diff --git a/src/common/mod.rs b/src/common/mod.rs index b122613..25d822a 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,3 +1,5 @@ +mod cast; mod create_info; +pub use cast::*; pub use create_info::*; diff --git a/src/execution/expression_executor.rs b/src/execution/expression_executor.rs index 03f8187..a9f9df4 100644 --- a/src/execution/expression_executor.rs +++ b/src/execution/expression_executor.rs @@ -26,7 +26,9 @@ impl ExpressionExecutor { ) -> Result { Ok(match expr { BoundExpression::BoundColumnRefExpression(_) => todo!(), - BoundExpression::BoundConstantExpression(e) => e.value.to_array(), + BoundExpression::BoundConstantExpression(e) => { + e.value.to_array_of_size(input.num_rows()) + } BoundExpression::BoundReferenceExpression(e) => input.column(e.index).clone(), BoundExpression::BoundCastExpression(e) => { let child_result = Self::execute_internal(&e.child, input)?; diff --git a/src/execution/physical_plan/mod.rs b/src/execution/physical_plan/mod.rs index cfd207d..408d075 100644 --- a/src/execution/physical_plan/mod.rs +++ b/src/execution/physical_plan/mod.rs @@ -3,6 +3,7 @@ mod physical_create_table; mod physical_dummy_scan; mod physical_explain; mod physical_expression_scan; +mod physical_filter; mod physical_insert; mod physical_projection; mod physical_table_scan; @@ -13,6 +14,7 @@ pub use physical_create_table::*; pub use physical_dummy_scan::*; pub use physical_explain::*; pub use physical_expression_scan::*; +pub use physical_filter::*; pub use physical_insert::*; pub use physical_projection::*; pub use physical_table_scan::*; @@ -22,6 +24,7 @@ use crate::types_v2::LogicalType; #[derive(new, Default, Clone)] pub struct PhysicalOperatorBase { pub(crate) children: Vec, + #[allow(dead_code)] /// The types returned by this physical operator pub(crate) types: 
Vec, } @@ -35,6 +38,7 @@ pub enum PhysicalOperator { PhysicalTableScan(PhysicalTableScan), PhysicalProjection(PhysicalProjection), PhysicalColumnDataScan(PhysicalColumnDataScan), + PhysicalFilter(PhysicalFilter), } impl PhysicalOperator { @@ -47,6 +51,7 @@ impl PhysicalOperator { PhysicalOperator::PhysicalProjection(op) => &op.base.children, PhysicalOperator::PhysicalDummyScan(op) => &op.base.children, PhysicalOperator::PhysicalColumnDataScan(op) => &op.base.children, + PhysicalOperator::PhysicalFilter(op) => &op.base.children, } } } diff --git a/src/execution/physical_plan/physical_expression_scan.rs b/src/execution/physical_plan/physical_expression_scan.rs index 906f635..672755a 100644 --- a/src/execution/physical_plan/physical_expression_scan.rs +++ b/src/execution/physical_plan/physical_expression_scan.rs @@ -8,7 +8,6 @@ use crate::types_v2::LogicalType; /// The PhysicalExpressionScan scans a set of expressions #[derive(new, Clone)] pub struct PhysicalExpressionScan { - #[new(default)] pub(crate) base: PhysicalOperatorBase, /// The types of the expressions pub(crate) expr_types: Vec, @@ -21,7 +20,17 @@ impl PhysicalPlanGenerator { &self, op: LogicalExpressionGet, ) -> PhysicalOperator { + assert!(op.base.children.len() == 1); + let new_children = op + .base + .children + .into_iter() + .map(|p| self.create_plan_internal(p)) + .collect::>(); + let types = op.base.types; + let base = PhysicalOperatorBase::new(new_children, types); PhysicalOperator::PhysicalExpressionScan(PhysicalExpressionScan::new( + base, op.expr_types, op.expressions, )) diff --git a/src/execution/physical_plan/physical_filter.rs b/src/execution/physical_plan/physical_filter.rs new file mode 100644 index 0000000..7860959 --- /dev/null +++ b/src/execution/physical_plan/physical_filter.rs @@ -0,0 +1,33 @@ +use super::{PhysicalOperator, PhysicalOperatorBase}; +use crate::execution::PhysicalPlanGenerator; +use crate::planner_v2::{BoundConjunctionExpression, BoundExpression, LogicalFilter}; + 
+#[derive(Clone)] +pub struct PhysicalFilter { + pub(crate) base: PhysicalOperatorBase, + pub(crate) expression: BoundExpression, +} + +impl PhysicalFilter { + pub fn new(base: PhysicalOperatorBase, expressions: Vec) -> Self { + let expression = + BoundConjunctionExpression::try_build_and_conjunction_expression(expressions); + Self { base, expression } + } +} + +impl PhysicalPlanGenerator { + pub(crate) fn create_physical_filter(&self, op: LogicalFilter) -> PhysicalOperator { + assert!(op.base.children.len() == 1); + // TODO: refactor this part to common method + let new_children = op + .base + .children + .into_iter() + .map(|p| self.create_plan_internal(p)) + .collect::>(); + let types = op.base.types; + let base = PhysicalOperatorBase::new(new_children, types); + PhysicalOperator::PhysicalFilter(PhysicalFilter::new(base, op.base.expressioins)) + } +} diff --git a/src/execution/physical_plan_generator.rs b/src/execution/physical_plan_generator.rs index 3d1e41c..262ce9b 100644 --- a/src/execution/physical_plan_generator.rs +++ b/src/execution/physical_plan_generator.rs @@ -42,6 +42,7 @@ impl PhysicalPlanGenerator { LogicalOperator::LogicalProjection(op) => self.create_physical_projection(op), LogicalOperator::LogicalDummyScan(op) => self.create_physical_dummy_scan(op), LogicalOperator::LogicalExplain(op) => self.create_physical_explain(op), + LogicalOperator::LogicalFilter(op) => self.create_physical_filter(op), } } } diff --git a/src/execution/volcano_executor/dummy_scan.rs b/src/execution/volcano_executor/dummy_scan.rs index d6b9b63..003a6a0 100644 --- a/src/execution/volcano_executor/dummy_scan.rs +++ b/src/execution/volcano_executor/dummy_scan.rs @@ -1,30 +1,25 @@ use std::collections::HashMap; use std::sync::Arc; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use derive_new::new; use futures_async_stream::try_stream; use crate::execution::{ExecutionContext, 
ExecutorError, PhysicalDummyScan}; +use crate::types_v2::ScalarValue; #[derive(new)] pub struct DummyScan { - pub(crate) plan: PhysicalDummyScan, + pub(crate) _plan: PhysicalDummyScan, } impl DummyScan { #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] pub async fn execute(self, _context: Arc) { - let mut fields = vec![]; - for (idx, ty) in self.plan.base.types.iter().enumerate() { - fields.push(Field::new( - format!("col{}", idx).as_str(), - ty.clone().into(), - true, - )); - } + let fields = vec![Field::new("dummy", DataType::Boolean, true)]; let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())); - yield RecordBatch::new_empty(schema.clone()); + let array = ScalarValue::Boolean(Some(true)).to_array(); + yield RecordBatch::try_new(schema.clone(), vec![array])?; } } diff --git a/src/execution/volcano_executor/expression_scan.rs b/src/execution/volcano_executor/expression_scan.rs index 9dfeb95..ccef17e 100644 --- a/src/execution/volcano_executor/expression_scan.rs +++ b/src/execution/volcano_executor/expression_scan.rs @@ -7,12 +7,13 @@ use derive_new::new; use futures_async_stream::try_stream; use crate::execution::{ - ExecutionContext, ExecutorError, ExpressionExecutor, PhysicalExpressionScan, + BoxedExecutor, ExecutionContext, ExecutorError, ExpressionExecutor, PhysicalExpressionScan, }; #[derive(new)] pub struct ExpressionScan { pub(crate) plan: PhysicalExpressionScan, + pub(crate) child: BoxedExecutor, } impl ExpressionScan { @@ -27,10 +28,14 @@ impl ExpressionScan { )); } let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())); - let input = RecordBatch::new_empty(schema.clone()); - for exprs in self.plan.expressions.iter() { - let columns = ExpressionExecutor::execute(exprs, &input)?; - yield RecordBatch::try_new(schema.clone(), columns)?; + + #[for_await] + for batch in self.child { + let input = batch?; + for exprs in self.plan.expressions.iter() { + let columns = 
ExpressionExecutor::execute(exprs, &input)?; + yield RecordBatch::try_new(schema.clone(), columns)?; + } } } } diff --git a/src/execution/volcano_executor/filter.rs b/src/execution/volcano_executor/filter.rs new file mode 100644 index 0000000..1829f2a --- /dev/null +++ b/src/execution/volcano_executor/filter.rs @@ -0,0 +1,32 @@ +use std::sync::Arc; + +use arrow::compute::filter_record_batch; +use arrow::record_batch::RecordBatch; +use derive_new::new; +use futures_async_stream::try_stream; + +use crate::common::as_boolean_array; +use crate::execution::{ + BoxedExecutor, ExecutionContext, ExecutorError, ExpressionExecutor, PhysicalFilter, +}; + +#[derive(new)] +pub struct Filter { + pub(crate) plan: PhysicalFilter, + pub(crate) child: BoxedExecutor, +} + +impl Filter { + #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] + pub async fn execute(self, _context: Arc) { + let exprs = vec![self.plan.expression]; + + #[for_await] + for batch in self.child { + let batch = batch?; + let eval_mask = ExpressionExecutor::execute(&exprs, &batch)?; + let predicate = as_boolean_array(&eval_mask[0])?; + yield filter_record_batch(&batch, predicate)?; + } + } +} diff --git a/src/execution/volcano_executor/mod.rs b/src/execution/volcano_executor/mod.rs index 621bfd4..9ee0806 100644 --- a/src/execution/volcano_executor/mod.rs +++ b/src/execution/volcano_executor/mod.rs @@ -2,6 +2,7 @@ mod column_data_scan; mod create_table; mod dummy_scan; mod expression_scan; +mod filter; mod insert; mod projection; mod table_scan; @@ -12,6 +13,7 @@ pub use column_data_scan::*; pub use create_table::*; pub use dummy_scan::*; pub use expression_scan::*; +pub use filter::*; use futures::stream::BoxStream; use futures::TryStreamExt; pub use insert::*; @@ -34,7 +36,9 @@ impl VolcanoExecutor { match plan { PhysicalOperator::PhysicalCreateTable(op) => CreateTable::new(op).execute(context), PhysicalOperator::PhysicalExpressionScan(op) => { - ExpressionScan::new(op).execute(context) + let child = 
op.base.children.first().unwrap().clone(); + let child_executor = self.build(child, context.clone()); + ExpressionScan::new(op, child_executor).execute(context) } PhysicalOperator::PhysicalInsert(op) => { let child = op.base.children.first().unwrap().clone(); @@ -51,6 +55,11 @@ impl VolcanoExecutor { PhysicalOperator::PhysicalColumnDataScan(op) => { ColumnDataScan::new(op).execute(context) } + PhysicalOperator::PhysicalFilter(op) => { + let child = op.base.children.first().unwrap().clone(); + let child_executor = self.build(child, context.clone()); + Filter::new(op, child_executor).execute(context) + } } } diff --git a/src/function/conjunction/conjunction_function.rs b/src/function/conjunction/conjunction_function.rs index e2c7436..36b462e 100644 --- a/src/function/conjunction/conjunction_function.rs +++ b/src/function/conjunction/conjunction_function.rs @@ -1,14 +1,22 @@ use arrow::array::ArrayRef; use derive_new::new; +use strum_macros::AsRefStr; use crate::function::FunctionError; pub type ConjunctionFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result; +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, AsRefStr)] +pub enum ConjunctionType { + And, + Or, +} + #[derive(new, Clone)] pub struct ConjunctionFunction { pub(crate) name: String, pub(crate) function: ConjunctionFunc, + pub(crate) ty: ConjunctionType, } impl std::fmt::Debug for ConjunctionFunction { diff --git a/src/function/conjunction/default_conjunction.rs b/src/function/conjunction/default_conjunction.rs index 93ad755..4cf1303 100644 --- a/src/function/conjunction/default_conjunction.rs +++ b/src/function/conjunction/default_conjunction.rs @@ -5,7 +5,7 @@ use arrow::compute::{and_kleene, or_kleene}; use arrow::datatypes::DataType; use sqlparser::ast::BinaryOperator; -use super::{ConjunctionFunc, ConjunctionFunction}; +use super::{ConjunctionFunc, ConjunctionFunction, ConjunctionType}; use crate::function::FunctionError; pub struct DefaultConjunctionFunctions; @@ -43,10 +43,10 @@ impl 
DefaultConjunctionFunctions { fn get_conjunction_function_internal( op: &BinaryOperator, - ) -> Result<(&str, ConjunctionFunc), FunctionError> { + ) -> Result<(ConjunctionType, ConjunctionFunc), FunctionError> { Ok(match op { - BinaryOperator::And => ("and", Self::default_and_function), - BinaryOperator::Or => ("or", Self::default_or_function), + BinaryOperator::And => (ConjunctionType::And, Self::default_and_function), + BinaryOperator::Or => (ConjunctionType::Or, Self::default_or_function), _ => { return Err(FunctionError::ConjunctionError(format!( "Unsupported conjunction operator {:?}", @@ -59,7 +59,7 @@ impl DefaultConjunctionFunctions { pub fn get_conjunction_function( op: &BinaryOperator, ) -> Result { - let (name, func) = Self::get_conjunction_function_internal(op)?; - Ok(ConjunctionFunction::new(name.to_string(), func)) + let (ty, func) = Self::get_conjunction_function_internal(op)?; + Ok(ConjunctionFunction::new(ty.as_ref().to_string(), func, ty)) } } diff --git a/src/planner_v2/binder/expression/bind_conjunction_expression.rs b/src/planner_v2/binder/expression/bind_conjunction_expression.rs index 6342d0d..0f0a6a1 100644 --- a/src/planner_v2/binder/expression/bind_conjunction_expression.rs +++ b/src/planner_v2/binder/expression/bind_conjunction_expression.rs @@ -12,6 +12,31 @@ pub struct BoundConjunctionExpression { pub(crate) children: Vec, } +impl BoundConjunctionExpression { + /// If there is more than one expression, build an AND conjunction expression; otherwise return + /// the first expression + pub fn try_build_and_conjunction_expression( + expressions: Vec, + ) -> BoundExpression { + assert!(!expressions.is_empty()); + // conjoin the expressions with AND so that only one expression remains + if expressions.len() > 1 { + let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean); + let and_func = DefaultConjunctionFunctions::get_conjunction_function( + &sqlparser::ast::BinaryOperator::And, + ) + .unwrap(); +
BoundExpression::BoundConjunctionExpression(BoundConjunctionExpression::new( + base, + and_func, + expressions, + )) + } else { + expressions[0].clone() + } + } +} + impl ExpressionBinder<'_> { pub fn bind_conjunction_expression( &mut self, diff --git a/src/planner_v2/binder/query_node/bind_select_node.rs b/src/planner_v2/binder/query_node/bind_select_node.rs index f7a5e1b..71aeca3 100644 --- a/src/planner_v2/binder/query_node/bind_select_node.rs +++ b/src/planner_v2/binder/query_node/bind_select_node.rs @@ -113,6 +113,7 @@ impl Binder { let column_alias_binder = ColumnAliasBinder::new(&original_select_items, &alias_map); let mut where_binder = WhereBinder::new(ExpressionBinder::new(self), column_alias_binder); + // FIXME: where_binder does not work with ExpressionBinder let bound_expr = where_binder.bind_expression(where_expr, &mut vec![], &mut vec![])?; Some(bound_expr) } else { diff --git a/src/planner_v2/binder/query_node/plan_select_node.rs b/src/planner_v2/binder/query_node/plan_select_node.rs index 3c50f76..31e1d70 100644 --- a/src/planner_v2/binder/query_node/plan_select_node.rs +++ b/src/planner_v2/binder/query_node/plan_select_node.rs @@ -3,8 +3,8 @@ use crate::planner_v2::BoundTableRef::{ BoundBaseTableRef, BoundDummyTableRef, BoundExpressionListRef, BoundTableFunction, }; use crate::planner_v2::{ - BindError, Binder, BoundCastExpression, BoundStatement, LogicalOperator, LogicalOperatorBase, - LogicalProjection, + BindError, Binder, BoundCastExpression, BoundStatement, LogicalFilter, LogicalOperator, + LogicalOperatorBase, LogicalProjection, }; use crate::types_v2::LogicalType; @@ -13,7 +13,7 @@ impl Binder { &mut self, node: BoundSelectNode, ) -> Result { - let root = match node.from_table { + let mut root = match node.from_table { BoundExpressionListRef(bound_ref) => { self.create_plan_for_expression_list_ref(bound_ref)?
} @@ -22,6 +22,14 @@ impl Binder { BoundTableFunction(bound_func) => self.create_plan_for_table_function(*bound_func)?, }; + if let Some(where_clause) = node.where_clause { + root = LogicalOperator::LogicalFilter(LogicalFilter::new(LogicalOperatorBase::new( + vec![root], + vec![where_clause], + vec![], + ))); + } + let root = LogicalOperator::LogicalProjection(LogicalProjection::new( LogicalOperatorBase::new(vec![root], node.select_list, node.types.clone()), node.projection_index, diff --git a/src/planner_v2/binder/sqlparser_util.rs b/src/planner_v2/binder/sqlparser_util.rs index b2fe757..4e0b556 100644 --- a/src/planner_v2/binder/sqlparser_util.rs +++ b/src/planner_v2/binder/sqlparser_util.rs @@ -1,6 +1,6 @@ use sqlparser::ast::{ - ColumnDef, Expr, Ident, ObjectName, Query, Select, SelectItem, SetExpr, TableFactor, - TableWithJoins, WildcardAdditionalOptions, + BinaryOperator, ColumnDef, Expr, Ident, ObjectName, Query, Select, SelectItem, SetExpr, + TableFactor, TableWithJoins, Value, WildcardAdditionalOptions, }; use super::BindError; @@ -35,6 +35,7 @@ impl SqlparserResolver { pub struct SqlparserSelectBuilder { projection: Vec, from: Vec, + selection: Option, } impl SqlparserSelectBuilder { @@ -91,15 +92,25 @@ impl SqlparserSelectBuilder { self } - pub fn build(self) -> sqlparser::ast::Select { - sqlparser::ast::Select { + pub fn selection_col_eq_string(mut self, col_name: &str, eq_str: &str) -> Self { + let selection = Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new(col_name))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::SingleQuotedString(eq_str.to_string()))), + }; + self.selection = Some(selection); + self + } + + pub fn build(self) -> Select { + Select { distinct: false, top: None, projection: self.projection, into: None, from: self.from, lateral_views: vec![], - selection: None, + selection: self.selection, group_by: vec![], cluster_by: vec![], distribute_by: vec![], diff --git 
a/src/planner_v2/binder/statement/bind_explain_table.rs b/src/planner_v2/binder/statement/bind_explain_table.rs index 4a19f80..024967c 100644 --- a/src/planner_v2/binder/statement/bind_explain_table.rs +++ b/src/planner_v2/binder/statement/bind_explain_table.rs @@ -18,11 +18,11 @@ impl Binder { "Only support describe table statement".to_string(), )); } - let (_, _table_name) = SqlparserResolver::object_name_to_schema_table(table_name)?; - // FIXME: support filter table_name + let (_, table_name) = SqlparserResolver::object_name_to_schema_table(table_name)?; let select = SqlparserSelectBuilder::default() .projection_wildcard() .from_table_function("sqlrs_columns") + .selection_col_eq_string("table_name", table_name.as_str()) .build(); let query = SqlparserQueryBuilder::new_from_select(select).build(); let node = self.bind_select_node(&query)?; diff --git a/src/planner_v2/binder/tableref/plan_expression_list_ref.rs b/src/planner_v2/binder/tableref/plan_expression_list_ref.rs index d3cb934..57012f9 100644 --- a/src/planner_v2/binder/tableref/plan_expression_list_ref.rs +++ b/src/planner_v2/binder/tableref/plan_expression_list_ref.rs @@ -1,6 +1,6 @@ use super::BoundExpressionListRef; use crate::planner_v2::{ - BindError, Binder, LogicalExpressionGet, LogicalOperator, LogicalOperatorBase, + BindError, Binder, LogicalDummyScan, LogicalExpressionGet, LogicalOperator, LogicalOperatorBase, }; impl Binder { @@ -9,7 +9,13 @@ impl Binder { bound_ref: BoundExpressionListRef, ) -> Result { let table_idx = bound_ref.bind_index; - let base = LogicalOperatorBase::default(); + let base = LogicalOperatorBase::new( + vec![LogicalOperator::LogicalDummyScan(LogicalDummyScan::new( + self.generate_table_index(), + ))], + vec![], + vec![], + ); let plan = LogicalExpressionGet::new(base, table_idx, bound_ref.types, bound_ref.values); Ok(LogicalOperator::LogicalExpressionGet(plan)) } diff --git a/src/planner_v2/operator/logical_dummy_scan.rs b/src/planner_v2/operator/logical_dummy_scan.rs 
index 1c293b1..bb76bf5 100644 --- a/src/planner_v2/operator/logical_dummy_scan.rs +++ b/src/planner_v2/operator/logical_dummy_scan.rs @@ -2,7 +2,7 @@ use derive_new::new; use super::LogicalOperatorBase; -/// LogicalDummyScan represents a dummy scan returning nothing. +/// LogicalDummyScan represents a dummy scan returning a single row. #[derive(new, Debug, Clone)] pub struct LogicalDummyScan { #[new(default)] diff --git a/src/planner_v2/operator/logical_filter.rs b/src/planner_v2/operator/logical_filter.rs new file mode 100644 index 0000000..6842cd8 --- /dev/null +++ b/src/planner_v2/operator/logical_filter.rs @@ -0,0 +1,135 @@ +use super::LogicalOperatorBase; +use crate::function::ConjunctionType; +use crate::planner_v2::BoundExpression; + +#[derive(Debug, Clone)] +pub struct LogicalFilter { + pub(crate) base: LogicalOperatorBase, +} + +impl LogicalFilter { + fn split_predicates_internal(expr: BoundExpression) -> Vec { + match expr { + BoundExpression::BoundConjunctionExpression(e) => { + if e.function.ty == ConjunctionType::And { + let mut res = vec![]; + for child in e.children.into_iter() { + res.extend(Self::split_predicates_internal(child)); + } + res + } else { + vec![BoundExpression::BoundConjunctionExpression(e)] + } + } + _ => vec![expr], + } + } + + // Split the predicates separated by AND statements + // These are the predicates that are safe to push down because all of them MUST be true + fn split_predicates(mut self) -> Self { + let mut new_expressions = vec![]; + for expr in self.base.expressioins.into_iter() { + let split_res = Self::split_predicates_internal(expr); + new_expressions.extend(split_res); + } + self.base.expressioins = new_expressions; + self + } + + pub fn new(base: LogicalOperatorBase) -> Self { + let op = Self { base }; + op.split_predicates() + } +} + +#[cfg(test)] +mod tests { + use sqlparser::ast::BinaryOperator; + + use super::*; + use crate::function::{DefaultComparisonFunctions, DefaultConjunctionFunctions}; + use 
crate::planner_v2::{ + BindError, BoundColumnRefExpression, BoundComparisonExpression, BoundConjunctionExpression, + BoundConstantExpression, BoundExpression, BoundExpressionBase, ColumnBinding, + }; + use crate::types_v2::{LogicalType, ScalarValue}; + + fn build_col_expr(name: String) -> BoundExpression { + BoundExpression::BoundColumnRefExpression(BoundColumnRefExpression::new( + BoundExpressionBase::new(name, LogicalType::Integer), + ColumnBinding::new(1, 1), + 0, + )) + } + + fn build_eq_expr( + left: BoundExpression, + right: BoundExpression, + ) -> Result<BoundExpression, BindError> { + let eq_func = DefaultComparisonFunctions::get_comparison_function( + &BinaryOperator::Eq, + &LogicalType::Integer, + )?; + let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean); + Ok(BoundExpression::BoundComparisonExpression( + BoundComparisonExpression::new(base, Box::new(left), Box::new(right), eq_func), + )) + } + + fn build_and_expr( + left: BoundExpression, + right: BoundExpression, + ) -> Result<BoundExpression, BindError> { + let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean); + let and_func = DefaultConjunctionFunctions::get_conjunction_function(&BinaryOperator::And)?; + Ok(BoundExpression::BoundConjunctionExpression( + BoundConjunctionExpression::new(base, and_func, vec![left, right]), + )) + } + + fn build_or_expr( + left: BoundExpression, + right: BoundExpression, + ) -> Result<BoundExpression, BindError> { + let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean); + let or_func = DefaultConjunctionFunctions::get_conjunction_function(&BinaryOperator::Or)?; + Ok(BoundExpression::BoundConjunctionExpression( + BoundConjunctionExpression::new(base, or_func, vec![left, right]), + )) + } + + #[test] + fn test_logical_filter_split_predicates() -> Result<(), BindError> { + let v1 = BoundExpression::BoundConstantExpression(BoundConstantExpression::new( + BoundExpressionBase::new("".to_string(), LogicalType::Integer), + ScalarValue::Int32(Some(1)), + )); + let col1 =
build_col_expr("col1".to_string()); + let col2 = build_col_expr("col2".to_string()); + let col3 = build_col_expr("col3".to_string()); + let col4 = build_col_expr("col4".to_string()); + let expr1 = build_eq_expr(col1, v1.clone())?; + let expr2 = build_eq_expr(col2, v1.clone())?; + let expr3 = build_eq_expr(col3, v1.clone())?; + let expr4 = build_eq_expr(col4, v1)?; + + // And(And(Col1=1, Col2=1), And(Col3=1, Col4=1)) + let and_expr1 = build_and_expr(expr1.clone(), expr2.clone())?; + let and_expr2 = build_and_expr(expr3.clone(), expr4.clone())?; + let case1 = build_and_expr(and_expr1, and_expr2)?; + let base = LogicalOperatorBase::new(vec![], vec![case1], vec![]); + let op = LogicalFilter::new(base); + assert_eq!(op.base.expressioins.len(), 4); + + // And(And(Col1=1, Col2=1), Or(Col3=1, Col4=1)) + let and_expr1 = build_and_expr(expr1, expr2)?; + let and_expr2 = build_or_expr(expr3, expr4)?; + let case2 = build_and_expr(and_expr1, and_expr2)?; + let base = LogicalOperatorBase::new(vec![], vec![case2], vec![]); + let op = LogicalFilter::new(base); + assert_eq!(op.base.expressioins.len(), 3); + + Ok(()) + } +} diff --git a/src/planner_v2/operator/mod.rs b/src/planner_v2/operator/mod.rs index 1c1b062..eba85c5 100644 --- a/src/planner_v2/operator/mod.rs +++ b/src/planner_v2/operator/mod.rs @@ -4,6 +4,7 @@ mod logical_create_table; mod logical_dummy_scan; mod logical_explain; mod logical_expression_get; +mod logical_filter; mod logical_get; mod logical_insert; mod logical_projection; @@ -12,6 +13,7 @@ pub use logical_create_table::*; pub use logical_dummy_scan::*; pub use logical_explain::*; pub use logical_expression_get::*; +pub use logical_filter::*; pub use logical_get::*; pub use logical_insert::*; pub use logical_projection::*; @@ -36,6 +38,7 @@ pub enum LogicalOperator { LogicalGet(LogicalGet), LogicalProjection(LogicalProjection), LogicalExplain(LogicalExplain), + LogicalFilter(LogicalFilter), } impl LogicalOperator { @@ -48,6 +51,7 @@ impl LogicalOperator { 
LogicalOperator::LogicalProjection(op) => &mut op.base.children, LogicalOperator::LogicalDummyScan(op) => &mut op.base.children, LogicalOperator::LogicalExplain(op) => &mut op.base.children, + LogicalOperator::LogicalFilter(op) => &mut op.base.children, } } @@ -60,6 +64,7 @@ impl LogicalOperator { LogicalOperator::LogicalProjection(op) => &op.base.children, LogicalOperator::LogicalDummyScan(op) => &op.base.children, LogicalOperator::LogicalExplain(op) => &op.base.children, + LogicalOperator::LogicalFilter(op) => &op.base.children, } } @@ -72,6 +77,7 @@ impl LogicalOperator { LogicalOperator::LogicalProjection(op) => &mut op.base.expressioins, LogicalOperator::LogicalDummyScan(op) => &mut op.base.expressioins, LogicalOperator::LogicalExplain(op) => &mut op.base.expressioins, + LogicalOperator::LogicalFilter(op) => &mut op.base.expressioins, } } @@ -84,6 +90,7 @@ impl LogicalOperator { LogicalOperator::LogicalProjection(op) => &op.base.types, LogicalOperator::LogicalDummyScan(op) => &op.base.types, LogicalOperator::LogicalExplain(op) => &op.base.types, + LogicalOperator::LogicalFilter(op) => &op.base.types, } } @@ -105,6 +112,7 @@ impl LogicalOperator { LogicalOperator::LogicalExplain(_) => { vec![ColumnBinding::new(0, 0), ColumnBinding::new(0, 1)] } + LogicalOperator::LogicalFilter(op) => op.base.children[0].get_column_bindings(), } } @@ -134,6 +142,9 @@ impl LogicalOperator { LogicalOperator::LogicalExplain(op) => { op.base.types = vec![LogicalType::Varchar, LogicalType::Varchar]; } + LogicalOperator::LogicalFilter(op) => { + op.base.types = op.base.children[0].types().to_vec(); + } } } diff --git a/src/storage_v2/local_storage.rs b/src/storage_v2/local_storage.rs index bad5ab6..670e51c 100644 --- a/src/storage_v2/local_storage.rs +++ b/src/storage_v2/local_storage.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; +use arrow::compute; use arrow::record_batch::RecordBatch; use derive_new::new; @@ -95,6 +96,25 @@ impl LocalTableStorage { } fn 
append(&mut self, batch: RecordBatch) { + if let Some(last_batch) = self.data.last_mut() { + let remaining_count = 1024usize.saturating_sub(last_batch.num_rows()); + if remaining_count > 0 { + // need to merge current batch into last unfull batch + let appended_batch = batch.slice(0, remaining_count.min(batch.num_rows())); + *last_batch = compute::concat_batches( + &last_batch.schema(), + &[last_batch.clone(), appended_batch], + ) + .unwrap(); + + if batch.num_rows() > remaining_count { + // need a new batch for the rows that did not fit; slice takes (offset, length) + let new_batch = batch.slice(remaining_count, batch.num_rows() - remaining_count); + self.data.push(new_batch); + } + return; + } + } self.data.push(batch); } diff --git a/src/util/tree_render.rs b/src/util/tree_render.rs index e604315..2dfd787 100644 --- a/src/util/tree_render.rs +++ b/src/util/tree_render.rs @@ -113,6 +113,16 @@ impl TreeRender { format!("LogicalProjection: {}", exprs) } LogicalOperator::LogicalExplain(_) => "LogicalExplain".to_string(), + LogicalOperator::LogicalFilter(op) => { + let exprs = op + .base + .expressioins + .iter() + .map(Self::bound_expression_to_string) + .collect::<Vec<_>>() + .join(", "); + format!("LogicalFilter: {}", exprs) + } } } @@ -143,6 +153,7 @@ impl TreeRender { PhysicalOperator::PhysicalTableScan(_) => "PhysicalTableScan".to_string(), PhysicalOperator::PhysicalProjection(_) => "PhysicalProjection".to_string(), PhysicalOperator::PhysicalColumnDataScan(_) => "PhysicalColumnDataScan".to_string(), + PhysicalOperator::PhysicalFilter(_) => "PhysicalFilter".to_string(), } } diff --git a/tests/slt/filter.slt b/tests/slt/filter.slt index 58122b6..a067dea 100644 --- a/tests/slt/filter.slt +++ b/tests/slt/filter.slt @@ -25,9 +25,23 @@ create table t1(v1 int, v2 int, v3 int); insert into t1(v3, v2, v1) values (0, 4, 1), (1, 5, 2); +onlyif sqlrs_v2 +query III +select v1, v2 from t1 where v1 >= 2 and v1 > v3; +---- +2 5 + +# filter alias +# onlyif sqlrs_v2 +# query III +# select v1+1 as a from t1 where a >= 1; +# ---- +# 1 4 +# 2 5 + # onlyif sqlrs_v2 # query III -# select v1, v2
from t1 where v1 >= 1; +# select v1+1 as a from t1 where a == a; # ---- # 1 4 # 2 5 diff --git a/tests/slt/pragma.slt b/tests/slt/pragma.slt index f046148..cc066cc 100644 --- a/tests/slt/pragma.slt +++ b/tests/slt/pragma.slt @@ -1,9 +1,19 @@ onlyif sqlrs_v2 statement ok create table t1(v1 int, v2 int, v3 int); +create table t2(v1 varchar, v2 boolean, v3 bigint); onlyif sqlrs_v2 -query II +query II rowsort show tables ---- main t1 +main t2 + + +onlyif sqlrs_v2 +query II +describe t1 +---- +t1 [v1, v2, v3] [Integer, Integer, Integer] + diff --git a/tests/slt/table_function.slt b/tests/slt/table_function.slt index 4211f1e..4ecce85 100644 --- a/tests/slt/table_function.slt +++ b/tests/slt/table_function.slt @@ -1,15 +1,17 @@ onlyif sqlrs_v2 statement ok create table t1(v1 int, v2 int, v3 int); +create table t2(v1 varchar, v2 boolean, v3 bigint); onlyif sqlrs_v2 query III select schema_name, schema_oid, table_name from sqlrs_tables(); ---- main 1 t1 +main 1 t2 onlyif sqlrs_v2 query III -select * from sqlrs_columns(); +select * from sqlrs_columns() where table_name = 't1'; ---- t1 [v1, v2, v3] [Integer, Integer, Integer] From 9bb6d73e6b1196c96a4ba758a99698fb478d7270 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Mon, 26 Dec 2022 15:34:23 +0800 Subject: [PATCH 3/4] feat(planner): merge where binder logic into expression binder Signed-off-by: Fedomn --- .../expression/bind_column_ref_expression.rs | 65 +++++++++++++------ .../expression_binder/column_alias_binder.rs | 29 +-------- .../binder/expression_binder/mod.rs | 2 - .../binder/expression_binder/where_binder.rs | 49 -------------- .../binder/query_node/bind_select_node.rs | 16 ++--- src/planner_v2/binder/sqlparser_util.rs | 18 +++++ src/planner_v2/expression_binder.rs | 8 ++- tests/slt/filter.slt | 26 ++++---- tests/slt/table_function.slt | 2 +- 9 files changed, 96 insertions(+), 119 deletions(-) delete mode 100644 src/planner_v2/binder/expression_binder/where_binder.rs diff --git 
a/src/planner_v2/binder/expression/bind_column_ref_expression.rs b/src/planner_v2/binder/expression/bind_column_ref_expression.rs index a792b11..e84122a 100644 --- a/src/planner_v2/binder/expression/bind_column_ref_expression.rs +++ b/src/planner_v2/binder/expression/bind_column_ref_expression.rs @@ -1,9 +1,8 @@ use derive_new::new; -use itertools::Itertools; use log::debug; use super::{BoundExpression, BoundExpressionBase, ColumnBinding}; -use crate::planner_v2::{BindError, ExpressionBinder, LOGGING_TARGET}; +use crate::planner_v2::{BindError, ExpressionBinder, SqlparserResolver, LOGGING_TARGET}; use crate::types_v2::LogicalType; /// A BoundColumnRef expression represents a ColumnRef expression that was bound to an actual table @@ -24,34 +23,28 @@ impl ExpressionBinder<'_> { /// qualify column name with existing table name fn qualify_column_name( &self, - table_name: Option<&String>, - column_name: &String, + table_name: Option, + column_name: String, ) -> Result<(String, String), BindError> { if let Some(table_name) = table_name { - Ok((table_name.to_string(), column_name.to_string())) + Ok((table_name, column_name)) } else { - let table_name = self.binder.bind_context.get_matching_binding(column_name)?; - Ok((table_name, column_name.to_string())) + let table_name = self + .binder + .bind_context + .get_matching_binding(&column_name)?; + Ok((table_name, column_name)) } } - pub fn bind_column_ref_expr( + fn bind_column_ref_expr_internal( &mut self, idents: &[sqlparser::ast::Ident], result_names: &mut Vec, result_types: &mut Vec, ) -> Result { - let idents = idents - .iter() - .map(|ident| ident.value.to_lowercase()) - .collect_vec(); - - let (_schema_name, table_name, column_name) = match idents.as_slice() { - [column] => (None, None, column), - [table, column] => (None, Some(table), column), - [schema, table, column] => (Some(schema), Some(table), column), - _ => return Err(BindError::UnsupportedExpr(format!("{:?}", idents))), - }; + let (_schema_name, 
table_name, column_name) = + SqlparserResolver::resolve_expr_idents(idents)?; let (table_name, column_name) = self.qualify_column_name(table_name, column_name)?; @@ -75,4 +68,38 @@ impl ExpressionBinder<'_> { ))) } } + + fn bind_column_ref_expr_as_alias( + &mut self, + idents: &[sqlparser::ast::Ident], + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + let (_, _, column_name) = SqlparserResolver::resolve_expr_idents(idents)?; + if let Some(column_alias_data) = &self.column_alias_data { + if let Some(alias_entry) = column_alias_data.alias_map.get(&column_name) { + let expr = column_alias_data.original_select_items[*alias_entry].clone(); + return self.bind_expression(&expr, result_names, result_types); + } + } + Err(BindError::Internal(format!( + "column not found: {}", + column_name + ))) + } + + pub fn bind_column_ref_expr( + &mut self, + idents: &[sqlparser::ast::Ident], + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + // bind table column ref expr first + let bind_res = self.bind_column_ref_expr_internal(idents, result_names, result_types); + if bind_res.is_ok() { + return bind_res; + } + // try to bind as alias + self.bind_column_ref_expr_as_alias(idents, result_names, result_types) + } } diff --git a/src/planner_v2/binder/expression_binder/column_alias_binder.rs b/src/planner_v2/binder/expression_binder/column_alias_binder.rs index a8d8d8d..38a3ac2 100644 --- a/src/planner_v2/binder/expression_binder/column_alias_binder.rs +++ b/src/planner_v2/binder/expression_binder/column_alias_binder.rs @@ -1,32 +1,9 @@ use std::collections::HashMap; use derive_new::new; -use expression_binder::ExpressionBinder; -use crate::planner_v2::{expression_binder, BindError, BoundExpression}; - -/// A helper binder for WhereBinder and HavingBinder which support alias as a columnref. 
#[derive(new)] -pub struct ColumnAliasBinder<'a> { - pub(crate) original_select_items: &'a [sqlparser::ast::Expr], - pub(crate) alias_map: &'a HashMap, -} - -impl<'a> ColumnAliasBinder<'a> { - pub fn bind_alias( - &self, - expression_binder: &mut ExpressionBinder, - expr: &sqlparser::ast::Expr, - ) -> Result { - if let sqlparser::ast::Expr::Identifier(ident) = expr { - let alias = ident.to_string(); - if let Some(alias_entry) = self.alias_map.get(&alias) { - let expr = self.original_select_items[*alias_entry].clone(); - let bound_expr = - expression_binder.bind_expression(&expr, &mut vec![], &mut vec![])?; - return Ok(bound_expr); - } - } - Err(BindError::Internal(format!("column not found: {}", expr))) - } +pub struct ColumnAliasData { + pub(crate) original_select_items: Vec, + pub(crate) alias_map: HashMap, } diff --git a/src/planner_v2/binder/expression_binder/mod.rs b/src/planner_v2/binder/expression_binder/mod.rs index 1c143d7..ddd71a3 100644 --- a/src/planner_v2/binder/expression_binder/mod.rs +++ b/src/planner_v2/binder/expression_binder/mod.rs @@ -1,4 +1,2 @@ mod column_alias_binder; -mod where_binder; pub use column_alias_binder::*; -pub use where_binder::*; diff --git a/src/planner_v2/binder/expression_binder/where_binder.rs b/src/planner_v2/binder/expression_binder/where_binder.rs deleted file mode 100644 index 8b720c8..0000000 --- a/src/planner_v2/binder/expression_binder/where_binder.rs +++ /dev/null @@ -1,49 +0,0 @@ -use derive_new::new; - -use super::ColumnAliasBinder; -use crate::planner_v2::{BindError, BoundExpression, ExpressionBinder}; -use crate::types_v2::LogicalType; - -/// The WHERE binder is responsible for binding an expression within the WHERE clause of a SQL -/// statement -#[derive(new)] -pub struct WhereBinder<'a> { - internal_binder: ExpressionBinder<'a>, - column_alias_binder: ColumnAliasBinder<'a>, -} - -impl<'a> WhereBinder<'a> { - pub fn bind_expression( - &mut self, - expr: &sqlparser::ast::Expr, - result_names: &mut Vec, - 
result_types: &mut Vec, - ) -> Result { - match expr { - sqlparser::ast::Expr::Identifier(..) | sqlparser::ast::Expr::CompoundIdentifier(..) => { - self.bind_column_ref_expr(expr, result_names, result_types) - } - other => self - .internal_binder - .bind_expression(other, result_names, result_types), - } - } - - fn bind_column_ref_expr( - &mut self, - expr: &sqlparser::ast::Expr, - result_names: &mut Vec, - result_types: &mut Vec, - ) -> Result { - // bind column ref expr first - let bind_res = self - .internal_binder - .bind_expression(expr, result_names, result_types); - if bind_res.is_ok() { - return bind_res; - } - // try to bind as alias - self.column_alias_binder - .bind_alias(&mut self.internal_binder, expr) - } -} diff --git a/src/planner_v2/binder/query_node/bind_select_node.rs b/src/planner_v2/binder/query_node/bind_select_node.rs index 71aeca3..6e050f5 100644 --- a/src/planner_v2/binder/query_node/bind_select_node.rs +++ b/src/planner_v2/binder/query_node/bind_select_node.rs @@ -4,8 +4,8 @@ use derive_new::new; use sqlparser::ast::{Ident, Query}; use crate::planner_v2::{ - BindError, Binder, BoundExpression, BoundTableRef, ColumnAliasBinder, ExpressionBinder, - SqlparserResolver, WhereBinder, VALUES_LIST_ALIAS, + BindError, Binder, BoundExpression, BoundTableRef, ColumnAliasData, ExpressionBinder, + SqlparserResolver, VALUES_LIST_ALIAS, }; use crate::types_v2::LogicalType; @@ -110,12 +110,12 @@ impl Binder { // first visit the WHERE clause // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses let where_clause = if let Some(where_expr) = &select.selection { - let column_alias_binder = ColumnAliasBinder::new(&original_select_items, &alias_map); - let mut where_binder = - WhereBinder::new(ExpressionBinder::new(self), column_alias_binder); - // FIXME: where_binder not work with ExpressionBinder - let bound_expr = where_binder.bind_expression(where_expr, &mut vec![], &mut vec![])?; - Some(bound_expr) + let mut expr_binder = 
ExpressionBinder::new(self); + expr_binder.set_column_alias_data(ColumnAliasData::new( + original_select_items.clone(), + alias_map.clone(), + )); + Some(expr_binder.bind_expression(where_expr, &mut vec![], &mut vec![])?) } else { None }; diff --git a/src/planner_v2/binder/sqlparser_util.rs b/src/planner_v2/binder/sqlparser_util.rs index 4e0b556..7142e2d 100644 --- a/src/planner_v2/binder/sqlparser_util.rs +++ b/src/planner_v2/binder/sqlparser_util.rs @@ -1,3 +1,4 @@ +use itertools::Itertools; use sqlparser::ast::{ BinaryOperator, ColumnDef, Expr, Ident, ObjectName, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins, Value, WildcardAdditionalOptions, @@ -29,6 +30,23 @@ impl SqlparserResolver { let ty = column_def.data_type.clone().try_into()?; Ok(ColumnDefinition::new(name, ty)) } + + pub fn resolve_expr_idents( + idents: &[sqlparser::ast::Ident], + ) -> Result<(Option, Option, String), BindError> { + let idents = idents + .iter() + .map(|ident| ident.value.to_lowercase()) + .collect_vec(); + + let (schema_name, table_name, column_name) = match idents.as_slice() { + [column] => (None, None, column.clone()), + [table, column] => (None, Some(table.clone()), column.clone()), + [schema, table, column] => (Some(schema.clone()), Some(table.clone()), column.clone()), + _ => return Err(BindError::UnsupportedExpr(format!("{:?}", idents))), + }; + Ok((schema_name, table_name, column_name)) + } } #[derive(Default)] diff --git a/src/planner_v2/expression_binder.rs b/src/planner_v2/expression_binder.rs index 7c963ce..8aae59f 100644 --- a/src/planner_v2/expression_binder.rs +++ b/src/planner_v2/expression_binder.rs @@ -2,15 +2,21 @@ use std::slice; use derive_new::new; -use super::{BindError, Binder, BoundExpression}; +use super::{BindError, Binder, BoundExpression, ColumnAliasData}; use crate::types_v2::LogicalType; #[derive(new)] pub struct ExpressionBinder<'a> { pub(crate) binder: &'a mut Binder, + #[new(default)] + pub(crate) column_alias_data: Option, } impl 
ExpressionBinder<'_> { + pub fn set_column_alias_data(&mut self, column_alias_data: ColumnAliasData) { + self.column_alias_data = Some(column_alias_data); + } + pub fn bind_expression( &mut self, expr: &sqlparser::ast::Expr, diff --git a/tests/slt/filter.slt b/tests/slt/filter.slt index a067dea..b3c5602 100644 --- a/tests/slt/filter.slt +++ b/tests/slt/filter.slt @@ -32,16 +32,16 @@ select v1, v2 from t1 where v1 >= 2 and v1 > v3; 2 5 # filter alias -# onlyif sqlrs_v2 -# query III -# select v1+1 as a from t1 where a >= 1; -# ---- -# 1 4 -# 2 5 - -# onlyif sqlrs_v2 -# query III -# select v1+1 as a from t1 where a == a; -# ---- -# 1 4 -# 2 5 +onlyif sqlrs_v2 +query III +select v1+1 as a from t1 where a >= 2; +---- +2 +3 + +onlyif sqlrs_v2 +query III +select v1+1 as a from t1 where a = a; +---- +2 +3 diff --git a/tests/slt/table_function.slt b/tests/slt/table_function.slt index 4ecce85..8dea32f 100644 --- a/tests/slt/table_function.slt +++ b/tests/slt/table_function.slt @@ -4,7 +4,7 @@ create table t1(v1 int, v2 int, v3 int); create table t2(v1 varchar, v2 boolean, v3 bigint); onlyif sqlrs_v2 -query III +query III rowsort select schema_name, schema_oid, table_name from sqlrs_tables(); ---- main 1 t1 From e9ca60c6c8233166430e9d28d24042b647d6d995 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Mon, 26 Dec 2022 15:59:45 +0800 Subject: [PATCH 4/4] feat(planner): refactor Physical operator base construction Signed-off-by: Fedomn --- src/execution/physical_plan/mod.rs | 7 +++--- .../physical_plan/physical_dummy_scan.rs | 2 +- .../physical_plan/physical_explain.rs | 4 ++-- .../physical_plan/physical_expression_scan.rs | 9 +------- .../physical_plan/physical_filter.rs | 22 ++++++------------- .../physical_plan/physical_insert.rs | 8 +------ .../physical_plan/physical_projection.rs | 14 +++--------- .../physical_plan/physical_table_scan.rs | 2 +- src/execution/physical_plan_generator.rs | 16 ++++++++++++-- src/execution/volcano_executor/filter.rs | 2 +- 
src/execution/volcano_executor/projection.rs | 2 +- 11 files changed, 35 insertions(+), 53 deletions(-) diff --git a/src/execution/physical_plan/mod.rs b/src/execution/physical_plan/mod.rs index 408d075..bb34ca5 100644 --- a/src/execution/physical_plan/mod.rs +++ b/src/execution/physical_plan/mod.rs @@ -19,14 +19,13 @@ pub use physical_insert::*; pub use physical_projection::*; pub use physical_table_scan::*; -use crate::types_v2::LogicalType; +use crate::planner_v2::BoundExpression; #[derive(new, Default, Clone)] pub struct PhysicalOperatorBase { pub(crate) children: Vec, - #[allow(dead_code)] - /// The types returned by this physical operator - pub(crate) types: Vec, + // The set of expressions contained within the operator, if any + pub(crate) expressioins: Vec, } #[derive(Clone)] diff --git a/src/execution/physical_plan/physical_dummy_scan.rs b/src/execution/physical_plan/physical_dummy_scan.rs index 91fab77..6357c5e 100644 --- a/src/execution/physical_plan/physical_dummy_scan.rs +++ b/src/execution/physical_plan/physical_dummy_scan.rs @@ -11,7 +11,7 @@ pub struct PhysicalDummyScan { impl PhysicalPlanGenerator { pub(crate) fn create_physical_dummy_scan(&self, op: LogicalDummyScan) -> PhysicalOperator { - let base = PhysicalOperatorBase::new(vec![], op.base.types); + let base = self.create_physical_operator_base(op.base); PhysicalOperator::PhysicalDummyScan(PhysicalDummyScan::new(base)) } } diff --git a/src/execution/physical_plan/physical_explain.rs b/src/execution/physical_plan/physical_explain.rs index 797b48f..9fae55c 100644 --- a/src/execution/physical_plan/physical_explain.rs +++ b/src/execution/physical_plan/physical_explain.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use arrow::array::StringArray; use arrow::record_batch::RecordBatch; -use super::{PhysicalColumnDataScan, PhysicalOperator, PhysicalOperatorBase}; +use super::{PhysicalColumnDataScan, PhysicalOperator}; use crate::execution::{PhysicalPlanGenerator, SchemaUtil}; use 
crate::planner_v2::LogicalExplain; use crate::util::tree_render::TreeRender; @@ -19,7 +19,7 @@ impl PhysicalPlanGenerator { // physical plan explain string let physical_plan_string = TreeRender::physical_plan_tree(&physical_child); - let base = PhysicalOperatorBase::new(vec![], types.clone()); + let base = self.create_physical_operator_base(op.base); let schema = SchemaUtil::new_schema_ref(&["type".to_string(), "plan".to_string()], &types); let types_column = Arc::new(StringArray::from(vec![ diff --git a/src/execution/physical_plan/physical_expression_scan.rs b/src/execution/physical_plan/physical_expression_scan.rs index 672755a..2e909a3 100644 --- a/src/execution/physical_plan/physical_expression_scan.rs +++ b/src/execution/physical_plan/physical_expression_scan.rs @@ -21,14 +21,7 @@ impl PhysicalPlanGenerator { op: LogicalExpressionGet, ) -> PhysicalOperator { assert!(op.base.children.len() == 1); - let new_children = op - .base - .children - .into_iter() - .map(|p| self.create_plan_internal(p)) - .collect::>(); - let types = op.base.types; - let base = PhysicalOperatorBase::new(new_children, types); + let base = self.create_physical_operator_base(op.base); PhysicalOperator::PhysicalExpressionScan(PhysicalExpressionScan::new( base, op.expr_types, diff --git a/src/execution/physical_plan/physical_filter.rs b/src/execution/physical_plan/physical_filter.rs index 7860959..69bea7d 100644 --- a/src/execution/physical_plan/physical_filter.rs +++ b/src/execution/physical_plan/physical_filter.rs @@ -1,33 +1,25 @@ use super::{PhysicalOperator, PhysicalOperatorBase}; use crate::execution::PhysicalPlanGenerator; -use crate::planner_v2::{BoundConjunctionExpression, BoundExpression, LogicalFilter}; +use crate::planner_v2::{BoundConjunctionExpression, LogicalFilter}; #[derive(Clone)] pub struct PhysicalFilter { pub(crate) base: PhysicalOperatorBase, - pub(crate) expression: BoundExpression, } impl PhysicalFilter { - pub fn new(base: PhysicalOperatorBase, expressions: Vec) -> 
Self { + pub fn new(mut base: PhysicalOperatorBase) -> Self { let expression = - BoundConjunctionExpression::try_build_and_conjunction_expression(expressions); - Self { base, expression } + BoundConjunctionExpression::try_build_and_conjunction_expression(base.expressioins); + base.expressioins = vec![expression]; + Self { base } } } impl PhysicalPlanGenerator { pub(crate) fn create_physical_filter(&self, op: LogicalFilter) -> PhysicalOperator { assert!(op.base.children.len() == 1); - // TODO: refactor this part to common method - let new_children = op - .base - .children - .into_iter() - .map(|p| self.create_plan_internal(p)) - .collect::>(); - let types = op.base.types; - let base = PhysicalOperatorBase::new(new_children, types); - PhysicalOperator::PhysicalFilter(PhysicalFilter::new(base, op.base.expressioins)) + let base = self.create_physical_operator_base(op.base); + PhysicalOperator::PhysicalFilter(PhysicalFilter::new(base)) } } diff --git a/src/execution/physical_plan/physical_insert.rs b/src/execution/physical_plan/physical_insert.rs index 5cd130c..8007daf 100644 --- a/src/execution/physical_plan/physical_insert.rs +++ b/src/execution/physical_plan/physical_insert.rs @@ -29,13 +29,7 @@ impl PhysicalInsert { impl PhysicalPlanGenerator { pub(crate) fn create_physical_insert(&self, op: LogicalInsert) -> PhysicalOperator { - let new_children = op - .base - .children - .into_iter() - .map(|op| self.create_plan_internal(op)) - .collect::>(); - let base = PhysicalOperatorBase::new(new_children, op.base.types); + let base = self.create_physical_operator_base(op.base); PhysicalOperator::PhysicalInsert(PhysicalInsert::new( base, op.column_index_list, diff --git a/src/execution/physical_plan/physical_projection.rs b/src/execution/physical_plan/physical_projection.rs index 133e802..14301d5 100644 --- a/src/execution/physical_plan/physical_projection.rs +++ b/src/execution/physical_plan/physical_projection.rs @@ -2,24 +2,16 @@ use derive_new::new; use 
super::{PhysicalOperator, PhysicalOperatorBase}; use crate::execution::PhysicalPlanGenerator; -use crate::planner_v2::{BoundExpression, LogicalProjection}; +use crate::planner_v2::LogicalProjection; #[derive(new, Clone)] pub struct PhysicalProjection { pub(crate) base: PhysicalOperatorBase, - pub(crate) select_list: Vec, } impl PhysicalPlanGenerator { pub(crate) fn create_physical_projection(&self, op: LogicalProjection) -> PhysicalOperator { - let new_children = op - .base - .children - .into_iter() - .map(|p| self.create_plan_internal(p)) - .collect::>(); - let types = op.base.types; - let base = PhysicalOperatorBase::new(new_children, types); - PhysicalOperator::PhysicalProjection(PhysicalProjection::new(base, op.base.expressioins)) + let base = self.create_physical_operator_base(op.base); + PhysicalOperator::PhysicalProjection(PhysicalProjection::new(base)) } } diff --git a/src/execution/physical_plan/physical_table_scan.rs b/src/execution/physical_plan/physical_table_scan.rs index 4416ce6..3dd8399 100644 --- a/src/execution/physical_plan/physical_table_scan.rs +++ b/src/execution/physical_plan/physical_table_scan.rs @@ -19,7 +19,7 @@ pub struct PhysicalTableScan { impl PhysicalPlanGenerator { pub(crate) fn create_physical_table_scan(&self, op: LogicalGet) -> PhysicalOperator { - let base = PhysicalOperatorBase::new(vec![], op.base.types); + let base = self.create_physical_operator_base(op.base); let plan = PhysicalTableScan::new(base, op.function, op.bind_data, op.returned_types, op.names); PhysicalOperator::PhysicalTableScan(plan) diff --git a/src/execution/physical_plan_generator.rs b/src/execution/physical_plan_generator.rs index 262ce9b..76bf1c5 100644 --- a/src/execution/physical_plan_generator.rs +++ b/src/execution/physical_plan_generator.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use derive_new::new; use log::debug; -use super::{ColumnBindingResolver, PhysicalOperator}; +use super::{ColumnBindingResolver, PhysicalOperator, PhysicalOperatorBase}; use 
crate::execution::LOGGING_TARGET; use crate::main_entry::ClientContext; -use crate::planner_v2::{LogicalOperator, LogicalOperatorVisitor}; +use crate::planner_v2::{LogicalOperator, LogicalOperatorBase, LogicalOperatorVisitor}; use crate::util::tree_render::TreeRender; #[derive(new)] @@ -45,4 +45,16 @@ impl PhysicalPlanGenerator { LogicalOperator::LogicalFilter(op) => self.create_physical_filter(op), } } + + pub(crate) fn create_physical_operator_base( + &self, + base: LogicalOperatorBase, + ) -> PhysicalOperatorBase { + let children = base + .children + .iter() + .map(|op| self.create_plan_internal(op.clone())) + .collect::>(); + PhysicalOperatorBase::new(children, base.expressioins) + } } diff --git a/src/execution/volcano_executor/filter.rs b/src/execution/volcano_executor/filter.rs index 1829f2a..e216bf8 100644 --- a/src/execution/volcano_executor/filter.rs +++ b/src/execution/volcano_executor/filter.rs @@ -19,7 +19,7 @@ pub struct Filter { impl Filter { #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] pub async fn execute(self, _context: Arc) { - let exprs = vec![self.plan.expression]; + let exprs = self.plan.base.expressioins; #[for_await] for batch in self.child { diff --git a/src/execution/volcano_executor/projection.rs b/src/execution/volcano_executor/projection.rs index ac32b8e..8a86a4b 100644 --- a/src/execution/volcano_executor/projection.rs +++ b/src/execution/volcano_executor/projection.rs @@ -18,7 +18,7 @@ pub struct Projection { impl Projection { #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] pub async fn execute(self, _context: Arc) { - let exprs = self.plan.select_list; + let exprs = self.plan.base.expressioins; let schema = SchemaUtil::new_schema_ref_from_exprs(&exprs); #[for_await]