Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(substrait): set ProjectRel output_mapping in producer #12495

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 53 additions & 26 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ use substrait::proto::expression::literal::{
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::read_rel::VirtualTable;
use substrait::proto::{CrossRel, ExchangeRel};
use substrait::proto::rel_common::EmitKind;
use substrait::proto::rel_common::EmitKind::Emit;
use substrait::proto::{rel_common, CrossRel, ExchangeRel, RelCommon};
use substrait::{
proto::{
aggregate_function::AggregationInvocation,
Expand Down Expand Up @@ -219,9 +221,20 @@ pub fn to_substrait_rel(
.iter()
.map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions))
.collect::<Result<Vec<_>>>()?;

let emit_kind = create_project_remapping(
expressions.len(),
p.input.as_ref().schema().fields().len(),
);
let common = RelCommon {
emit_kind: Some(emit_kind),
hint: None,
advanced_extension: None,
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
common: None,
common: Some(common),
input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?),
expressions,
advanced_extension: None,
Expand Down Expand Up @@ -432,38 +445,39 @@ pub fn to_substrait_rel(
}
LogicalPlan::Window(window) => {
let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?;
// If the input is a Project relation, we can just append the WindowFunction expressions
// before returning
// Otherwise, wrap the input in a Project relation before appending the WindowFunction
// expressions
let mut project_rel: Box<ProjectRel> = match &input.as_ref().rel_type {
Some(RelType::Project(p)) => Box::new(*p.clone()),
_ => {
// Create Projection with field referencing all output fields in the input relation
let expressions = (0..window.input.schema().fields().len())
.map(substrait_field_ref)
.collect::<Result<Vec<_>>>()?;
Box::new(ProjectRel {
common: None,
input: Some(input),
expressions,
advanced_extension: None,
})
}
};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is potentially a useful optimization, however it becomes a bit more complicated with the introduction of the output_mapping because you a need to modify it along with the expressions. I've opted to simplify this for now to favour simplicity and correctness.

As well, I think this is better handled in when consuming plans and/or by the optimizer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the datafusion optimizer already handles projection pushdown quite well -- so keeping the substrait producer simpler makes the most sense to me

// Parse WindowFunction expression
let mut window_exprs = vec![];

// create a field reference for each input field
let mut expressions = (0..window.input.schema().fields().len())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you could not add these ref expressions, and then use a direct emit (or remap starting from 0) instead, would that not be the same? That said it'd break roundtrip until you've gotten the consumer part done as well 😅

.map(substrait_field_ref)
.collect::<Result<Vec<_>>>()?;

// process and add each window function expression
for expr in &window.window_expr {
window_exprs.push(to_substrait_rex(
expressions.push(to_substrait_rex(
ctx,
expr,
window.input.schema(),
0,
extensions,
)?);
}
// Append parsed WindowFunction expressions
project_rel.expressions.extend(window_exprs);

let emit_kind = create_project_remapping(
expressions.len(),
window.input.schema().fields().len(),
);
let common = RelCommon {
emit_kind: Some(emit_kind),
hint: None,
advanced_extension: None,
};
let project_rel = Box::new(ProjectRel {
common: Some(common),
input: Some(input),
expressions,
advanced_extension: None,
});

Ok(Box::new(Rel {
rel_type: Some(RelType::Project(project_rel)),
}))
Expand Down Expand Up @@ -553,6 +567,19 @@ pub fn to_substrait_rel(
}
}

/// By default, a Substrait Project outputs all input fields followed by all expressions.
/// A DataFusion Projection only outputs expressions. In order to keep the Substrait
/// plan consistent with DataFusion, we must apply an output mapping that skips the input
/// fields so that the Substrait Project will only output the expression fields.
fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind {
let expression_field_start = input_field_count;
let expression_field_end = expression_field_start + expr_count;
let output_mapping = (expression_field_start..expression_field_end)
.map(|i| i as i32)
.collect();
Emit(rel_common::Emit { output_mapping })
}

fn to_substrait_named_struct(
schema: &DFSchemaRef,
extensions: &mut Extensions,
Expand Down
101 changes: 101 additions & 0 deletions datafusion/substrait/tests/cases/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ mod tests {
use datafusion::error::Result;
use datafusion::prelude::*;

use datafusion_substrait::logical_plan::producer::to_substrait_plan;
use std::fs;
use substrait::proto::plan_rel::RelType;
use substrait::proto::rel_common::{Emit, EmitKind};
use substrait::proto::{rel, RelCommon};

#[tokio::test]
async fn serialize_simple_select() -> Result<()> {
Expand Down Expand Up @@ -63,6 +67,103 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn include_remaps_for_projects() -> Result<()> {
let ctx = create_context().await?;
let df = ctx.sql("SELECT b, a + a, a FROM data").await?;
let datafusion_plan = df.into_optimized_plan()?;

assert_eq!(
format!("{}", datafusion_plan),
"Projection: data.b, data.a + data.a, data.a\
\n TableScan: data projection=[a, b]",
);

let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();

let relation = plan.relations.first().unwrap().rel_type.as_ref();
let root_rel = match relation {
Some(RelType::Root(root)) => root.input.as_ref().unwrap(),
_ => panic!("expected Root"),
};
if let Some(rel::RelType::Project(p)) = root_rel.rel_type.as_ref() {
// The input has 2 columns [a, b], the Projection has 3 expressions [b, a + a, a]
// The required output mapping is [2,3,4], which skips the 2 input columns.
assert_emit(p.common.as_ref(), vec![2, 3, 4]);

if let Some(rel::RelType::Read(r)) =
p.input.as_ref().unwrap().rel_type.as_ref()
{
let mask_expression = r.projection.as_ref().unwrap();
let select = mask_expression.select.as_ref().unwrap();
assert_eq!(
2,
select.struct_items.len(),
"Read outputs two columns: a, b"
);
return Ok(());
}
}
panic!("plan did not match expected structure")
}

#[tokio::test]
async fn include_remaps_for_windows() -> Result<()> {
let ctx = create_context().await?;
// let df = ctx.sql("SELECT a, b, lead(b) OVER (PARTITION BY a) FROM data").await?;
let df = ctx
.sql("SELECT b, RANK() OVER (PARTITION BY a), c FROM data;")
.await?;
let datafusion_plan = df.into_optimized_plan()?;
assert_eq!(
format!("{}", datafusion_plan),
"Projection: data.b, RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\
\n WindowAggr: windowExpr=[[RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: data projection=[a, b, c]",
);

let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();

let relation = plan.relations.first().unwrap().rel_type.as_ref();
let root_rel = match relation {
Some(RelType::Root(root)) => root.input.as_ref().unwrap(),
_ => panic!("expected Root"),
};

if let Some(rel::RelType::Project(p1)) = root_rel.rel_type.as_ref() {
// The WindowAggr outputs 4 columns, the Projection has 4 columns
Copy link
Contributor

@Blizzara Blizzara Sep 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think both are 3 columns?

Or well the window outputs 4 but the project outputs only 3.

I guess that's what you're saying here anyways 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that is a typo. It should have been

The WindowAggr outputs 4 columns, the Projection has 3 columns
                                                    ^^^

which is what the assert below is actually (correctly) checking.

I'll make a note to myself to update this when I make the consumer changes.

assert_emit(p1.common.as_ref(), vec![4, 5, 6]);

if let Some(rel::RelType::Project(p2)) =
p1.input.as_ref().unwrap().rel_type.as_ref()
{
// The input has 3 columns, the WindowAggr has 4 expression
assert_emit(p2.common.as_ref(), vec![3, 4, 5, 6]);

if let Some(rel::RelType::Read(r)) =
p2.input.as_ref().unwrap().rel_type.as_ref()
{
let mask_expression = r.projection.as_ref().unwrap();
let select = mask_expression.select.as_ref().unwrap();
assert_eq!(
3,
select.struct_items.len(),
"Read outputs three columns: a, b, c"
);
return Ok(());
}
}
}
panic!("plan did not match expected structure")
}

fn assert_emit(rel_common: Option<&RelCommon>, output_mapping: Vec<i32>) {
assert_eq!(
rel_common.unwrap().emit_kind.clone(),
Some(EmitKind::Emit(Emit { output_mapping }))
);
}

async fn create_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
Expand Down