Commit

Update documentation for creating User Defined Aggregates (AggregateUDF) (#6729)

* Update documentation for creating User Defined Aggregates (AggregateUDF)

* Fix other references
alamb authored Jun 22, 2023
1 parent 8966990 commit eb290a0
Showing 10 changed files with 171 additions and 71 deletions.
9 changes: 4 additions & 5 deletions datafusion/core/src/lib.rs
@@ -17,9 +17,9 @@
#![warn(missing_docs, clippy::needless_borrow)]

//! [DataFusion] is an extensible query engine written in Rust that
//! uses [Apache Arrow] as its in-memory format. DataFusion's [use
//! cases] include building very fast database and analytic systems,
//! customized to particular workloads.
//! uses [Apache Arrow] as its in-memory format. DataFusion's many [use
//! cases] help developers build very fast and feature rich database
//! and analytic systems, customized to particular workloads.
//!
//! "Out of the box," DataFusion quickly runs complex [SQL] and
//! [`DataFrame`] queries using a sophisticated query planner, a columnar,
@@ -143,8 +143,7 @@
//! * read from any datasource ([`TableProvider`])
//! * define your own catalogs, schemas, and table lists ([`CatalogProvider`])
//! * build your own query language or plans using the ([`LogicalPlanBuilder`])
//! * declare and use user-defined scalar functions ([`ScalarUDF`])
//! * declare and use user-defined aggregate functions ([`AggregateUDF`])
//! * declare and use user-defined functions: ([`ScalarUDF`], and [`AggregateUDF`])
//! * add custom optimizer rewrite passes ([`OptimizerRule`] and [`PhysicalOptimizerRule`])
//! * extend the planner to use user-defined logical and physical nodes ([`QueryPlanner`])
//!
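For reference, the "out of the box" usage described in this `lib.rs` excerpt looks roughly like the following sketch (not part of this commit; the file name, table name, and query are illustrative, and a `tokio` runtime is assumed):

```rust
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Register a CSV file as a table, then query it with SQL.
    let ctx = SessionContext::new();
    ctx.register_csv("example", "example.csv", CsvReadOptions::new())
        .await?;
    let df = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a").await?;
    df.show().await?;
    Ok(())
}
```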
6 changes: 3 additions & 3 deletions datafusion/core/tests/user_defined_aggregates.rs
@@ -33,7 +33,7 @@ use datafusion::{
assert_batches_eq,
error::Result,
logical_expr::{
AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, TypeSignature, Volatility,
},
physical_plan::Accumulator,
@@ -296,7 +296,7 @@ impl TimeSum {
let signature = Signature::exact(vec![timestamp_type], volatility);

let captured_state = Arc::clone(&test_state);
let accumulator: AccumulatorFunctionImplementation =
let accumulator: AccumulatorFactoryFunction =
Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));

let name = "time_sum";
@@ -396,7 +396,7 @@ impl FirstSelector {
// Possible input signatures
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];

let accumulator: AccumulatorFunctionImplementation =
let accumulator: AccumulatorFactoryFunction =
Arc::new(|_| Ok(Box::new(Self::new())));

let volatility = Volatility::Immutable;
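The test hunks above declare argument signatures in two ways: a single exact argument list (`Signature::exact`, as in `TimeSum`) and a set of alternative exact lists (as in `FirstSelector`). A minimal sketch of both patterns, using illustrative argument types rather than the ones the tests use, and `Signature::one_of` for the second case:

```rust
use datafusion::arrow::datatypes::{DataType, TimeUnit};
use datafusion::logical_expr::{Signature, TypeSignature, Volatility};

fn example_signatures() -> (Signature, Signature) {
    // Exactly one argument: a nanosecond timestamp (TimeSum-style).
    let exact = Signature::exact(
        vec![DataType::Timestamp(TimeUnit::Nanosecond, None)],
        Volatility::Immutable,
    );

    // Accept any one of several exact argument lists (FirstSelector-style).
    let one_of = Signature::one_of(
        vec![
            TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]),
            TypeSignature::Exact(vec![DataType::Float64, DataType::UInt64]),
        ],
        Volatility::Immutable,
    );

    (exact, one_of)
}
```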
174 changes: 134 additions & 40 deletions datafusion/expr/src/accumulator.rs
@@ -21,49 +21,161 @@ use arrow::array::ArrayRef;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::fmt::Debug;

/// Accumulates an aggregate's state.
/// Describes an aggregate function's state.
///
/// `Accumulator`s are stateful objects that lives throughout the
/// `Accumulator`s are stateful objects that live throughout the
/// evaluation of multiple rows and aggregate multiple values together
/// into a final output aggregate.
///
/// An accumulator knows how to:
/// * update its state from inputs via `update_batch`
/// * (optionally) retract an update to its state from given inputs via `retract_batch`
/// * convert its internal state to a vector of aggregate values
/// * update its state from multiple accumulators' states via `merge_batch`
/// * compute the final value from its internal state via `evaluate`
/// * update its state from inputs via [`update_batch`]
///
/// * compute the final value from its internal state via [`evaluate`]
///
/// * retract an update to its state from given inputs via
/// [`retract_batch`] (when used as a window aggregate [window
/// function])
///
/// * convert its internal state to a vector of aggregate values via
/// [`state`] and combine the state from multiple accumulators'
/// via [`merge_batch`], as part of efficient multi-phase grouping.
///
/// [`update_batch`]: Self::update_batch
/// [`retract_batch`]: Self::retract_batch
/// [`state`]: Self::state
/// [`evaluate`]: Self::evaluate
/// [`merge_batch`]: Self::merge_batch
/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
pub trait Accumulator: Send + Sync + Debug {
/// Returns the partial intermediate state of the accumulator. This
/// partial state is serialized as `Arrays` and then combined with
/// other partial states from different instances of this
/// accumulator (that ran on different partitions, for
/// example).
/// Updates the accumulator's state from its input.
///
/// The state can be and often is a different type than the output
/// type of the [`Accumulator`].
/// `values` contains the arguments to this aggregate function.
///
/// For example, the `SUM` accumulator maintains a running sum,
/// and `update_batch` adds each of the input values to the
/// running sum.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;

/// Returns the final aggregate value.
///
/// For example, the `SUM` accumulator maintains a running sum,
/// and `evaluate` will produce that running sum as its output.
fn evaluate(&self) -> Result<ScalarValue>;

/// Returns the allocated size required for this accumulator, in
/// bytes, including `Self`.
///
/// See [`Self::merge_batch`] for more details on the merging process.
/// This value is used to calculate the memory used during
/// execution so DataFusion can stay within its allotted limit.
///
/// "Allocated" means that for internal containers such as `Vec`,
/// the `capacity` should be used not the `len`.
fn size(&self) -> usize;

/// Returns the intermediate state of the accumulator.
///
/// Intermediate state is used for "multi-phase" grouping in
/// DataFusion, where an aggregate is computed in parallel with
/// multiple `Accumulator` instances, as illustrated below:
///
/// # MultiPhase Grouping
///
/// ```text
/// ▲
/// │ evaluate() is called to
/// │ produce the final aggregate
/// │ value per group
/// │
/// ┌─────────────────────────┐
/// │GroupBy │
/// │(AggregateMode::Final) │ state() is called for each
/// │ │ group and the resulting
/// └─────────────────────────┘ RecordBatches passed to the
/// ▲
/// │
/// ┌────────────────┴───────────────┐
/// │ │
/// │ │
/// ┌─────────────────────────┐ ┌─────────────────────────┐
/// │ GroubyBy │ │ GroubyBy │
/// │(AggregateMode::Partial) │ │(AggregateMode::Partial) │
/// └─────────────────────────┘ └────────────▲────────────┘
/// ▲ │
/// │ │ update_batch() is called for
/// │ │ each input RecordBatch
/// .─────────. .─────────.
/// ,─' '─. ,─' '─.
/// ; Input : ; Input :
/// : Partition 0 ; : Partition 1 ;
/// ╲ ╱ ╲ ╱
/// '─. ,─' '─. ,─'
/// `───────' `───────'
/// ```
///
/// The partial state is serialized as `Arrays` and then combined
/// with other partial states from different instances of this
/// Accumulator (that ran on different partitions, for example).
///
/// The state can be and often is a different type than the output
/// type of the [`Accumulator`] and needs different merge
/// operations (for example, the partial state for `COUNT` needs
/// to be summed together)
///
/// Some accumulators can return multiple values for their
/// intermediate states. For example, average tracks `sum` and
/// `n`, and this function should return
/// a vector of two values, sum and n.
///
/// `ScalarValue::List` can also be used to pass multiple values
/// if the number of intermediate values is not known at planning
/// time (e.g. median)
/// Note that [`ScalarValue::List`] can be used to pass multiple
/// values if the number of intermediate values is not known at
/// planning time (e.g. for `MEDIAN`)
fn state(&self) -> Result<Vec<ScalarValue>>;

/// Updates the accumulator's state from a vector of arrays.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
/// Updates the accumulator's state from an `Array` containing one
/// or more intermediate values.
///
/// For some aggregates (such as `SUM`), merge_batch is the same
/// as `update_batch`, but for some aggregates (such as `COUNT`)
/// the operations differ. See [`Self::state`] for more details on how
/// state is used and merged.
///
/// The `states` array passed was formed by concatenating the
/// results of calling [`Self::state`] on zero or more other
/// `Accumulator` instances.
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;

/// Retracts an update (caused by the given inputs) to
/// Retracts (removes) an update (caused by the given inputs) to
/// accumulator's state.
///
/// This is the inverse operation of [`Self::update_batch`] and is used
/// to incrementally calculate window aggregates where the OVER
/// to incrementally calculate window aggregates where the `OVER`
/// clause defines a bounded window.
///
/// # Example
///
/// For example, given the following input partition
///
/// ```text
/// │ current │
/// window
/// │ │
/// ┌────┬────┬────┬────┬────┬────┬────┬────┬────┐
/// Input │ A │ B │ C │ D │ E │ F │ G │ H │ I │
/// partition └────┴────┴────┴────┼────┴────┴────┴────┼────┘
///
/// │ next │
/// window
/// ```
///
/// First, [`Self::evaluate`] will be called to produce the output
/// for the current window.
///
/// Then, to advance to the next window:
///
/// First, [`Self::retract_batch`] will be called with the values
/// that are leaving the window, `[B, C, D]` and then
/// [`Self::update_batch`] will be called with the values that are
/// entering the window, `[F, G, H]`.
fn retract_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
// TODO add retract for all accumulators
Err(DataFusionError::Internal(
@@ -80,22 +80,4 @@ pub trait Accumulator: Send + Sync + Debug {
fn supports_retract_batch(&self) -> bool {
false
}

/// Updates the accumulator's state from an `Array` containing one
/// or more intermediate values.
///
/// The `states` array passed was formed by concatenating the
/// results of calling `[state]` on zero or more other accumulator
/// instances.
///
/// `states` is an array of the same types as returned by [`Self::state`]
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;

/// Returns the final aggregate value based on its current state.
fn evaluate(&self) -> Result<ScalarValue>;

/// Allocated size required for this accumulator, in bytes, including `Self`.
/// Allocated means that for internal containers such as `Vec`, the `capacity` should be used
/// not the `len`
fn size(&self) -> usize;
}
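To make the reworked trait documentation concrete, here is a minimal sketch of a SUM-style accumulator over `Int64` input. It is illustrative only: it is not part of this commit, and real code would validate input types rather than `expect`:

```rust
use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::error::Result;
use datafusion::logical_expr::Accumulator;
use datafusion::scalar::ScalarValue;

/// Sums non-null Int64 values.
#[derive(Debug, Default)]
struct I64Sum {
    sum: i64,
}

impl I64Sum {
    fn batch_total(values: &[ArrayRef]) -> i64 {
        values[0]
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("i64_sum expects Int64 input")
            .iter()
            .flatten()
            .sum()
    }
}

impl Accumulator for I64Sum {
    /// Add each incoming value to the running sum.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        self.sum += Self::batch_total(values);
        Ok(())
    }

    /// The final aggregate value is just the running sum.
    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.sum)))
    }

    /// Intermediate state: a single scalar holding the partial sum.
    fn state(&self) -> Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Int64(Some(self.sum))])
    }

    /// For SUM, merging partial sums is the same operation as updating
    /// from raw input (unlike COUNT, where merge must add counts).
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        self.update_batch(states)
    }

    /// Inverse of `update_batch`, used for bounded window frames.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        self.sum -= Self::batch_total(values);
        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    /// No heap allocations beyond `Self` itself.
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}
```

Continuing the sketch, the multi-phase flow described in the `state()` documentation can be exercised directly: two partial accumulators each see one input partition, their states are merged into a final accumulator, and `evaluate()` produces the result. This assumes the `ScalarValue::to_array` of this DataFusion release, which returns an `ArrayRef` directly:

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::error::Result;
use datafusion::logical_expr::Accumulator;
use datafusion::scalar::ScalarValue;

fn main() -> Result<()> {
    // "Partial" phase: each accumulator sees one input partition.
    let mut partial_a = I64Sum::default();
    let mut partial_b = I64Sum::default();
    partial_a.update_batch(&[Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef])?;
    partial_b.update_batch(&[Arc::new(Int64Array::from(vec![10, 20])) as ArrayRef])?;

    // "Final" phase: serialize each partial state to arrays and merge.
    let mut final_acc = I64Sum::default();
    for partial in [&partial_a, &partial_b] {
        let state: Vec<ArrayRef> = partial.state()?.iter().map(|s| s.to_array()).collect();
        final_acc.merge_batch(&state)?;
    }

    assert_eq!(final_acc.evaluate()?, ScalarValue::Int64(Some(36)));
    Ok(())
}
```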
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr_fn.rs
@@ -23,7 +23,7 @@ use crate::expr::{
};
use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
};
@@ -777,7 +777,7 @@ pub fn create_udaf(
input_type: DataType,
return_type: Arc<DataType>,
volatility: Volatility,
accumulator: AccumulatorFunctionImplementation,
accumulator: AccumulatorFactoryFunction,
state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
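A sketch of how `create_udaf` (whose full parameter list appears in the hunk above, beginning with the function name) ties an accumulator factory to a name callable from SQL. It reuses the hypothetical `I64Sum` accumulator sketched earlier; the table and column names are illustrative and a `tokio` runtime is assumed:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::logical_expr::{create_udaf, Volatility};
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let i64_sum = create_udaf(
        // name to use in SQL and DataFrame expressions
        "i64_sum",
        // input type
        DataType::Int64,
        // output type
        Arc::new(DataType::Int64),
        Volatility::Immutable,
        // factory: called whenever the engine needs a new accumulator
        Arc::new(|_| Ok(Box::new(I64Sum::default()))),
        // types of the intermediate state returned by Accumulator::state
        Arc::new(vec![DataType::Int64]),
    );

    let ctx = SessionContext::new();
    ctx.register_udaf(i64_sum);
    ctx.register_csv("example", "example.csv", CsvReadOptions::new())
        .await?;
    ctx.sql("SELECT a, i64_sum(b) FROM example GROUP BY a")
        .await?
        .show()
        .await?;
    Ok(())
}
```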
2 changes: 1 addition & 1 deletion datafusion/expr/src/function.rs
@@ -42,7 +42,7 @@ pub type ReturnTypeFunction =

/// Factory that returns an accumulator for the given aggregate, given
/// its return datatype.
pub type AccumulatorFunctionImplementation =
pub type AccumulatorFactoryFunction =
Arc<dyn Fn(&DataType) -> Result<Box<dyn Accumulator>> + Send + Sync>;

/// Factory that returns the types used by an aggregator to serialize
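One reason the factory receives the aggregate's return `DataType` is that it can dispatch on it. A sketch, again using the hypothetical `I64Sum` from earlier:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::DataType;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{Accumulator, AccumulatorFactoryFunction};

fn sum_factory() -> AccumulatorFactoryFunction {
    Arc::new(|return_type: &DataType| {
        // Only Int64 is supported by this sketch; reject anything else.
        if return_type != &DataType::Int64 {
            return Err(DataFusionError::NotImplemented(format!(
                "i64_sum does not support return type {return_type}"
            )));
        }
        let acc: Box<dyn Accumulator> = Box::new(I64Sum::default());
        Ok(acc)
    })
}
```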
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
@@ -64,7 +64,7 @@ pub use expr::{
pub use expr_fn::*;
pub use expr_schema::ExprSchemable;
pub use function::{
AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation,
AccumulatorFactoryFunction, ReturnTypeFunction, ScalarFunctionImplementation,
StateTypeFunction,
};
pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
26 changes: 18 additions & 8 deletions datafusion/expr/src/udaf.rs
@@ -19,20 +19,30 @@
use crate::Expr;
use crate::{
AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction,
AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction,
};
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

/// Logical representation of a user-defined aggregate function (UDAF).
/// Logical representation of a user-defined [aggregate function] (UDAF).
///
/// A UDAF is different from a user-defined scalar function (UDF) in
/// that it is stateful across batches. UDAFs can be used as normal
/// aggregate functions as well as window functions (the `OVER` clause)
/// An aggregate function combines the values from multiple input rows
/// into a single output "aggregate" (summary) row. It is different
/// from a scalar function because it is stateful across batches. User
/// defined aggregate functions can be used as normal SQL aggregate
/// functions (`GROUP BY` clause) as well as window functions (`OVER`
/// clause).
///
/// For more information, please see [the examples]
/// `AggregateUDF` provides DataFusion the information needed to plan
/// and call aggregate functions, including name, type information,
/// and a factory function to create an [`Accumulator`], which performs the
/// actual aggregation.
///
/// For more information, please see [the examples].
///
/// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process
/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
/// [`Accumulator`]: crate::Accumulator
#[derive(Clone)]
pub struct AggregateUDF {
/// name
Expand All @@ -42,7 +52,7 @@ pub struct AggregateUDF {
/// Return type
pub return_type: ReturnTypeFunction,
/// actual implementation
pub accumulator: AccumulatorFunctionImplementation,
pub accumulator: AccumulatorFactoryFunction,
/// the accumulator's state's description as a function of the return type
pub state_type: StateTypeFunction,
}
@@ -78,7 +78,7 @@ impl AggregateUDF {
name: &str,
signature: &Signature,
return_type: &ReturnTypeFunction,
accumulator: &AccumulatorFunctionImplementation,
accumulator: &AccumulatorFactoryFunction,
state_type: &StateTypeFunction,
) -> Self {
Self {
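A sketch of building an `AggregateUDF` directly from its parts with `AggregateUDF::new`, in the spirit of the optimizer tests further down. The closures are instances of the `ReturnTypeFunction`, `StateTypeFunction`, and `AccumulatorFactoryFunction` aliases, and `I64Sum` is the hypothetical accumulator sketched earlier:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::{
    AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
    StateTypeFunction, Volatility,
};

fn i64_sum_udaf() -> AggregateUDF {
    // One Int64 argument.
    let signature = Signature::exact(vec![DataType::Int64], Volatility::Immutable);

    // The output type is Int64 regardless of the (already checked) argument types.
    let return_type: ReturnTypeFunction = Arc::new(|_arg_types| Ok(Arc::new(DataType::Int64)));

    // A single intermediate state value: the partial sum.
    let state_type: StateTypeFunction =
        Arc::new(|_return_type| Ok(Arc::new(vec![DataType::Int64])));

    // Called whenever the engine needs a fresh accumulator.
    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::new(I64Sum::default())));

    AggregateUDF::new("i64_sum", &signature, &return_type, &accumulator, &state_type)
}
```

The `common_subexpr_eliminate` tests below show the other half of the story: wrapping such a UDAF in `Expr::AggregateUDF` so it can appear in a logical plan.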
9 changes: 4 additions & 5 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -783,10 +783,9 @@ mod test {
use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue};
use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
use datafusion_expr::{
cast, col, concat, concat_ws, create_udaf, is_true,
AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, BinaryExpr,
BuiltinScalarFunction, Case, ColumnarValue, ExprSchemable, Filter, Operator,
StateTypeFunction, Subquery,
cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction,
AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case,
ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery,
};
use datafusion_expr::{
lit,
@@ -941,7 +940,7 @@ mod test {
Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
let state_type: StateTypeFunction =
Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| {
let accumulator: AccumulatorFactoryFunction = Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
5 changes: 2 additions & 3 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -773,7 +773,7 @@ mod test {
avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
};
use datafusion_expr::{
AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, Volatility,
};

@@ -898,8 +898,7 @@ mod test {
assert_eq!(inputs, &[DataType::UInt32]);
Ok(Arc::new(DataType::UInt32))
});
let accumulator: AccumulatorFunctionImplementation =
Arc::new(|_| unimplemented!());
let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
let udf_agg = |inner: Expr| {
Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new(