Skip to content

Commit

Permalink
Minor: Move TableProviderFactories up out of RuntimeEnv and into Sess…
Browse files Browse the repository at this point in the history
…ionState (#5477)
  • Loading branch information
alamb authored Mar 7, 2023
1 parent 50e9d78 commit deeaa56
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 63 deletions.
75 changes: 62 additions & 13 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
//! SessionContext contains methods for registering data sources and executing queries
use crate::{
catalog::catalog::{CatalogList, MemoryCatalogList},
datasource::listing::{ListingOptions, ListingTable},
datasource::{
datasource::TableProviderFactory,
listing::{ListingOptions, ListingTable},
listing_table_factory::ListingTableFactory,
},
datasource::{MemTable, ViewTable},
logical_expr::{PlanType, ToStringifiedPlan},
optimizer::optimizer::Optimizer,
Expand Down Expand Up @@ -278,6 +282,15 @@ impl SessionContext {
self.session_id.clone()
}

/// Return the [`TableFactoryProvider`] that is registered for the
/// specified file type, if any.
pub fn table_factory(
&self,
file_type: &str,
) -> Option<Arc<dyn TableProviderFactory>> {
self.state.read().table_factories().get(file_type).cloned()
}

/// Return the `enable_ident_normalization` of this Session
pub fn enable_ident_normalization(&self) -> bool {
self.state
Expand Down Expand Up @@ -579,16 +592,16 @@ impl SessionContext {
) -> Result<Arc<dyn TableProvider>> {
let state = self.state.read().clone();
let file_type = cmd.file_type.to_uppercase();
let factory = &state
.runtime_env
.table_factories
.get(file_type.as_str())
.ok_or_else(|| {
DataFusionError::Execution(format!(
"Unable to find factory for {}",
cmd.file_type
))
})?;
let factory =
&state
.table_factories
.get(file_type.as_str())
.ok_or_else(|| {
DataFusionError::Execution(format!(
"Unable to find factory for {}",
cmd.file_type
))
})?;
let table = (*factory).create(&state, cmd).await?;
Ok(table)
}
Expand Down Expand Up @@ -1507,6 +1520,14 @@ pub struct SessionState {
config: SessionConfig,
/// Execution properties
execution_props: ExecutionProps,
/// TableProviderFactories for different file formats.
///
/// Maps strings like "JSON" to an instance of [`TableProviderFactory`]
///
/// This is used to create [`TableProvider`] instances for the
/// `CREATE EXTERNAL TABLE ... STORED AS <FORMAT>` for custom file
/// formats other than those built into DataFusion
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
/// Runtime environment
runtime_env: Arc<RuntimeEnv>,
}
Expand Down Expand Up @@ -1540,6 +1561,15 @@ impl SessionState {
) -> Self {
let session_id = Uuid::new_v4().to_string();

// Create table_factories for all default formats
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("PARQUET".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("CSV".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("JSON".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("NDJSON".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("AVRO".into(), Arc::new(ListingTableFactory::new()));

if config.create_default_catalog_and_schema() {
let default_catalog = MemoryCatalogProvider::new();

Expand All @@ -1550,7 +1580,12 @@ impl SessionState {
)
.expect("memory catalog provider can register schema");

Self::register_default_schema(&config, &runtime, &default_catalog);
Self::register_default_schema(
&config,
&table_factories,
&runtime,
&default_catalog,
);

catalog_list.register_catalog(
config.config_options().catalog.default_catalog.clone(),
Expand Down Expand Up @@ -1619,11 +1654,13 @@ impl SessionState {
config,
execution_props: ExecutionProps::new(),
runtime_env: runtime,
table_factories,
}
}

fn register_default_schema(
config: &SessionConfig,
table_factories: &HashMap<String, Arc<dyn TableProviderFactory>>,
runtime: &Arc<RuntimeEnv>,
default_catalog: &MemoryCatalogProvider,
) {
Expand All @@ -1650,7 +1687,7 @@ impl SessionState {
Ok(store) => store,
_ => return,
};
let factory = match runtime.table_factories.get(format.as_str()) {
let factory = match table_factories.get(format.as_str()) {
Some(factory) => factory,
_ => return,
};
Expand Down Expand Up @@ -1756,6 +1793,18 @@ impl SessionState {
self
}

/// Get the table factories
pub fn table_factories(&self) -> &HashMap<String, Arc<dyn TableProviderFactory>> {
&self.table_factories
}

/// Get the table factories
pub fn table_factories_mut(
&mut self,
) -> &mut HashMap<String, Arc<dyn TableProviderFactory>> {
&mut self.table_factories
}

/// Convert a SQL string into an AST Statement
pub fn sql_to_statement(
&self,
Expand Down
38 changes: 1 addition & 37 deletions datafusion/core/src/execution/runtime_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ use crate::{
error::Result,
execution::disk_manager::{DiskManager, DiskManagerConfig},
};
use std::collections::HashMap;

use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::listing_table_factory::ListingTableFactory;
use crate::datasource::object_store::ObjectStoreRegistry;
use crate::execution::memory_pool::{GreedyMemoryPool, MemoryPool, UnboundedMemoryPool};
use datafusion_common::DataFusionError;
Expand All @@ -44,8 +41,6 @@ pub struct RuntimeEnv {
pub disk_manager: Arc<DiskManager>,
/// Object Store Registry
pub object_store_registry: Arc<ObjectStoreRegistry>,
/// TableProviderFactories
pub table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
}

impl Debug for RuntimeEnv {
Expand All @@ -61,7 +56,6 @@ impl RuntimeEnv {
memory_pool,
disk_manager,
object_store_registry,
table_factories,
} = config;

let memory_pool =
Expand All @@ -71,7 +65,6 @@ impl RuntimeEnv {
memory_pool,
disk_manager: DiskManager::try_new(disk_manager)?,
object_store_registry,
table_factories,
})
}

Expand All @@ -94,14 +87,6 @@ impl RuntimeEnv {
.register_store(scheme, host, object_store)
}

/// Registers TableFactories
pub fn register_table_factories(
&mut self,
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
) {
self.table_factories.extend(table_factories)
}

/// Retrieves a `ObjectStore` instance for a url by consulting the
/// registery. See [`ObjectStoreRegistry::get_by_url`] for more
/// details.
Expand Down Expand Up @@ -129,24 +114,12 @@ pub struct RuntimeConfig {
pub memory_pool: Option<Arc<dyn MemoryPool>>,
/// ObjectStoreRegistry to get object store based on url
pub object_store_registry: Arc<ObjectStoreRegistry>,
/// Custom table factories for things like deltalake that are not part of core datafusion
pub table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
}

impl RuntimeConfig {
/// New with default values
pub fn new() -> Self {
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("PARQUET".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("CSV".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("JSON".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("NDJSON".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("AVRO".into(), Arc::new(ListingTableFactory::new()));
Self {
table_factories,
..Default::default()
}
Default::default()
}

/// Customize disk manager
Expand All @@ -170,15 +143,6 @@ impl RuntimeConfig {
self
}

/// Customize object store registry
pub fn with_table_factories(
mut self,
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
) -> Self {
self.table_factories = table_factories;
self
}

/// Specify the total memory to use while running the DataFusion
/// plan to `max_memory * memory_fraction` in bytes.
///
Expand Down
21 changes: 13 additions & 8 deletions datafusion/core/tests/sql/create_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::test_util::TestTableFactory;

Expand Down Expand Up @@ -106,12 +107,14 @@ async fn sql_create_table_exists() -> Result<()> {

#[tokio::test]
async fn create_custom_table() -> Result<()> {
let mut cfg = RuntimeConfig::new();
cfg.table_factories
.insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {}));
let cfg = RuntimeConfig::new();
let env = RuntimeEnv::new(cfg).unwrap();
let ses = SessionConfig::new();
let ctx = SessionContext::with_config_rt(ses, Arc::new(env));
let mut state = SessionState::with_config_rt(ses, Arc::new(env));
state
.table_factories_mut()
.insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {}));
let ctx = SessionContext::with_state(state);

let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';";
ctx.sql(sql).await.unwrap();
Expand All @@ -126,12 +129,14 @@ async fn create_custom_table() -> Result<()> {

#[tokio::test]
async fn create_external_table_with_ddl() -> Result<()> {
let mut cfg = RuntimeConfig::new();
cfg.table_factories
.insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {}));
let cfg = RuntimeConfig::new();
let env = RuntimeEnv::new(cfg).unwrap();
let ses = SessionConfig::new();
let ctx = SessionContext::with_config_rt(ses, Arc::new(env));
let mut state = SessionState::with_config_rt(ses, Arc::new(env));
state
.table_factories_mut()
.insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {}));
let ctx = SessionContext::with_state(state);

let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS MOCKTABLE LOCATION 'mockprotocol://path/to/table';";
ctx.sql(sql).await.unwrap();
Expand Down
13 changes: 8 additions & 5 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,9 @@ impl AsLogicalPlan for LogicalPlanNode {
};

let file_type = create_extern_table.file_type.as_str();
let env = ctx.runtime_env();
if !env.table_factories.contains_key(file_type) {
if ctx.table_factory(file_type).is_none() {
Err(DataFusionError::Internal(format!(
"No TableProvider for file type: {file_type}"
"No TableProviderFactory for file type: {file_type}"
)))?
}

Expand Down Expand Up @@ -1377,6 +1376,7 @@ mod roundtrip_tests {
};
use datafusion::datasource::datasource::TableProviderFactory;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::{
Expand Down Expand Up @@ -1523,10 +1523,13 @@ mod roundtrip_tests {
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {}));
let cfg = RuntimeConfig::new().with_table_factories(table_factories);
let cfg = RuntimeConfig::new();
let env = RuntimeEnv::new(cfg).unwrap();
let ses = SessionConfig::new();
let ctx = SessionContext::with_config_rt(ses, Arc::new(env));
let mut state = SessionState::with_config_rt(ses, Arc::new(env));
// replace factories
*state.table_factories_mut() = table_factories;
let ctx = SessionContext::with_state(state);

let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';";
ctx.sql(sql).await.unwrap();
Expand Down

0 comments on commit deeaa56

Please sign in to comment.