Skip to content

Commit

Permalink
Get expr planners when creating new planner (apache#11485)
Browse files Browse the repository at this point in the history
* get expr planners when creating new planner

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* get expr planner when creating planner

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* no planners in sqltorel

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* Add docs about SessionContextProvider

* Use Slice rather than Vec to access expr planners

* add test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* clippy

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
jayzhan211 and alamb authored Jul 17, 2024
1 parent d67b0fb commit de0765a
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 35 deletions.
70 changes: 54 additions & 16 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ impl SessionState {
}
}

let query = self.build_sql_query_planner(&provider);
let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
query.statement_to_plan(statement)
}

Expand Down Expand Up @@ -569,7 +569,7 @@ impl SessionState {
tables: HashMap::new(),
};

let query = self.build_sql_query_planner(&provider);
let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new())
}

Expand Down Expand Up @@ -854,20 +854,6 @@ impl SessionState {
let udtf = self.table_functions.remove(name);
Ok(udtf.map(|x| x.function().clone()))
}

fn build_sql_query_planner<'a, S>(&self, provider: &'a S) -> SqlToRel<'a, S>
where
S: ContextProvider,
{
let mut query = SqlToRel::new_with_options(provider, self.get_parser_options());

// custom planners are registered first, so they're run first and take precedence over built-in planners
for planner in self.expr_planners.iter() {
query = query.with_user_defined_planner(planner.clone());
}

query
}
}

/// A builder to be used for building [`SessionState`]'s. Defaults will
Expand Down Expand Up @@ -1597,12 +1583,20 @@ impl SessionStateDefaults {
}
}

/// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`]
///
/// This is used so the SQL planner can access the state of the session without
/// having a direct dependency on the [`SessionState`] struct (and core crate)
struct SessionContextProvider<'a> {
state: &'a SessionState,
tables: HashMap<String, Arc<dyn TableSource>>,
}

impl<'a> ContextProvider for SessionContextProvider<'a> {
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&self.state.expr_planners
}

fn get_table_source(
&self,
name: TableReference,
Expand Down Expand Up @@ -1898,3 +1892,47 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {
expr.get_type(self.df_schema)
}
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use arrow_schema::{DataType, Field, Schema};
use datafusion_common::DFSchema;
use datafusion_common::Result;
use datafusion_expr::Expr;
use datafusion_sql::planner::{PlannerContext, SqlToRel};

use crate::execution::context::SessionState;

use super::{SessionContextProvider, SessionStateBuilder};

#[test]
fn test_session_state_with_default_features() {
// test array planners with and without builtin planners
fn sql_to_expr(state: &SessionState) -> Result<Expr> {
let provider = SessionContextProvider {
state,
tables: HashMap::new(),
};

let sql = "[1,2,3]";
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
let df_schema = DFSchema::try_from(schema)?;
let dialect = state.config.options().sql_parser.dialect.as_str();
let sql_expr = state.sql_to_expr(sql, dialect)?;

let query = SqlToRel::new_with_options(&provider, state.get_parser_options());
query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())
}

let state = SessionStateBuilder::new().with_default_features().build();

assert!(sql_to_expr(&state).is_ok());

// if no builtin planners exist, you should register your own, otherwise returns error
let state = SessionStateBuilder::new().build();

assert!(sql_to_expr(&state).is_err())
}
}
5 changes: 5 additions & 0 deletions datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ pub trait ContextProvider {
not_impl_err!("Recursive CTE is not implemented")
}

/// Getter for expr planners
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&[]
}

/// Getter for a UDF description
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
Expand Down
14 changes: 7 additions & 7 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<Expr> {
// try extension planers
let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right };
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_binary_op(binary_expr, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.sql_expr_to_logical_expr(*expr, schema, planner_context)?,
];

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_extract(extract_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down Expand Up @@ -283,7 +283,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
};

let mut field_access_expr = RawFieldAccessExpr { expr, field_access };
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_field_access(field_access_expr, schema)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(expr) => {
Expand Down Expand Up @@ -653,7 +653,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.create_struct_expr(values, schema, planner_context)?
};

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_struct_literal(create_struct_args, is_named_struct)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => create_struct_args = args,
Expand All @@ -673,7 +673,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?;
let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?;
let mut position_args = vec![fullstr, substr];
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_position(position_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down Expand Up @@ -703,7 +703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

let mut raw_expr = RawDictionaryExpr { keys, values };

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_dictionary_literal(raw_expr, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down Expand Up @@ -927,7 +927,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
None => vec![arg, what_arg, from_arg],
};
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_overlay(overlay_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => overlay_args = args,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/substring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
};

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_substring(substring_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema: &DFSchema,
) -> Result<Expr> {
let mut exprs = values;
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_array_literal(exprs, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down
10 changes: 0 additions & 10 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use arrow_schema::*;
use datafusion_common::{
field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError,
};
use datafusion_expr::planner::ExprPlanner;
use sqlparser::ast::TimezoneInfo;
use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
Expand Down Expand Up @@ -186,8 +185,6 @@ pub struct SqlToRel<'a, S: ContextProvider> {
pub(crate) context_provider: &'a S,
pub(crate) options: ParserOptions,
pub(crate) normalizer: IdentNormalizer,
/// user defined planner extensions
pub(crate) planners: Vec<Arc<dyn ExprPlanner>>,
}

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Expand All @@ -196,12 +193,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Self::new_with_options(context_provider, ParserOptions::default())
}

/// add an user defined planner
pub fn with_user_defined_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
self.planners.push(planner);
self
}

/// Create a new query planner
pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self {
let normalize = options.enable_ident_normalization;
Expand All @@ -210,7 +201,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
context_provider,
options,
normalizer: IdentNormalizer::new(normalize),
planners: vec![],
}
}

Expand Down

0 comments on commit de0765a

Please sign in to comment.