Skip to content

Commit

Permalink
Move average unit tests to slt (#10401)
Browse files Browse the repository at this point in the history
  • Loading branch information
lewiszlw authored May 7, 2024
1 parent b5cc6b9 commit ede3de8
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 149 deletions.
107 changes: 0 additions & 107 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,110 +567,3 @@ where
+ self.sums.capacity() * std::mem::size_of::<T>()
}
}

#[cfg(test)]
mod tests {
    //! Unit tests for the `Avg` aggregate over decimal, integer and floating point inputs.
    //!
    //! Every case builds a single-column array and hands it to `assert_aggregate`
    //! together with the scalar the aggregate is expected to produce.

    use super::*;
    use crate::expressions::tests::assert_aggregate;
    use arrow::array::*;
    use datafusion_expr::AggregateFunction;

    /// Decimal values 1 through 6 at precision 10 / scale 0; the expected result is
    /// 35000 at precision 14 / scale 4.
    #[test]
    fn avg_decimal() {
        let values = (1..7).map(Some).collect::<Decimal128Array>();
        let input: ArrayRef = Arc::new(values.with_precision_and_scale(10, 0).unwrap());
        let expected = ScalarValue::Decimal128(Some(35000), 14, 4);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }

    /// Decimal values 1 through 5 with the value 2 replaced by a null; the expected
    /// result is 32500 at precision 14 / scale 4.
    #[test]
    fn avg_decimal_with_nulls() {
        let values = (1..6)
            .map(|v| if v != 2 { Some(v) } else { None })
            .collect::<Decimal128Array>();
        let input: ArrayRef = Arc::new(values.with_precision_and_scale(10, 0).unwrap());
        let expected = ScalarValue::Decimal128(Some(32500), 14, 4);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }

    /// Six null decimals; the expected result is a null Decimal128 at precision 14 / scale 4.
    #[test]
    fn avg_decimal_all_nulls() {
        let values = std::iter::repeat(None::<i128>)
            .take(6)
            .collect::<Decimal128Array>();
        let input: ArrayRef = Arc::new(values.with_precision_and_scale(10, 0).unwrap());
        let expected = ScalarValue::Decimal128(None, 14, 4);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }

    /// Int32 values 1 through 5; the expected result is the f64 scalar 3.
    #[test]
    fn avg_i32() {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let expected = ScalarValue::from(3.0_f64);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }

    /// Int32 values 1, null, 3, 4, 5; the expected result is the f64 scalar 3.25.
    #[test]
    fn avg_i32_with_nulls() {
        let values = vec![Some(1), None, Some(3), Some(4), Some(5)];
        let input: ArrayRef = Arc::new(Int32Array::from(values));
        let expected = ScalarValue::from(3.25_f64);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }

    /// Two null Int32 values; the expected result is a null Float64 scalar.
    #[test]
    fn avg_i32_all_nulls() {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
        let expected = ScalarValue::Float64(None);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }

    /// UInt32 values 1 through 5; the expected result is the f64 scalar 3.
    #[test]
    fn avg_u32() {
        let input: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2, 3, 4, 5]));
        let expected = ScalarValue::from(3.0_f64);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }

    /// Float32 values 1 through 5; the expected result is the f64 scalar 3.
    #[test]
    fn avg_f32() {
        let input: ArrayRef =
            Arc::new(Float32Array::from(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0]));
        let expected = ScalarValue::from(3.0_f64);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }

    /// Float64 values 1 through 5; the expected result is the f64 scalar 3.
    #[test]
    fn avg_f64() {
        let input: ArrayRef =
            Arc::new(Float64Array::from(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]));
        let expected = ScalarValue::from(3.0_f64);

        assert_aggregate(input, AggregateFunction::Avg, false, expected);
    }
}
43 changes: 1 addition & 42 deletions datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,11 @@ pub use try_cast::{try_cast, TryCastExpr};
pub(crate) mod tests {
use std::sync::Arc;

use crate::expressions::{col, create_aggregate_expr, try_cast};
use crate::AggregateExpr;
use arrow::record_batch::RecordBatch;
use arrow_array::ArrayRef;
use arrow_schema::{Field, Schema};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::type_coercion::aggregates::coerce_types;
use datafusion_expr::{AggregateFunction, EmitTo};
use datafusion_expr::EmitTo;

/// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the
/// result.
Expand Down Expand Up @@ -201,44 +198,6 @@ pub(crate) mod tests {
}};
}

/// Asserts that applying `function` to `array` yields `expected`.
///
/// The input column is first cast (via `try_cast`) to the first type returned by
/// `coerce_types` for `function`'s signature, so callers may pass an array of any
/// type the aggregate accepts.
///
/// # Panics
///
/// Panics if coercion, batch construction, the cast, aggregate creation or
/// evaluation fails, or if the computed value differs from `expected`.
pub fn assert_aggregate(
    array: ArrayRef,
    function: AggregateFunction,
    distinct: bool,
    expected: ScalarValue,
) {
    // Work out which type the aggregate wants for an input of this type.
    let input_type = array.data_type().clone();
    let signature = function.signature();
    let coerced_types =
        coerce_types(&function, &[input_type.clone()], &signature).unwrap();
    let target_type = coerced_types[0].clone();

    // Single nullable column "a" holding the caller's array, uncoerced.
    let raw_schema = Schema::new(vec![Field::new("a", input_type, true)]);
    let batch =
        RecordBatch::try_new(Arc::new(raw_schema.clone()), vec![array]).unwrap();

    // Read column "a" and cast it to the coerced type before aggregating.
    let column = col("a", &raw_schema).unwrap();
    let cast_input = try_cast(column, &raw_schema, target_type.clone()).unwrap();

    // The aggregate itself is planned against the coerced schema.
    let coerced_schema = Schema::new(vec![Field::new("a", target_type, true)]);
    let agg = create_aggregate_expr(
        &function,
        distinct,
        &[cast_input],
        &[],
        &coerced_schema,
        "agg",
        false,
    )
    .unwrap();

    let actual = aggregate(&batch, agg).unwrap();
    assert_eq!(expected, actual);
}

/// macro to perform an aggregation with two inputs and verify the result.
#[macro_export]
macro_rules! generic_test_op2 {
Expand Down
108 changes: 108 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1704,6 +1704,114 @@ select avg(c1) from test
----
1.75

# avg_decimal
# AVG over a decimal(10, 0) column holding 1 through 6.
# Expected: 3.5, with the result type reported as Decimal128(14, 4).
statement ok
create table t (c1 decimal(10, 0)) as values (1), (2), (3), (4), (5), (6);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
3.5 Decimal128(14, 4)

statement ok
drop table t;

# avg_decimal_with_nulls
# AVG over a decimal(10, 0) column holding 1, NULL, 3, 4, 5.
# Expected: 3.25, with the result type reported as Decimal128(14, 4).
statement ok
create table t (c1 decimal(10, 0)) as values (1), (NULL), (3), (4), (5);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
3.25 Decimal128(14, 4)

statement ok
drop table t;

# avg_decimal_all_nulls
# AVG over a decimal(10, 0) column holding six NULLs.
# Expected: NULL, with the result type still reported as Decimal128(14, 4).
statement ok
create table t (c1 decimal(10, 0)) as values (NULL), (NULL), (NULL), (NULL), (NULL), (NULL);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
NULL Decimal128(14, 4)

statement ok
drop table t;

# avg_i32
# AVG over an int column holding 1 through 5.
# Expected: 3, with the result type reported as Float64.
statement ok
create table t (c1 int) as values (1), (2), (3), (4), (5);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
3 Float64

statement ok
drop table t;

# avg_i32_with_nulls
# AVG over an int column holding 1, NULL, 3, 4, 5.
# Expected: 3.25, with the result type reported as Float64.
statement ok
create table t (c1 int) as values (1), (NULL), (3), (4), (5);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
3.25 Float64

statement ok
drop table t;

# avg_i32_all_nulls
# AVG over an int column holding two NULLs.
# Expected: NULL, with the result type still reported as Float64.
statement ok
create table t (c1 int) as values (NULL), (NULL);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
NULL Float64

statement ok
drop table t;

# avg_u32
# AVG over an int unsigned column holding 1 through 5.
# Expected: 3, with the result type reported as Float64.
statement ok
create table t (c1 int unsigned) as values (1), (2), (3), (4), (5);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
3 Float64

statement ok
drop table t;

# avg_f32
# AVG over a float column holding 1 through 5.
# Expected: 3, with the result type reported as Float64.
statement ok
create table t (c1 float) as values (1), (2), (3), (4), (5);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
3 Float64

statement ok
drop table t;

# avg_f64
# AVG over a double column holding 1 through 5.
# Expected: 3, with the result type reported as Float64.
statement ok
create table t (c1 double) as values (1), (2), (3), (4), (5);

query RT
select avg(c1), arrow_typeof(avg(c1)) from t;
----
3 Float64

statement ok
drop table t;

# simple_mean
query R
select mean(c1) from test
Expand Down

0 comments on commit ede3de8

Please sign in to comment.