datafusion-examples/examples/json_shredding.rs (130 changes: 20 additions & 110 deletions)
@@ -20,35 +20,29 @@ use std::sync::Arc;

use arrow::array::{RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};
use async_trait::async_trait;

use datafusion::assert_batches_eq;
use datafusion::catalog::memory::DataSourceExec;
use datafusion::catalog::{Session, TableProvider};
use datafusion::common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion::common::{assert_contains, DFSchema, Result};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource};
use datafusion::common::{assert_contains, Result};
use datafusion::datasource::listing::{
ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::execution::context::SessionContext;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::logical_expr::utils::conjunction;
use datafusion::logical_expr::{
ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
TableProviderFilterPushDown, TableType, Volatility,
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::{expressions, ScalarFunctionExpr};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::{lit, SessionConfig};
use datafusion::prelude::SessionConfig;
use datafusion::scalar::ScalarValue;
use datafusion_physical_expr_adapter::{
DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
};
use futures::StreamExt;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, PutPayload};
@@ -95,23 +89,29 @@ async fn main() -> Result<()> {
let payload = PutPayload::from_bytes(buf.into());
store.put(&path, payload).await?;

// Create a custom table provider that rewrites struct field access
let table_provider = Arc::new(ExampleTableProvider::new(table_schema));

// Set up query execution
let mut cfg = SessionConfig::new();
cfg.options_mut().execution.parquet.pushdown_filters = true;
let ctx = SessionContext::new_with_config(cfg);

// Register our table
ctx.register_table("structs", table_provider)?;
ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default()));

ctx.runtime_env().register_object_store(
ObjectStoreUrl::parse("memory://")?.as_ref(),
Arc::new(store),
);

// Create a ListingTable whose expression adapter rewrites struct field access
let listing_table_config =
ListingTableConfig::new(ListingTableUrl::parse("memory:///example.parquet")?)
.infer_options(&ctx.state())
.await?
.with_schema(table_schema)
.with_expr_adapter_factory(Arc::new(ShreddedJsonRewriterFactory));
let table = ListingTable::try_new(listing_table_config)?;
let table_provider = Arc::new(table);

// Register our table
ctx.register_table("structs", table_provider)?;
ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default()));

println!("\n=== Showing all data ===");
let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?;
arrow::util::pretty::print_batches(&batches)?;
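
For comparison, the same `ListingTable` wiring also works with the stock `DefaultPhysicalExprAdapterFactory` (already imported above) when no JSON-specific rewriting is needed. A minimal sketch, not taken from this change, reusing `ctx` and `table_schema` from the example and assuming the factory is constructible as a unit struct; the table name `structs_default` is hypothetical:

```rust
// Sketch only: swap in the default adapter factory, which adapts the physical file
// schema to the table schema (type coercion, missing columns) without custom rewrites.
// Assumes `ctx` and `table_schema` from the surrounding example.
let config =
    ListingTableConfig::new(ListingTableUrl::parse("memory:///example.parquet")?)
        .infer_options(&ctx.state())
        .await?
        .with_schema(table_schema.clone())
        .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory));
let table = Arc::new(ListingTable::try_new(config)?);
ctx.register_table("structs_default", table)?; // hypothetical table name
```
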
@@ -191,96 +191,6 @@ fn create_sample_record_batch(file_schema: &Schema) -> RecordBatch {
.unwrap()
}

/// Custom TableProvider that uses a StructFieldRewriter
#[derive(Debug)]
struct ExampleTableProvider {
schema: SchemaRef,
}

impl ExampleTableProvider {
fn new(schema: SchemaRef) -> Self {
Self { schema }
}
}

#[async_trait]
impl TableProvider for ExampleTableProvider {
fn as_any(&self) -> &dyn Any {
self
}

fn schema(&self) -> SchemaRef {
self.schema.clone()
}

fn table_type(&self) -> TableType {
TableType::Base
}

fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> Result<Vec<TableProviderFilterPushDown>> {
// Implementers can choose to mark these filters as exact or inexact.
// If marked as exact they cannot have false positives and must always be applied.
// If marked as Inexact they can have false positives and at runtime the rewriter
// can decide to not rewrite / ignore some filters since they will be re-evaluated upstream.
// For the purposes of this example we mark them as Exact to demonstrate the rewriter is working and the filtering is not being re-evaluated upstream.
Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
}

async fn scan(
&self,
state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let schema = self.schema.clone();
let df_schema = DFSchema::try_from(schema.clone())?;
let filter = state.create_physical_expr(
conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)),
&df_schema,
)?;

let parquet_source = ParquetSource::default()
.with_predicate(filter)
.with_pushdown_filters(true);

let object_store_url = ObjectStoreUrl::parse("memory://")?;

let store = state.runtime_env().object_store(object_store_url)?;

let mut files = vec![];
let mut listing = store.list(None);
while let Some(file) = listing.next().await {
if let Ok(file) = file {
files.push(file);
}
}

let file_group = files
.iter()
.map(|file| PartitionedFile::new(file.location.clone(), file.size))
.collect();

let file_scan_config = FileScanConfigBuilder::new(
ObjectStoreUrl::parse("memory://")?,
schema,
Arc::new(parquet_source),
)
.with_projection(projection.cloned())
.with_limit(limit)
.with_file_group(file_group)
// if the rewriter needs a reference to the table schema you can bind self.schema() here
.with_expr_adapter(Some(Arc::new(ShreddedJsonRewriterFactory) as _));

Ok(Arc::new(DataSourceExec::new(Arc::new(
file_scan_config.build(),
))))
}
}
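
The removed `supports_filters_pushdown` above marks every filter as `Exact`; as its comment notes, `Inexact` is the other option when the rewriter may skip some filters. A hypothetical sketch of that variant, not part of this change, written as a drop-in replacement for the method body inside the same impl:

```rust
// Hypothetical variant (not in the original example): report Inexact so DataFusion
// re-evaluates the pushed-down filters above the scan, tolerating false positives
// when the rewriter chooses not to rewrite some of them.
fn supports_filters_pushdown(
    &self,
    filters: &[&Expr],
) -> Result<Vec<TableProviderFilterPushDown>> {
    Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
}
```
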

/// Scalar UDF that uses serde_json to access json fields
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct JsonGetStr {