
 use std::sync::Arc;

-use arrow::array::{record_batch, RecordBatch};
-use arrow_schema::{DataType, Field, Schema};
+use arrow::array::{record_batch, RecordBatch, RecordBatchOptions};
+use arrow::compute::{cast_with_options, CastOptions};
+use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
 use bytes::{BufMut, BytesMut};
 use datafusion::assert_batches_eq;
+use datafusion::common::Result;
 use datafusion::datasource::listing::{ListingTable, ListingTableConfig};
-use datafusion::prelude::SessionContext;
-use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory;
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_common::{ColumnStatistics, ScalarValue};
+use datafusion_datasource::schema_adapter::{
+    DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper,
+};
 use datafusion_datasource::ListingTableUrl;
 use datafusion_execution::object_store::ObjectStoreUrl;
-use datafusion_physical_expr::schema_rewriter::DefaultPhysicalExprAdapterFactory;
+use datafusion_physical_expr::expressions::{self, Column};
+use datafusion_physical_expr::schema_rewriter::{
+    DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
+};
+use datafusion_physical_expr::{DefaultPhysicalExprAdapter, PhysicalExpr};
+use itertools::Itertools;
 use object_store::{memory::InMemory, path::Path, ObjectStore};
 use parquet::arrow::ArrowWriter;

@@ -41,6 +52,180 @@ async fn write_parquet(batch: RecordBatch, store: Arc<dyn ObjectStore>, path: &s
     store.put(&Path::from(path), data.into()).await.unwrap();
 }

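+// A custom `SchemaAdapterFactory` whose adapter maps file batches to the table
+// schema, filling columns missing from the file with type-specific defaults
+// (0 for Int64, "a" for Utf8) instead of nulls.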
+#[derive(Debug)]
+struct CustomSchemaAdapterFactory;
+
+impl SchemaAdapterFactory for CustomSchemaAdapterFactory {
+    fn create(
+        &self,
+        projected_table_schema: SchemaRef,
+        _table_schema: SchemaRef,
+    ) -> Box<dyn SchemaAdapter> {
+        Box::new(CustomSchemaAdapter {
+            logical_file_schema: projected_table_schema,
+        })
+    }
+}
+
+#[derive(Debug)]
+struct CustomSchemaAdapter {
+    logical_file_schema: SchemaRef,
+}
+
+impl SchemaAdapter for CustomSchemaAdapter {
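+    // Map a table-schema column index to the file column with the same name,
+    // regardless of its position in the file.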
+    fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
+        for (idx, field) in file_schema.fields().iter().enumerate() {
+            if field.name() == self.logical_file_schema.field(index).name() {
+                return Some(idx);
+            }
+        }
+        None
+    }
+
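+    // Project all file columns as-is; `CustomSchemaMapper` does the actual
+    // casting and default-filling per batch.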
+    fn map_schema(
+        &self,
+        file_schema: &Schema,
+    ) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
+        let projection = (0..file_schema.fields().len()).collect_vec();
+        Ok((
+            Arc::new(CustomSchemaMapper {
+                logical_file_schema: Arc::clone(&self.logical_file_schema),
+            }),
+            projection,
+        ))
+    }
+}
+
+#[derive(Debug)]
+struct CustomSchemaMapper {
+    logical_file_schema: SchemaRef,
+}
+
+impl SchemaMapper for CustomSchemaMapper {
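+    // Rebuild each batch in table-schema order: cast columns that exist in the
+    // file and synthesize default-valued arrays for those that do not.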
+    fn map_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
+        let mut output_columns =
+            Vec::with_capacity(self.logical_file_schema.fields().len());
+        for field in self.logical_file_schema.fields() {
+            if let Some(array) = batch.column_by_name(field.name()) {
+                output_columns.push(cast_with_options(
+                    array,
+                    field.data_type(),
+                    &CastOptions::default(),
+                )?);
+            } else {
+                // Create a new array filled with the default value for the field's type
+                let default_value = match field.data_type() {
+                    DataType::Int64 => ScalarValue::Int64(Some(0)),
+                    DataType::Utf8 => ScalarValue::Utf8(Some("a".to_string())),
+                    _ => unimplemented!("Unsupported data type: {:?}", field.data_type()),
+                };
+                output_columns
+                    .push(default_value.to_array_of_size(batch.num_rows()).unwrap());
+            }
+        }
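+        // Specify the row count explicitly: it cannot be inferred from the
+        // columns if the projection selects no columns.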
+        let batch = RecordBatch::try_new_with_options(
+            Arc::clone(&self.logical_file_schema),
+            output_columns,
+            &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
+        )
+        .unwrap();
+        Ok(batch)
+    }
+
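+    // This mapper does not translate file-level statistics; report unknown
+    // statistics for every output column.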
+    fn map_column_statistics(
+        &self,
+        _file_col_statistics: &[ColumnStatistics],
+    ) -> Result<Vec<ColumnStatistics>> {
+        Ok(vec![
+            ColumnStatistics::new_unknown();
+            self.logical_file_schema.fields().len()
+        ])
+    }
+}
+
+// A custom `PhysicalExprAdapterFactory` that fills in references to missing
+// columns with a default value for the field's type.
+#[derive(Debug)]
+struct CustomPhysicalExprAdapterFactory;
+
+impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory {
+    fn create(
+        &self,
+        logical_file_schema: SchemaRef,
+        physical_file_schema: SchemaRef,
+    ) -> Arc<dyn PhysicalExprAdapter> {
+        Arc::new(CustomPhysicalExprAdapter {
+            logical_file_schema: Arc::clone(&logical_file_schema),
+            physical_file_schema: Arc::clone(&physical_file_schema),
+            inner: Arc::new(DefaultPhysicalExprAdapter::new(
+                logical_file_schema,
+                physical_file_schema,
+            )),
+        })
+    }
+}
+
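+// Wraps `DefaultPhysicalExprAdapter`: references to columns missing from the
+// file are replaced with literals first, and everything else is delegated to
+// the default adapter.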
+#[derive(Debug, Clone)]
+struct CustomPhysicalExprAdapter {
+    logical_file_schema: SchemaRef,
+    physical_file_schema: SchemaRef,
+    inner: Arc<dyn PhysicalExprAdapter>,
+}
+
+impl PhysicalExprAdapter for CustomPhysicalExprAdapter {
+    fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
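+        // Walk the expression tree and look for references to columns that are
+        // absent from the physical (file) schema.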
+        expr = expr
+            .transform(|expr| {
+                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                    let field_name = column.name();
+                    if self
+                        .physical_file_schema
+                        .field_with_name(field_name)
+                        .is_err()
+                    {
+                        let field = self
+                            .logical_file_schema
+                            .field_with_name(field_name)
+                            .map_err(|_| {
+                                datafusion_common::DataFusionError::Plan(format!(
+                                    "Field '{}' not found in logical file schema",
+                                    field_name
+                                ))
+                            })?;
+                        // The field does not exist in the file: replace the
+                        // column with a default value expression.
+                        // Note that these defaults deliberately differ from the
+                        // ones used by `CustomSchemaMapper` so that tests can
+                        // tell the two code paths apart.
+                        let default_value = match field.data_type() {
+                            DataType::Int64 => ScalarValue::Int64(Some(1)),
+                            DataType::Utf8 => ScalarValue::Utf8(Some("b".to_string())),
+                            _ => unimplemented!(
+                                "Unsupported data type: {:?}",
+                                field.data_type()
+                            ),
+                        };
+                        return Ok(Transformed::yes(Arc::new(expressions::Literal::new(
+                            default_value,
+                        ))));
+                    }
+                }
+
+                Ok(Transformed::no(expr))
+            })
+            .data()?;
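+        // Hand the partially rewritten expression to the default adapter for
+        // the remaining adaptations (e.g. casting file types to table types).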
+        self.inner.rewrite(expr)
+    }
+
+    fn with_partition_values(
+        &self,
+        partition_values: Vec<(FieldRef, ScalarValue)>,
+    ) -> Arc<dyn PhysicalExprAdapter> {
+        assert!(
+            partition_values.is_empty(),
+            "Partition values are not supported in this test"
+        );
+        Arc::new(self.clone())
+    }
+}
+
 #[tokio::test]
 async fn single_file() {
     let batch =
@@ -56,8 +241,22 @@ async fn single_file() {
         Field::new("c2", DataType::Utf8, true),
     ]));

-    let ctx = SessionContext::new();
+    let mut cfg = SessionConfig::new()
+        // Disable statistics collection and pruning for this test; otherwise
+        // early pruning makes it hard to demonstrate data adaptation.
+        .with_collect_statistics(false)
+        .with_parquet_pruning(false)
+        .with_parquet_page_index_pruning(false);
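+    // Push filters down into the Parquet scan so the predicate rewriting done
+    // by the adapters is actually exercised.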
+    cfg.options_mut().execution.parquet.pushdown_filters = true;
+    let ctx = SessionContext::new_with_config(cfg);
     ctx.register_object_store(store_url.as_ref(), Arc::clone(&store));
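+    // Sanity check that the statistics setting is visible through both config
+    // accessors.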
+    assert!(
+        !ctx.state()
+            .config_mut()
+            .options_mut()
+            .execution
+            .collect_statistics
+    );
+    assert!(!ctx.state().config().collect_statistics());

     let listing_table_config =
         ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
@@ -89,4 +288,92 @@ async fn single_file() {
89288 "+----+----+" ,
90289 ] ;
91290 assert_batches_eq ! ( expected, & batches) ;
291+
+    // Test a custom schema adapter with no explicit physical expr adapter.
+    // The custom schema adapter should be used both for projections and for
+    // predicate pushdown.
+    let listing_table_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
+            .infer_options(&ctx.state())
+            .await
+            .unwrap()
+            .with_schema(table_schema.clone())
+            .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory));
+    let table = ListingTable::try_new(listing_table_config).unwrap();
+    ctx.deregister_table("t").unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+    let batches = ctx
+        .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'a'")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+    let expected = [
+        "+----+----+",
+        "| c2 | c1 |",
+        "+----+----+",
+        "| a  | 2  |",
+        "+----+----+",
+    ];
+    assert_batches_eq!(expected, &batches);
+
+    // Run the same test with a custom physical expr adapter instead. Now the
+    // default schema adapter is used for projections, while the custom
+    // physical expr adapter is used for predicate pushdown.
+    let listing_table_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
+            .infer_options(&ctx.state())
+            .await
+            .unwrap()
+            .with_schema(table_schema.clone())
+            .with_physical_expr_adapter_factory(Arc::new(
+                CustomPhysicalExprAdapterFactory,
+            ));
+    let table = ListingTable::try_new(listing_table_config).unwrap();
+    ctx.deregister_table("t").unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+    let batches = ctx
+        .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+    let expected = [
+        "+----+----+",
+        "| c2 | c1 |",
+        "+----+----+",
+        "|    | 2  |",
+        "+----+----+",
+    ];
+    assert_batches_eq!(expected, &batches);
+
+    // With both factories set, the custom physical expr adapter is used for
+    // predicate pushdown and the custom schema adapter for projections.
+    let listing_table_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
+            .infer_options(&ctx.state())
+            .await
+            .unwrap()
+            .with_schema(table_schema.clone())
+            .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory))
+            .with_physical_expr_adapter_factory(Arc::new(
+                CustomPhysicalExprAdapterFactory,
+            ));
+    let table = ListingTable::try_new(listing_table_config).unwrap();
+    ctx.deregister_table("t").unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+    let batches = ctx
+        .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+    let expected = [
+        "+----+----+",
+        "| c2 | c1 |",
+        "+----+----+",
+        "| a  | 2  |",
+        "+----+----+",
+    ];
+    assert_batches_eq!(expected, &batches);
 }