Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions datafusion/common/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

use crate::{downcast_value, Result};
use arrow::array::{
BinaryViewArray, DurationMicrosecondArray, DurationMillisecondArray,
DurationNanosecondArray, DurationSecondArray, Float16Array, Int16Array, Int8Array,
LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array,
BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray,
DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array,
Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray,
UInt16Array,
};
use arrow::{
array::{
Expand Down Expand Up @@ -97,6 +98,16 @@ pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> {
Ok(downcast_value!(array, UInt64Array))
}

// Downcast Array to Decimal32Array
pub fn as_decimal32_array(array: &dyn Array) -> Result<&Decimal32Array> {
Ok(downcast_value!(array, Decimal32Array))
}

// Downcast Array to Decimal64Array
pub fn as_decimal64_array(array: &dyn Array) -> Result<&Decimal64Array> {
Ok(downcast_value!(array, Decimal64Array))
}

// Downcast Array to Decimal128Array
pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> {
Ok(downcast_value!(array, Decimal128Array))
Expand Down
29 changes: 29 additions & 0 deletions datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,14 @@ impl DFSchema {
.zip(iter2)
.all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2))
}
(
DataType::Decimal32(_l_precision, _l_scale),
DataType::Decimal32(_r_precision, _r_scale),
) => true,
(
DataType::Decimal64(_l_precision, _l_scale),
DataType::Decimal64(_r_precision, _r_scale),
) => true,
(
DataType::Decimal128(_l_precision, _l_scale),
DataType::Decimal128(_r_precision, _r_scale),
Expand Down Expand Up @@ -1596,6 +1604,27 @@ mod tests {
&DataType::Int16
));

// Succeeds if decimal precision and scale are different
assert!(DFSchema::datatype_is_semantically_equal(
&DataType::Decimal32(1, 2),
&DataType::Decimal32(2, 1),
));

assert!(DFSchema::datatype_is_semantically_equal(
&DataType::Decimal64(1, 2),
&DataType::Decimal64(2, 1),
));

assert!(DFSchema::datatype_is_semantically_equal(
&DataType::Decimal128(1, 2),
&DataType::Decimal128(2, 1),
));

assert!(DFSchema::datatype_is_semantically_equal(
&DataType::Decimal256(1, 2),
&DataType::Decimal256(2, 1),
));

// Test lists

// Succeeds if both have the same element type, disregards names and nullability
Expand Down
418 changes: 381 additions & 37 deletions datafusion/common/src/scalar/mod.rs

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::error::{Result, _internal_err};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
};
use std::{fmt::Display, sync::Arc};

Expand Down Expand Up @@ -228,7 +229,15 @@ impl LogicalType for NativeType {
(Self::Float16, _) => Float16,
(Self::Float32, _) => Float32,
(Self::Float64, _) => Float64,
(Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s),
(Self::Decimal(p, s), _) if *p <= DECIMAL32_MAX_PRECISION => {
Decimal32(*p, *s)
}
(Self::Decimal(p, s), _) if *p <= DECIMAL64_MAX_PRECISION => {
Decimal64(*p, *s)
}
(Self::Decimal(p, s), _) if *p <= DECIMAL128_MAX_PRECISION => {
Decimal128(*p, *s)
}
(Self::Decimal(p, s), _) => Decimal256(*p, *s),
(Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()),
// If given type is Date, return the same type
Expand Down
55 changes: 48 additions & 7 deletions datafusion/core/tests/fuzz_cases/record_batch_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch};
use arrow::datatypes::{
ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type,
Decimal256Type, DurationMicrosecondType, DurationMillisecondType,
DurationNanosecondType, DurationSecondType, Field, Float32Type, Float64Type,
Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
Decimal256Type, Decimal32Type, Decimal64Type, DurationMicrosecondType,
DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field,
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
};
use arrow_schema::{
DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION,
DECIMAL256_MAX_SCALE,
DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};
use rand::{rng, rngs::StdRng, Rng, SeedableRng};
Expand Down Expand Up @@ -104,6 +105,20 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec<ColumnDescr> {
"duration_nanosecond",
DataType::Duration(TimeUnit::Nanosecond),
),
ColumnDescr::new("decimal32", {
let precision: u8 = rng.random_range(1..=DECIMAL32_MAX_PRECISION);
let scale: i8 = rng.random_range(
i8::MIN..=std::cmp::min(precision as i8, DECIMAL32_MAX_SCALE),
);
DataType::Decimal32(precision, scale)
}),
ColumnDescr::new("decimal64", {
let precision: u8 = rng.random_range(1..=DECIMAL64_MAX_PRECISION);
let scale: i8 = rng.random_range(
i8::MIN..=std::cmp::min(precision as i8, DECIMAL64_MAX_SCALE),
);
DataType::Decimal64(precision, scale)
}),
ColumnDescr::new("decimal128", {
let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION);
let scale: i8 = rng.random_range(
Expand Down Expand Up @@ -682,6 +697,32 @@ impl RecordBatchGenerator {
_ => unreachable!(),
}
}
DataType::Decimal32(precision, scale) => {
generate_decimal_array!(
self,
num_rows,
max_num_distinct,
null_pct,
batch_gen_rng,
array_gen_rng,
precision,
scale,
Decimal32Type
)
}
DataType::Decimal64(precision, scale) => {
generate_decimal_array!(
self,
num_rows,
max_num_distinct,
null_pct,
batch_gen_rng,
array_gen_rng,
precision,
scale,
Decimal64Type
)
}
DataType::Decimal128(precision, scale) => {
generate_decimal_array!(
self,
Expand Down
68 changes: 67 additions & 1 deletion datafusion/expr-common/src/casts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use std::cmp::Ordering;

use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
MIN_DECIMAL128_FOR_EACH_PRECISION,
MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION,
MIN_DECIMAL64_FOR_EACH_PRECISION,
};
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -69,6 +71,8 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool {
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Decimal32(_, _)
| DataType::Decimal64(_, _)
| DataType::Decimal128(_, _)
| DataType::Timestamp(_, _)
)
Expand Down Expand Up @@ -114,6 +118,8 @@ fn try_cast_numeric_literal(
| DataType::Int32
| DataType::Int64 => 1_i128,
DataType::Timestamp(_, _) => 1_i128,
DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
_ => return None,
};
Expand All @@ -127,6 +133,20 @@ fn try_cast_numeric_literal(
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
DataType::Decimal32(precision, _) => (
// Different precision for decimal32 can store different range of value.
// For example, the precision is 3, the max of value is `999` and the min
// value is `-999`
MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
),
DataType::Decimal64(precision, _) => (
// Different precision for decimal64 can store different range of value.
// For example, the precision is 3, the max of value is `999` and the min
// value is `-999`
MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
),
DataType::Decimal128(precision, _) => (
// Different precision for decimal128 can store different range of value.
// For example, the precision is 3, the max of value is `999` and the min
Expand All @@ -149,6 +169,46 @@ fn try_cast_numeric_literal(
ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
ScalarValue::Decimal32(Some(v), _, scale) => {
let v = *v as i128;
let lit_scale_mul = 10_i128.pow(*scale as u32);
if mul >= lit_scale_mul {
// Example:
// lit is decimal(123,3,2)
// target type is decimal(5,3)
// the lit can be converted to the decimal(1230,5,3)
v.checked_mul(mul / lit_scale_mul)
} else if v % (lit_scale_mul / mul) == 0 {
// Example:
// lit is decimal(123000,10,3)
// target type is int32: the lit can be converted to INT32(123)
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
Some(v / (lit_scale_mul / mul))
} else {
// can't convert the lit decimal to the target data type
None
}
}
ScalarValue::Decimal64(Some(v), _, scale) => {
let v = *v as i128;
let lit_scale_mul = 10_i128.pow(*scale as u32);
if mul >= lit_scale_mul {
// Example:
// lit is decimal(123,3,2)
// target type is decimal(5,3)
// the lit can be converted to the decimal(1230,5,3)
v.checked_mul(mul / lit_scale_mul)
} else if v % (lit_scale_mul / mul) == 0 {
// Example:
// lit is decimal(123000,10,3)
// target type is int32: the lit can be converted to INT32(123)
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
Some(v / (lit_scale_mul / mul))
} else {
// can't convert the lit decimal to the target data type
None
}
}
ScalarValue::Decimal128(Some(v), _, scale) => {
let lit_scale_mul = 10_i128.pow(*scale as u32);
if mul >= lit_scale_mul {
Expand Down Expand Up @@ -218,6 +278,12 @@ fn try_cast_numeric_literal(
);
ScalarValue::TimestampNanosecond(value, tz.clone())
}
DataType::Decimal32(p, s) => {
ScalarValue::Decimal32(Some(value as i32), *p, *s)
}
DataType::Decimal64(p, s) => {
ScalarValue::Decimal64(Some(value as i64), *p, *s)
}
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
}
Expand Down
45 changes: 42 additions & 3 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
use crate::signature::TypeSignature;
use arrow::datatypes::{
DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
};

use datafusion_common::{internal_err, plan_err, Result};
Expand Down Expand Up @@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
DataType::Int64 => Ok(DataType::Int64),
DataType::UInt64 => Ok(DataType::UInt64),
DataType::Float64 => Ok(DataType::Float64),
DataType::Decimal32(precision, scale) => {
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal32(new_precision, *scale))
}
DataType::Decimal64(precision, scale) => {
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal64(new_precision, *scale))
}
DataType::Decimal128(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+10), s)
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
Expand Down Expand Up @@ -196,6 +209,20 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
/// Function return type of an average
pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Decimal32(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal32(new_precision, new_scale))
}
DataType::Decimal64(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal64(new_precision, new_scale))
}
DataType::Decimal128(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
Expand All @@ -222,6 +249,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType>
/// Internal sum type of an average
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Decimal32(precision, scale) => {
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal32(new_precision, *scale))
}
DataType::Decimal64(precision, scale) => {
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal64(new_precision, *scale))
}
DataType::Decimal128(precision, scale) => {
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
Expand Down Expand Up @@ -249,7 +286,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
_ => matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
|| matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
|| matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
),
}
}
Expand All @@ -262,7 +299,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
_ => matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
|| matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _))
|| matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
),
}
}
Expand Down Expand Up @@ -297,6 +334,8 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> {
match &data_type {
DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
d if d.is_numeric() => Ok(DataType::Float64),
Expand Down
Loading
Loading