11use std:: sync:: Arc ;
22
3+ use datafusion:: arrow:: datatypes:: { DataType , Field } ;
34use datafusion:: {
45 arrow:: datatypes:: { Schema , SchemaRef } ,
56 datasource:: empty:: EmptyTable ,
67 error:: { DataFusionError , Result } ,
7- logical_plan:: { plan:: Projection , DFSchema , Expr , LogicalPlan , TableScan } ,
8+ logical_plan:: { plan:: Projection , DFSchema , DFSchemaRef , Expr , LogicalPlan , TableScan } ,
9+ prelude:: Column ,
810} ;
911
1012use substrait:: protobuf:: {
1113 expression:: {
14+ field_reference:: { ReferenceType , ReferenceType :: MaskedReference } ,
1215 mask_expression:: { StructItem , StructSelect } ,
1316 FieldReference , MaskExpression , RexType ,
1417 } ,
1518 read_rel:: { NamedTable , ReadType } ,
1619 rel:: RelType ,
17- Expression , ProjectRel , ReadRel , Rel ,
20+ Expression , NamedStruct , ProjectRel , ReadRel , Rel ,
1821} ;
19- //
20- // pub fn to_substrait_rex(expr: &Expr) -> Result<Box<Expression>> {
21- // match expr {
22- // Expr::Column(col) => {
23- // Ok(Box::new(Expression {
24- // rex_type: Some(RexType::Selection(Box::new(FieldReference {
25- // reference_type: None,
26- // root_type: None,
27- // })))
28- // }))
29- // }
30- // _ => Err(DataFusionError::NotImplemented(
31- // "Unsupported logical plan expression".to_string(),
32- // )),
33- // }
34- // }
3522
23+ /// Convert DataFusion Expr to Substrait Rex
24+ pub fn to_substrait_rex ( expr : & Expr , schema : & DFSchemaRef ) -> Result < Expression > {
25+ match expr {
26+ Expr :: Column ( col) => Ok ( Expression {
27+ rex_type : Some ( RexType :: Selection ( Box :: new ( FieldReference {
28+ reference_type : Some ( ReferenceType :: MaskedReference ( MaskExpression {
29+ select : Some ( StructSelect {
30+ struct_items : vec ! [ StructItem {
31+ field: schema. index_of( & col. name) ? as i32 ,
32+ child: None ,
33+ } ] ,
34+ } ) ,
35+ maintain_singular_struct : false ,
36+ } ) ) ,
37+ root_type : None ,
38+ } ) ) ) ,
39+ } ) ,
40+ _ => Err ( DataFusionError :: NotImplemented (
41+ "Unsupported logical plan expression" . to_string ( ) ,
42+ ) ) ,
43+ }
44+ }
45+
46+ /// Convert DataFusion LogicalPlan to Substrait Rel
3647pub fn to_substrait_rel ( plan : & LogicalPlan ) -> Result < Box < Rel > > {
3748 match plan {
3849 LogicalPlan :: TableScan ( scan) => {
@@ -48,7 +59,15 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
4859 Ok ( Box :: new ( Rel {
4960 rel_type : Some ( RelType :: Read ( Box :: new ( ReadRel {
5061 common : None ,
51- base_schema : None ,
62+ base_schema : Some ( NamedStruct {
63+ names : scan
64+ . projected_schema
65+ . fields ( )
66+ . iter ( )
67+ . map ( |f| f. name ( ) . to_owned ( ) )
68+ . collect ( ) ,
69+ r#struct : None , // TODO
70+ } ) ,
5271 filter : None ,
5372 projection : Some ( MaskExpression {
5473 select : Some ( StructSelect {
@@ -64,70 +83,134 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
6483 } ) ) ) ,
6584 } ) )
6685 }
67- LogicalPlan :: Projection ( p) => Ok ( Box :: new ( Rel {
68- rel_type : Some ( RelType :: Project ( Box :: new ( ProjectRel {
69- common : None ,
70- input : Some ( to_substrait_rel ( p. input . as_ref ( ) ) ?) ,
71- expressions : vec ! [ ] ,
72- advanced_extension : None ,
73- } ) ) ) ,
74- } ) ) ,
86+ LogicalPlan :: Projection ( p) => {
87+ let expressions = p
88+ . expr
89+ . iter ( )
90+ . map ( |e| to_substrait_rex ( e, p. input . schema ( ) ) )
91+ . collect :: < Result < Vec < _ > > > ( ) ?;
92+ Ok ( Box :: new ( Rel {
93+ rel_type : Some ( RelType :: Project ( Box :: new ( ProjectRel {
94+ common : None ,
95+ input : Some ( to_substrait_rel ( p. input . as_ref ( ) ) ?) ,
96+ expressions,
97+ advanced_extension : None ,
98+ } ) ) ) ,
99+ } ) )
100+ }
75101 _ => Err ( DataFusionError :: NotImplemented (
76102 "Unsupported logical plan operator" . to_string ( ) ,
77103 ) ) ,
78104 }
79105}
80106
81- pub fn from_substrait ( proto : & Rel ) -> Result < LogicalPlan > {
82- match & proto. rel_type {
83- Some ( RelType :: Project ( p) ) => Ok ( LogicalPlan :: Projection ( Projection {
84- expr : vec ! [ ] ,
85- input : Arc :: new ( from_substrait ( p. input . as_ref ( ) . unwrap ( ) ) ?) ,
86- schema : Arc :: new ( DFSchema :: empty ( ) ) ,
87- alias : None ,
88- } ) ) ,
107+ /// Convert Substrait Rex to DataFusion Expr
108+ // pub fn from_substrait_rex(rex: &Expression) -> Result<Expr> {
109+ // }
110+
111+ /// Convert Substrait Rel to DataFusion LogicalPlan
112+ pub fn from_substrait_rel ( rel : & Rel ) -> Result < LogicalPlan > {
113+ match & rel. rel_type {
114+ Some ( RelType :: Project ( p) ) => {
115+ let input = from_substrait_rel ( p. input . as_ref ( ) . unwrap ( ) ) ?;
116+ let z: Vec < Expr > = p
117+ . expressions
118+ . iter ( )
119+ . map ( |e| {
120+ match & e. rex_type {
121+ Some ( RexType :: Selection ( field_ref) ) => {
122+ match & field_ref. reference_type {
123+ Some ( MaskedReference ( mask) ) => {
124+ //TODO remove unwrap
125+ let xx = & mask. select . as_ref ( ) . unwrap ( ) . struct_items ;
126+ assert ! ( xx. len( ) == 1 ) ;
127+ Ok ( Expr :: Column ( Column {
128+ relation : Some ( "data" . to_string ( ) ) , // TODO remove hard-coded relation name
129+ name : input
130+ . schema ( )
131+ . field ( xx[ 0 ] . field as usize )
132+ . name ( )
133+ . to_string ( ) ,
134+ } ) )
135+ }
136+ _ => Err ( DataFusionError :: NotImplemented (
137+ "unsupported field ref type" . to_string ( ) ,
138+ ) ) ,
139+ }
140+ }
141+ _ => Err ( DataFusionError :: NotImplemented (
142+ "unsupported rex_type in projection" . to_string ( ) ,
143+ ) ) ,
144+ }
145+ } )
146+ . collect :: < Result < Vec < _ > > > ( ) ?;
147+
148+ Ok ( LogicalPlan :: Projection ( Projection {
149+ expr : z,
150+ input : Arc :: new ( input) ,
151+ schema : Arc :: new ( DFSchema :: empty ( ) ) ,
152+ alias : None ,
153+ } ) )
154+ }
89155 Some ( RelType :: Read ( read) ) => {
90156 let projection = & read. projection . as_ref ( ) . map ( |mask| match & mask. select {
91157 Some ( x) => x. struct_items . iter ( ) . map ( |i| i. field as usize ) . collect ( ) ,
92158 None => unimplemented ! ( ) ,
93159 } ) ;
94160
161+ let schema = match & read. base_schema {
162+ Some ( named_struct) => Schema :: new (
163+ named_struct
164+ . names
165+ . iter ( )
166+ . map ( |n| Field :: new ( n, DataType :: Utf8 , false ) )
167+ . collect ( ) ,
168+ ) ,
169+ _ => unimplemented ! ( ) ,
170+ } ;
171+
95172 Ok ( LogicalPlan :: TableScan ( TableScan {
96- table_name : "" . to_string ( ) ,
97- source : Arc :: new ( EmptyTable :: new ( SchemaRef :: new ( Schema :: empty ( ) ) ) ) ,
173+ table_name : match & read. as_ref ( ) . read_type {
174+ Some ( ReadType :: NamedTable ( nt) ) => nt. names [ 0 ] . to_owned ( ) ,
175+ _ => unimplemented ! ( ) ,
176+ } ,
177+ source : Arc :: new ( EmptyTable :: new ( SchemaRef :: new ( schema. clone ( ) ) ) ) ,
98178 projection : projection. to_owned ( ) ,
99- projected_schema : Arc :: new ( DFSchema :: empty ( ) ) ,
179+ projected_schema : Arc :: new ( DFSchema :: try_from ( schema ) ? ) ,
100180 filters : vec ! [ ] ,
101181 limit : None ,
102182 } ) )
103183 }
104184 _ => Err ( DataFusionError :: NotImplemented ( format ! (
105185 "{:?}" ,
106- proto . rel_type
186+ rel . rel_type
107187 ) ) ) ,
108188 }
109189}
110190
#[cfg(test)]
mod tests {
    use crate::{from_substrait_rel, to_substrait_rel};
    use datafusion::error::Result;
    use datafusion::prelude::*;

    /// A plain two-column selection must come back unchanged from the roundtrip.
    #[tokio::test]
    async fn simple_select() -> Result<()> {
        roundtrip("SELECT a, b FROM data").await
    }

    /// Plans `sql` against `testdata/data.csv` (registered as table `data`),
    /// converts the logical plan to Substrait and back, and asserts that the
    /// debug rendering of both plans is identical.
    async fn roundtrip(sql: &str) -> Result<()> {
        let mut ctx = ExecutionContext::new();
        let csv_options = CsvReadOptions::new();
        ctx.register_csv("data", "testdata/data.csv", csv_options)
            .await?;

        let original = ctx.sql(sql).await?.to_logical_plan();
        let encoded = to_substrait_rel(&original)?;
        let decoded = from_substrait_rel(&encoded)?;

        let expected = format!("{:?}", original);
        let actual = format!("{:?}", decoded);
        assert_eq!(expected, actual);
        Ok(())
    }
}
0 commit comments