diff --git a/datafusion/jit/Cargo.toml b/datafusion/jit/Cargo.toml index 54c010812eb5..e539e2b1f30b 100644 --- a/datafusion/jit/Cargo.toml +++ b/datafusion/jit/Cargo.toml @@ -36,6 +36,7 @@ path = "src/lib.rs" jit = [] [dependencies] +arrow = { version = "11" } cranelift = "0.82.0" cranelift-jit = "0.82.0" cranelift-module = "0.82.0" diff --git a/datafusion/jit/src/ast.rs b/datafusion/jit/src/ast.rs index 8b9139a32e40..fd10a909e783 100644 --- a/datafusion/jit/src/ast.rs +++ b/datafusion/jit/src/ast.rs @@ -16,7 +16,7 @@ // under the License. use cranelift::codegen::ir; -use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; use std::fmt::{Display, Formatter}; #[derive(Clone, Debug)] @@ -138,11 +138,13 @@ pub enum Literal { Typed(TypedLit), } -impl TryFrom for Expr { +impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr { type Error = DataFusionError; // Try to JIT compile the Expr for faster evaluation - fn try_from(value: datafusion_expr::Expr) -> Result { + fn try_from( + (value, schema): (datafusion_expr::Expr, DFSchemaRef), + ) -> Result { match &value { datafusion_expr::Expr::BinaryExpr { left, op, right } => { let op = match op { @@ -164,10 +166,30 @@ impl TryFrom for Expr { } }; Ok(Expr::Binary(op( - Box::new((*left.clone()).try_into()?), - Box::new((*right.clone()).try_into()?), + Box::new((*left.clone(), schema.clone()).try_into()?), + Box::new((*right.clone(), schema).try_into()?), ))) } + datafusion_expr::Expr::Column(col) => { + let field = schema.field_from_column(col)?; + let ty = field.data_type(); + + let jit_type = match ty { + arrow::datatypes::DataType::Int64 => I64, + arrow::datatypes::DataType::Float32 => F32, + arrow::datatypes::DataType::Float64 => F64, + arrow::datatypes::DataType::Boolean => BOOL, + + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Compiling Expression with type {} not yet supported in JIT mode", + ty + ))) + } + }; + + Ok(Expr::Identifier(field.qualified_name(), jit_type)) + } datafusion_expr::Expr::Literal(s) => { let lit = match s { ScalarValue::Boolean(Some(b)) => TypedLit::Bool(*b), diff --git a/datafusion/jit/src/lib.rs b/datafusion/jit/src/lib.rs index c1db48b45bae..dff27da317e4 100644 --- a/datafusion/jit/src/lib.rs +++ b/datafusion/jit/src/lib.rs @@ -23,11 +23,15 @@ pub mod jit; #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use crate::api::{Assembler, GeneratedFunction}; use crate::ast::{BinaryExpr, Expr, Literal, TypedLit, I64}; use crate::jit::JIT; - use datafusion_common::Result; - use datafusion_expr::lit; + use arrow::datatypes::DataType; + use datafusion_common::{DFField, DFSchema, Result}; + use datafusion_expr::{col, lit}; #[test] fn iterative_fib() -> Result<()> { @@ -89,7 +93,8 @@ mod tests { #[test] fn from_datafusion_expression() -> Result<()> { let df_expr = lit(1.0f32) + lit(2.0f32); - let jit_expr: crate::ast::Expr = df_expr.try_into()?; + let schema = Arc::new(DFSchema::empty()); + let jit_expr: crate::ast::Expr = (df_expr, schema).try_into()?; assert_eq!( jit_expr, @@ -102,6 +107,26 @@ mod tests { Ok(()) } + #[test] + fn from_datafusion_expression_schema() -> Result<()> { + let df_expr = col("a") + lit(1i64); + let schema = Arc::new(DFSchema::new_with_metadata( + vec![DFField::new(Some("table1"), "a", DataType::Int64, false)], + HashMap::new(), + )?); + let jit_expr: crate::ast::Expr = (df_expr, schema).try_into()?; + + assert_eq!( + jit_expr, + Expr::Binary(BinaryExpr::Add( + Box::new(Expr::Identifier("table1.a".to_string(), I64)), + Box::new(Expr::Literal(Literal::Typed(TypedLit::Int(1)))) + )), + ); + + Ok(()) + } + unsafe fn run_code( jit: &mut JIT, code: GeneratedFunction,