-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(substrait): set ProjectRel output_mapping in producer #12495
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,9 @@ use substrait::proto::expression::literal::{ | |
use substrait::proto::expression::subquery::InPredicate; | ||
use substrait::proto::expression::window_function::BoundsType; | ||
use substrait::proto::read_rel::VirtualTable; | ||
use substrait::proto::{CrossRel, ExchangeRel}; | ||
use substrait::proto::rel_common::EmitKind; | ||
use substrait::proto::rel_common::EmitKind::Emit; | ||
use substrait::proto::{rel_common, CrossRel, ExchangeRel, RelCommon}; | ||
use substrait::{ | ||
proto::{ | ||
aggregate_function::AggregationInvocation, | ||
|
@@ -219,9 +221,20 @@ pub fn to_substrait_rel( | |
.iter() | ||
.map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
let emit_kind = create_project_remapping( | ||
expressions.len(), | ||
p.input.as_ref().schema().fields().len(), | ||
); | ||
let common = RelCommon { | ||
emit_kind: Some(emit_kind), | ||
hint: None, | ||
advanced_extension: None, | ||
}; | ||
|
||
Ok(Box::new(Rel { | ||
rel_type: Some(RelType::Project(Box::new(ProjectRel { | ||
common: None, | ||
common: Some(common), | ||
input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), | ||
expressions, | ||
advanced_extension: None, | ||
|
@@ -432,38 +445,39 @@ pub fn to_substrait_rel( | |
} | ||
LogicalPlan::Window(window) => { | ||
let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; | ||
// If the input is a Project relation, we can just append the WindowFunction expressions | ||
// before returning | ||
// Otherwise, wrap the input in a Project relation before appending the WindowFunction | ||
// expressions | ||
let mut project_rel: Box<ProjectRel> = match &input.as_ref().rel_type { | ||
Some(RelType::Project(p)) => Box::new(*p.clone()), | ||
_ => { | ||
// Create Projection with field referencing all output fields in the input relation | ||
let expressions = (0..window.input.schema().fields().len()) | ||
.map(substrait_field_ref) | ||
.collect::<Result<Vec<_>>>()?; | ||
Box::new(ProjectRel { | ||
common: None, | ||
input: Some(input), | ||
expressions, | ||
advanced_extension: None, | ||
}) | ||
} | ||
}; | ||
// Parse WindowFunction expression | ||
let mut window_exprs = vec![]; | ||
|
||
// create a field reference for each input field | ||
let mut expressions = (0..window.input.schema().fields().len()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess you could not add these ref expressions, and then use a direct emit (or remap starting from 0) instead, would that not be the same? That said it'd break roundtrip until you've gotten the consumer part done as well 😅 |
||
.map(substrait_field_ref) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
// process and add each window function expression | ||
for expr in &window.window_expr { | ||
window_exprs.push(to_substrait_rex( | ||
expressions.push(to_substrait_rex( | ||
ctx, | ||
expr, | ||
window.input.schema(), | ||
0, | ||
extensions, | ||
)?); | ||
} | ||
// Append parsed WindowFunction expressions | ||
project_rel.expressions.extend(window_exprs); | ||
|
||
let emit_kind = create_project_remapping( | ||
expressions.len(), | ||
window.input.schema().fields().len(), | ||
); | ||
let common = RelCommon { | ||
emit_kind: Some(emit_kind), | ||
hint: None, | ||
advanced_extension: None, | ||
}; | ||
let project_rel = Box::new(ProjectRel { | ||
common: Some(common), | ||
input: Some(input), | ||
expressions, | ||
advanced_extension: None, | ||
}); | ||
|
||
Ok(Box::new(Rel { | ||
rel_type: Some(RelType::Project(project_rel)), | ||
})) | ||
|
@@ -553,6 +567,19 @@ pub fn to_substrait_rel( | |
} | ||
} | ||
|
||
/// By default, a Substrait Project outputs all input fields followed by all expressions. | ||
/// A DataFusion Projection only outputs expressions. In order to keep the Substrait | ||
/// plan consistent with DataFusion, we must apply an output mapping that skips the input | ||
/// fields so that the Substrait Project will only output the expression fields. | ||
fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { | ||
let expression_field_start = input_field_count; | ||
let expression_field_end = expression_field_start + expr_count; | ||
let output_mapping = (expression_field_start..expression_field_end) | ||
.map(|i| i as i32) | ||
.collect(); | ||
Emit(rel_common::Emit { output_mapping }) | ||
} | ||
|
||
fn to_substrait_named_struct( | ||
schema: &DFSchemaRef, | ||
extensions: &mut Extensions, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,11 @@ mod tests { | |
use datafusion::error::Result; | ||
use datafusion::prelude::*; | ||
|
||
use datafusion_substrait::logical_plan::producer::to_substrait_plan; | ||
use std::fs; | ||
use substrait::proto::plan_rel::RelType; | ||
use substrait::proto::rel_common::{Emit, EmitKind}; | ||
use substrait::proto::{rel, RelCommon}; | ||
|
||
#[tokio::test] | ||
async fn serialize_simple_select() -> Result<()> { | ||
|
@@ -63,6 +67,103 @@ mod tests { | |
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn include_remaps_for_projects() -> Result<()> { | ||
let ctx = create_context().await?; | ||
let df = ctx.sql("SELECT b, a + a, a FROM data").await?; | ||
let datafusion_plan = df.into_optimized_plan()?; | ||
|
||
assert_eq!( | ||
format!("{}", datafusion_plan), | ||
"Projection: data.b, data.a + data.a, data.a\ | ||
\n TableScan: data projection=[a, b]", | ||
); | ||
|
||
let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); | ||
|
||
let relation = plan.relations.first().unwrap().rel_type.as_ref(); | ||
let root_rel = match relation { | ||
Some(RelType::Root(root)) => root.input.as_ref().unwrap(), | ||
_ => panic!("expected Root"), | ||
}; | ||
if let Some(rel::RelType::Project(p)) = root_rel.rel_type.as_ref() { | ||
// The input has 2 columns [a, b], the Projection has 3 expressions [b, a + a, a] | ||
// The required output mapping is [2,3,4], which skips the 2 input columns. | ||
assert_emit(p.common.as_ref(), vec![2, 3, 4]); | ||
|
||
if let Some(rel::RelType::Read(r)) = | ||
p.input.as_ref().unwrap().rel_type.as_ref() | ||
{ | ||
let mask_expression = r.projection.as_ref().unwrap(); | ||
let select = mask_expression.select.as_ref().unwrap(); | ||
assert_eq!( | ||
2, | ||
select.struct_items.len(), | ||
"Read outputs two columns: a, b" | ||
); | ||
return Ok(()); | ||
} | ||
} | ||
panic!("plan did not match expected structure") | ||
} | ||
|
||
#[tokio::test] | ||
async fn include_remaps_for_windows() -> Result<()> { | ||
let ctx = create_context().await?; | ||
// let df = ctx.sql("SELECT a, b, lead(b) OVER (PARTITION BY a) FROM data").await?; | ||
let df = ctx | ||
.sql("SELECT b, RANK() OVER (PARTITION BY a), c FROM data;") | ||
.await?; | ||
let datafusion_plan = df.into_optimized_plan()?; | ||
assert_eq!( | ||
format!("{}", datafusion_plan), | ||
"Projection: data.b, RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\ | ||
\n WindowAggr: windowExpr=[[RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ | ||
\n TableScan: data projection=[a, b, c]", | ||
); | ||
|
||
let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); | ||
|
||
let relation = plan.relations.first().unwrap().rel_type.as_ref(); | ||
let root_rel = match relation { | ||
Some(RelType::Root(root)) => root.input.as_ref().unwrap(), | ||
_ => panic!("expected Root"), | ||
}; | ||
|
||
if let Some(rel::RelType::Project(p1)) = root_rel.rel_type.as_ref() { | ||
// The WindowAggr outputs 4 columns, the Projection has 4 columns | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I guess that's what you're saying here anyways 😅 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, that is a typo. It should have been
which is what the assert below is actually (correctly) checking. I'll make a note to myself to update this when I make the consumer changes. |
||
assert_emit(p1.common.as_ref(), vec![4, 5, 6]); | ||
|
||
if let Some(rel::RelType::Project(p2)) = | ||
p1.input.as_ref().unwrap().rel_type.as_ref() | ||
{ | ||
// The input has 3 columns, the WindowAggr has 4 expression | ||
assert_emit(p2.common.as_ref(), vec![3, 4, 5, 6]); | ||
|
||
if let Some(rel::RelType::Read(r)) = | ||
p2.input.as_ref().unwrap().rel_type.as_ref() | ||
{ | ||
let mask_expression = r.projection.as_ref().unwrap(); | ||
let select = mask_expression.select.as_ref().unwrap(); | ||
assert_eq!( | ||
3, | ||
select.struct_items.len(), | ||
"Read outputs three columns: a, b, c" | ||
); | ||
return Ok(()); | ||
} | ||
} | ||
} | ||
panic!("plan did not match expected structure") | ||
} | ||
|
||
fn assert_emit(rel_common: Option<&RelCommon>, output_mapping: Vec<i32>) { | ||
assert_eq!( | ||
rel_common.unwrap().emit_kind.clone(), | ||
Some(EmitKind::Emit(Emit { output_mapping })) | ||
); | ||
} | ||
|
||
async fn create_context() -> Result<SessionContext> { | ||
let ctx = SessionContext::new(); | ||
ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new()) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is potentially a useful optimization, however it becomes a bit more complicated with the introduction of the
output_mapping
because you a need to modify it along with the expressions. I've opted to simplify this for now to favour simplicity and correctness.As well, I think this is better handled in when consuming plans and/or by the optimizer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the datafusion optimizer already handles projection pushdown quite well -- so keeping the substrait producer simpler makes the most sense to me