Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Projection Pushdown through user defined LogicalPlan nodes. #9690

Merged
43 changes: 43 additions & 0 deletions datafusion/expr/src/logical_plan/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
inputs: &[LogicalPlan],
) -> Arc<dyn UserDefinedLogicalNode>;

/// Reports which input columns this node needs in order to produce a given
/// subset of the columns in its output schema.
///
/// This is used for projection push-down: once DataFusion has determined that
/// only a subset of the output columns of this node are needed by its parents,
/// it calls this method to learn which, if any, of the input columns are no
/// longer needed.
///
/// # Arguments
///
/// * `_output_columns` - indices of the output columns of this node that are
///   still required by its parents. The default implementation ignores it,
///   hence the leading underscore.
///
/// # Returns
///
/// * `None`, the default, if this information can not be determined. The
///   projection optimizer then does not route parent requirements through
///   this node.
/// * `Some(v)` where `v[i]` holds the column indices of the `i`-th child of
///   this node that are needed to compute the requested output columns. The
///   projection optimizer expects exactly one entry per input, in the same
///   order as the node's inputs.
fn necessary_children_exprs(
&self,
_output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
None
}

/// Update the hash `state` with this node requirements from
/// [`Hash`].
///
Expand Down Expand Up @@ -243,6 +261,24 @@ pub trait UserDefinedLogicalNodeCore:
// but the doc comments have not been updated.
#[allow(clippy::wrong_self_convention)]
fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self;

/// Reports which input columns this node needs in order to produce a given
/// subset of the columns in its output schema.
///
/// This is used for projection push-down: once DataFusion has determined that
/// only a subset of the output columns of this node are needed by its parents,
/// it calls this method to learn which, if any, of the input columns are no
/// longer needed.
///
/// # Arguments
///
/// * `_output_columns` - indices of the output columns of this node that are
///   still required by its parents. The default implementation ignores it,
///   hence the leading underscore.
///
/// # Returns
///
/// * `None`, the default, if this information can not be determined.
/// * `Some(v)` where `v[i]` holds the column indices of the `i`-th child of
///   this node that are needed to compute the requested output columns. The
///   projection optimizer expects exactly one entry per input, in the same
///   order as the node's inputs.
fn necessary_children_exprs(
&self,
_output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
None
}
}

/// Automatically implement `UserDefinedLogicalNode` for any type that implements `UserDefinedLogicalNodeCore`
Expand Down Expand Up @@ -284,6 +320,13 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
Arc::new(self.from_template(exprs, inputs))
}

fn necessary_children_exprs(
&self,
output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
self.necessary_children_exprs(output_columns)
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
Expand Down
296 changes: 290 additions & 6 deletions datafusion/optimizer/src/optimize_projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,36 @@ fn optimize_projections(
.map(|input| ((0..input.schema().fields().len()).collect_vec(), false))
.collect::<Vec<_>>()
}
LogicalPlan::Extension(extension) => {
let necessary_children_indices = if let Some(necessary_children_indices) =
extension.node.necessary_children_exprs(indices)
{
necessary_children_indices
} else {
// Requirements from parent cannot be routed down to user defined logical plan safely
return Ok(None);
};
let children = extension.node.inputs();
assert_eq!(children.len(), necessary_children_indices.len());
mustafasrepo marked this conversation as resolved.
Show resolved Hide resolved
// Expressions used by node.
let exprs = plan.expressions();
children
.into_iter()
.zip(necessary_children_indices)
.map(|(child, necessary_indices)| {
let child_schema = child.schema();
let child_req_indices =
indices_referred_by_exprs(child_schema, exprs.iter())?;
Ok((merge_slices(&necessary_indices, &child_req_indices), false))
})
.collect::<Result<Vec<_>>>()?
}
LogicalPlan::EmptyRelation(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
| LogicalPlan::Extension(_)
| LogicalPlan::DescribeTable(_) => {
// These operators have no inputs, so stop the optimization process.
// TODO: Add support for `LogicalPlan::Extension`.
return Ok(None);
}
LogicalPlan::Projection(proj) => {
Expand Down Expand Up @@ -899,21 +921,161 @@ fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result

#[cfg(test)]
mod tests {
use std::fmt::Formatter;
use std::sync::Arc;

use crate::optimize_projections::OptimizeProjections;
use crate::test::{assert_optimized_plan_eq, test_table_scan};
use crate::test::{
assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name,
};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, TableReference};
use datafusion_common::{Column, DFSchemaRef, JoinType, Result, TableReference};
use datafusion_expr::{
binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not,
table_scan, try_cast, when, Expr, Like, LogicalPlan, Operator,
binary_expr, build_join_schema, col, count, lit,
logical_plan::builder::LogicalPlanBuilder, not, table_scan, try_cast, when,
BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator,
UserDefinedLogicalNodeCore,
};

/// Checks `plan` against `expected` using a fresh `OptimizeProjections` rule,
/// delegating the actual optimization and comparison to
/// `assert_optimized_plan_eq`.
fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
    let rule = Arc::new(OptimizeProjections::new());
    assert_optimized_plan_eq(rule, plan, expected)
}

/// Test-only user-defined logical node with exactly one input.
///
/// Its `necessary_children_exprs` implementation maps every requested output
/// column index to the same index in the input, so it is only meaningful when
/// `schema` lines up with the input's schema.
#[derive(Debug, Hash, PartialEq, Eq)]
struct NoOpUserDefined {
/// Expressions reported by `expressions()`; empty unless set via `with_exprs`.
exprs: Vec<Expr>,
/// Output schema reported by `schema()`.
schema: DFSchemaRef,
/// The single input plan reported by `inputs()`.
input: Arc<LogicalPlan>,
}

impl NoOpUserDefined {
    /// Creates a node over `input` that reports `schema` as its output schema
    /// and carries no expressions.
    fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
        let exprs = Vec::new();
        Self {
            exprs,
            schema,
            input,
        }
    }

    /// Builder-style setter: returns the node with its expressions replaced
    /// by `exprs`.
    fn with_exprs(self, exprs: Vec<Expr>) -> Self {
        Self { exprs, ..self }
    }
}

impl UserDefinedLogicalNodeCore for NoOpUserDefined {
    /// Name of this node.
    fn name(&self) -> &str {
        "NoOpUserDefined"
    }

    /// The single input plan.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![self.input.as_ref()]
    }

    /// Output schema supplied at construction time.
    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    /// A copy of the expressions attached to this node.
    fn expressions(&self) -> Vec<Expr> {
        self.exprs.to_vec()
    }

    /// Writes the node name as its explain representation.
    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("NoOpUserDefined")
    }

    /// Rebuilds the node from new expressions and inputs, keeping the current
    /// schema. Only the first entry of `inputs` is used.
    fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
        let exprs = exprs.to_vec();
        let input = Arc::new(inputs[0].clone());
        let schema = Arc::clone(&self.schema);
        Self {
            exprs,
            schema,
            input,
        }
    }

    /// Requests from the single input exactly the column indices requested
    /// from the output.
    fn necessary_children_exprs(
        &self,
        output_columns: &[usize],
    ) -> Option<Vec<Vec<usize>>> {
        // Output and input share a schema, so output index `i` is produced
        // from input index `i`.
        let input_columns = Vec::from(output_columns);
        Some(vec![input_columns])
    }
}

/// Test-only user-defined logical node with two inputs.
///
/// Its output schema is built with `build_join_schema` from the left and
/// right child schemas, and its `necessary_children_exprs` treats output
/// indices below the left child's field count as left columns and the
/// remaining indices as right columns.
#[derive(Debug, Hash, PartialEq, Eq)]
struct UserDefinedCrossJoin {
/// Expressions reported by `expressions()`; empty when built with `new`.
exprs: Vec<Expr>,
/// Output schema reported by `schema()`.
schema: DFSchemaRef,
/// First input plan.
left_child: Arc<LogicalPlan>,
/// Second input plan.
right_child: Arc<LogicalPlan>,
}

impl UserDefinedCrossJoin {
    /// Creates a node over the two children with no expressions. The output
    /// schema is derived from both child schemas via `build_join_schema` with
    /// `JoinType::Inner`; panics if that schema cannot be built.
    fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
        let joined = build_join_schema(
            left_child.schema(),
            right_child.schema(),
            &JoinType::Inner,
        )
        .unwrap();
        Self {
            exprs: Vec::new(),
            schema: Arc::new(joined),
            left_child,
            right_child,
        }
    }
}

impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
    /// Name of this node.
    fn name(&self) -> &str {
        "UserDefinedCrossJoin"
    }

    /// Both inputs, left first.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![self.left_child.as_ref(), self.right_child.as_ref()]
    }

    /// Output schema computed at construction time.
    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    /// A copy of the expressions attached to this node.
    fn expressions(&self) -> Vec<Expr> {
        self.exprs.to_vec()
    }

    /// Writes the node name as its explain representation.
    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("UserDefinedCrossJoin")
    }

    /// Rebuilds the node from new expressions and exactly two inputs (left,
    /// right), keeping the current schema. Panics if `inputs` does not hold
    /// exactly two plans.
    fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
        assert_eq!(inputs.len(), 2);
        let (left, right) = (&inputs[0], &inputs[1]);
        Self {
            exprs: exprs.to_vec(),
            schema: Arc::clone(&self.schema),
            left_child: Arc::new(left.clone()),
            right_child: Arc::new(right.clone()),
        }
    }

    /// Splits the requested output indices into per-child requirements.
    ///
    /// Indices below the left child's field count refer to the left child
    /// unchanged; the remaining indices refer to the right child after
    /// subtracting that count. Relative order is preserved within each side.
    fn necessary_children_exprs(
        &self,
        output_columns: &[usize],
    ) -> Option<Vec<Vec<usize>>> {
        let left_width = self.left_child.schema().fields().len();
        let (left_reqs, right_outputs): (Vec<usize>, Vec<usize>) = output_columns
            .iter()
            .copied()
            .partition(|&out_idx| out_idx < left_width);
        // Re-base the right-hand indices onto the right child's own schema.
        let right_reqs = right_outputs
            .into_iter()
            .map(|out_idx| out_idx - left_width)
            .collect::<Vec<_>>();
        Some(vec![left_reqs, right_reqs])
    }
}

#[test]
fn merge_two_projection() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down Expand Up @@ -1192,4 +1354,126 @@ mod tests {
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}

// The OptimizeProjections rule pushes projections down through a user-defined logical plan node.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please add a few more tests:

  1. The NoOpUserDefined plan itself refers to column a in its expressions
  2. The NoOpUserDefined plan itself refers to column b in its expressions
  3. The NoOpUserDefined plan itself refers to column a + b in its expressions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added tests for these cases.

#[test]
fn test_user_defined_logical_plan_node() -> Result<()> {
    // The node carries no expressions of its own, so only the column the
    // projection asks for (`a`) should be requested from the scan.
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let node = NoOpUserDefined::new(schema, Arc::new(scan));
    let extension = LogicalPlan::Extension(Extension {
        node: Arc::new(node),
    });
    let plan = LogicalPlanBuilder::from(extension)
        .project(vec![col("a"), lit(0).alias("d")])?
        .build()?;

    let expected = "Projection: test.a, Int32(0) AS d\
    \n NoOpUserDefined\
    \n TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}

// Projection push-down through a user-defined node whose own expression
// refers to column `a`, the same column the projection needs: the scan is
// still expected to read only `a`.
#[test]
fn test_user_defined_logical_plan_node2() -> Result<()> {
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let node_exprs = vec![Expr::Column(Column::from_qualified_name("a"))];
    let node = NoOpUserDefined::new(schema, Arc::new(scan)).with_exprs(node_exprs);
    let extension = LogicalPlan::Extension(Extension {
        node: Arc::new(node),
    });
    let plan = LogicalPlanBuilder::from(extension)
        .project(vec![col("a"), lit(0).alias("d")])?
        .build()?;

    let expected = "Projection: test.a, Int32(0) AS d\
    \n NoOpUserDefined\
    \n TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}

// Projection push-down through a user-defined node whose own expression
// refers to column `b`: the scan is expected to read `b` in addition to the
// projected column `a`.
#[test]
fn test_user_defined_logical_plan_node3() -> Result<()> {
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let node_exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
    let node = NoOpUserDefined::new(schema, Arc::new(scan)).with_exprs(node_exprs);
    let extension = LogicalPlan::Extension(Extension {
        node: Arc::new(node),
    });
    let plan = LogicalPlanBuilder::from(extension)
        .project(vec![col("a"), lit(0).alias("d")])?
        .build()?;

    let expected = "Projection: test.a, Int32(0) AS d\
    \n NoOpUserDefined\
    \n TableScan: test projection=[a, b]";
    assert_optimized_plan_equal(&plan, expected)
}

// Projection push-down through a user-defined node whose own expression is
// the binary expression `b + c`: the scan is expected to read `b` and `c` in
// addition to the projected column `a`.
#[test]
fn test_user_defined_logical_plan_node4() -> Result<()> {
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let sum_b_c = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(Expr::Column(Column::from_qualified_name("b"))),
        Operator::Plus,
        Box::new(Expr::Column(Column::from_qualified_name("c"))),
    ));
    let node =
        NoOpUserDefined::new(schema, Arc::new(scan)).with_exprs(vec![sum_b_c]);
    let extension = LogicalPlan::Extension(Extension {
        node: Arc::new(node),
    });
    let plan = LogicalPlanBuilder::from(extension)
        .project(vec![col("a"), lit(0).alias("d")])?
        .build()?;

    let expected = "Projection: test.a, Int32(0) AS d\
    \n NoOpUserDefined\
    \n TableScan: test projection=[a, b, c]";
    assert_optimized_plan_equal(&plan, expected)
}

// Projection push-down through a user-defined node with two children: each
// scan is expected to read only the columns the projection takes from that
// side (`a`, `c` from `l` and `a` from `r`).
#[test]
fn test_user_defined_logical_plan_node5() -> Result<()> {
    let left = test_table_scan_with_name("l")?;
    let right = test_table_scan_with_name("r")?;
    let cross_join = UserDefinedCrossJoin::new(Arc::new(left), Arc::new(right));
    let extension = LogicalPlan::Extension(Extension {
        node: Arc::new(cross_join),
    });
    let projected = vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")];
    let plan = LogicalPlanBuilder::from(extension)
        .project(projected)?
        .build()?;

    let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\
    \n UserDefinedCrossJoin\
    \n TableScan: l projection=[a, c]\
    \n TableScan: r projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
}
Loading