Skip to content

Commit

Permalink
Make fields of ScalarUDF , AggregateUDF and WindowUDF non pub (
Browse files Browse the repository at this point in the history
…#8079)

* Make fields of ScalarUDF non pub

* Make fields of `WindowUDF` and `AggregateUDF` non pub.

* fix doc
  • Loading branch information
alamb authored Nov 15, 2023
1 parent 6ecb6cd commit 7f11125
Show file tree
Hide file tree
Showing 21 changed files with 172 additions and 87 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
}
}
Expr::ScalarUDF(ScalarUDF { fun, .. }) => {
match fun.signature.volatility {
match fun.signature().volatility {
Volatility::Immutable => VisitRecursion::Continue,
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ impl SessionContext {
self.state
.write()
.scalar_functions
.insert(f.name.clone(), Arc::new(f));
.insert(f.name().to_string(), Arc::new(f));
}

/// Registers an aggregate UDF within this context.
Expand All @@ -820,7 +820,7 @@ impl SessionContext {
self.state
.write()
.aggregate_functions
.insert(f.name.clone(), Arc::new(f));
.insert(f.name().to_string(), Arc::new(f));
}

/// Registers a window UDF within this context.
Expand All @@ -834,7 +834,7 @@ impl SessionContext {
self.state
.write()
.window_functions
.insert(f.name.clone(), Arc::new(f));
.insert(f.name().to_string(), Arc::new(f));
}

/// Creates a [`DataFrame`] for reading a data source.
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
create_function_physical_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_physical_name(&fun.name, false, args)
create_function_physical_name(fun.name(), false, args)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
Expand Down Expand Up @@ -250,7 +250,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
for e in args {
names.push(create_physical_name(e, false)?);
}
Ok(format!("{}({})", fun.name, names.join(",")))
Ok(format!("{}({})", fun.name(), names.join(",")))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
Expand Down
14 changes: 8 additions & 6 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ impl Between {
}
}

/// ScalarFunction expression
/// ScalarFunction expression invokes a built-in scalar function
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarFunction {
/// The function
Expand All @@ -354,7 +354,9 @@ impl ScalarFunction {
}
}

/// ScalarUDF expression
/// ScalarUDF expression invokes a user-defined scalar function [`ScalarUDF`]
///
/// [`ScalarUDF`]: crate::ScalarUDF
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarUDF {
/// The function
Expand Down Expand Up @@ -1200,7 +1202,7 @@ impl fmt::Display for Expr {
fmt_function(f, &func.fun.to_string(), false, &func.args, true)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
fmt_function(f, &fun.name, false, args, true)
fmt_function(f, fun.name(), false, args, true)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down Expand Up @@ -1247,7 +1249,7 @@ impl fmt::Display for Expr {
order_by,
..
}) => {
fmt_function(f, &fun.name, false, args, true)?;
fmt_function(f, fun.name(), false, args, true)?;
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {fe})")?;
}
Expand Down Expand Up @@ -1536,7 +1538,7 @@ fn create_name(e: &Expr) -> Result<String> {
create_function_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_name(&fun.name, false, args)
create_function_name(fun.name(), false, args)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down Expand Up @@ -1589,7 +1591,7 @@ fn create_name(e: &Expr) -> Result<String> {
if let Some(ob) = order_by {
info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob));
}
Ok(format!("{}({}){}", fun.name, names.join(","), info))
Ok(format!("{}({}){}", fun.name(), names.join(","), info))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
Ok(fun.return_type(&data_types)?)
}
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let arg_data_types = args
Expand Down Expand Up @@ -128,7 +128,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
fun.return_type(&data_types)
}
Expr::Not(_)
| Expr::IsNull(_)
Expand Down
47 changes: 40 additions & 7 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
// specific language governing permissions and limitations
// under the License.

//! Udaf module contains functions and structs supporting user-defined aggregate functions.
//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::Expr;
use crate::{Accumulator, Expr};
use crate::{
AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

Expand All @@ -46,15 +48,15 @@ use std::sync::Arc;
#[derive(Clone)]
pub struct AggregateUDF {
/// name
pub name: String,
name: String,
/// Signature (input arguments)
pub signature: Signature,
signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
return_type: ReturnTypeFunction,
/// actual implementation
pub accumulator: AccumulatorFactoryFunction,
accumulator: AccumulatorFactoryFunction,
/// the accumulator's state's description as a function of the return type
pub state_type: StateTypeFunction,
state_type: StateTypeFunction,
}

impl Debug for AggregateUDF {
Expand Down Expand Up @@ -112,4 +114,35 @@ impl AggregateUDF {
order_by: None,
})
}

/// Returns this function's name
pub fn name(&self) -> &str {
&self.name
}

/// Returns this function's signature (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature
}

/// Return the type of the function given its input types
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
}

/// Return an accumualator the given aggregate, given
/// its return datatype.
pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn Accumulator>> {
(self.accumulator)(return_type)
}

/// Return the type of the intermediate state used by this aggregator, given
/// its return datatype. Supports multi-phase aggregations
pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
// old API returns an Arc for some reason, try and unwrap it here
let res = (self.state_type)(return_type)?;
Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone()))
}
}
50 changes: 41 additions & 9 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,31 @@
// specific language governing permissions and limitations
// under the License.

//! Udf module contains foundational types that are used to represent UDFs in DataFusion.
//! [`ScalarUDF`]: Scalar User Defined Functions

use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::sync::Arc;

/// Logical representation of a UDF.
/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input.
///
/// This struct contains the information DataFusion needs to plan and invoke
/// functions such name, type signature, return type, and actual implementation.
///
#[derive(Clone)]
pub struct ScalarUDF {
/// name
pub name: String,
/// signature
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
/// The name of the function
name: String,
/// The signature (the types of arguments that are supported)
signature: Signature,
/// Function that returns the return type given the argument types
return_type: ReturnTypeFunction,
/// actual implementation
///
/// The fn param is the wrapped function but be aware that the function will
Expand All @@ -40,7 +48,7 @@ pub struct ScalarUDF {
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
pub fun: ScalarFunctionImplementation,
fun: ScalarFunctionImplementation,
}

impl Debug for ScalarUDF {
Expand Down Expand Up @@ -89,4 +97,28 @@ impl ScalarUDF {
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args))
}

/// Returns this function's name
pub fn name(&self) -> &str {
&self.name
}

/// Returns this function's signature (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature
}

/// Return the type of the function given its input types
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
}

/// Return the actual implementation
pub fn fun(&self) -> ScalarFunctionImplementation {
self.fun.clone()
}

// TODO maybe add an invoke() method that runs the actual function?
}
44 changes: 34 additions & 10 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
// specific language governing permissions and limitations
// under the License.

//! Support for user-defined window (UDWF) window functions
//! [`WindowUDF`]: User Defined Window Functions

use crate::{
Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature,
WindowFrame,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::{
fmt::{self, Debug, Display, Formatter},
sync::Arc,
};

use crate::{
Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
};

/// Logical representation of a user-defined window function (UDWF)
/// A UDWF is different from a UDF in that it is stateful across batches.
///
Expand All @@ -35,13 +37,13 @@ use crate::{
#[derive(Clone)]
pub struct WindowUDF {
/// name
pub name: String,
name: String,
/// signature
pub signature: Signature,
signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
return_type: ReturnTypeFunction,
/// Return the partition evaluator
pub partition_evaluator_factory: PartitionEvaluatorFactory,
partition_evaluator_factory: PartitionEvaluatorFactory,
}

impl Debug for WindowUDF {
Expand Down Expand Up @@ -86,7 +88,7 @@ impl WindowUDF {
partition_evaluator_factory: &PartitionEvaluatorFactory,
) -> Self {
Self {
name: name.to_owned(),
name: name.to_string(),
signature: signature.clone(),
return_type: return_type.clone(),
partition_evaluator_factory: partition_evaluator_factory.clone(),
Expand Down Expand Up @@ -115,4 +117,26 @@ impl WindowUDF {
window_frame,
})
}

/// Returns this function's name
pub fn name(&self) -> &str {
&self.name
}

/// Returns this function's signature (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature
}

/// Return the type of the function given its input types
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
}

/// Return a `PartitionEvaluator` for evaluating this window function
pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> {
(self.partition_evaluator_factory)()
}
}
12 changes: 4 additions & 8 deletions datafusion/expr/src/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,8 @@ impl WindowFunction {
WindowFunction::BuiltInWindowFunction(fun) => {
fun.return_type(input_expr_types)
}
WindowFunction::AggregateUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
WindowFunction::WindowUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
WindowFunction::AggregateUDF(fun) => fun.return_type(input_expr_types),
WindowFunction::WindowUDF(fun) => fun.return_type(input_expr_types),
}
}
}
Expand Down Expand Up @@ -234,8 +230,8 @@ impl WindowFunction {
match self {
WindowFunction::AggregateFunction(fun) => fun.signature(),
WindowFunction::BuiltInWindowFunction(fun) => fun.signature(),
WindowFunction::AggregateUDF(fun) => fun.signature.clone(),
WindowFunction::WindowUDF(fun) => fun.signature.clone(),
WindowFunction::AggregateUDF(fun) => fun.signature().clone(),
WindowFunction::WindowUDF(fun) => fun.signature().clone(),
}
}
}
Expand Down
Loading

0 comments on commit 7f11125

Please sign in to comment.