Change FileFormat to depend on TaskContext rather than `SessionState`

alamb committed Apr 4, 2023
1 parent 33af59e commit 5b79b5e
Showing 23 changed files with 187 additions and 133 deletions.
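
The substance of the commit is a signature change on the `FileFormat` trait: its async methods (`infer_schema`, `infer_stats`, `create_physical_plan`) now take a `TaskContext` rather than the much broader `SessionState`. A before/after sketch of `infer_schema`, reconstructed from the avro.rs and csv.rs hunks below (the other two methods change the same way):

// Before: every implementation received the full session state.
async fn infer_schema(
    &self,
    state: &SessionState,
    store: &Arc<dyn ObjectStore>,
    objects: &[ObjectMeta],
) -> Result<SchemaRef>;

// After: the narrower, execution-scoped TaskContext is enough.
async fn infer_schema(
    &self,
    task_ctx: &TaskContext,
    store: &Arc<dyn ObjectStore>,
    objects: &[ObjectMeta],
) -> Result<SchemaRef>;

Callers that previously passed `&state` now pass `&state.task_ctx()` (or `&ctx.task_ctx()`), as the hunks below show.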
2 changes: 1 addition & 1 deletion benchmarks/src/bin/tpch.rs
@@ -387,7 +387,7 @@ async fn get_table(
     let config = ListingTableConfig::new(table_path).with_listing_options(options);
 
     let config = match table_format {
-        "parquet" => config.infer_schema(&state).await?,
+        "parquet" => config.infer_schema(&state.task_ctx()).await?,
         "tbl" => config.with_schema(Arc::new(get_tbl_tpch_table_schema(table))),
         "csv" => config.with_schema(Arc::new(get_tpch_table_schema(table))),
         _ => unreachable!(),
46 changes: 23 additions & 23 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions datafusion-cli/src/catalog.rs
@@ -147,9 +147,9 @@ impl SchemaProvider for DynamicFileSchemaProvider {
 
         // if the inner schema provider didn't have a table by
         // that name, try to treat it as a listing table
-        let state = self.state.upgrade()?.read().clone();
+        let task_ctx = self.state.upgrade()?.read().task_ctx();
         let config = ListingTableConfig::new(ListingTableUrl::parse(name).ok()?)
-            .infer(&state)
+            .infer(&task_ctx)
             .await
             .ok()?;
         Some(Arc::new(ListingTable::try_new(config).ok()?))
14 changes: 9 additions & 5 deletions datafusion-examples/examples/catalog.rs
@@ -32,7 +32,7 @@ use datafusion::{
         TableProvider,
     },
     error::Result,
-    execution::context::SessionState,
+    execution::context::TaskContext,
     prelude::SessionContext,
 };
 use std::sync::RwLock;
@@ -53,6 +53,7 @@ async fn main() -> Result<()> {
         .unwrap();
     let mut ctx = SessionContext::new();
     let state = ctx.state();
+    let task_ctx = state.task_ctx();
     let catlist = Arc::new(CustomCatalogList::new());
     // use our custom catalog list for context. each context has a single catalog list.
     // context will by default have MemoryCatalogList
@@ -61,7 +62,7 @@
     // intitialize our catalog and schemas
     let catalog = DirCatalog::new();
     let parquet_schema = DirSchema::create(
-        &state,
+        &task_ctx,
         DirSchemaOpts {
             format: Arc::new(ParquetFormat::default()),
             dir: &repo_dir.join("parquet-testing").join("data"),
@@ -70,7 +71,7 @@
     )
     .await?;
     let csv_schema = DirSchema::create(
-        &state,
+        &task_ctx,
         DirSchemaOpts {
             format: Arc::new(CsvFormat::default()),
             dir: &repo_dir.join("testing").join("data").join("csv"),
@@ -138,7 +139,10 @@ struct DirSchema {
     tables: RwLock<HashMap<String, Arc<dyn TableProvider>>>,
 }
 impl DirSchema {
-    async fn create(state: &SessionState, opts: DirSchemaOpts<'_>) -> Result<Arc<Self>> {
+    async fn create(
+        task_ctx: &TaskContext,
+        opts: DirSchemaOpts<'_>,
+    ) -> Result<Arc<Self>> {
         let DirSchemaOpts { ext, dir, format } = opts;
         let mut tables = HashMap::new();
         let listdir = std::fs::read_dir(dir).unwrap();
@@ -153,7 +157,7 @@ impl DirSchema {
         let opts = ListingOptions::new(format.clone());
         let conf = ListingTableConfig::new(table_path)
             .with_listing_options(opts)
-            .infer_schema(state)
+            .infer_schema(task_ctx)
             .await?;
         let table = ListingTable::try_new(conf)?;
         tables.insert(filename, Arc::new(table) as Arc<dyn TableProvider>);
2 changes: 1 addition & 1 deletion datafusion-examples/examples/flight_server.rs
@@ -59,7 +59,7 @@ impl FlightService for FlightServiceImpl {
 
         let ctx = SessionContext::new();
         let schema = listing_options
-            .infer_schema(&ctx.state(), &table_path)
+            .infer_schema(&ctx.task_ctx(), &table_path)
             .await
             .unwrap();
 
2 changes: 1 addition & 1 deletion datafusion/core/src/catalog/schema.rs
@@ -180,7 +180,7 @@ mod tests {
         let ctx = SessionContext::new();
 
         let config = ListingTableConfig::new(table_path)
-            .infer(&ctx.state())
+            .infer(&ctx.task_ctx())
             .await
             .unwrap();
         let table = ListingTable::try_new(config).unwrap();
23 changes: 17 additions & 6 deletions datafusion/core/src/datasource/file_format/avro.rs
@@ -23,13 +23,13 @@ use std::sync::Arc;
 use arrow::datatypes::Schema;
 use arrow::{self, datatypes::SchemaRef};
 use async_trait::async_trait;
+use datafusion_execution::TaskContext;
 use datafusion_physical_expr::PhysicalExpr;
 use object_store::{GetResult, ObjectMeta, ObjectStore};
 
 use super::FileFormat;
 use crate::avro_to_arrow::read_avro_schema_from_reader;
 use crate::error::Result;
-use crate::execution::context::SessionState;
 use crate::physical_plan::file_format::{AvroExec, FileScanConfig};
 use crate::physical_plan::ExecutionPlan;
 use crate::physical_plan::Statistics;
@@ -48,7 +48,7 @@ impl FileFormat for AvroFormat {
 
     async fn infer_schema(
         &self,
-        _state: &SessionState,
+        _task_ctx: &TaskContext,
         store: &Arc<dyn ObjectStore>,
         objects: &[ObjectMeta],
     ) -> Result<SchemaRef> {
@@ -70,7 +70,7 @@
 
     async fn infer_stats(
         &self,
-        _state: &SessionState,
+        _task_ctx: &TaskContext,
         _store: &Arc<dyn ObjectStore>,
         _table_schema: SchemaRef,
         _object: &ObjectMeta,
@@ -80,7 +80,7 @@
 
     async fn create_physical_plan(
         &self,
-        _state: &SessionState,
+        _task_ctx: &TaskContext,
         conf: FileScanConfig,
         _filters: Option<&Arc<dyn PhysicalExpr>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
@@ -94,6 +94,7 @@ impl FileFormat for AvroFormat {
 mod tests {
     use super::*;
     use crate::datasource::file_format::test_util::scan_format;
+    use crate::execution::context::SessionState;
     use crate::physical_plan::collect;
     use crate::prelude::{SessionConfig, SessionContext};
     use datafusion_common::cast::{
@@ -359,7 +360,15 @@
         let testdata = crate::test_util::arrow_test_data();
         let store_root = format!("{testdata}/avro");
         let format = AvroFormat {};
-        scan_format(state, &format, &store_root, file_name, projection, limit).await
+        scan_format(
+            &state.task_ctx(),
+            &format,
+            &store_root,
+            file_name,
+            projection,
+            limit,
+        )
+        .await
     }
 }
 
@@ -379,7 +388,9 @@
         let format = AvroFormat {};
         let testdata = crate::test_util::arrow_test_data();
         let filename = "avro/alltypes_plain.avro";
-        let result = scan_format(&state, &format, &testdata, filename, None, None).await;
+        let result =
+            scan_format(&state.task_ctx(), &format, &testdata, filename, None, None)
+                .await;
         assert!(matches!(
             result,
             Err(DataFusionError::NotImplemented(msg))
14 changes: 8 additions & 6 deletions datafusion/core/src/datasource/file_format/csv.rs
@@ -37,10 +37,10 @@ use super::FileFormat;
 use crate::datasource::file_format::file_type::FileCompressionType;
 use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
 use crate::error::Result;
-use crate::execution::context::SessionState;
 use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
 use crate::physical_plan::ExecutionPlan;
 use crate::physical_plan::Statistics;
+use datafusion_execution::TaskContext;
 
 /// The default file extension of csv files
 pub const DEFAULT_CSV_EXTENSION: &str = ".csv";
@@ -115,7 +115,7 @@ impl FileFormat for CsvFormat {
 
     async fn infer_schema(
         &self,
-        _state: &SessionState,
+        _task_ctx: &TaskContext,
         store: &Arc<dyn ObjectStore>,
         objects: &[ObjectMeta],
     ) -> Result<SchemaRef> {
@@ -142,7 +142,7 @@
 
     async fn infer_stats(
         &self,
-        _state: &SessionState,
+        _task_ctx: &TaskContext,
         _store: &Arc<dyn ObjectStore>,
         _table_schema: SchemaRef,
         _object: &ObjectMeta,
@@ -152,7 +152,7 @@
 
     async fn create_physical_plan(
         &self,
-        _state: &SessionState,
+        _task_ctx: &TaskContext,
         conf: FileScanConfig,
         _filters: Option<&Arc<dyn PhysicalExpr>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
@@ -296,6 +296,7 @@ mod tests {
     use super::super::test_util::scan_format;
     use super::*;
     use crate::datasource::file_format::test_util::VariableStream;
+    use crate::execution::context::SessionState;
     use crate::physical_plan::collect;
     use crate::prelude::{SessionConfig, SessionContext};
     use bytes::Bytes;
@@ -430,7 +431,7 @@
         };
         let inferred_schema = csv_format
             .infer_schema(
-                &state,
+                &state.task_ctx(),
                 &(variable_object_store.clone() as Arc<dyn ObjectStore>),
                 &[object_meta],
             )
@@ -467,8 +468,9 @@
         projection: Option<Vec<usize>>,
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        let task_ctx = state.task_ctx();
         let root = format!("{}/csv", crate::test_util::arrow_test_data());
         let format = CsvFormat::default();
-        scan_format(state, &format, &root, file_name, projection, limit).await
+        scan_format(&task_ctx, &format, &root, file_name, projection, limit).await
     }
 }
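
Taken together, the hunks above show one migration pattern: obtain a `TaskContext` from the `SessionContext` or `SessionState` and pass that wherever `&state` used to go. A minimal caller-side sketch under this commit's API; the function name, registered table name, and URL parameter are illustrative, not taken from the diff:

use std::sync::Arc;
use datafusion::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn register_listing_table(ctx: &SessionContext, url: &str) -> Result<()> {
    // task_ctx() returns an Arc<TaskContext>; &task_ctx deref-coerces to
    // the &TaskContext that infer() expects after this change.
    let task_ctx = ctx.task_ctx();
    let config = ListingTableConfig::new(ListingTableUrl::parse(url)?)
        .infer(&task_ctx) // was: .infer(&ctx.state())
        .await?;
    ctx.register_table("my_table", Arc::new(ListingTable::try_new(config)?))?;
    Ok(())
}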