Skip to content

Commit

Permalink
sql: add a SQL IR and factor out optimizations.
Browse files Browse the repository at this point in the history
  • Loading branch information
jacksonrnewhouse committed Apr 26, 2023
1 parent d121ced commit 91d0f26
Show file tree
Hide file tree
Showing 11 changed files with 2,041 additions and 1,156 deletions.
2 changes: 1 addition & 1 deletion arroyo-api/src/optimizations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ impl FusedExpressionOperatorBuilder {
}
Some(Record) => {
self.body.push(quote!(
let record:#out_type = #expression;));
let record:#out_type = #expression?;));
self.current_return_type = Some(OptionalRecord);
}
Some(OptionalRecord) => {
Expand Down
2 changes: 1 addition & 1 deletion arroyo-datastream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl Debug for WindowType {
}
}

#[derive(Clone, Encode, Decode, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, Encode, Decode, Serialize, Deserialize, PartialEq, Eq)]
pub enum WatermarkType {
Periodic {
period: Duration,
Expand Down
1 change: 0 additions & 1 deletion arroyo-sql-testing/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(test)]
mod tests {
use arroyo_sql_macro::single_test_codegen;
use chrono;

// Casts
single_test_codegen!(
Expand Down
111 changes: 83 additions & 28 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Debug;

use crate::{
operators::TwoPhaseAggregation,
pipeline::SortDirection,
types::{StructDef, StructField, TypeDef},
};
Expand Down Expand Up @@ -28,7 +29,7 @@ pub trait ExpressionGenerator: Debug {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum Expression {
Column(ColumnExpression),
UnaryBoolean(UnaryBooleanExpression),
Expand Down Expand Up @@ -92,7 +93,7 @@ impl ExpressionGenerator for Expression {
}

impl Expression {
pub(crate) fn has_max_value(&self, field: &StructField) -> Option<i64> {
pub(crate) fn has_max_value(&self, field: &StructField) -> Option<u64> {
match self {
Expression::BinaryComparison(BinaryComparisonExpression { left, op, right }) => {
if let BinaryComparison::And = op {
Expand All @@ -116,13 +117,15 @@ impl Expression {
) => {
if field == column_field {
match (op, literal) {
(BinaryComparison::Lt, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::Lt, ScalarValue::UInt64(Some(max))) => {
Some(*max - 1)
}
(BinaryComparison::LtEq, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::LtEq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::Int64(Some(max))) => Some(*max),
_ => None,
}
} else {
Expand All @@ -135,13 +138,15 @@ impl Expression {
) => {
if field == column_field {
match (op, literal) {
(BinaryComparison::Gt, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::Gt, ScalarValue::UInt64(Some(max))) => {
Some(*max + 1)
}
(BinaryComparison::GtEq, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::GtEq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::Int64(Some(max))) => Some(*max),
_ => None,
}
} else {
Expand Down Expand Up @@ -415,7 +420,7 @@ impl Column {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
/// Payload of `Expression::Column`: holds the single `StructField` that the
/// expression refers to.
pub struct ColumnExpression {
// The referenced field; presumably a column of the input struct — confirm against callers.
column_field: StructField,
}
Expand Down Expand Up @@ -460,7 +465,7 @@ pub enum UnaryOperator {
Negative,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UnaryBooleanExpression {
operator: UnaryOperator,
input: Box<Expression>,
Expand Down Expand Up @@ -512,7 +517,7 @@ impl UnaryBooleanExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
/// Expression node that holds a single `ScalarValue` literal.
pub struct LiteralExpression {
// The wrapped literal value.
literal: ScalarValue,
}
Expand All @@ -533,7 +538,7 @@ impl LiteralExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BinaryComparison {
Eq,
NotEq,
Expand Down Expand Up @@ -568,7 +573,7 @@ impl TryFrom<datafusion_expr::Operator> for BinaryComparison {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BinaryComparisonExpression {
pub left: Box<Expression>,
pub op: BinaryComparison,
Expand Down Expand Up @@ -633,7 +638,7 @@ impl ExpressionGenerator for BinaryComparisonExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BinaryMathOperator {
Plus,
Minus,
Expand Down Expand Up @@ -670,7 +675,7 @@ impl TryFrom<datafusion_expr::Operator> for BinaryMathOperator {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BinaryMathExpression {
left: Box<Expression>,
op: BinaryMathOperator,
Expand Down Expand Up @@ -718,7 +723,7 @@ impl ExpressionGenerator for BinaryMathExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StructFieldExpression {
struct_expression: Box<Expression>,
struct_field: StructField,
Expand Down Expand Up @@ -815,10 +820,28 @@ impl Aggregator {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct AggregationExpression {
producing_expression: Box<Expression>,
aggregator: Aggregator,
pub producing_expression: Box<Expression>,
pub aggregator: Aggregator,
}

impl TryFrom<AggregationExpression> for TwoPhaseAggregation {
type Error = anyhow::Error;

fn try_from(aggregation_expression: AggregationExpression) -> Result<Self> {
if aggregation_expression.allows_two_phase() {
Ok(TwoPhaseAggregation {
incoming_expression: *aggregation_expression.producing_expression,
aggregator: aggregation_expression.aggregator,
})
} else {
bail!(
"{:?} does not support two phase aggregation",
aggregation_expression.aggregator
);
}
}
}

impl AggregationExpression {
Expand All @@ -833,6 +856,40 @@ impl AggregationExpression {
aggregator,
}))
}

/// Reports whether this expression's aggregator supports two phase aggregation.
///
/// Every aggregator except `CountDistinct` yields `true`.
pub(crate) fn allows_two_phase(&self) -> bool {
    // Deliberately exhaustive (no wildcard arm) so that adding an aggregator
    // forces an explicit decision here.
    match self.aggregator {
        Aggregator::CountDistinct => false,
        Aggregator::Count
        | Aggregator::Sum
        | Aggregator::Min
        | Aggregator::Max
        | Aggregator::Avg => true,
    }
}

pub fn try_from_expression(expr: &Expr, input_struct: &StructDef) -> Result<Self> {
match expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
fun,
args,
distinct,
filter: None,
}) => {
if args.len() != 1 {
bail!("unexpected arg length");
}
let producing_expression =
Box::new(to_expression_generator(&args[0], input_struct)?);
let aggregator = Aggregator::from_datafusion(fun.clone(), *distinct)?;
Ok(AggregationExpression {
producing_expression,
aggregator,
})
}
_ => bail!("expected aggregate function, not {}", expr),
}
}
}

impl ExpressionGenerator for AggregationExpression {
Expand Down Expand Up @@ -903,7 +960,7 @@ impl ExpressionGenerator for AggregationExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct CastExpression {
input: Box<Expression>,
data_type: DataType,
Expand Down Expand Up @@ -937,10 +994,8 @@ impl CastExpression {
{
true
// handle date to string casts.
} else if Self::is_date(input_data_type) || Self::is_string(output_data_type) {
true
} else {
false
Self::is_date(input_data_type) || Self::is_string(output_data_type)
}
}

Expand Down Expand Up @@ -1084,7 +1139,7 @@ impl TryFrom<BuiltinScalarFunction> for NumericFunction {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct NumericExpression {
function: NumericFunction,
input: Box<Expression>,
Expand Down Expand Up @@ -1113,7 +1168,7 @@ impl ExpressionGenerator for NumericExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct SortExpression {
value: Expression,
direction: SortDirection,
Expand Down Expand Up @@ -1173,7 +1228,7 @@ impl SortExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum StringFunction {
Ascii(Box<Expression>),
BitLength(Box<Expression>),
Expand Down Expand Up @@ -1211,7 +1266,7 @@ pub enum StringFunction {
Rtrim(Box<Expression>, Option<Box<Expression>>),
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum HashFunction {
MD5,
SHA224,
Expand Down Expand Up @@ -1247,7 +1302,7 @@ impl TryFrom<BuiltinScalarFunction> for HashFunction {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct HashExpression {
function: HashFunction,
input: Box<Expression>,
Expand Down
6 changes: 4 additions & 2 deletions arroyo-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use datafusion::physical_plan::functions::make_scalar_function;

mod expressions;
mod operators;
mod optimizations;
mod pipeline;
mod plan_graph;
pub mod schemas;
pub mod types;

Expand Down Expand Up @@ -435,10 +437,10 @@ pub fn get_test_expression(
let statement = &ast[0];
let sql_to_rel = SqlToRel::new(&schema_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
let mut optimizer_config = OptimizerContext::default();
let optimizer_config = OptimizerContext::default();
let optimizer = Optimizer::new();
let plan = optimizer
.optimize(&plan, &mut optimizer_config, |_plan, _rule| {})
.optimize(&plan, &optimizer_config, |_plan, _rule| {})
.unwrap();
let LogicalPlan::Projection(projection) = plan else {panic!("expect projection")};
let generating_expression = to_expression_generator(&projection.expr[0], &struct_def).unwrap();
Expand Down
Loading

0 comments on commit 91d0f26

Please sign in to comment.