Skip to content

Commit 76f95b1

Browse files
authored
Merge pull request #1 from andygrove/roundtrip
Get first test passing
2 parents b17155a + 8dbc69c commit 76f95b1

File tree

1 file changed

+127
-44
lines changed

1 file changed

+127
-44
lines changed

src/lib.rs

Lines changed: 127 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,49 @@
11
use std::sync::Arc;
22

3+
use datafusion::arrow::datatypes::{DataType, Field};
34
use datafusion::{
45
arrow::datatypes::{Schema, SchemaRef},
56
datasource::empty::EmptyTable,
67
error::{DataFusionError, Result},
7-
logical_plan::{plan::Projection, DFSchema, Expr, LogicalPlan, TableScan},
8+
logical_plan::{plan::Projection, DFSchema, DFSchemaRef, Expr, LogicalPlan, TableScan},
9+
prelude::Column,
810
};
911

1012
use substrait::protobuf::{
1113
expression::{
14+
field_reference::{ReferenceType, ReferenceType::MaskedReference},
1215
mask_expression::{StructItem, StructSelect},
1316
FieldReference, MaskExpression, RexType,
1417
},
1518
read_rel::{NamedTable, ReadType},
1619
rel::RelType,
17-
Expression, ProjectRel, ReadRel, Rel,
20+
Expression, NamedStruct, ProjectRel, ReadRel, Rel,
1821
};
19-
//
20-
// pub fn to_substrait_rex(expr: &Expr) -> Result<Box<Expression>> {
21-
// match expr {
22-
// Expr::Column(col) => {
23-
// Ok(Box::new(Expression {
24-
// rex_type: Some(RexType::Selection(Box::new(FieldReference {
25-
// reference_type: None,
26-
// root_type: None,
27-
// })))
28-
// }))
29-
// }
30-
// _ => Err(DataFusionError::NotImplemented(
31-
// "Unsupported logical plan expression".to_string(),
32-
// )),
33-
// }
34-
// }
3522

23+
/// Convert DataFusion Expr to Substrait Rex
24+
pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression> {
25+
match expr {
26+
Expr::Column(col) => Ok(Expression {
27+
rex_type: Some(RexType::Selection(Box::new(FieldReference {
28+
reference_type: Some(ReferenceType::MaskedReference(MaskExpression {
29+
select: Some(StructSelect {
30+
struct_items: vec![StructItem {
31+
field: schema.index_of(&col.name)? as i32,
32+
child: None,
33+
}],
34+
}),
35+
maintain_singular_struct: false,
36+
})),
37+
root_type: None,
38+
}))),
39+
}),
40+
_ => Err(DataFusionError::NotImplemented(
41+
"Unsupported logical plan expression".to_string(),
42+
)),
43+
}
44+
}
45+
46+
/// Convert DataFusion LogicalPlan to Substrait Rel
3647
pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
3748
match plan {
3849
LogicalPlan::TableScan(scan) => {
@@ -48,7 +59,15 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
4859
Ok(Box::new(Rel {
4960
rel_type: Some(RelType::Read(Box::new(ReadRel {
5061
common: None,
51-
base_schema: None,
62+
base_schema: Some(NamedStruct {
63+
names: scan
64+
.projected_schema
65+
.fields()
66+
.iter()
67+
.map(|f| f.name().to_owned())
68+
.collect(),
69+
r#struct: None, // TODO
70+
}),
5271
filter: None,
5372
projection: Some(MaskExpression {
5473
select: Some(StructSelect {
@@ -64,70 +83,134 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
6483
}))),
6584
}))
6685
}
67-
LogicalPlan::Projection(p) => Ok(Box::new(Rel {
68-
rel_type: Some(RelType::Project(Box::new(ProjectRel {
69-
common: None,
70-
input: Some(to_substrait_rel(p.input.as_ref())?),
71-
expressions: vec![],
72-
advanced_extension: None,
73-
}))),
74-
})),
86+
LogicalPlan::Projection(p) => {
87+
let expressions = p
88+
.expr
89+
.iter()
90+
.map(|e| to_substrait_rex(e, p.input.schema()))
91+
.collect::<Result<Vec<_>>>()?;
92+
Ok(Box::new(Rel {
93+
rel_type: Some(RelType::Project(Box::new(ProjectRel {
94+
common: None,
95+
input: Some(to_substrait_rel(p.input.as_ref())?),
96+
expressions,
97+
advanced_extension: None,
98+
}))),
99+
}))
100+
}
75101
_ => Err(DataFusionError::NotImplemented(
76102
"Unsupported logical plan operator".to_string(),
77103
)),
78104
}
79105
}
80106

81-
pub fn from_substrait(proto: &Rel) -> Result<LogicalPlan> {
82-
match &proto.rel_type {
83-
Some(RelType::Project(p)) => Ok(LogicalPlan::Projection(Projection {
84-
expr: vec![],
85-
input: Arc::new(from_substrait(p.input.as_ref().unwrap())?),
86-
schema: Arc::new(DFSchema::empty()),
87-
alias: None,
88-
})),
107+
/// Convert Substrait Rex to DataFusion Expr
108+
// pub fn from_substrait_rex(rex: &Expression) -> Result<Expr> {
109+
// }
110+
111+
/// Convert Substrait Rel to DataFusion LogicalPlan
112+
pub fn from_substrait_rel(rel: &Rel) -> Result<LogicalPlan> {
113+
match &rel.rel_type {
114+
Some(RelType::Project(p)) => {
115+
let input = from_substrait_rel(p.input.as_ref().unwrap())?;
116+
let z: Vec<Expr> = p
117+
.expressions
118+
.iter()
119+
.map(|e| {
120+
match &e.rex_type {
121+
Some(RexType::Selection(field_ref)) => {
122+
match &field_ref.reference_type {
123+
Some(MaskedReference(mask)) => {
124+
//TODO remove unwrap
125+
let xx = &mask.select.as_ref().unwrap().struct_items;
126+
assert!(xx.len() == 1);
127+
Ok(Expr::Column(Column {
128+
relation: Some("data".to_string()), // TODO remove hard-coded relation name
129+
name: input
130+
.schema()
131+
.field(xx[0].field as usize)
132+
.name()
133+
.to_string(),
134+
}))
135+
}
136+
_ => Err(DataFusionError::NotImplemented(
137+
"unsupported field ref type".to_string(),
138+
)),
139+
}
140+
}
141+
_ => Err(DataFusionError::NotImplemented(
142+
"unsupported rex_type in projection".to_string(),
143+
)),
144+
}
145+
})
146+
.collect::<Result<Vec<_>>>()?;
147+
148+
Ok(LogicalPlan::Projection(Projection {
149+
expr: z,
150+
input: Arc::new(input),
151+
schema: Arc::new(DFSchema::empty()),
152+
alias: None,
153+
}))
154+
}
89155
Some(RelType::Read(read)) => {
90156
let projection = &read.projection.as_ref().map(|mask| match &mask.select {
91157
Some(x) => x.struct_items.iter().map(|i| i.field as usize).collect(),
92158
None => unimplemented!(),
93159
});
94160

161+
let schema = match &read.base_schema {
162+
Some(named_struct) => Schema::new(
163+
named_struct
164+
.names
165+
.iter()
166+
.map(|n| Field::new(n, DataType::Utf8, false))
167+
.collect(),
168+
),
169+
_ => unimplemented!(),
170+
};
171+
95172
Ok(LogicalPlan::TableScan(TableScan {
96-
table_name: "".to_string(),
97-
source: Arc::new(EmptyTable::new(SchemaRef::new(Schema::empty()))),
173+
table_name: match &read.as_ref().read_type {
174+
Some(ReadType::NamedTable(nt)) => nt.names[0].to_owned(),
175+
_ => unimplemented!(),
176+
},
177+
source: Arc::new(EmptyTable::new(SchemaRef::new(schema.clone()))),
98178
projection: projection.to_owned(),
99-
projected_schema: Arc::new(DFSchema::empty()),
179+
projected_schema: Arc::new(DFSchema::try_from(schema)?),
100180
filters: vec![],
101181
limit: None,
102182
}))
103183
}
104184
_ => Err(DataFusionError::NotImplemented(format!(
105185
"{:?}",
106-
proto.rel_type
186+
rel.rel_type
107187
))),
108188
}
109189
}
110190

111191
#[cfg(test)]
112192
mod tests {
113193

114-
use crate::{from_substrait, to_substrait_rel};
194+
use crate::{from_substrait_rel, to_substrait_rel};
115195
use datafusion::error::Result;
116196
use datafusion::prelude::*;
117197

118198
#[tokio::test]
119-
async fn it_works() -> Result<()> {
199+
async fn simple_select() -> Result<()> {
200+
roundtrip("SELECT a, b FROM data").await
201+
}
202+
203+
async fn roundtrip(sql: &str) -> Result<()> {
120204
let mut ctx = ExecutionContext::new();
121205
ctx.register_csv("data", "testdata/data.csv", CsvReadOptions::new())
122206
.await?;
123-
let df = ctx.sql("SELECT a, b FROM data").await?;
207+
let df = ctx.sql(sql).await?;
124208
let plan = df.to_logical_plan();
125209
let proto = to_substrait_rel(&plan)?;
126-
let plan2 = from_substrait(&proto)?;
210+
let plan2 = from_substrait_rel(&proto)?;
127211
let plan1str = format!("{:?}", plan);
128212
let plan2str = format!("{:?}", plan2);
129213
assert_eq!(plan1str, plan2str);
130-
131214
Ok(())
132215
}
133216
}

0 commit comments

Comments
 (0)