Skip to content

Commit d2e7917

Browse files
committed
Fix bugs, tests, handle more aggregate functions and schema
1 parent 8e6338c commit d2e7917

File tree

6 files changed

+158
-17
lines changed

6 files changed

+158
-17
lines changed

datafusion/common/src/dfschema.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,14 @@ impl DFSchema {
798798
.zip(iter2)
799799
.all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2))
800800
}
801+
(
802+
DataType::Decimal32(_l_precision, _l_scale),
803+
DataType::Decimal32(_r_precision, _r_scale),
804+
) => true,
805+
(
806+
DataType::Decimal64(_l_precision, _l_scale),
807+
DataType::Decimal64(_r_precision, _r_scale),
808+
) => true,
801809
(
802810
DataType::Decimal128(_l_precision, _l_scale),
803811
DataType::Decimal128(_r_precision, _r_scale),
@@ -1056,6 +1064,12 @@ fn format_simple_data_type(data_type: &DataType) -> String {
10561064
DataType::Dictionary(_, value_type) => {
10571065
format_simple_data_type(value_type.as_ref())
10581066
}
1067+
DataType::Decimal32(precision, scale) => {
1068+
format!("decimal32({precision}, {scale})")
1069+
}
1070+
DataType::Decimal64(precision, scale) => {
1071+
format!("decimal64({precision}, {scale})")
1072+
}
10591073
DataType::Decimal128(precision, scale) => {
10601074
format!("decimal128({precision}, {scale})")
10611075
}
@@ -1794,6 +1808,27 @@ mod tests {
17941808
&DataType::Int16
17951809
));
17961810

1811+
// Succeeds even when decimal precision and scale differ: decimals of the same width are semantically equal
1812+
assert!(DFSchema::datatype_is_semantically_equal(
1813+
&DataType::Decimal32(1, 2),
1814+
&DataType::Decimal32(2, 1),
1815+
));
1816+
1817+
assert!(DFSchema::datatype_is_semantically_equal(
1818+
&DataType::Decimal64(1, 2),
1819+
&DataType::Decimal64(2, 1),
1820+
));
1821+
1822+
assert!(DFSchema::datatype_is_semantically_equal(
1823+
&DataType::Decimal128(1, 2),
1824+
&DataType::Decimal128(2, 1),
1825+
));
1826+
1827+
assert!(DFSchema::datatype_is_semantically_equal(
1828+
&DataType::Decimal256(1, 2),
1829+
&DataType::Decimal256(2, 1),
1830+
));
1831+
17971832
// Test lists
17981833

17991834
// Succeeds if both have the same element type, disregards names and nullability
@@ -2377,6 +2412,8 @@ mod tests {
23772412
),
23782413
false,
23792414
),
2415+
Field::new("decimal32", DataType::Decimal32(9, 4), true),
2416+
Field::new("decimal64", DataType::Decimal64(9, 4), true),
23802417
Field::new("decimal128", DataType::Decimal128(18, 4), true),
23812418
Field::new("decimal256", DataType::Decimal256(38, 10), false),
23822419
Field::new("date32", DataType::Date32, true),
@@ -2408,6 +2445,8 @@ mod tests {
24082445
|-- fixed_size_binary: fixed_size_binary (nullable = true)
24092446
|-- fixed_size_list: fixed size list (nullable = false)
24102447
| |-- item: int32 (nullable = true)
2448+
|-- decimal32: decimal32(9, 4) (nullable = true)
2449+
|-- decimal64: decimal64(9, 4) (nullable = true)
24112450
|-- decimal128: decimal128(18, 4) (nullable = true)
24122451
|-- decimal256: decimal256(38, 10) (nullable = false)
24132452
|-- date32: date32 (nullable = true)

datafusion/common/src/scalar/mod.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,12 @@ impl ScalarValue {
10871087
DataType::UInt16 => ScalarValue::UInt16(None),
10881088
DataType::UInt32 => ScalarValue::UInt32(None),
10891089
DataType::UInt64 => ScalarValue::UInt64(None),
1090+
DataType::Decimal32(precision, scale) => {
1091+
ScalarValue::Decimal32(None, *precision, *scale)
1092+
}
1093+
DataType::Decimal64(precision, scale) => {
1094+
ScalarValue::Decimal64(None, *precision, *scale)
1095+
}
10901096
DataType::Decimal128(precision, scale) => {
10911097
ScalarValue::Decimal128(None, *precision, *scale)
10921098
}
@@ -3185,6 +3191,24 @@ impl ScalarValue {
31853191
scale: i8,
31863192
) -> Result<ScalarValue> {
31873193
match array.data_type() {
3194+
DataType::Decimal32(_, _) => {
3195+
let array = as_decimal32_array(array)?;
3196+
if array.is_null(index) {
3197+
Ok(ScalarValue::Decimal32(None, precision, scale))
3198+
} else {
3199+
let value = array.value(index);
3200+
Ok(ScalarValue::Decimal32(Some(value), precision, scale))
3201+
}
3202+
}
3203+
DataType::Decimal64(_, _) => {
3204+
let array = as_decimal64_array(array)?;
3205+
if array.is_null(index) {
3206+
Ok(ScalarValue::Decimal64(None, precision, scale))
3207+
} else {
3208+
let value = array.value(index);
3209+
Ok(ScalarValue::Decimal64(Some(value), precision, scale))
3210+
}
3211+
}
31883212
DataType::Decimal128(_, _) => {
31893213
let array = as_decimal128_array(array)?;
31903214
if array.is_null(index) {
@@ -3203,7 +3227,9 @@ impl ScalarValue {
32033227
Ok(ScalarValue::Decimal256(Some(value), precision, scale))
32043228
}
32053229
}
3206-
_ => _internal_err!("Unsupported decimal type"),
3230+
other => {
3231+
unreachable!("Invalid type isn't decimal: {other:?}")
3232+
}
32073233
}
32083234
}
32093235

@@ -3317,6 +3343,16 @@ impl ScalarValue {
33173343

33183344
Ok(match array.data_type() {
33193345
DataType::Null => ScalarValue::Null,
3346+
DataType::Decimal32(precision, scale) => {
3347+
ScalarValue::get_decimal_value_from_array(
3348+
array, index, *precision, *scale,
3349+
)?
3350+
}
3351+
DataType::Decimal64(precision, scale) => {
3352+
ScalarValue::get_decimal_value_from_array(
3353+
array, index, *precision, *scale,
3354+
)?
3355+
}
33203356
DataType::Decimal128(precision, scale) => {
33213357
ScalarValue::get_decimal_value_from_array(
33223358
array, index, *precision, *scale,

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
2020
use arrow::array::{
2121
ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
22-
Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray,
23-
DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray,
24-
FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, Int16Array,
25-
Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
26-
IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray,
27-
StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
28-
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
29-
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
30-
UInt64Array, UInt8Array,
22+
Date64Array, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array,
23+
DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray,
24+
DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
25+
Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
26+
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
27+
LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
28+
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
29+
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
30+
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
3131
};
3232
use arrow::compute;
3333
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
@@ -69,6 +69,26 @@ macro_rules! min_max_batch {
6969
($VALUES:expr, $OP:ident) => {{
7070
match $VALUES.data_type() {
7171
DataType::Null => ScalarValue::Null,
72+
DataType::Decimal32(precision, scale) => {
73+
typed_min_max_batch!(
74+
$VALUES,
75+
Decimal32Array,
76+
Decimal32,
77+
$OP,
78+
precision,
79+
scale
80+
)
81+
}
82+
DataType::Decimal64(precision, scale) => {
83+
typed_min_max_batch!(
84+
$VALUES,
85+
Decimal64Array,
86+
Decimal64,
87+
$OP,
88+
precision,
89+
scale
90+
)
91+
}
7292
DataType::Decimal128(precision, scale) => {
7393
typed_min_max_batch!(
7494
$VALUES,

datafusion/functions-aggregate/src/median.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ use arrow::{
3535

3636
use arrow::array::Array;
3737
use arrow::array::ArrowNativeTypeOp;
38-
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef};
38+
use arrow::datatypes::{
39+
ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
40+
};
3941

4042
use datafusion_common::{
4143
internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
@@ -166,6 +168,8 @@ impl AggregateUDFImpl for Median {
166168
DataType::Float16 => helper!(Float16Type, dt),
167169
DataType::Float32 => helper!(Float32Type, dt),
168170
DataType::Float64 => helper!(Float64Type, dt),
171+
DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
172+
DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
169173
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
170174
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
171175
_ => Err(DataFusionError::NotImplemented(format!(
@@ -205,6 +209,8 @@ impl AggregateUDFImpl for Median {
205209
DataType::Float16 => helper!(Float16Type, dt),
206210
DataType::Float32 => helper!(Float32Type, dt),
207211
DataType::Float64 => helper!(Float64Type, dt),
212+
DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
213+
DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
208214
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
209215
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
210216
_ => Err(DataFusionError::NotImplemented(format!(

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ mod min_max_struct;
2323

2424
use arrow::array::ArrayRef;
2525
use arrow::datatypes::{
26-
DataType, Decimal128Type, Decimal256Type, DurationMicrosecondType,
27-
DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type,
28-
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
29-
UInt32Type, UInt64Type, UInt8Type,
26+
DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type,
27+
DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
28+
DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type,
29+
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
3030
};
3131
use datafusion_common::stats::Precision;
3232
use datafusion_common::{
@@ -323,6 +323,12 @@ impl AggregateUDFImpl for Max {
323323
Duration(Nanosecond) => {
324324
primitive_max_accumulator!(data_type, i64, DurationNanosecondType)
325325
}
326+
Decimal32(_, _) => {
327+
primitive_max_accumulator!(data_type, i32, Decimal32Type)
328+
}
329+
Decimal64(_, _) => {
330+
primitive_max_accumulator!(data_type, i64, Decimal64Type)
331+
}
326332
Decimal128(_, _) => {
327333
primitive_max_accumulator!(data_type, i128, Decimal128Type)
328334
}
@@ -484,6 +490,32 @@ macro_rules! min_max {
484490
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
485491
Ok(match ($VALUE, $DELTA) {
486492
(ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null,
493+
(
494+
lhs @ ScalarValue::Decimal32(lhsv, lhsp, lhss),
495+
rhs @ ScalarValue::Decimal32(rhsv, rhsp, rhss)
496+
) => {
497+
if lhsp.eq(rhsp) && lhss.eq(rhss) {
498+
typed_min_max!(lhsv, rhsv, Decimal32, $OP, lhsp, lhss)
499+
} else {
500+
return internal_err!(
501+
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
502+
(lhs, rhs)
503+
);
504+
}
505+
}
506+
(
507+
lhs @ ScalarValue::Decimal64(lhsv, lhsp, lhss),
508+
rhs @ ScalarValue::Decimal64(rhsv, rhsp, rhss)
509+
) => {
510+
if lhsp.eq(rhsp) && lhss.eq(rhss) {
511+
typed_min_max!(lhsv, rhsv, Decimal64, $OP, lhsp, lhss)
512+
} else {
513+
return internal_err!(
514+
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
515+
(lhs, rhs)
516+
);
517+
}
518+
}
487519
(
488520
lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss),
489521
rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss)
@@ -919,6 +951,8 @@ impl AggregateUDFImpl for Min {
919951
| Float16
920952
| Float32
921953
| Float64
954+
| Decimal32(_, _)
955+
| Decimal64(_, _)
922956
| Decimal128(_, _)
923957
| Decimal256(_, _)
924958
| Date32
@@ -1000,6 +1034,12 @@ impl AggregateUDFImpl for Min {
10001034
Duration(Nanosecond) => {
10011035
primitive_min_accumulator!(data_type, i64, DurationNanosecondType)
10021036
}
1037+
Decimal32(_, _) => {
1038+
primitive_min_accumulator!(data_type, i32, Decimal32Type)
1039+
}
1040+
Decimal64(_, _) => {
1041+
primitive_min_accumulator!(data_type, i64, Decimal64Type)
1042+
}
10031043
Decimal128(_, _) => {
10041044
primitive_min_accumulator!(data_type, i128, Decimal128Type)
10051045
}

datafusion/functions-aggregate/src/sum.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,13 @@ impl AggregateUDFImpl for Sum {
176176
// In Spark, the result type is DECIMAL(min(38, precision + 10), s); this arm caps the precision at DECIMAL32_MAX_PRECISION instead of 38.
177177
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
178178
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
179-
Ok(DataType::Decimal128(new_precision, *scale))
179+
Ok(DataType::Decimal32(new_precision, *scale))
180180
}
181181
DataType::Decimal64(precision, scale) => {
182182
// In Spark, the result type is DECIMAL(min(38, precision + 10), s); this arm caps the precision at DECIMAL64_MAX_PRECISION instead of 38.
183183
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
184184
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
185-
Ok(DataType::Decimal128(new_precision, *scale))
185+
Ok(DataType::Decimal64(new_precision, *scale))
186186
}
187187
DataType::Decimal128(precision, scale) => {
188188
// In Spark, the result type is DECIMAL(min(38, precision + 10), s)

0 commit comments

Comments
 (0)