Skip to content

Commit b4c7400

Browse files
AdamGSJefffrey
andcommitted
CR comments
Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
1 parent 4145a04 commit b4c7400

File tree

3 files changed

+25
-19
lines changed

3 files changed

+25
-19
lines changed

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,16 @@ impl<'a> BinaryTypeCoercer<'a> {
327327

328328
// TODO Move the rest inside of BinaryTypeCoercer
329329

330+
fn is_decimal(data_type: &DataType) -> bool {
331+
matches!(
332+
data_type,
333+
DataType::Decimal32(..)
334+
| DataType::Decimal64(..)
335+
| DataType::Decimal128(..)
336+
| DataType::Decimal256(..)
337+
)
338+
}
339+
330340
/// Coercion rules for mathematics operators between decimal and non-decimal types.
331341
fn math_decimal_coercion(
332342
lhs_type: &DataType,
@@ -358,10 +368,11 @@ fn math_decimal_coercion(
358368
Some((lhs_type.clone(), rhs_type.clone()))
359369
}
360370
// Cross-variant decimal coercion - choose larger variant with appropriate precision/scale
361-
(Decimal32(_, _), Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _))
362-
| (Decimal64(_, _), Decimal32(_, _) | Decimal128(_, _) | Decimal256(_, _))
363-
| (Decimal128(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal256(_, _))
364-
| (Decimal256(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _)) => {
371+
(lhs, rhs)
372+
if is_decimal(lhs)
373+
&& is_decimal(rhs)
374+
&& std::mem::discriminant(lhs) != std::mem::discriminant(rhs) =>
375+
{
365376
let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?;
366377
Some((coerced_type.clone(), coerced_type))
367378
}
@@ -1023,6 +1034,11 @@ fn get_wider_decimal_type_cross_variant(
10231034
let range = (p1 as i8 - s1).max(p2 as i8 - s2);
10241035
let required_precision = (range + s) as u8;
10251036

1037+
// We currently don't handle cases where the required percision overflows
1038+
if required_precision > DECIMAL256_MAX_PRECISION {
1039+
return None;
1040+
}
1041+
10261042
// Choose the larger variant between the two input types
10271043
match (lhs_type, rhs_type) {
10281044
(Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _)) => {

datafusion/physical-plan/src/joins/sort_merge_join/stream.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1999,6 +1999,7 @@ fn is_join_arrays_equal(
19991999
DataType::Decimal32(..) => compare_value!(Decimal32Array),
20002000
DataType::Decimal64(..) => compare_value!(Decimal64Array),
20012001
DataType::Decimal128(..) => compare_value!(Decimal128Array),
2002+
DataType::Decimal256(..) => compare_value!(Decimal256Array),
20022003
DataType::Timestamp(time_unit, None) => match time_unit {
20032004
TimeUnit::Second => compare_value!(TimestampSecondArray),
20042005
TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),

datafusion/spark/src/function/math/width_bucket.rs

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use datafusion_common::cast::{
3232
};
3333
use datafusion_common::{exec_err, Result};
3434
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
35+
use datafusion_expr::type_coercion::is_signed_numeric;
3536
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
3637
use datafusion_functions::utils::make_scalar_function;
3738

@@ -93,23 +94,11 @@ impl ScalarUDFImpl for SparkWidthBucket {
9394

9495
let (v, lo, hi, n) = (&types[0], &types[1], &types[2], &types[3]);
9596

96-
let is_num = |t: &DataType| {
97-
matches!(
98-
t,
99-
Int8 | Int16
100-
| Int32
101-
| Int64
102-
| Float32
103-
| Float64
104-
| Decimal32(_, _)
105-
| Decimal64(_, _)
106-
| Decimal128(_, _)
107-
)
108-
};
109-
11097
match (v, lo, hi, n) {
11198
(a, b, c, &(Int8 | Int16 | Int32 | Int64))
112-
if is_num(a) && is_num(b) && is_num(c) =>
99+
if is_signed_numeric(a)
100+
&& is_signed_numeric(b)
101+
&& is_signed_numeric(c) =>
113102
{
114103
Ok(vec![Float64, Float64, Float64, Int32])
115104
}

0 commit comments

Comments
 (0)