
Commit a846492: better test
Parent: 6bc7fc6

4 files changed, +316 -35 lines changed

datafusion/core/src/datasource/listing/table.rs

Lines changed: 4 additions & 0 deletions

@@ -938,6 +938,8 @@ pub struct ListingTable {
     column_defaults: HashMap<String, Expr>,
     /// Optional [`SchemaAdapterFactory`] for creating schema adapters
     schema_adapter_factory: Option<Arc<dyn SchemaAdapterFactory>>,
+    /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters
+    expr_adapter_factory: Option<Arc<dyn PhysicalExprAdapterFactory>>,
 }
 
 impl ListingTable {
@@ -979,6 +981,7 @@ impl ListingTable {
             constraints: Constraints::default(),
             column_defaults: HashMap::new(),
             schema_adapter_factory: config.schema_adapter_factory,
+            expr_adapter_factory: config.physical_expr_adapter_factory,
         };
 
         Ok(table)
@@ -1223,6 +1226,7 @@ impl TableProvider for ListingTable {
                     .with_limit(limit)
                     .with_output_ordering(output_ordering)
                     .with_table_partition_cols(table_partition_cols)
+                    .with_expr_adapter(self.expr_adapter_factory.clone())
                     .build(),
             )
             .await
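
For orientation, here is a minimal sketch (not part of the commit) of how a caller opts in to this new field. The builder methods mirror the test below; the helper name `register_with_expr_adapter` and the table name "t" are illustrative only.

use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion::common::Result;
use datafusion::datasource::listing::{ListingTable, ListingTableConfig};
use datafusion::prelude::SessionContext;
use datafusion_datasource::ListingTableUrl;
use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapterFactory;

// Hypothetical helper: the factory passed here lands in
// ListingTable::expr_adapter_factory and is forwarded to the scan via
// `.with_expr_adapter(...)` as shown in the diff above.
async fn register_with_expr_adapter(
    ctx: &SessionContext,
    table_schema: SchemaRef,
    factory: Arc<dyn PhysicalExprAdapterFactory>,
) -> Result<()> {
    let config = ListingTableConfig::new(ListingTableUrl::parse("memory:///")?)
        .infer_options(&ctx.state())
        .await?
        .with_schema(table_schema)
        .with_physical_expr_adapter_factory(factory);
    let table = ListingTable::try_new(config)?;
    ctx.register_table("t", Arc::new(table))?;
    Ok(())
}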

datafusion/core/tests/parquet/schema_adapter.rs

Lines changed: 293 additions & 6 deletions
@@ -17,16 +17,27 @@
 
 use std::sync::Arc;
 
-use arrow::array::{record_batch, RecordBatch};
-use arrow_schema::{DataType, Field, Schema};
+use arrow::array::{record_batch, RecordBatch, RecordBatchOptions};
+use arrow::compute::{cast_with_options, CastOptions};
+use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
 use bytes::{BufMut, BytesMut};
 use datafusion::assert_batches_eq;
+use datafusion::common::Result;
 use datafusion::datasource::listing::{ListingTable, ListingTableConfig};
-use datafusion::prelude::SessionContext;
-use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory;
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_common::{ColumnStatistics, ScalarValue};
+use datafusion_datasource::schema_adapter::{
+    DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper,
+};
 use datafusion_datasource::ListingTableUrl;
 use datafusion_execution::object_store::ObjectStoreUrl;
-use datafusion_physical_expr::schema_rewriter::DefaultPhysicalExprAdapterFactory;
+use datafusion_physical_expr::expressions::{self, Column};
+use datafusion_physical_expr::schema_rewriter::{
+    DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
+};
+use datafusion_physical_expr::{DefaultPhysicalExprAdapter, PhysicalExpr};
+use itertools::Itertools;
 use object_store::{memory::InMemory, path::Path, ObjectStore};
 use parquet::arrow::ArrowWriter;
 
@@ -41,6 +52,180 @@ async fn write_parquet(batch: RecordBatch, store: Arc<dyn ObjectStore>, path: &s
     store.put(&Path::from(path), data.into()).await.unwrap();
 }
 
+#[derive(Debug)]
+struct CustomSchemaAdapterFactory;
+
+impl SchemaAdapterFactory for CustomSchemaAdapterFactory {
+    fn create(
+        &self,
+        projected_table_schema: SchemaRef,
+        _table_schema: SchemaRef,
+    ) -> Box<dyn SchemaAdapter> {
+        Box::new(CustomSchemaAdapter {
+            logical_file_schema: projected_table_schema,
+        })
+    }
+}
+
+#[derive(Debug)]
+struct CustomSchemaAdapter {
+    logical_file_schema: SchemaRef,
+}
+
+impl SchemaAdapter for CustomSchemaAdapter {
+    fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
+        for (idx, field) in file_schema.fields().iter().enumerate() {
+            if field.name() == self.logical_file_schema.field(index).name() {
+                return Some(idx);
+            }
+        }
+        None
+    }
+
+    fn map_schema(
+        &self,
+        file_schema: &Schema,
+    ) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
+        let projection = (0..file_schema.fields().len()).collect_vec();
+        Ok((
+            Arc::new(CustomSchemaMapper {
+                logical_file_schema: Arc::clone(&self.logical_file_schema),
+            }),
+            projection,
+        ))
+    }
+}
+
+#[derive(Debug)]
+struct CustomSchemaMapper {
+    logical_file_schema: SchemaRef,
+}
+
+impl SchemaMapper for CustomSchemaMapper {
+    fn map_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
+        let mut output_columns =
+            Vec::with_capacity(self.logical_file_schema.fields().len());
+        for field in self.logical_file_schema.fields() {
+            if let Some(array) = batch.column_by_name(field.name()) {
+                output_columns.push(cast_with_options(
+                    array,
+                    field.data_type(),
+                    &CastOptions::default(),
+                )?);
+            } else {
+                // Create a new array with the default value for the field type
+                let default_value = match field.data_type() {
+                    DataType::Int64 => ScalarValue::Int64(Some(0)),
+                    DataType::Utf8 => ScalarValue::Utf8(Some("a".to_string())),
+                    _ => unimplemented!("Unsupported data type: {:?}", field.data_type()),
+                };
+                output_columns
+                    .push(default_value.to_array_of_size(batch.num_rows()).unwrap());
+            }
+        }
+        let batch = RecordBatch::try_new_with_options(
+            Arc::clone(&self.logical_file_schema),
+            output_columns,
+            &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
+        )
+        .unwrap();
+        Ok(batch)
+    }
+
+    fn map_column_statistics(
+        &self,
+        _file_col_statistics: &[ColumnStatistics],
+    ) -> Result<Vec<ColumnStatistics>> {
+        Ok(vec![
+            ColumnStatistics::new_unknown();
+            self.logical_file_schema.fields().len()
+        ])
+    }
+}
+
+// Implement a custom PhysicalExprAdapterFactory that fills in missing columns with the default value for the field type
+#[derive(Debug)]
+struct CustomPhysicalExprAdapterFactory;
+
+impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory {
+    fn create(
+        &self,
+        logical_file_schema: SchemaRef,
+        physical_file_schema: SchemaRef,
+    ) -> Arc<dyn PhysicalExprAdapter> {
+        Arc::new(CustomPhysicalExprAdapter {
+            logical_file_schema: Arc::clone(&logical_file_schema),
+            physical_file_schema: Arc::clone(&physical_file_schema),
+            inner: Arc::new(DefaultPhysicalExprAdapter::new(
+                logical_file_schema,
+                physical_file_schema,
+            )),
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct CustomPhysicalExprAdapter {
+    logical_file_schema: SchemaRef,
+    physical_file_schema: SchemaRef,
+    inner: Arc<dyn PhysicalExprAdapter>,
+}
+
+impl PhysicalExprAdapter for CustomPhysicalExprAdapter {
+    fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
+        expr = expr
+            .transform(|expr| {
+                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                    let field_name = column.name();
+                    if self
+                        .physical_file_schema
+                        .field_with_name(field_name)
+                        .ok()
+                        .is_none()
+                    {
+                        let field = self
+                            .logical_file_schema
+                            .field_with_name(field_name)
+                            .map_err(|_| {
+                                datafusion_common::DataFusionError::Plan(format!(
+                                    "Field '{}' not found in logical file schema",
+                                    field_name
+                                ))
+                            })?;
+                        // If the field does not exist, create a default value expression
+                        // Note that we use slightly different logic here to create a default value so that we can see different behavior in tests
+                        let default_value = match field.data_type() {
+                            DataType::Int64 => ScalarValue::Int64(Some(1)),
+                            DataType::Utf8 => ScalarValue::Utf8(Some("b".to_string())),
+                            _ => unimplemented!(
+                                "Unsupported data type: {:?}",
+                                field.data_type()
+                            ),
+                        };
+                        return Ok(Transformed::yes(Arc::new(expressions::Literal::new(
+                            default_value,
+                        ))));
+                    }
+                }
+
+                Ok(Transformed::no(expr))
+            })
+            .data()?;
+        self.inner.rewrite(expr)
+    }
+
+    fn with_partition_values(
+        &self,
+        partition_values: Vec<(FieldRef, ScalarValue)>,
+    ) -> Arc<dyn PhysicalExprAdapter> {
+        assert!(
+            partition_values.is_empty(),
+            "Partition values are not supported in this test"
+        );
+        Arc::new(self.clone())
+    }
+}
+
 #[tokio::test]
 async fn single_file() {
     let batch =
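
Note the deliberately different defaults in the two adapters: CustomSchemaMapper fills a missing column with 0 or "a", while CustomPhysicalExprAdapter substitutes 1 or "b". As the inline comment says, this asymmetry is what lets the tests below observe which adapter handled the pushed-down predicate and which handled the projected output.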
@@ -56,8 +241,22 @@ async fn single_file() {
         Field::new("c2", DataType::Utf8, true),
     ]));
 
-    let ctx = SessionContext::new();
+    let mut cfg = SessionConfig::new()
+        // Disable statistics collection for this test otherwise early pruning makes it hard to demonstrate data adaptation
+        .with_collect_statistics(false)
+        .with_parquet_pruning(false)
+        .with_parquet_page_index_pruning(false);
+    cfg.options_mut().execution.parquet.pushdown_filters = true;
+    let ctx = SessionContext::new_with_config(cfg);
     ctx.register_object_store(store_url.as_ref(), Arc::clone(&store));
+    assert!(
+        !ctx.state()
+            .config_mut()
+            .options_mut()
+            .execution
+            .collect_statistics
+    );
+    assert!(!ctx.state().config().collect_statistics());
 
     let listing_table_config =
         ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
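
This session configuration is what makes the behavior observable: statistics collection and Parquet (page index) pruning are disabled so row groups are not pruned away before the scan, while pushdown_filters is enabled so the predicate actually reaches the Parquet reader, where the expression adapter can rewrite it against each file's physical schema.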
@@ -89,4 +288,92 @@ async fn single_file() {
         "+----+----+",
     ];
     assert_batches_eq!(expected, &batches);
+
+    // Test using a custom schema adapter and no explicit physical expr adapter
+    // This should use the custom schema adapter both for projections and predicate pushdown
+    let listing_table_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
+            .infer_options(&ctx.state())
+            .await
+            .unwrap()
+            .with_schema(table_schema.clone())
+            .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory));
+    let table = ListingTable::try_new(listing_table_config).unwrap();
+    ctx.deregister_table("t").unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+    let batches = ctx
+        .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'a'")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+    let expected = [
+        "+----+----+",
+        "| c2 | c1 |",
+        "+----+----+",
+        "| a  | 2  |",
+        "+----+----+",
+    ];
+    assert_batches_eq!(expected, &batches);
+
+    // Do the same test but with a custom physical expr adapter
+    // Now the default schema adapter will be used for projections, but the custom physical expr adapter will be used for predicate pushdown
+    let listing_table_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
+            .infer_options(&ctx.state())
+            .await
+            .unwrap()
+            .with_schema(table_schema.clone())
+            .with_physical_expr_adapter_factory(Arc::new(
+                CustomPhysicalExprAdapterFactory,
+            ));
+    let table = ListingTable::try_new(listing_table_config).unwrap();
+    ctx.deregister_table("t").unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+    let batches = ctx
+        .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+    let expected = [
+        "+----+----+",
+        "| c2 | c1 |",
+        "+----+----+",
+        "|    | 2  |",
+        "+----+----+",
+    ];
+    assert_batches_eq!(expected, &batches);
+
+    // If we use both then the custom physical expr adapter will be used for predicate pushdown and the custom schema adapter will be used for projections
+    let listing_table_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
+            .infer_options(&ctx.state())
+            .await
+            .unwrap()
+            .with_schema(table_schema.clone())
+            .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory))
+            .with_physical_expr_adapter_factory(Arc::new(
+                CustomPhysicalExprAdapterFactory,
+            ));
+    let table = ListingTable::try_new(listing_table_config).unwrap();
+    ctx.deregister_table("t").unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+    let batches = ctx
+        .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+    let expected = [
+        "+----+----+",
+        "| c2 | c1 |",
+        "+----+----+",
+        "| a  | 2  |",
+        "+----+----+",
+    ];
+    assert_batches_eq!(expected, &batches);
 }
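
Reading the three scenarios together (the Parquet file evidently lacks c2, hence the null from the default adapter in the middle case): with only the custom schema adapter, both the pushed-down predicate and the projection see the fill value "a", so c2 = 'a' matches. With only the custom physical expression adapter, the predicate is rewritten so c2 = 'b' becomes 'b' = 'b' and passes, but the projection falls back to the default schema adapter and yields null (the empty cell). With both registered, the expression adapter serves the predicate and the custom schema adapter serves the projection, which is why the final output shows "a".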
