Skip to content

Commit

Permalink
Merge pull request #129 from Fedomn/conjunction-function
Browse files Browse the repository at this point in the history
feat(planner): introduce conjunction function
  • Loading branch information
mergify[bot] authored Dec 24, 2022
2 parents ca8c189 + a92c7b9 commit c647e99
Show file tree
Hide file tree
Showing 14 changed files with 350 additions and 4 deletions.
12 changes: 12 additions & 0 deletions src/execution/expression_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ impl ExpressionExecutor {
let func = e.function.function;
func(&left_result, &right_result)?
}
BoundExpression::BoundConjunctionExpression(e) => {
assert!(e.children.len() >= 2);
let mut conjunction_result = Self::execute_internal(&e.children[0], input)?;
for i in 1..e.children.len() {
let func = e.function.function;
conjunction_result = func(
&conjunction_result,
&Self::execute_internal(&e.children[i], input)?,
)?;
}
conjunction_result
}
})
}
}
20 changes: 20 additions & 0 deletions src/function/conjunction/conjunction_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use arrow::array::ArrayRef;
use derive_new::new;

use crate::function::FunctionError;

pub type ConjunctionFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError>;

#[derive(new, Clone)]
pub struct ConjunctionFunction {
pub(crate) name: String,
pub(crate) function: ConjunctionFunc,
}

impl std::fmt::Debug for ConjunctionFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConjunctionFunction")
.field("name", &self.name)
.finish()
}
}
65 changes: 65 additions & 0 deletions src/function/conjunction/default_conjunction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray};
use arrow::compute::{and_kleene, or_kleene};
use arrow::datatypes::DataType;
use sqlparser::ast::BinaryOperator;

use super::{ConjunctionFunc, ConjunctionFunction};
use crate::function::FunctionError;

pub struct DefaultConjunctionFunctions;

macro_rules! boolean_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
if *$LEFT.data_type() != DataType::Boolean || *$RIGHT.data_type() != DataType::Boolean {
return Err(FunctionError::ConjunctionError(format!(
"Cannot evaluate binary expression with types {:?} and {:?}, only Boolean supported",
$LEFT.data_type(),
$RIGHT.data_type()
)));
}

let ll = $LEFT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op failed to downcast array");
let rr = $RIGHT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op failed to downcast array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
}

impl DefaultConjunctionFunctions {
fn default_and_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
boolean_op!(left, right, and_kleene)
}

fn default_or_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
boolean_op!(left, right, or_kleene)
}

fn get_conjunction_function_internal(
op: &BinaryOperator,
) -> Result<(&str, ConjunctionFunc), FunctionError> {
Ok(match op {
BinaryOperator::And => ("and", Self::default_and_function),
BinaryOperator::Or => ("or", Self::default_or_function),
_ => {
return Err(FunctionError::ConjunctionError(format!(
"Unsupported conjunction operator {:?}",
op
)))
}
})
}

pub fn get_conjunction_function(
op: &BinaryOperator,
) -> Result<ConjunctionFunction, FunctionError> {
let (name, func) = Self::get_conjunction_function_internal(op)?;
Ok(ConjunctionFunction::new(name.to_string(), func))
}
}
4 changes: 4 additions & 0 deletions src/function/conjunction/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod conjunction_function;
mod default_conjunction;
pub use conjunction_function::*;
pub use default_conjunction::*;
2 changes: 2 additions & 0 deletions src/function/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ pub enum FunctionError {
CastError(String),
#[error("Comparison error: {0}")]
ComparisonError(String),
#[error("Conjunction error: {0}")]
ConjunctionError(String),
}
2 changes: 2 additions & 0 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod cast;
mod comparison;
mod conjunction;
mod errors;
mod scalar;
mod table;
Expand All @@ -8,6 +9,7 @@ use std::sync::Arc;

pub use cast::*;
pub use comparison::*;
pub use conjunction::*;
use derive_new::new;
pub use errors::*;
pub use scalar::*;
Expand Down
1 change: 1 addition & 0 deletions src/planner_v2/binder/expression/bind_cast_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ impl BoundCastExpression {
alias: String,
try_cast: bool,
) -> Result<BoundExpression, BindError> {
// TODO: enhance alias to reduce outside alias assignment
let source_type = expr.return_type();
assert!(source_type != target_type);
let cast_function = DefaultCastFunctions::get_cast_function(&source_type, &target_type)?;
Expand Down
55 changes: 55 additions & 0 deletions src/planner_v2/binder/expression/bind_conjunction_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use derive_new::new;

use super::{BoundCastExpression, BoundExpression, BoundExpressionBase};
use crate::function::{ConjunctionFunction, DefaultConjunctionFunctions};
use crate::planner_v2::{BindError, ExpressionBinder};
use crate::types_v2::LogicalType;

#[derive(new, Debug, Clone)]
pub struct BoundConjunctionExpression {
pub(crate) base: BoundExpressionBase,
pub(crate) function: ConjunctionFunction,
pub(crate) children: Vec<BoundExpression>,
}

impl ExpressionBinder<'_> {
pub fn bind_conjunction_expression(
&mut self,
left: &sqlparser::ast::Expr,
op: &sqlparser::ast::BinaryOperator,
right: &sqlparser::ast::Expr,
result_names: &mut Vec<String>,
result_types: &mut Vec<LogicalType>,
) -> Result<BoundExpression, BindError> {
let function = DefaultConjunctionFunctions::get_conjunction_function(op)?;
let mut return_names = vec![];
let mut left = self.bind_expression(left, &mut return_names, &mut vec![])?;
let mut right = self.bind_expression(right, &mut return_names, &mut vec![])?;
if left.return_type() != LogicalType::Boolean {
let alias = format!("cast({} as {}", left.alias(), LogicalType::Boolean);
left = BoundCastExpression::add_cast_to_type(
left,
LogicalType::Boolean,
alias.clone(),
true,
)?;
return_names[0] = alias;
}
if right.return_type() != LogicalType::Boolean {
let alias = format!("cast({} as {}", right.alias(), LogicalType::Boolean);
right = BoundCastExpression::add_cast_to_type(
right,
LogicalType::Boolean,
alias.clone(),
true,
)?;
return_names[1] = alias;
}
result_names.push(format!("{}({},{})", op, return_names[0], return_names[1]));
result_types.push(LogicalType::Boolean);
let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean);
Ok(BoundExpression::BoundConjunctionExpression(
BoundConjunctionExpression::new(base, function, vec![left, right]),
))
}
}
6 changes: 6 additions & 0 deletions src/planner_v2/binder/expression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod bind_cast_expression;
mod bind_column_ref_expression;
mod bind_comparison_expression;
mod bind_conjunction_expression;
mod bind_constant_expression;
mod bind_function_expression;
mod bind_reference_expression;
Expand All @@ -9,6 +10,7 @@ mod column_binding;
pub use bind_cast_expression::*;
pub use bind_column_ref_expression::*;
pub use bind_comparison_expression::*;
pub use bind_conjunction_expression::*;
pub use bind_constant_expression::*;
pub use bind_function_expression::*;
pub use bind_reference_expression::*;
Expand All @@ -33,6 +35,7 @@ pub enum BoundExpression {
BoundCastExpression(BoundCastExpression),
BoundFunctionExpression(BoundFunctionExpression),
BoundComparisonExpression(BoundComparisonExpression),
BoundConjunctionExpression(BoundConjunctionExpression),
}

impl BoundExpression {
Expand All @@ -44,6 +47,7 @@ impl BoundExpression {
BoundExpression::BoundCastExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundFunctionExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundComparisonExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundConjunctionExpression(expr) => expr.base.return_type.clone(),
}
}

Expand All @@ -55,6 +59,7 @@ impl BoundExpression {
BoundExpression::BoundCastExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundFunctionExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundComparisonExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundConjunctionExpression(expr) => expr.base.alias.clone(),
}
}

Expand All @@ -66,6 +71,7 @@ impl BoundExpression {
BoundExpression::BoundCastExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundFunctionExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundComparisonExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundConjunctionExpression(expr) => expr.base.alias = alias,
}
}
}
5 changes: 3 additions & 2 deletions src/planner_v2/expression_binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ impl ExpressionBinder<'_> {
| sqlparser::ast::BinaryOperator::NotEq => {
self.bind_comparison_expression(left, op, right, result_names, result_types)
}
sqlparser::ast::BinaryOperator::And => todo!(),
sqlparser::ast::BinaryOperator::Or => todo!(),
sqlparser::ast::BinaryOperator::And | sqlparser::ast::BinaryOperator::Or => {
self.bind_conjunction_expression(left, op, right, result_names, result_types)
}
other => Err(BindError::UnsupportedExpr(other.to_string())),
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/planner_v2/expression_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ impl ExpressionIterator {
callback(&mut e.left);
callback(&mut e.right);
}
BoundExpression::BoundConjunctionExpression(e) => {
e.children.iter_mut().for_each(callback)
}
}
}
}
11 changes: 9 additions & 2 deletions src/planner_v2/logical_operator_visitor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
BoundCastExpression, BoundColumnRefExpression, BoundComparisonExpression,
BoundConstantExpression, BoundExpression, BoundFunctionExpression, BoundReferenceExpression,
ExpressionIterator, LogicalOperator,
BoundConjunctionExpression, BoundConstantExpression, BoundExpression, BoundFunctionExpression,
BoundReferenceExpression, ExpressionIterator, LogicalOperator,
};

/// Visitor pattern on logical operators, also includes rewrite expression ability.
Expand Down Expand Up @@ -38,6 +38,7 @@ pub trait LogicalOperatorVisitor {
BoundExpression::BoundCastExpression(e) => self.visit_replace_cast(e),
BoundExpression::BoundFunctionExpression(e) => self.visit_function_expression(e),
BoundExpression::BoundComparisonExpression(e) => self.visit_comparison_expression(e),
BoundExpression::BoundConjunctionExpression(e) => self.visit_conjunction_expression(e),
};
if let Some(new_expr) = result {
*expr = new_expr;
Expand Down Expand Up @@ -71,4 +72,10 @@ pub trait LogicalOperatorVisitor {
) -> Option<BoundExpression> {
None
}
fn visit_conjunction_expression(
&self,
_: &BoundConjunctionExpression,
) -> Option<BoundExpression> {
None
}
}
9 changes: 9 additions & 0 deletions src/util/tree_render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ impl TreeRender {
let r = Self::bound_expression_to_string(&e.right);
format!("{} {} {}", l, e.function.name, r)
}
BoundExpression::BoundConjunctionExpression(e) => {
let args = e
.children
.iter()
.map(Self::bound_expression_to_string)
.collect::<Vec<_>>()
.join(", ");
format!("{}({}])", e.function.name, args)
}
}
}

Expand Down
Loading

0 comments on commit c647e99

Please sign in to comment.