Skip to content

Commit

Permalink
Allow setting config extensions for TaskContext (#5497)
Browse files Browse the repository at this point in the history
* allow setting config extensions for TaskContext

* builder like api for ConfigOptions::with_extensions
  • Loading branch information
mpurins-coralogix authored Mar 8, 2023
1 parent 0ead640 commit 8a1b133
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
6 changes: 6 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,12 @@ impl ConfigOptions {
Self::default()
}

/// Set extensions to provided value
pub fn with_extensions(mut self, extensions: Extensions) -> Self {
self.extensions = extensions;
self
}

/// Set a configuration option
pub fn set(&mut self, key: &str, value: &str) -> Result<()> {
let (prefix, key) = key.split_once('.').ok_or_else(|| {
Expand Down
56 changes: 49 additions & 7 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ use crate::physical_plan::PhysicalPlanner;
use crate::variable::{VarProvider, VarType};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_common::{OwnedTableReference, ScalarValue};
use datafusion_common::{config::Extensions, OwnedTableReference, ScalarValue};
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
Expand Down Expand Up @@ -2143,27 +2143,28 @@ pub struct TaskContext {

impl TaskContext {
/// Create a new task context instance
pub fn new(
pub fn try_new(
task_id: String,
session_id: String,
task_props: HashMap<String, String>,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
runtime: Arc<RuntimeEnv>,
) -> Self {
let mut config = ConfigOptions::new();
extensions: Extensions,
) -> Result<Self> {
let mut config = ConfigOptions::new().with_extensions(extensions);
for (k, v) in task_props {
let _ = config.set(&k, &v);
config.set(&k, &v)?;
}

Self {
Ok(Self {
task_id: Some(task_id),
session_id,
session_config: config.into(),
scalar_functions,
aggregate_functions,
runtime,
}
})
}

/// Return the SessionConfig associated with the Task
Expand Down Expand Up @@ -2258,6 +2259,8 @@ mod tests {
use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_common::config::ConfigExtension;
use datafusion_common::extensions_options;
use datafusion_expr::{create_udaf, create_udf, Expr, Volatility};
use datafusion_physical_expr::functions::make_scalar_function;
use std::fs::File;
Expand Down Expand Up @@ -2925,4 +2928,43 @@ mod tests {
.unwrap()
}
}

extensions_options! {
struct TestExtension {
value: usize, default = 42
}
}

impl ConfigExtension for TestExtension {
const PREFIX: &'static str = "test";
}

#[test]
fn task_context_extensions() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let task_props = HashMap::from([("test.value".to_string(), "24".to_string())]);
let mut extensions = Extensions::default();
extensions.insert(TestExtension::default());

let task_context = TaskContext::try_new(
"task_id".to_string(),
"session_id".to_string(),
task_props,
HashMap::default(),
HashMap::default(),
runtime,
extensions,
)?;

let test = task_context
.session_config()
.config_options()
.extensions
.get::<TestExtension>();
assert!(test.is_some());

assert_eq!(test.unwrap().value, 24);

Ok(())
}
}

0 comments on commit 8a1b133

Please sign in to comment.