diff --git a/Cargo.toml b/Cargo.toml index 60ff770d0d13..234d76707a58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ members = [ "datafusion-examples", "test-utils", "benchmarks", + "extension/scalar-function/test-func", ] resolver = "2" diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index e5146c7fd94e..66db736fdc2e 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -41,6 +41,7 @@ datafusion-common = { path = "../datafusion/common" } datafusion-expr = { path = "../datafusion/expr" } datafusion-optimizer = { path = "../datafusion/optimizer" } datafusion-sql = { path = "../datafusion/sql" } +datafusion-extension-test-scalar-func = {path = "../extension/scalar-function/test-func"} env_logger = "0.10" futures = "0.3" log = "0.4" diff --git a/datafusion-examples/examples/external_function_package.rs b/datafusion-examples/examples/external_function_package.rs new file mode 100644 index 000000000000..31ea4413ef6c --- /dev/null +++ b/datafusion-examples/examples/external_function_package.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_extension_test_scalar_func::TestFunctionPackage; + +/// This example demonstrates executing a simple query against an Arrow data source (CSV) and +/// fetching results +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::arrow_test_data(); + ctx.register_csv( + "aggregate_test_100", + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new(), + ) + .await?; + + // Register add_one(x), multiply_two(x) function from `TestFunctionPackage` + ctx.register_scalar_function_package(Box::new(TestFunctionPackage)); + + let df = ctx + .sql("select add_one(1), multiply_two(c3), add_one(multiply_two(c4)) from aggregate_test_100 limit 5").await?; + df.show().await?; + + Ok(()) +} diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index ca6da6cfa047..b9c9d5c483f7 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -36,7 +36,8 @@ use datafusion_common::{ use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - StringifiedPlan, UserDefinedLogicalNode, WindowUDF, + ScalarFunctionDef, ScalarFunctionPackage, StringifiedPlan, UserDefinedLogicalNode, + WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -79,6 +80,7 @@ use sqlparser::dialect::dialect_from_str; use crate::config::ConfigOptions; use crate::datasource::physical_plan::{plan_to_csv, plan_to_json, plan_to_parquet}; use crate::execution::{runtime_env::RuntimeEnv, FunctionRegistry}; +use crate::physical_plan::functions::make_scalar_function; use crate::physical_plan::udaf::AggregateUDF; use crate::physical_plan::udf::ScalarUDF; use crate::physical_plan::ExecutionPlan; @@ -792,6 +794,30 @@ impl SessionContext { .add_var_provider(variable_type, provider); } + /// Register a function package into this context + pub fn register_scalar_function_package( + &self, + func_pkg: Box, + ) { + // Make a `dyn ScalarFunctionDef` into a internal struct for scalar functions, then it can be + // registered into context + pub fn to_scalar_function(func: Box) -> ScalarUDF { + let name = func.name().to_string(); + let signature = func.signature(); + let return_type = func.return_type(); + let func_impl = make_scalar_function(move |args| func.execute(args)); + + ScalarUDF::new(&name, &signature, &return_type, &func_impl) + } + + for func in func_pkg.functions() { + self.state + .write() + .scalar_functions + .insert(func.name().to_string(), Arc::new(to_scalar_function(func))); + } + } + /// Registers a scalar UDF within this context. /// /// Note in SQL queries, function names are looked up using diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d35233bc39d2..7a99852e5802 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -77,7 +77,7 @@ pub use partition_evaluator::PartitionEvaluator; pub use signature::{Signature, TypeSignature, Volatility}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; -pub use udf::ScalarUDF; +pub use udf::{ScalarFunctionDef, ScalarFunctionPackage, ScalarUDF}; pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index be6c90aa5985..0498515dd190 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -18,11 +18,30 @@ //! Udf module contains foundational types that are used to represent UDFs in DataFusion. use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use arrow::array::ArrayRef; +use datafusion_common::Result; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; +pub trait ScalarFunctionDef: Sync + Send + std::fmt::Debug { + // TODO: support alias + fn name(&self) -> &str; + + fn signature(&self) -> Signature; + + // TODO: ReturnTypeFunction -> a ENUM + // most function's return type is either the same as 1st arg or a fixed type + fn return_type(&self) -> ReturnTypeFunction; + + fn execute(&self, args: &[ArrayRef]) -> Result; +} + +pub trait ScalarFunctionPackage { + fn functions(&self) -> Vec>; +} + /// Logical representation of a UDF. #[derive(Clone)] pub struct ScalarUDF { diff --git a/extension/scalar-function/test-func/Cargo.toml b/extension/scalar-function/test-func/Cargo.toml new file mode 100644 index 000000000000..89eae6c9222a --- /dev/null +++ b/extension/scalar-function/test-func/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "datafusion-extension-test-scalar-func" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +datafusion = { path = "../../../datafusion/core" } +datafusion-common = { path = "../../../datafusion/common" } +datafusion-expr = { path = "../../../datafusion/expr" } +arrow = { workspace = true } +#arrow-flight = { workspace = true } +#arrow-schema = { workspace = true } diff --git a/extension/scalar-function/test-func/src/lib.rs b/extension/scalar-function/test-func/src/lib.rs new file mode 100644 index 000000000000..b3181cab492d --- /dev/null +++ b/extension/scalar-function/test-func/src/lib.rs @@ -0,0 +1,79 @@ +use arrow::array::{ArrayRef, Float64Array}; +use arrow::datatypes::DataType; +use datafusion::error::Result; +use datafusion::logical_expr::Volatility; +use datafusion_common::cast::as_float64_array; +use datafusion_expr::{ReturnTypeFunction, Signature}; +use datafusion_expr::{ScalarFunctionDef, ScalarFunctionPackage}; +use std::sync::Arc; + +#[derive(Debug)] +pub struct AddOneFunction; + +impl ScalarFunctionDef for AddOneFunction { + fn name(&self) -> &str { + "add_one" + } + + fn signature(&self) -> Signature { + Signature::exact(vec![DataType::Float64], Volatility::Immutable) + } + + fn return_type(&self) -> ReturnTypeFunction { + let return_type = Arc::new(DataType::Float64); + Arc::new(move |_| Ok(return_type.clone())) + } + + fn execute(&self, args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 1); + let input = as_float64_array(&args[0]).expect("cast failed"); + let array = input + .iter() + .map(|value| match value { + Some(value) => Some(value + 1.0), + _ => None, + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + } +} + +#[derive(Debug)] +pub struct MultiplyTwoFunction; + +impl ScalarFunctionDef for MultiplyTwoFunction { + fn name(&self) -> &str { + "multiply_two" + } + + fn signature(&self) -> Signature { + Signature::exact(vec![DataType::Float64], Volatility::Immutable) + } + + fn return_type(&self) -> ReturnTypeFunction { + let return_type = Arc::new(DataType::Float64); + Arc::new(move |_| Ok(return_type.clone())) + } + + fn execute(&self, args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 1); + let input = as_float64_array(&args[0]).expect("cast failed"); + let array = input + .iter() + .map(|value| match value { + Some(value) => Some(value * 2.0), + _ => None, + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + } +} + +// Function package declaration +pub struct TestFunctionPackage; + +impl ScalarFunctionPackage for TestFunctionPackage { + fn functions(&self) -> Vec> { + vec![Box::new(AddOneFunction), Box::new(MultiplyTwoFunction)] + } +}