Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions datafusion/catalog/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,12 @@ pub trait TableProviderFactory: Debug + Sync + Send {
pub trait TableFunctionImpl: Debug + Sync + Send {
/// Create a table provider
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;

/// Returns true if the arguments should be coerced and simplified.
/// Defaults to true for backward compatibility.
fn coerce_arguments(&self) -> bool {
true
}
}

/// A table that uses a function to generate data
Expand Down Expand Up @@ -520,4 +526,9 @@ impl TableFunction {
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
self.fun.call(args)
}

/// Returns true if the arguments should be coerced and simplified
pub fn coerce_arguments(&self) -> bool {
self.fun.coerce_arguments()
}
}
183 changes: 175 additions & 8 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1842,14 +1842,17 @@ impl ContextProvider for SessionContextProvider<'_> {
);
let simplifier = ExprSimplifier::new(simplify_context);
let schema = DFSchema::empty();
let args = args
.into_iter()
.map(|arg| {
simplifier
.coerce(arg, &schema)
.and_then(|e| simplifier.simplify(e))
})
.collect::<datafusion_common::Result<Vec<_>>>()?;
let args = if tbl_func.coerce_arguments() {
args.into_iter()
.map(|arg| {
simplifier
.coerce(arg, &schema)
.and_then(|e| simplifier.simplify(e))
})
.collect::<datafusion_common::Result<Vec<_>>>()?
} else {
args
};
let provider = tbl_func.create_table_provider(&args)?;

Ok(provider_as_source(provider))
Expand Down Expand Up @@ -2509,3 +2512,167 @@ mod tests {
}
}
}

#[cfg(test)]
mod udtf_tests {
use super::*;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_catalog::Session;
use datafusion_catalog::{TableFunction, TableFunctionImpl, TableProvider};
use datafusion_common::{Result, plan_err};
use datafusion_expr::{Expr, TableType};
use datafusion_physical_plan::ExecutionPlan;
use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
struct MockTableProvider {
schema: SchemaRef,
}

#[async_trait]
impl TableProvider for MockTableProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
_projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let schema = self.schema.clone();
Ok(Arc::new(datafusion_physical_plan::empty::EmptyExec::new(
schema,
)))
}
}

#[derive(Debug)]
struct NoCoerceUDTF {}

impl TableFunctionImpl for NoCoerceUDTF {
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
// Verify that the argument 'index' (which is technically a column reference in SQL)
// survives as an identifier instead of failing coercion because it's missing from the empty schema.
match &args[0] {
Expr::BinaryExpr(be) => {
match be.left.as_ref() {
Expr::Column(c) if c.name == "index" => {
// Success!
}
_ => {
return plan_err!(
"Expected Column('index') on left side, got {:?}",
be.left
);
}
}
}
_ => return plan_err!("Expected BinaryExpr, got {:?}", args[0]),
}

let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
Ok(Arc::new(MockTableProvider { schema }))
}

fn coerce_arguments(&self) -> bool {
false
}
}

#[test]
fn test_udtf_no_coercion() -> Result<()> {
let udtf = Arc::new(TableFunction::new(
"scan_with".to_string(),
Arc::new(NoCoerceUDTF {}),
));

let state = SessionStateBuilder::new()
.with_default_features()
.with_table_function_list(vec![udtf])
.build();

let provider = SessionContextProvider {
state: &state,
tables: HashMap::new(),
};

// SQL: SELECT * FROM scan_with(index=1)
let args = vec![Expr::BinaryExpr(datafusion_expr::BinaryExpr {
left: Box::new(Expr::Column(datafusion_common::Column::from_name("index"))),
op: datafusion_expr::Operator::Eq,
right: Box::new(Expr::Literal(
datafusion_common::ScalarValue::Int32(Some(1)),
None,
)),
})];

let source = provider.get_table_function_source("scan_with", args)?;
assert_eq!(source.schema().fields().len(), 1);
assert_eq!(source.schema().field(0).name(), "a");

Ok(())
}

#[test]
fn test_udtf_default_coercion() -> Result<()> {
#[derive(Debug)]
struct CoerceUDTF {}
impl TableFunctionImpl for CoerceUDTF {
fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
Ok(Arc::new(MockTableProvider { schema }))
}
}

let udtf = Arc::new(TableFunction::new(
"scan_with".to_string(),
Arc::new(CoerceUDTF {}),
));

let state = SessionStateBuilder::new()
.with_default_features()
.with_table_function_list(vec![udtf])
.build();

let provider = SessionContextProvider {
state: &state,
tables: HashMap::new(),
};

// In SQL: SELECT * FROM scan_with(unknown_col=1)
let args = vec![Expr::BinaryExpr(datafusion_expr::BinaryExpr {
left: Box::new(Expr::Column(datafusion_common::Column::from_name(
"unknown_col",
))),
op: datafusion_expr::Operator::Eq,
right: Box::new(Expr::Literal(
datafusion_common::ScalarValue::Int32(Some(1)),
None,
)),
})];

// Should fail because coercion is ON and "unknown_col" is not in the empty schema.
let res = provider.get_table_function_source("scan_with", args);
match res {
Ok(_) => panic!("Expected error, but got success"),
Err(e) => assert!(
e.to_string()
.contains("Schema error: No field named unknown_col")
),
}

Ok(())
}
}