Skip to content

Commit 43aa441

Browse files
committed
udtf: provide session state ref to the call
This patch adds a session state argument to the [`TableFunctionImpl::call`] method. It is useful for implementing table functions that depend on other tables in the state. For example, a table function that returns the current list of all views in the state.
1 parent 8821e01 commit 43aa441

File tree

8 files changed

+52
-15
lines changed

8 files changed

+52
-15
lines changed

datafusion-cli/src/functions.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,11 @@ fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option<Str
317317
pub struct ParquetMetadataFunc {}
318318

319319
impl TableFunctionImpl for ParquetMetadataFunc {
320-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
320+
fn call(
321+
&self,
322+
_state: &SessionState,
323+
exprs: &[Expr],
324+
) -> Result<Arc<dyn TableProvider>> {
321325
let filename = match exprs.first() {
322326
Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet')
323327
Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet")

datafusion-examples/examples/simple_udtf.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use datafusion::datasource::function::TableFunctionImpl;
2525
use datafusion::datasource::TableProvider;
2626
use datafusion::error::Result;
2727
use datafusion::execution::context::ExecutionProps;
28+
use datafusion::execution::SessionState;
2829
use datafusion::physical_plan::memory::MemoryExec;
2930
use datafusion::physical_plan::ExecutionPlan;
3031
use datafusion::prelude::SessionContext;
@@ -130,7 +131,11 @@ impl TableProvider for LocalCsvTable {
130131
struct LocalCsvTableFunc {}
131132

132133
impl TableFunctionImpl for LocalCsvTableFunc {
133-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
134+
fn call(
135+
&self,
136+
_state: &SessionState,
137+
exprs: &[Expr],
138+
) -> Result<Arc<dyn TableProvider>> {
134139
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else {
135140
return plan_err!("read_csv requires at least one string argument");
136141
};

datafusion/core/src/datasource/function.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
//! A table that uses a function to generate data
1919
20+
use crate::execution::SessionState;
21+
2022
use super::TableProvider;
2123

2224
use datafusion_common::Result;
@@ -27,7 +29,8 @@ use std::sync::Arc;
2729
/// A trait for table function implementations
2830
pub trait TableFunctionImpl: Sync + Send {
2931
/// Create a table provider
30-
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
32+
fn call(&self, state: &SessionState, args: &[Expr])
33+
-> Result<Arc<dyn TableProvider>>;
3134
}
3235

3336
/// A table that uses a function to generate data
@@ -55,7 +58,11 @@ impl TableFunction {
5558
}
5659

5760
/// Get the function implementation and generate a table
58-
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
59-
self.fun.call(args)
61+
pub fn create_table_provider(
62+
&self,
63+
state: &SessionState,
64+
args: &[Expr],
65+
) -> Result<Arc<dyn TableProvider>> {
66+
self.fun.call(state, args)
6067
}
6168
}

datafusion/core/src/execution/session_state.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1543,7 +1543,7 @@ impl ContextProvider for SessionContextProvider<'_> {
15431543
.get(name)
15441544
.cloned()
15451545
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
1546-
let provider = tbl_func.create_table_provider(&args)?;
1546+
let provider = tbl_func.create_table_provider(self.state, &args)?;
15471547

15481548
Ok(provider_as_source(provider))
15491549
}

datafusion/core/tests/user_defined/user_defined_table_functions.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use datafusion::arrow::record_batch::RecordBatch;
2424
use datafusion::datasource::function::TableFunctionImpl;
2525
use datafusion::datasource::TableProvider;
2626
use datafusion::error::Result;
27-
use datafusion::execution::TaskContext;
27+
use datafusion::execution::{SessionState, TaskContext};
2828
use datafusion::physical_plan::memory::MemoryExec;
2929
use datafusion::physical_plan::{collect, ExecutionPlan};
3030
use datafusion::prelude::SessionContext;
@@ -194,7 +194,11 @@ impl SimpleCsvTable {
194194
struct SimpleCsvTableFunc {}
195195

196196
impl TableFunctionImpl for SimpleCsvTableFunc {
197-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
197+
fn call(
198+
&self,
199+
_state: &SessionState,
200+
exprs: &[Expr],
201+
) -> Result<Arc<dyn TableProvider>> {
198202
let mut new_exprs = vec![];
199203
let mut filepath = String::new();
200204
for expr in exprs {

datafusion/sql/src/query.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,21 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
8686
}
8787

8888
let skip = match skip {
89-
Some(skip_expr) => self.get_constant_usize_result(skip_expr.value, input.schema(), "OFFSET"),
89+
Some(skip_expr) => {
90+
self.get_constant_usize_result(skip_expr.value, input.schema(), "OFFSET")
91+
}
9092
_ => Ok(0),
9193
}?;
9294

9395
let fetch = match fetch {
9496
Some(limit_expr)
9597
if limit_expr != sqlparser::ast::Expr::Value(Value::Null) =>
9698
{
97-
Some(self.get_constant_usize_result(limit_expr, input.schema(), "LIMIT")?)
99+
Some(self.get_constant_usize_result(
100+
limit_expr,
101+
input.schema(),
102+
"LIMIT",
103+
)?)
98104
}
99105
_ => None,
100106
};
@@ -159,7 +165,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
159165
///
160166
/// * `Result<usize>` - An `Ok` variant containing the constant result if evaluation is successful,
161167
/// or an `Err` variant containing an error message if evaluation fails.
162-
pub(super) fn get_constant_usize_result(&self, expr: SQLExpr, schema: &datafusion_common::DFSchema, arg_name: &str) -> Result<usize> {
168+
pub(super) fn get_constant_usize_result(
169+
&self,
170+
expr: SQLExpr,
171+
schema: &datafusion_common::DFSchema,
172+
arg_name: &str,
173+
) -> Result<usize> {
163174
let expr = self.sql_to_expr(expr, schema, &mut PlannerContext::new())?;
164175
let value = get_constant_result(&expr, arg_name)?;
165176
convert_usize_with_check(value, arg_name)

datafusion/sql/src/statement.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,7 +1219,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
12191219
&self,
12201220
table_name: ObjectName,
12211221
predicate_expr: Option<SQLExpr>,
1222-
limit: Option<SQLExpr>
1222+
limit: Option<SQLExpr>,
12231223
) -> Result<LogicalPlan> {
12241224
// Do a table lookup to verify the table exists
12251225
let table_ref = self.object_name_to_table_reference(table_name.clone())?;
@@ -1251,9 +1251,11 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
12511251
}
12521252
};
12531253

1254-
if let Some(limit_expr) = limit {
1254+
if let Some(limit_expr) = limit {
12551255
let limit = (limit_expr != SQLExpr::Value(Value::Null))
1256-
.then(|| self.get_constant_usize_result(limit_expr, source.schema(), "LIMIT"))
1256+
.then(|| {
1257+
self.get_constant_usize_result(limit_expr, source.schema(), "LIMIT")
1258+
})
12571259
.transpose()?;
12581260

12591261
source = LogicalPlanBuilder::from(source).limit(0, limit)?.build()?

docs/source/library-user-guide/adding-udfs.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,14 +562,18 @@ In the `call` method, you parse the input `Expr`s and return a `TableProvider`.
562562
```rust
563563
use datafusion::common::plan_err;
564564
use datafusion::datasource::function::TableFunctionImpl;
565+
use datafusion::execution::SessionState;
565566
// Other imports here
566567

567568
/// A table function that returns a table provider with the value as a single column
568569
#[derive(Default)]
569570
pub struct EchoFunction {}
570571

571572
impl TableFunctionImpl for EchoFunction {
572-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
573+
fn call(&self,
574+
_state: &SessionState,
575+
exprs: &[Expr],
576+
) -> Result<Arc<dyn TableProvider>> {
573577
let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else {
574578
return plan_err!("First argument must be an integer");
575579
};

0 commit comments

Comments
 (0)