diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs
index 260c6cbee710..05ad231949a8 100644
--- a/datafusion-examples/examples/simple_udtf.rs
+++ b/datafusion-examples/examples/simple_udtf.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::Int64Array;
 use arrow::csv::reader::Format;
 use arrow::csv::ReaderBuilder;
 use async_trait::async_trait;
@@ -24,22 +23,22 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::function::TableFunctionImpl;
 use datafusion::datasource::TableProvider;
 use datafusion::error::Result;
-use datafusion::execution::context::SessionState;
-use datafusion::execution::TaskContext;
+use datafusion::execution::context::{ExecutionProps, SessionState};
 use datafusion::physical_plan::memory::MemoryExec;
-use datafusion::physical_plan::{collect, ExecutionPlan};
+use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
-use datafusion_common::{DFSchema, ScalarValue};
-use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType};
+use datafusion_common::{plan_err, DataFusionError, ScalarValue};
+use datafusion_expr::{Expr, TableType};
+use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
 use std::fs::File;
 use std::io::Seek;
 use std::path::Path;
 use std::sync::Arc;
 
 // To define your own table function, you only need to do the following 3 things:
-// 1. Implement your own TableProvider
-// 2. Implement your own TableFunctionImpl and return your TableProvider
-// 3. Register the function using ctx.register_udtf
+// 1. Implement your own [`TableProvider`]
+// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`]
+// 3. Register the function using [`SessionContext::register_udtf`]
 
 /// This example demonstrates how to register a TableFunction
 #[tokio::main]
@@ -47,14 +46,15 @@ async fn main() -> Result<()> {
     // create local execution context
     let ctx = SessionContext::new();
 
+    // register the table function that will be called in SQL statements by `read_csv`
     ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {}));
 
     let testdata = datafusion::test_util::arrow_test_data();
     let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");
 
-    // read csv with at most 2 rows
+    // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2)
     let df = ctx
-        .sql(format!("SELECT * FROM read_csv('{csv_file}', 2);").as_str())
+        .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str())
         .await?;
     df.show().await?;
 
@@ -67,9 +67,14 @@ async fn main() -> Result<()> {
     Ok(())
 }
 
+/// Table Function that mimics the [`read_csv`] function in DuckDB.
+///
+/// Usage: `read_csv(filename, [limit])`
+///
+/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html
 struct LocalCsvTable {
     schema: SchemaRef,
-    exprs: Vec<Expr>,
+    limit: Option<usize>,
     batches: Vec<RecordBatch>,
 }
 
@@ -89,13 +94,12 @@ impl TableProvider for LocalCsvTable {
 
     async fn scan(
         &self,
-        state: &SessionState,
+        _state: &SessionState,
         projection: Option<&Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let batches = if !self.exprs.is_empty() {
-            let max_return_lines = self.interpreter_expr(state).await?;
+        let batches = if let Some(max_return_lines) = self.limit {
             // get max return rows from self.batches
             let mut batches = vec![];
             let mut lines = 0;
@@ -121,56 +125,35 @@ impl TableProvider for LocalCsvTable {
         )?))
     }
 }
-
-impl LocalCsvTable {
-    async fn interpreter_expr(&self, state: &SessionState) -> Result<i64> {
-        use datafusion::logical_expr::expr_rewriter::normalize_col;
-        use datafusion::logical_expr::utils::columnize_expr;
-        let plan = LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: true,
-            schema: Arc::new(DFSchema::empty()),
-        });
-        let logical_plan = Projection::try_new(
-            vec![columnize_expr(
-                normalize_col(self.exprs[0].clone(), &plan)?,
-                plan.schema(),
-            )],
-            Arc::new(plan),
-        )
-        .map(LogicalPlan::Projection)?;
-        let rbs = collect(
-            state.create_physical_plan(&logical_plan).await?,
-            Arc::new(TaskContext::from(state)),
-        )
-        .await?;
-        let limit = rbs[0]
-            .column(0)
-            .as_any()
-            .downcast_ref::<Int64Array>()
-            .unwrap()
-            .value(0);
-        Ok(limit)
-    }
-}
-
 struct LocalCsvTableFunc {}
 
 impl TableFunctionImpl for LocalCsvTableFunc {
     fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
-        let mut new_exprs = vec![];
-        let mut filepath = String::new();
-        for expr in exprs {
-            match expr {
-                Expr::Literal(ScalarValue::Utf8(Some(ref path))) => {
-                    filepath = path.clone()
+        let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.get(0) else {
+            return plan_err!("read_csv requires at least one string argument");
+        };
+
+        let limit = exprs
+            .get(1)
+            .map(|expr| {
+                // try to simplify the expression, so 1+2 becomes 3, for example
+                let execution_props = ExecutionProps::new();
+                let info = SimplifyContext::new(&execution_props);
+                let expr = ExprSimplifier::new(info).simplify(expr.clone())?;
+
+                if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr {
+                    Ok(limit as usize)
+                } else {
+                    plan_err!("Limit must be an integer")
                 }
-                expr => new_exprs.push(expr.clone()),
-            }
-        }
-        let (schema, batches) = read_csv_batches(filepath)?;
+            })
+            .transpose()?;
+
+        let (schema, batches) = read_csv_batches(path)?;
+
         let table = LocalCsvTable {
             schema,
-            exprs: new_exprs.clone(),
+            limit,
             batches,
         };
         Ok(Arc::new(table))
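
For reference, the constant folding that the new `call` implementation relies on can be exercised on its own. The standalone sketch below is illustrative and not part of the patch; it uses the same `ExprSimplifier` / `SimplifyContext` imports the example now has, and folds `1 + 2` into a single `Int64` literal, which is how `read_csv('{csv_file}', 1 + 1)` arrives at a concrete row limit. A schema-free `SimplifyContext` suffices because the expression references no columns.

use datafusion::execution::context::ExecutionProps;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{lit, Expr};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};

fn main() -> Result<()> {
    // Build the unevaluated expression tree for `1 + 2`.
    let expr = lit(1_i64) + lit(2_i64);

    // Same pipeline as the UDTF's `call`: no schema is attached, which is
    // fine here because the expression contains no column references.
    let execution_props = ExecutionProps::new();
    let simplifier = ExprSimplifier::new(SimplifyContext::new(&execution_props));
    let simplified = simplifier.simplify(expr)?;

    // Constant folding should collapse the addition into the literal 3.
    assert_eq!(simplified, Expr::Literal(ScalarValue::Int64(Some(3))));
    Ok(())
}

The same mechanism is why the example's SQL can pass `1 + 1` instead of `2`: by the time `call` pattern-matches on the second argument, the arithmetic has already been reduced to a plain `Int64` literal.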