@@ -320,6 +320,16 @@ impl<'a> BinaryTypeCoercer<'a> {
320
320
321
321
// TODO Move the rest inside of BinaryTypeCoercer
322
322
323
+ fn is_decimal ( data_type : & DataType ) -> bool {
324
+ matches ! (
325
+ data_type,
326
+ DataType :: Decimal32 ( ..)
327
+ | DataType :: Decimal64 ( ..)
328
+ | DataType :: Decimal128 ( ..)
329
+ | DataType :: Decimal256 ( ..)
330
+ )
331
+ }
332
+
323
333
/// Coercion rules for mathematics operators between decimal and non-decimal types.
324
334
fn math_decimal_coercion (
325
335
lhs_type : & DataType ,
@@ -350,6 +360,15 @@ fn math_decimal_coercion(
350
360
| ( Decimal256 ( _, _) , Decimal256 ( _, _) ) => {
351
361
Some ( ( lhs_type. clone ( ) , rhs_type. clone ( ) ) )
352
362
}
363
+ // Cross-variant decimal coercion - choose larger variant with appropriate precision/scale
364
+ ( lhs, rhs)
365
+ if is_decimal ( lhs)
366
+ && is_decimal ( rhs)
367
+ && std:: mem:: discriminant ( lhs) != std:: mem:: discriminant ( rhs) =>
368
+ {
369
+ let coerced_type = get_wider_decimal_type_cross_variant ( lhs_type, rhs_type) ?;
370
+ Some ( ( coerced_type. clone ( ) , coerced_type) )
371
+ }
353
372
// Unlike with comparison we don't coerce to a decimal in the case of floating point
354
373
// numbers, instead falling back to floating point arithmetic instead
355
374
(
@@ -946,21 +965,92 @@ pub fn binary_numeric_coercion(
946
965
pub fn decimal_coercion ( lhs_type : & DataType , rhs_type : & DataType ) -> Option < DataType > {
947
966
use arrow:: datatypes:: DataType :: * ;
948
967
968
+ // Prefer decimal data type over floating point for comparison operation
949
969
match ( lhs_type, rhs_type) {
950
- // Prefer decimal data type over floating point for comparison operation
951
- ( Decimal128 ( _, _) , Decimal128 ( _, _) ) => {
970
+ // Same decimal types
971
+ ( lhs_type, rhs_type)
972
+ if is_decimal ( lhs_type)
973
+ && is_decimal ( rhs_type)
974
+ && std:: mem:: discriminant ( lhs_type)
975
+ == std:: mem:: discriminant ( rhs_type) =>
976
+ {
952
977
get_wider_decimal_type ( lhs_type, rhs_type)
953
978
}
954
- ( Decimal128 ( _, _) , _) => get_common_decimal_type ( lhs_type, rhs_type) ,
955
- ( _, Decimal128 ( _, _) ) => get_common_decimal_type ( rhs_type, lhs_type) ,
956
- ( Decimal256 ( _, _) , Decimal256 ( _, _) ) => {
957
- get_wider_decimal_type ( lhs_type, rhs_type)
979
+ // Mismatched decimal types
980
+ ( lhs_type, rhs_type)
981
+ if is_decimal ( lhs_type)
982
+ && is_decimal ( rhs_type)
983
+ && std:: mem:: discriminant ( lhs_type)
984
+ != std:: mem:: discriminant ( rhs_type) =>
985
+ {
986
+ get_wider_decimal_type_cross_variant ( lhs_type, rhs_type)
987
+ }
988
+ // Decimal + non-decimal types
989
+ ( Decimal32 ( _, _) | Decimal64 ( _, _) | Decimal128 ( _, _) | Decimal256 ( _, _) , _) => {
990
+ get_common_decimal_type ( lhs_type, rhs_type)
991
+ }
992
+ ( _, Decimal32 ( _, _) | Decimal64 ( _, _) | Decimal128 ( _, _) | Decimal256 ( _, _) ) => {
993
+ get_common_decimal_type ( rhs_type, lhs_type)
958
994
}
959
- ( Decimal256 ( _, _) , _) => get_common_decimal_type ( lhs_type, rhs_type) ,
960
- ( _, Decimal256 ( _, _) ) => get_common_decimal_type ( rhs_type, lhs_type) ,
961
995
( _, _) => None ,
962
996
}
963
997
}
998
+ /// Handle cross-variant decimal widening by choosing the larger variant
999
+ fn get_wider_decimal_type_cross_variant (
1000
+ lhs_type : & DataType ,
1001
+ rhs_type : & DataType ,
1002
+ ) -> Option < DataType > {
1003
+ use arrow:: datatypes:: DataType :: * ;
1004
+
1005
+ let ( p1, s1) = match lhs_type {
1006
+ Decimal32 ( p, s) => ( * p, * s) ,
1007
+ Decimal64 ( p, s) => ( * p, * s) ,
1008
+ Decimal128 ( p, s) => ( * p, * s) ,
1009
+ Decimal256 ( p, s) => ( * p, * s) ,
1010
+ _ => return None ,
1011
+ } ;
1012
+
1013
+ let ( p2, s2) = match rhs_type {
1014
+ Decimal32 ( p, s) => ( * p, * s) ,
1015
+ Decimal64 ( p, s) => ( * p, * s) ,
1016
+ Decimal128 ( p, s) => ( * p, * s) ,
1017
+ Decimal256 ( p, s) => ( * p, * s) ,
1018
+ _ => return None ,
1019
+ } ;
1020
+
1021
+ // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
1022
+ let s = s1. max ( s2) ;
1023
+ let range = ( p1 as i8 - s1) . max ( p2 as i8 - s2) ;
1024
+ let required_precision = ( range + s) as u8 ;
1025
+
1026
+ // Choose the larger variant between the two input types, while making sure we don't overflow the precision.
1027
+ match ( lhs_type, rhs_type) {
1028
+ ( Decimal32 ( _, _) , Decimal64 ( _, _) ) | ( Decimal64 ( _, _) , Decimal32 ( _, _) )
1029
+ if required_precision <= DECIMAL64_MAX_PRECISION =>
1030
+ {
1031
+ Some ( Decimal64 ( required_precision, s) )
1032
+ }
1033
+ ( Decimal32 ( _, _) , Decimal128 ( _, _) )
1034
+ | ( Decimal128 ( _, _) , Decimal32 ( _, _) )
1035
+ | ( Decimal64 ( _, _) , Decimal128 ( _, _) )
1036
+ | ( Decimal128 ( _, _) , Decimal64 ( _, _) )
1037
+ if required_precision <= DECIMAL128_MAX_PRECISION =>
1038
+ {
1039
+ Some ( Decimal128 ( required_precision, s) )
1040
+ }
1041
+ ( Decimal32 ( _, _) , Decimal256 ( _, _) )
1042
+ | ( Decimal256 ( _, _) , Decimal32 ( _, _) )
1043
+ | ( Decimal64 ( _, _) , Decimal256 ( _, _) )
1044
+ | ( Decimal256 ( _, _) , Decimal64 ( _, _) )
1045
+ | ( Decimal128 ( _, _) , Decimal256 ( _, _) )
1046
+ | ( Decimal256 ( _, _) , Decimal128 ( _, _) )
1047
+ if required_precision <= DECIMAL256_MAX_PRECISION =>
1048
+ {
1049
+ Some ( Decimal256 ( required_precision, s) )
1050
+ }
1051
+ _ => None ,
1052
+ }
1053
+ }
964
1054
965
1055
/// Coerce `lhs_type` and `rhs_type` to a common type.
966
1056
fn get_common_decimal_type (
@@ -969,7 +1059,15 @@ fn get_common_decimal_type(
969
1059
) -> Option < DataType > {
970
1060
use arrow:: datatypes:: DataType :: * ;
971
1061
match decimal_type {
972
- Decimal32 ( _, _) | Decimal64 ( _, _) | Decimal128 ( _, _) => {
1062
+ Decimal32 ( _, _) => {
1063
+ let other_decimal_type = coerce_numeric_type_to_decimal32 ( other_type) ?;
1064
+ get_wider_decimal_type ( decimal_type, & other_decimal_type)
1065
+ }
1066
+ Decimal64 ( _, _) => {
1067
+ let other_decimal_type = coerce_numeric_type_to_decimal64 ( other_type) ?;
1068
+ get_wider_decimal_type ( decimal_type, & other_decimal_type)
1069
+ }
1070
+ Decimal128 ( _, _) => {
973
1071
let other_decimal_type = coerce_numeric_type_to_decimal128 ( other_type) ?;
974
1072
get_wider_decimal_type ( decimal_type, & other_decimal_type)
975
1073
}
@@ -981,7 +1079,7 @@ fn get_common_decimal_type(
981
1079
}
982
1080
}
983
1081
984
- /// Returns a `DataType::Decimal128` that can store any value from either
1082
+ /// Returns a decimal [ `DataType`] variant that can store any value from either
985
1083
/// `lhs_decimal_type` and `rhs_decimal_type`
986
1084
///
987
1085
/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`.
@@ -1202,14 +1300,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
1202
1300
}
1203
1301
1204
1302
fn create_decimal32_type ( precision : u8 , scale : i8 ) -> DataType {
1205
- DataType :: Decimal128 (
1303
+ DataType :: Decimal32 (
1206
1304
DECIMAL32_MAX_PRECISION . min ( precision) ,
1207
1305
DECIMAL32_MAX_SCALE . min ( scale) ,
1208
1306
)
1209
1307
}
1210
1308
1211
1309
fn create_decimal64_type ( precision : u8 , scale : i8 ) -> DataType {
1212
- DataType :: Decimal128 (
1310
+ DataType :: Decimal64 (
1213
1311
DECIMAL64_MAX_PRECISION . min ( precision) ,
1214
1312
DECIMAL64_MAX_SCALE . min ( scale) ,
1215
1313
)
0 commit comments