diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs
index b7acb5c7b74c..44ea62d04c09 100644
--- a/datafusion-examples/examples/json_shredding.rs
+++ b/datafusion-examples/examples/json_shredding.rs
@@ -20,35 +20,29 @@ use std::sync::Arc;
 
 use arrow::array::{RecordBatch, StringArray};
 use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};
-use async_trait::async_trait;
 use datafusion::assert_batches_eq;
-use datafusion::catalog::memory::DataSourceExec;
-use datafusion::catalog::{Session, TableProvider};
 use datafusion::common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
 };
-use datafusion::common::{assert_contains, DFSchema, Result};
-use datafusion::datasource::listing::PartitionedFile;
-use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource};
+use datafusion::common::{assert_contains, Result};
+use datafusion::datasource::listing::{
+    ListingTable, ListingTableConfig, ListingTableUrl,
+};
 use datafusion::execution::context::SessionContext;
 use datafusion::execution::object_store::ObjectStoreUrl;
-use datafusion::logical_expr::utils::conjunction;
 use datafusion::logical_expr::{
-    ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
-    TableProviderFilterPushDown, TableType, Volatility,
+    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
 };
 use datafusion::parquet::arrow::ArrowWriter;
 use datafusion::parquet::file::properties::WriterProperties;
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_expr::{expressions, ScalarFunctionExpr};
-use datafusion::physical_plan::ExecutionPlan;
-use datafusion::prelude::{lit, SessionConfig};
+use datafusion::prelude::SessionConfig;
 use datafusion::scalar::ScalarValue;
 use datafusion_physical_expr_adapter::{
     DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
 };
 
-use futures::StreamExt;
 use object_store::memory::InMemory;
 use object_store::path::Path;
 use object_store::{ObjectStore, PutPayload};
@@ -95,23 +89,29 @@ async fn main() -> Result<()> {
     let payload = PutPayload::from_bytes(buf.into());
     store.put(&path, payload).await?;
 
-    // Create a custom table provider that rewrites struct field access
-    let table_provider = Arc::new(ExampleTableProvider::new(table_schema));
-
     // Set up query execution
     let mut cfg = SessionConfig::new();
     cfg.options_mut().execution.parquet.pushdown_filters = true;
     let ctx = SessionContext::new_with_config(cfg);
-
-    // Register our table
-    ctx.register_table("structs", table_provider)?;
-    ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default()));
-
     ctx.runtime_env().register_object_store(
         ObjectStoreUrl::parse("memory://")?.as_ref(),
         Arc::new(store),
     );
 
+    // Create a custom table provider that rewrites struct field access
+    let listing_table_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///example.parquet")?)
+            .infer_options(&ctx.state())
+            .await?
+            .with_schema(table_schema)
+            .with_expr_adapter_factory(Arc::new(ShreddedJsonRewriterFactory));
+    let table = ListingTable::try_new(listing_table_config).unwrap();
+    let table_provider = Arc::new(table);
+
+    // Register our table
+    ctx.register_table("structs", table_provider)?;
+    ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default()));
+
     println!("\n=== Showing all data ===");
     let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?;
     arrow::util::pretty::print_batches(&batches)?;
@@ -191,96 +191,6 @@ fn create_sample_record_batch(file_schema: &Schema) -> RecordBatch {
     .unwrap()
 }
 
-/// Custom TableProvider that uses a StructFieldRewriter
-#[derive(Debug)]
-struct ExampleTableProvider {
-    schema: SchemaRef,
-}
-
-impl ExampleTableProvider {
-    fn new(schema: SchemaRef) -> Self {
-        Self { schema }
-    }
-}
-
-#[async_trait]
-impl TableProvider for ExampleTableProvider {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-
-    fn table_type(&self) -> TableType {
-        TableType::Base
-    }
-
-    fn supports_filters_pushdown(
-        &self,
-        filters: &[&Expr],
-    ) -> Result<Vec<TableProviderFilterPushDown>> {
-        // Implementers can choose to mark these filters as exact or inexact.
-        // If marked as exact they cannot have false positives and must always be applied.
-        // If marked as Inexact they can have false positives and at runtime the rewriter
-        // can decide to not rewrite / ignore some filters since they will be re-evaluated upstream.
-        // For the purposes of this example we mark them as Exact to demonstrate the rewriter is working and the filtering is not being re-evaluated upstream.
-        Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
-    }
-
-    async fn scan(
-        &self,
-        state: &dyn Session,
-        projection: Option<&Vec<usize>>,
-        filters: &[Expr],
-        limit: Option<usize>,
-    ) -> Result<Arc<dyn ExecutionPlan>> {
-        let schema = self.schema.clone();
-        let df_schema = DFSchema::try_from(schema.clone())?;
-        let filter = state.create_physical_expr(
-            conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)),
-            &df_schema,
-        )?;
-
-        let parquet_source = ParquetSource::default()
-            .with_predicate(filter)
-            .with_pushdown_filters(true);
-
-        let object_store_url = ObjectStoreUrl::parse("memory://")?;
-
-        let store = state.runtime_env().object_store(object_store_url)?;
-
-        let mut files = vec![];
-        let mut listing = store.list(None);
-        while let Some(file) = listing.next().await {
-            if let Ok(file) = file {
-                files.push(file);
-            }
-        }
-
-        let file_group = files
-            .iter()
-            .map(|file| PartitionedFile::new(file.location.clone(), file.size))
-            .collect();
-
-        let file_scan_config = FileScanConfigBuilder::new(
-            ObjectStoreUrl::parse("memory://")?,
-            schema,
-            Arc::new(parquet_source),
-        )
-        .with_projection(projection.cloned())
-        .with_limit(limit)
-        .with_file_group(file_group)
-        // if the rewriter needs a reference to the table schema you can bind self.schema() here
-        .with_expr_adapter(Some(Arc::new(ShreddedJsonRewriterFactory) as _));
-
-        Ok(Arc::new(DataSourceExec::new(Arc::new(
-            file_scan_config.build(),
-        ))))
-    }
-}
-
 /// Scalar UDF that uses serde_json to access json fields
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct JsonGetStr {