Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Projection Pushdown through user defined LogicalPlan nodes. #9690

Merged
43 changes: 43 additions & 0 deletions datafusion/expr/src/logical_plan/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
inputs: &[LogicalPlan],
) -> Arc<dyn UserDefinedLogicalNode>;

/// Returns the necessary input columns for this node required to compute
/// the columns in the output schema
///
/// This is used for projection push-down when DataFusion has determined that
/// only a subset of the output columns of this node are needed by its parents.
/// This API is used to tell DataFusion which, if any, of the input columns are no longer
/// needed.
///
/// Return `None`, the default, if this information can not be determined.
/// Returns `Some(_)` with the column indices for each child of this node that are
/// needed to compute `output_columns`
fn necessary_children_exprs(
&self,
_output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
None
}

/// Update the hash `state` with this node requirements from
/// [`Hash`].
///
Expand Down Expand Up @@ -243,6 +261,24 @@ pub trait UserDefinedLogicalNodeCore:
// but the doc comments have not been updated.
#[allow(clippy::wrong_self_convention)]
fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self;

/// Returns the necessary input columns for this node required to compute
/// the columns in the output schema
///
/// This is used for projection push-down when DataFusion has determined that
/// only a subset of the output columns of this node are needed by its parents.
/// This API is used to tell DataFusion which, if any, of the input columns are no longer
/// needed.
///
/// Return `None`, the default, if this information can not be determined.
/// Returns `Some(_)` with the column indices for each child of this node that are
/// needed to compute `output_columns`
fn necessary_children_exprs(
&self,
_output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
None
}
}

/// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode`
Expand Down Expand Up @@ -284,6 +320,13 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
Arc::new(self.from_template(exprs, inputs))
}

fn necessary_children_exprs(
&self,
output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
self.necessary_children_exprs(output_columns)
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
Expand Down
289 changes: 282 additions & 7 deletions datafusion/optimizer/src/optimize_projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ use crate::{OptimizerConfig, OptimizerRule};

use arrow::datatypes::SchemaRef;
use datafusion_common::{
get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result,
get_required_group_by_exprs_indices, internal_err, Column, DFSchema, DFSchemaRef,
JoinType, Result,
};
use datafusion_expr::expr::{Alias, ScalarFunction};
use datafusion_expr::{
Expand Down Expand Up @@ -162,14 +163,40 @@ fn optimize_projections(
.map(|input| ((0..input.schema().fields().len()).collect_vec(), false))
.collect::<Vec<_>>()
}
LogicalPlan::Extension(extension) => {
let necessary_children_indices = if let Some(necessary_children_indices) =
extension.node.necessary_children_exprs(indices)
{
necessary_children_indices
} else {
// Requirements from parent cannot be routed down to user defined logical plan safely
return Ok(None);
};
let children = extension.node.inputs();
if children.len() != necessary_children_indices.len() {
return internal_err!("Inconsistent length between children and necessary children indices. \
Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \
consistent with actual children length for the node.");
}
// Expressions used by node.
let exprs = plan.expressions();
children
.into_iter()
.zip(necessary_children_indices)
.map(|(child, necessary_indices)| {
let child_schema = child.schema();
let child_req_indices =
indices_referred_by_exprs(child_schema, exprs.iter())?;
Ok((merge_slices(&necessary_indices, &child_req_indices), false))
})
.collect::<Result<Vec<_>>>()?
}
LogicalPlan::EmptyRelation(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
| LogicalPlan::Extension(_)
| LogicalPlan::DescribeTable(_) => {
// These operators have no inputs, so stop the optimization process.
// TODO: Add support for `LogicalPlan::Extension`.
return Ok(None);
}
LogicalPlan::Projection(proj) => {
Expand Down Expand Up @@ -899,21 +926,161 @@ fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result

#[cfg(test)]
mod tests {
use std::fmt::Formatter;
use std::sync::Arc;

use crate::optimize_projections::OptimizeProjections;
use crate::test::{assert_optimized_plan_eq, test_table_scan};
use crate::test::{
assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name,
};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, TableReference};
use datafusion_common::{Column, DFSchemaRef, JoinType, Result, TableReference};
use datafusion_expr::{
binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not,
table_scan, try_cast, when, Expr, Like, LogicalPlan, Operator,
binary_expr, build_join_schema, col, count, lit,
logical_plan::builder::LogicalPlanBuilder, not, table_scan, try_cast, when,
BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator,
UserDefinedLogicalNodeCore,
};

fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected)
}

#[derive(Debug, Hash, PartialEq, Eq)]
struct NoOpUserDefined {
exprs: Vec<Expr>,
schema: DFSchemaRef,
input: Arc<LogicalPlan>,
}

impl NoOpUserDefined {
fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
Self {
exprs: vec![],
schema,
input,
}
}

fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
self.exprs = exprs;
self
}
}

impl UserDefinedLogicalNodeCore for NoOpUserDefined {
fn name(&self) -> &str {
"NoOpUserDefined"
}

fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.input]
}

fn schema(&self) -> &DFSchemaRef {
&self.schema
}

fn expressions(&self) -> Vec<Expr> {
self.exprs.clone()
}

fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "NoOpUserDefined")
}

fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
Self {
exprs: exprs.to_vec(),
input: Arc::new(inputs[0].clone()),
schema: self.schema.clone(),
}
}

fn necessary_children_exprs(
&self,
output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
// Since schema is same. Output columns requires their corresponding version in the input columns.
Some(vec![output_columns.to_vec()])
}
}

#[derive(Debug, Hash, PartialEq, Eq)]
struct UserDefinedCrossJoin {
exprs: Vec<Expr>,
schema: DFSchemaRef,
left_child: Arc<LogicalPlan>,
right_child: Arc<LogicalPlan>,
}

impl UserDefinedCrossJoin {
fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
let left_schema = left_child.schema();
let right_schema = right_child.schema();
let schema = Arc::new(
build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
);
Self {
exprs: vec![],
schema,
left_child,
right_child,
}
}
}

impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
fn name(&self) -> &str {
"UserDefinedCrossJoin"
}

fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.left_child, &self.right_child]
}

fn schema(&self) -> &DFSchemaRef {
&self.schema
}

fn expressions(&self) -> Vec<Expr> {
self.exprs.clone()
}

fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "UserDefinedCrossJoin")
}

fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
assert_eq!(inputs.len(), 2);
Self {
exprs: exprs.to_vec(),
left_child: Arc::new(inputs[0].clone()),
right_child: Arc::new(inputs[1].clone()),
schema: self.schema.clone(),
}
}

fn necessary_children_exprs(
&self,
output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
let left_child_len = self.left_child.schema().fields().len();
let mut left_reqs = vec![];
let mut right_reqs = vec![];
for &out_idx in output_columns {
if out_idx < left_child_len {
left_reqs.push(out_idx);
} else {
// Output indices further than the left_child_len
// comes from right children
right_reqs.push(out_idx - left_child_len)
}
}
Some(vec![left_reqs, right_reqs])
}
}

#[test]
fn merge_two_projection() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down Expand Up @@ -1192,4 +1359,112 @@ mod tests {
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}

// Since only column `a` is referred at the output. Scan should only contain projection=[a].
// User defined node should be able to propagate necessary expressions by its parent to its child.
#[test]
fn test_user_defined_logical_plan_node() -> Result<()> {
mustafasrepo marked this conversation as resolved.
Show resolved Hide resolved
let table_scan = test_table_scan()?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoOpUserDefined::new(
table_scan.schema().clone(),
Arc::new(table_scan.clone()),
)),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.project(vec![col("a"), lit(0).alias("d")])?
.build()?;

let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}

// Only column `a` is referred at the output. However, User defined node itself uses column `b`
// during its operation. Hence, scan should contain projection=[a, b].
// User defined node should be able to propagate necessary expressions by its parent, as well as its own
// required expressions.
#[test]
fn test_user_defined_logical_plan_node2() -> Result<()> {
let table_scan = test_table_scan()?;
let exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(
NoOpUserDefined::new(
table_scan.schema().clone(),
Arc::new(table_scan.clone()),
)
.with_exprs(exprs),
),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.project(vec![col("a"), lit(0).alias("d")])?
.build()?;

let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a, b]";
assert_optimized_plan_equal(&plan, expected)
}

// Only column `a` is referred at the output. However, User defined node itself uses expression `b+c`
// during its operation. Hence, scan should contain projection=[a, b, c].
// User defined node should be able to propagate necessary expressions by its parent, as well as its own
// required expressions. Expressions doesn't have to be just column. Requirements from complex expressions
// should be propagated also.
#[test]
fn test_user_defined_logical_plan_node3() -> Result<()> {
let table_scan = test_table_scan()?;
let left_expr = Expr::Column(Column::from_qualified_name("b"));
let right_expr = Expr::Column(Column::from_qualified_name("c"));
let binary_expr = Expr::BinaryExpr(BinaryExpr::new(
Box::new(left_expr),
Operator::Plus,
Box::new(right_expr),
));
let exprs = vec![binary_expr];
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(
NoOpUserDefined::new(
table_scan.schema().clone(),
Arc::new(table_scan.clone()),
)
.with_exprs(exprs),
),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.project(vec![col("a"), lit(0).alias("d")])?
.build()?;

let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a, b, c]";
assert_optimized_plan_equal(&plan, expected)
}

// Columns `l.a`, `l.c`, `r.a` is referred at the output.
// User defined node should be able to propagate necessary expressions by its parent, to its children.
// Even if it has multiple children.
// left child should have `projection=[a, c]`, and right side should have `projection=[a]`.
#[test]
fn test_user_defined_logical_plan_node4() -> Result<()> {
let left_table = test_table_scan_with_name("l")?;
let right_table = test_table_scan_with_name("r")?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(UserDefinedCrossJoin::new(
Arc::new(left_table.clone()),
Arc::new(right_table.clone()),
)),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])?
.build()?;

let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\
\n UserDefinedCrossJoin\
\n TableScan: l projection=[a, c]\
\n TableScan: r projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
}
Loading