Skip to content

Commit e19c669

Browse files
Veeupupalamb
andauthored
Support User Defined Table Function (#8306)
* Support User Defined Table Function Signed-off-by: veeupup <code@tanweime.com> * fix comments Signed-off-by: veeupup <code@tanweime.com> * add udtf test Signed-off-by: veeupup <code@tanweime.com> * add file header * Simply table function example, add some comments * Simplfy exprs * make clippy happy * Update datafusion/core/tests/user_defined/user_defined_table_functions.rs --------- Signed-off-by: veeupup <code@tanweime.com> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 5c02664 commit e19c669

File tree

8 files changed

+550
-21
lines changed

8 files changed

+550
-21
lines changed
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::csv::reader::Format;
19+
use arrow::csv::ReaderBuilder;
20+
use async_trait::async_trait;
21+
use datafusion::arrow::datatypes::SchemaRef;
22+
use datafusion::arrow::record_batch::RecordBatch;
23+
use datafusion::datasource::function::TableFunctionImpl;
24+
use datafusion::datasource::TableProvider;
25+
use datafusion::error::Result;
26+
use datafusion::execution::context::{ExecutionProps, SessionState};
27+
use datafusion::physical_plan::memory::MemoryExec;
28+
use datafusion::physical_plan::ExecutionPlan;
29+
use datafusion::prelude::SessionContext;
30+
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
31+
use datafusion_expr::{Expr, TableType};
32+
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
33+
use std::fs::File;
34+
use std::io::Seek;
35+
use std::path::Path;
36+
use std::sync::Arc;
37+
38+
// To define your own table function, you only need to do the following 3 things:
39+
// 1. Implement your own [`TableProvider`]
40+
// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`]
41+
// 3. Register the function using [`SessionContext::register_udtf`]
42+
43+
/// This example demonstrates how to register a TableFunction
44+
#[tokio::main]
45+
async fn main() -> Result<()> {
46+
// create local execution context
47+
let ctx = SessionContext::new();
48+
49+
// register the table function that will be called in SQL statements by `read_csv`
50+
ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {}));
51+
52+
let testdata = datafusion::test_util::arrow_test_data();
53+
let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");
54+
55+
// Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2)
56+
let df = ctx
57+
.sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str())
58+
.await?;
59+
df.show().await?;
60+
61+
// just run, return all rows
62+
let df = ctx
63+
.sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str())
64+
.await?;
65+
df.show().await?;
66+
67+
Ok(())
68+
}
69+
70+
/// Table Function that mimics the [`read_csv`] function in DuckDB.
71+
///
72+
/// Usage: `read_csv(filename, [limit])`
73+
///
74+
/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html
75+
struct LocalCsvTable {
76+
schema: SchemaRef,
77+
limit: Option<usize>,
78+
batches: Vec<RecordBatch>,
79+
}
80+
81+
#[async_trait]
82+
impl TableProvider for LocalCsvTable {
83+
fn as_any(&self) -> &dyn std::any::Any {
84+
self
85+
}
86+
87+
fn schema(&self) -> SchemaRef {
88+
self.schema.clone()
89+
}
90+
91+
fn table_type(&self) -> TableType {
92+
TableType::Base
93+
}
94+
95+
async fn scan(
96+
&self,
97+
_state: &SessionState,
98+
projection: Option<&Vec<usize>>,
99+
_filters: &[Expr],
100+
_limit: Option<usize>,
101+
) -> Result<Arc<dyn ExecutionPlan>> {
102+
let batches = if let Some(max_return_lines) = self.limit {
103+
// get max return rows from self.batches
104+
let mut batches = vec![];
105+
let mut lines = 0;
106+
for batch in &self.batches {
107+
let batch_lines = batch.num_rows();
108+
if lines + batch_lines > max_return_lines {
109+
let batch_lines = max_return_lines - lines;
110+
batches.push(batch.slice(0, batch_lines));
111+
break;
112+
} else {
113+
batches.push(batch.clone());
114+
lines += batch_lines;
115+
}
116+
}
117+
batches
118+
} else {
119+
self.batches.clone()
120+
};
121+
Ok(Arc::new(MemoryExec::try_new(
122+
&[batches],
123+
TableProvider::schema(self),
124+
projection.cloned(),
125+
)?))
126+
}
127+
}
128+
struct LocalCsvTableFunc {}
129+
130+
impl TableFunctionImpl for LocalCsvTableFunc {
131+
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
132+
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.get(0) else {
133+
return plan_err!("read_csv requires at least one string argument");
134+
};
135+
136+
let limit = exprs
137+
.get(1)
138+
.map(|expr| {
139+
// try to simpify the expression, so 1+2 becomes 3, for example
140+
let execution_props = ExecutionProps::new();
141+
let info = SimplifyContext::new(&execution_props);
142+
let expr = ExprSimplifier::new(info).simplify(expr.clone())?;
143+
144+
if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr {
145+
Ok(limit as usize)
146+
} else {
147+
plan_err!("Limit must be an integer")
148+
}
149+
})
150+
.transpose()?;
151+
152+
let (schema, batches) = read_csv_batches(path)?;
153+
154+
let table = LocalCsvTable {
155+
schema,
156+
limit,
157+
batches,
158+
};
159+
Ok(Arc::new(table))
160+
}
161+
}
162+
163+
fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
164+
let mut file = File::open(csv_path)?;
165+
let (schema, _) = Format::default().infer_schema(&mut file, None)?;
166+
file.rewind()?;
167+
168+
let reader = ReaderBuilder::new(Arc::new(schema.clone()))
169+
.with_header(true)
170+
.build(file)?;
171+
let mut batches = vec![];
172+
for bacth in reader {
173+
batches.push(bacth?);
174+
}
175+
let schema = Arc::new(schema);
176+
Ok((schema, batches))
177+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! A table that uses a function to generate data
19+
20+
use super::TableProvider;
21+
22+
use datafusion_common::Result;
23+
use datafusion_expr::Expr;
24+
25+
use std::sync::Arc;
26+
27+
/// A trait for table function implementations
28+
pub trait TableFunctionImpl: Sync + Send {
29+
/// Create a table provider
30+
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
31+
}
32+
33+
/// A table that uses a function to generate data
34+
pub struct TableFunction {
35+
/// Name of the table function
36+
name: String,
37+
/// Function implementation
38+
fun: Arc<dyn TableFunctionImpl>,
39+
}
40+
41+
impl TableFunction {
42+
/// Create a new table function
43+
pub fn new(name: String, fun: Arc<dyn TableFunctionImpl>) -> Self {
44+
Self { name, fun }
45+
}
46+
47+
/// Get the name of the table function
48+
pub fn name(&self) -> &str {
49+
&self.name
50+
}
51+
52+
/// Get the function implementation and generate a table
53+
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
54+
self.fun.call(args)
55+
}
56+
}

datafusion/core/src/datasource/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub mod avro_to_arrow;
2323
pub mod default_table_source;
2424
pub mod empty;
2525
pub mod file_format;
26+
pub mod function;
2627
pub mod listing;
2728
pub mod listing_table_factory;
2829
pub mod memory;

datafusion/core/src/execution/context/mod.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod parquet;
2626
use crate::{
2727
catalog::{CatalogList, MemoryCatalogList},
2828
datasource::{
29+
function::{TableFunction, TableFunctionImpl},
2930
listing::{ListingOptions, ListingTable},
3031
provider::TableProviderFactory,
3132
},
@@ -42,7 +43,7 @@ use datafusion_common::{
4243
use datafusion_execution::registry::SerializerRegistry;
4344
use datafusion_expr::{
4445
logical_plan::{DdlStatement, Statement},
45-
StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
46+
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
4647
};
4748
pub use datafusion_physical_expr::execution_props::ExecutionProps;
4849
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -803,6 +804,14 @@ impl SessionContext {
803804
.add_var_provider(variable_type, provider);
804805
}
805806

807+
/// Register a table UDF with this context
808+
pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
809+
self.state.write().table_functions.insert(
810+
name.to_owned(),
811+
Arc::new(TableFunction::new(name.to_owned(), fun)),
812+
);
813+
}
814+
806815
/// Registers a scalar UDF within this context.
807816
///
808817
/// Note in SQL queries, function names are looked up using
@@ -1241,6 +1250,8 @@ pub struct SessionState {
12411250
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
12421251
/// Collection of catalogs containing schemas and ultimately TableProviders
12431252
catalog_list: Arc<dyn CatalogList>,
1253+
/// Table Functions
1254+
table_functions: HashMap<String, Arc<TableFunction>>,
12441255
/// Scalar functions that are registered with the context
12451256
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
12461257
/// Aggregate functions registered in the context
@@ -1339,6 +1350,7 @@ impl SessionState {
13391350
physical_optimizers: PhysicalOptimizer::new(),
13401351
query_planner: Arc::new(DefaultQueryPlanner {}),
13411352
catalog_list,
1353+
table_functions: HashMap::new(),
13421354
scalar_functions: HashMap::new(),
13431355
aggregate_functions: HashMap::new(),
13441356
window_functions: HashMap::new(),
@@ -1877,6 +1889,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
18771889
.ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
18781890
}
18791891

1892+
fn get_table_function_source(
1893+
&self,
1894+
name: &str,
1895+
args: Vec<Expr>,
1896+
) -> Result<Arc<dyn TableSource>> {
1897+
let tbl_func = self
1898+
.state
1899+
.table_functions
1900+
.get(name)
1901+
.cloned()
1902+
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
1903+
let provider = tbl_func.create_table_provider(&args)?;
1904+
1905+
Ok(provider_as_source(provider))
1906+
}
1907+
18801908
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
18811909
self.state.scalar_functions().get(name).cloned()
18821910
}

datafusion/core/tests/user_defined/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,6 @@ mod user_defined_plan;
2626

2727
/// Tests for User Defined Window Functions
2828
mod user_defined_window_functions;
29+
30+
/// Tests for User Defined Table Functions
31+
mod user_defined_table_functions;

0 commit comments

Comments
 (0)