diff --git a/Cargo.lock b/Cargo.lock index 25bcaf68cb84..01e909e4c2b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2177,7 +2177,7 @@ dependencies = [ "datafusion-datasource", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", "datafusion-physical-expr", "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", diff --git a/datafusion/datasource-parquet/Cargo.toml b/datafusion/datasource-parquet/Cargo.toml index ae67f9118486..ae3371234d59 100644 --- a/datafusion/datasource-parquet/Cargo.toml +++ b/datafusion/datasource-parquet/Cargo.toml @@ -40,7 +40,7 @@ datafusion-common-runtime = { workspace = true } datafusion-datasource = { workspace = true, features = ["parquet"] } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } diff --git a/datafusion/datasource-parquet/src/metadata.rs b/datafusion/datasource-parquet/src/metadata.rs index 71c81a25001b..81d5511d6974 100644 --- a/datafusion/datasource-parquet/src/metadata.rs +++ b/datafusion/datasource-parquet/src/metadata.rs @@ -32,7 +32,7 @@ use datafusion_common::{ ColumnStatistics, DataFusionError, Result, ScalarValue, Statistics, }; use datafusion_execution::cache::cache_manager::{FileMetadata, FileMetadataCache}; -use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; +use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_plan::Accumulator; use log::debug; use object_store::path::Path; diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index b02001753215..806071dd2f58 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -31,8 +31,452 @@ use arrow::array::{ }; use arrow::compute; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::{downcast_value, Result, ScalarValue}; -use std::cmp::Ordering; +use datafusion_common::{ + downcast_value, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr_common::accumulator::Accumulator; +use std::{cmp::Ordering, mem::size_of_val}; + +// min/max of two non-string scalar values. +macro_rules! typed_min_max { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $($EXTRA_ARGS.clone()),* + ) + }}; +} + +macro_rules! typed_min_max_float { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => match a.total_cmp(b) { + choose_min_max!($OP) => Some(*b), + _ => Some(*a), + }, + }) + }}; +} + +// min/max of two scalar string values. +macro_rules! typed_min_max_string { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }) + }}; +} + +// min/max of two scalar string values with a prefix argument. +macro_rules! typed_min_max_string_arg { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $ARG:expr) => {{ + ScalarValue::$SCALAR( + $ARG, + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }, + ) + }}; +} + +macro_rules! choose_min_max { + (min) => { + std::cmp::Ordering::Greater + }; + (max) => { + std::cmp::Ordering::Less + }; +} + +macro_rules! interval_min_max { + ($OP:tt, $LHS:expr, $RHS:expr) => {{ + match $LHS.partial_cmp(&$RHS) { + Some(choose_min_max!($OP)) => $RHS.clone(), + Some(_) => $LHS.clone(), + None => { + return internal_err!("Comparison error while computing interval min/max") + } + } + }}; +} + +macro_rules! min_max_generic { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + if $VALUE.is_null() { + let mut delta_copy = $DELTA.clone(); + // When the new value won we want to compact it to + // avoid storing the entire input + delta_copy.compact(); + delta_copy + } else if $DELTA.is_null() { + $VALUE.clone() + } else { + match $VALUE.partial_cmp(&$DELTA) { + Some(choose_min_max!($OP)) => { + // When the new value won we want to compact it to + // avoid storing the entire input + let mut delta_copy = $DELTA.clone(); + delta_copy.compact(); + delta_copy + } + _ => $VALUE.clone(), + } + } + }}; +} + +// min/max of two scalar values of the same type +macro_rules! min_max { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + Ok(match ($VALUE, $DELTA) { + (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, + ( + lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { + typed_min_max!(lhs, rhs, Boolean, $OP) + } + (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { + typed_min_max_float!(lhs, rhs, Float64, $OP) + } + (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { + typed_min_max_float!(lhs, rhs, Float32, $OP) + } + (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { + typed_min_max_float!(lhs, rhs, Float16, $OP) + } + (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { + typed_min_max!(lhs, rhs, UInt64, $OP) + } + (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { + typed_min_max!(lhs, rhs, UInt32, $OP) + } + (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { + typed_min_max!(lhs, rhs, UInt16, $OP) + } + (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { + typed_min_max!(lhs, rhs, UInt8, $OP) + } + (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { + typed_min_max!(lhs, rhs, Int64, $OP) + } + (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { + typed_min_max!(lhs, rhs, Int32, $OP) + } + (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { + typed_min_max!(lhs, rhs, Int16, $OP) + } + (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { + typed_min_max!(lhs, rhs, Int8, $OP) + } + (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8, $OP) + } + (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) + } + (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8View, $OP) + } + (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { + typed_min_max_string!(lhs, rhs, Binary, $OP) + } + (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeBinary, $OP) + } + (ScalarValue::FixedSizeBinary(lsize, lhs), ScalarValue::FixedSizeBinary(rsize, rhs)) => { + if lsize == rsize { + typed_min_max_string_arg!(lhs, rhs, FixedSizeBinary, $OP, *lsize) + } + else { + return internal_err!( + "MIN/MAX is not expected to receive FixedSizeBinary of incompatible sizes {:?}", + (lsize, rsize)) + } + } + (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { + typed_min_max_string!(lhs, rhs, BinaryView, $OP) + } + (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { + typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMillisecond(lhs, l_tz), + ScalarValue::TimestampMillisecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMicrosecond(lhs, l_tz), + ScalarValue::TimestampMicrosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) + } + ( + ScalarValue::TimestampNanosecond(lhs, l_tz), + ScalarValue::TimestampNanosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) + } + ( + ScalarValue::Date32(lhs), + ScalarValue::Date32(rhs), + ) => { + typed_min_max!(lhs, rhs, Date32, $OP) + } + ( + ScalarValue::Date64(lhs), + ScalarValue::Date64(rhs), + ) => { + typed_min_max!(lhs, rhs, Date64, $OP) + } + ( + ScalarValue::Time32Second(lhs), + ScalarValue::Time32Second(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Second, $OP) + } + ( + ScalarValue::Time32Millisecond(lhs), + ScalarValue::Time32Millisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Millisecond, $OP) + } + ( + ScalarValue::Time64Microsecond(lhs), + ScalarValue::Time64Microsecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Microsecond, $OP) + } + ( + ScalarValue::Time64Nanosecond(lhs), + ScalarValue::Time64Nanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) + } + ( + ScalarValue::IntervalYearMonth(lhs), + ScalarValue::IntervalYearMonth(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) + } + ( + ScalarValue::IntervalMonthDayNano(lhs), + ScalarValue::IntervalMonthDayNano(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) + } + ( + ScalarValue::IntervalDayTime(lhs), + ScalarValue::IntervalDayTime(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalDayTime, $OP) + } + ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalMonthDayNano(_), + ) | ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalMonthDayNano(_), + ) => { + interval_min_max!($OP, $VALUE, $DELTA) + } + ( + ScalarValue::DurationSecond(lhs), + ScalarValue::DurationSecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationSecond, $OP) + } + ( + ScalarValue::DurationMillisecond(lhs), + ScalarValue::DurationMillisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMillisecond, $OP) + } + ( + ScalarValue::DurationMicrosecond(lhs), + ScalarValue::DurationMicrosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) + } + ( + ScalarValue::DurationNanosecond(lhs), + ScalarValue::DurationNanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationNanosecond, $OP) + } + + ( + lhs @ ScalarValue::Struct(_), + rhs @ ScalarValue::Struct(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + ( + lhs @ ScalarValue::List(_), + rhs @ ScalarValue::List(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::LargeList(_), + rhs @ ScalarValue::LargeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::FixedSizeList(_), + rhs @ ScalarValue::FixedSizeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + e => { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + e + ) + } + }) + }}; +} + +/// An accumulator to compute the maximum value +#[derive(Debug, Clone)] +pub struct MaxAccumulator { + max: ScalarValue, +} + +impl MaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MaxAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &max_batch(values)?; + let new_max: Result = + min_max!(&self.max, delta, max); + self.max = new_max?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + fn evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.max) + self.max.size() + } +} + +/// An accumulator to compute the minimum value +#[derive(Debug, Clone)] +pub struct MinAccumulator { + min: ScalarValue, +} + +impl MinAccumulator { + /// new min accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &min_batch(values)?; + let new_min: Result = + min_max!(&self.min, delta, min); + self.min = new_min?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.min) + self.min.size() + } +} // Statically-typed version of min/max(array) -> ScalarValue for string types macro_rules! typed_min_max_batch_string { diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 1edf10dfee30..639c08706bc0 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -29,11 +29,8 @@ use arrow::datatypes::{ UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::stats::Precision; -use datafusion_common::{ - exec_err, internal_err, ColumnStatistics, DataFusionError, Result, -}; +use datafusion_common::{exec_err, internal_err, ColumnStatistics, Result}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch}; use datafusion_physical_expr::expressions; use std::cmp::Ordering; use std::fmt::Debug; @@ -378,404 +375,6 @@ impl AggregateUDFImpl for Max { } } -macro_rules! min_max_generic { - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - if $VALUE.is_null() { - let mut delta_copy = $DELTA.clone(); - // When the new value won we want to compact it to - // avoid storing the entire input - delta_copy.compact(); - delta_copy - } else if $DELTA.is_null() { - $VALUE.clone() - } else { - match $VALUE.partial_cmp(&$DELTA) { - Some(choose_min_max!($OP)) => { - // When the new value won we want to compact it to - // avoid storing the entire input - let mut delta_copy = $DELTA.clone(); - delta_copy.compact(); - delta_copy - } - _ => $VALUE.clone(), - } - } - }}; -} - -// min/max of two non-string scalar values. -macro_rules! typed_min_max { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ - ScalarValue::$SCALAR( - match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(*a), - (None, Some(b)) => Some(*b), - (Some(a), Some(b)) => Some((*a).$OP(*b)), - }, - $($EXTRA_ARGS.clone()),* - ) - }}; -} -macro_rules! typed_min_max_float { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR(match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(*a), - (None, Some(b)) => Some(*b), - (Some(a), Some(b)) => match a.total_cmp(b) { - choose_min_max!($OP) => Some(*b), - _ => Some(*a), - }, - }) - }}; -} - -// min/max of two scalar string values. -macro_rules! typed_min_max_string { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR(match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), - (Some(a), Some(b)) => Some((a).$OP(b).clone()), - }) - }}; -} - -// min/max of two scalar string values with a prefix argument. -macro_rules! typed_min_max_string_arg { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $ARG:expr) => {{ - ScalarValue::$SCALAR( - $ARG, - match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), - (Some(a), Some(b)) => Some((a).$OP(b).clone()), - }, - ) - }}; -} - -macro_rules! choose_min_max { - (min) => { - std::cmp::Ordering::Greater - }; - (max) => { - std::cmp::Ordering::Less - }; -} - -macro_rules! interval_min_max { - ($OP:tt, $LHS:expr, $RHS:expr) => {{ - match $LHS.partial_cmp(&$RHS) { - Some(choose_min_max!($OP)) => $RHS.clone(), - Some(_) => $LHS.clone(), - None => { - return internal_err!("Comparison error while computing interval min/max") - } - } - }}; -} - -// min/max of two scalar values of the same type -macro_rules! min_max { - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - Ok(match ($VALUE, $DELTA) { - (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, - ( - lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - ( - lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { - typed_min_max!(lhs, rhs, Boolean, $OP) - } - (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - typed_min_max_float!(lhs, rhs, Float64, $OP) - } - (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - typed_min_max_float!(lhs, rhs, Float32, $OP) - } - (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { - typed_min_max_float!(lhs, rhs, Float16, $OP) - } - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - typed_min_max!(lhs, rhs, UInt64, $OP) - } - (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { - typed_min_max!(lhs, rhs, UInt32, $OP) - } - (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { - typed_min_max!(lhs, rhs, UInt16, $OP) - } - (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { - typed_min_max!(lhs, rhs, UInt8, $OP) - } - (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - typed_min_max!(lhs, rhs, Int64, $OP) - } - (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { - typed_min_max!(lhs, rhs, Int32, $OP) - } - (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { - typed_min_max!(lhs, rhs, Int16, $OP) - } - (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { - typed_min_max!(lhs, rhs, Int8, $OP) - } - (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8, $OP) - } - (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) - } - (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8View, $OP) - } - (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { - typed_min_max_string!(lhs, rhs, Binary, $OP) - } - (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeBinary, $OP) - } - (ScalarValue::FixedSizeBinary(lsize, lhs), ScalarValue::FixedSizeBinary(rsize, rhs)) => { - if lsize == rsize { - typed_min_max_string_arg!(lhs, rhs, FixedSizeBinary, $OP, *lsize) - } - else { - return internal_err!( - "MIN/MAX is not expected to receive FixedSizeBinary of incompatible sizes {:?}", - (lsize, rsize)) - } - } - (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { - typed_min_max_string!(lhs, rhs, BinaryView, $OP) - } - (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { - typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) - } - ( - ScalarValue::TimestampMillisecond(lhs, l_tz), - ScalarValue::TimestampMillisecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) - } - ( - ScalarValue::TimestampMicrosecond(lhs, l_tz), - ScalarValue::TimestampMicrosecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) - } - ( - ScalarValue::TimestampNanosecond(lhs, l_tz), - ScalarValue::TimestampNanosecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) - } - ( - ScalarValue::Date32(lhs), - ScalarValue::Date32(rhs), - ) => { - typed_min_max!(lhs, rhs, Date32, $OP) - } - ( - ScalarValue::Date64(lhs), - ScalarValue::Date64(rhs), - ) => { - typed_min_max!(lhs, rhs, Date64, $OP) - } - ( - ScalarValue::Time32Second(lhs), - ScalarValue::Time32Second(rhs), - ) => { - typed_min_max!(lhs, rhs, Time32Second, $OP) - } - ( - ScalarValue::Time32Millisecond(lhs), - ScalarValue::Time32Millisecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time32Millisecond, $OP) - } - ( - ScalarValue::Time64Microsecond(lhs), - ScalarValue::Time64Microsecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time64Microsecond, $OP) - } - ( - ScalarValue::Time64Nanosecond(lhs), - ScalarValue::Time64Nanosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) - } - ( - ScalarValue::IntervalYearMonth(lhs), - ScalarValue::IntervalYearMonth(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) - } - ( - ScalarValue::IntervalMonthDayNano(lhs), - ScalarValue::IntervalMonthDayNano(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) - } - ( - ScalarValue::IntervalDayTime(lhs), - ScalarValue::IntervalDayTime(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalDayTime, $OP) - } - ( - ScalarValue::IntervalYearMonth(_), - ScalarValue::IntervalMonthDayNano(_), - ) | ( - ScalarValue::IntervalYearMonth(_), - ScalarValue::IntervalDayTime(_), - ) | ( - ScalarValue::IntervalMonthDayNano(_), - ScalarValue::IntervalDayTime(_), - ) | ( - ScalarValue::IntervalMonthDayNano(_), - ScalarValue::IntervalYearMonth(_), - ) | ( - ScalarValue::IntervalDayTime(_), - ScalarValue::IntervalYearMonth(_), - ) | ( - ScalarValue::IntervalDayTime(_), - ScalarValue::IntervalMonthDayNano(_), - ) => { - interval_min_max!($OP, $VALUE, $DELTA) - } - ( - ScalarValue::DurationSecond(lhs), - ScalarValue::DurationSecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationSecond, $OP) - } - ( - ScalarValue::DurationMillisecond(lhs), - ScalarValue::DurationMillisecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationMillisecond, $OP) - } - ( - ScalarValue::DurationMicrosecond(lhs), - ScalarValue::DurationMicrosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) - } - ( - ScalarValue::DurationNanosecond(lhs), - ScalarValue::DurationNanosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationNanosecond, $OP) - } - - ( - lhs @ ScalarValue::Struct(_), - rhs @ ScalarValue::Struct(_), - ) => { - min_max_generic!(lhs, rhs, $OP) - } - - ( - lhs @ ScalarValue::List(_), - rhs @ ScalarValue::List(_), - ) => { - min_max_generic!(lhs, rhs, $OP) - } - - - ( - lhs @ ScalarValue::LargeList(_), - rhs @ ScalarValue::LargeList(_), - ) => { - min_max_generic!(lhs, rhs, $OP) - } - - - ( - lhs @ ScalarValue::FixedSizeList(_), - rhs @ ScalarValue::FixedSizeList(_), - ) => { - min_max_generic!(lhs, rhs, $OP) - } - - e => { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - e - ) - } - }) - }}; -} - -/// An accumulator to compute the maximum value -#[derive(Debug, Clone)] -pub struct MaxAccumulator { - max: ScalarValue, -} - -impl MaxAccumulator { - /// new max accumulator - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - max: ScalarValue::try_from(datatype)?, - }) - } -} - -impl Accumulator for MaxAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &max_batch(values)?; - let new_max: Result = - min_max!(&self.max, delta, max); - self.max = new_max?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - fn evaluate(&mut self) -> Result { - Ok(self.max.clone()) - } - - fn size(&self) -> usize { - size_of_val(self) - size_of_val(&self.max) + self.max.size() - } -} - #[derive(Debug)] pub struct SlidingMaxAccumulator { max: ScalarValue, @@ -1056,48 +655,6 @@ impl AggregateUDFImpl for Min { } } -/// An accumulator to compute the minimum value -#[derive(Debug, Clone)] -pub struct MinAccumulator { - min: ScalarValue, -} - -impl MinAccumulator { - /// new min accumulator - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - min: ScalarValue::try_from(datatype)?, - }) - } -} - -impl Accumulator for MinAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &min_batch(values)?; - let new_min: Result = - min_max!(&self.min, delta, min); - self.min = new_min?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - Ok(self.min.clone()) - } - - fn size(&self) -> usize { - size_of_val(self) - size_of_val(&self.min) + self.min.size() - } -} - #[derive(Debug)] pub struct SlidingMinAccumulator { min: ScalarValue, @@ -1429,6 +986,11 @@ make_udaf_expr_and_func!( min_udaf ); +// Re-export accumulators from the common module for backwards compatibility +pub use datafusion_functions_aggregate_common::min_max::{ + MaxAccumulator, MinAccumulator, +}; + #[cfg(test)] mod tests { use super::*;