-
Notifications
You must be signed in to change notification settings - Fork 1.7k
More decimal 32/64 support - type coercsion and misc gaps #17808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -327,6 +327,16 @@ impl<'a> BinaryTypeCoercer<'a> { | |
|
||
// TODO Move the rest inside of BinaryTypeCoercer | ||
|
||
fn is_decimal(data_type: &DataType) -> bool { | ||
matches!( | ||
data_type, | ||
DataType::Decimal32(..) | ||
| DataType::Decimal64(..) | ||
| DataType::Decimal128(..) | ||
| DataType::Decimal256(..) | ||
) | ||
} | ||
|
||
/// Coercion rules for mathematics operators between decimal and non-decimal types. | ||
fn math_decimal_coercion( | ||
lhs_type: &DataType, | ||
|
@@ -357,6 +367,15 @@ fn math_decimal_coercion( | |
| (Decimal256(_, _), Decimal256(_, _)) => { | ||
Some((lhs_type.clone(), rhs_type.clone())) | ||
} | ||
// Cross-variant decimal coercion - choose larger variant with appropriate precision/scale | ||
(lhs, rhs) | ||
if is_decimal(lhs) | ||
&& is_decimal(rhs) | ||
&& std::mem::discriminant(lhs) != std::mem::discriminant(rhs) => | ||
{ | ||
let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?; | ||
Some((coerced_type.clone(), coerced_type)) | ||
} | ||
// Unlike with comparison we don't coerce to a decimal in the case of floating point | ||
// numbers, instead falling back to floating point arithmetic instead | ||
( | ||
|
@@ -953,21 +972,92 @@ pub fn binary_numeric_coercion( | |
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { | ||
use arrow::datatypes::DataType::*; | ||
|
||
// Prefer decimal data type over floating point for comparison operation | ||
match (lhs_type, rhs_type) { | ||
// Prefer decimal data type over floating point for comparison operation | ||
(Decimal128(_, _), Decimal128(_, _)) => { | ||
// Same decimal types | ||
(lhs_type, rhs_type) | ||
if is_decimal(lhs_type) | ||
&& is_decimal(rhs_type) | ||
&& std::mem::discriminant(lhs_type) | ||
== std::mem::discriminant(rhs_type) => | ||
{ | ||
get_wider_decimal_type(lhs_type, rhs_type) | ||
} | ||
(Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), | ||
(_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), | ||
(Decimal256(_, _), Decimal256(_, _)) => { | ||
get_wider_decimal_type(lhs_type, rhs_type) | ||
// Mismatched decimal types | ||
(lhs_type, rhs_type) | ||
if is_decimal(lhs_type) | ||
&& is_decimal(rhs_type) | ||
&& std::mem::discriminant(lhs_type) | ||
!= std::mem::discriminant(rhs_type) => | ||
{ | ||
get_wider_decimal_type_cross_variant(lhs_type, rhs_type) | ||
} | ||
// Decimal + non-decimal types | ||
(Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), _) => { | ||
get_common_decimal_type(lhs_type, rhs_type) | ||
} | ||
(_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => { | ||
get_common_decimal_type(rhs_type, lhs_type) | ||
} | ||
(Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), | ||
(_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type), | ||
(_, _) => None, | ||
} | ||
} | ||
/// Handle cross-variant decimal widening by choosing the larger variant | ||
fn get_wider_decimal_type_cross_variant( | ||
lhs_type: &DataType, | ||
rhs_type: &DataType, | ||
) -> Option<DataType> { | ||
use arrow::datatypes::DataType::*; | ||
|
||
let (p1, s1) = match lhs_type { | ||
Decimal32(p, s) => (*p, *s), | ||
Decimal64(p, s) => (*p, *s), | ||
Decimal128(p, s) => (*p, *s), | ||
Decimal256(p, s) => (*p, *s), | ||
_ => return None, | ||
}; | ||
|
||
let (p2, s2) = match rhs_type { | ||
Decimal32(p, s) => (*p, *s), | ||
Decimal64(p, s) => (*p, *s), | ||
Decimal128(p, s) => (*p, *s), | ||
Decimal256(p, s) => (*p, *s), | ||
_ => return None, | ||
}; | ||
|
||
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) | ||
let s = s1.max(s2); | ||
let range = (p1 as i8 - s1).max(p2 as i8 - s2); | ||
let required_precision = (range + s) as u8; | ||
Comment on lines
+1029
to
+1031
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if we have: Decimal256 with precision 76 (max) and scale 0, and Decimal128 with precision 38 (max) with scale 1; So Is this a valid case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think an overflow is valid, I'll have to think about it and maybe look into solutions in other systems. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I looked around a bit, and what I could find is:
I'm not sure what's the desired behavior regarding precision loss (should it be configurable? Is there currently an accepted desired behavior?), I think for this PR it should be fine to just return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think returning There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done as part of fd1f043 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cheers; left another minor comment related to the check below. Also would be nice if we had a test for this edge case. |
||
|
||
// Choose the larger variant between the two input types, while making sure we don't overflow the precision. | ||
match (lhs_type, rhs_type) { | ||
(Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _)) | ||
if required_precision <= DECIMAL64_MAX_PRECISION => | ||
{ | ||
Some(Decimal64(required_precision, s)) | ||
} | ||
(Decimal32(_, _), Decimal128(_, _)) | ||
| (Decimal128(_, _), Decimal32(_, _)) | ||
| (Decimal64(_, _), Decimal128(_, _)) | ||
| (Decimal128(_, _), Decimal64(_, _)) | ||
if required_precision <= DECIMAL128_MAX_PRECISION => | ||
{ | ||
Some(Decimal128(required_precision, s)) | ||
} | ||
(Decimal32(_, _), Decimal256(_, _)) | ||
| (Decimal256(_, _), Decimal32(_, _)) | ||
| (Decimal64(_, _), Decimal256(_, _)) | ||
| (Decimal256(_, _), Decimal64(_, _)) | ||
| (Decimal128(_, _), Decimal256(_, _)) | ||
| (Decimal256(_, _), Decimal128(_, _)) | ||
if required_precision <= DECIMAL256_MAX_PRECISION => | ||
{ | ||
Some(Decimal256(required_precision, s)) | ||
} | ||
_ => None, | ||
} | ||
} | ||
|
||
/// Coerce `lhs_type` and `rhs_type` to a common type. | ||
fn get_common_decimal_type( | ||
|
@@ -976,7 +1066,15 @@ fn get_common_decimal_type( | |
) -> Option<DataType> { | ||
use arrow::datatypes::DataType::*; | ||
match decimal_type { | ||
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => { | ||
Decimal32(_, _) => { | ||
let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?; | ||
get_wider_decimal_type(decimal_type, &other_decimal_type) | ||
} | ||
Decimal64(_, _) => { | ||
let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?; | ||
get_wider_decimal_type(decimal_type, &other_decimal_type) | ||
} | ||
Decimal128(_, _) => { | ||
let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?; | ||
get_wider_decimal_type(decimal_type, &other_decimal_type) | ||
} | ||
|
@@ -988,7 +1086,7 @@ fn get_common_decimal_type( | |
} | ||
} | ||
|
||
/// Returns a `DataType::Decimal128` that can store any value from either | ||
/// Returns a decimal [`DataType`] variant that can store any value from either | ||
/// `lhs_decimal_type` and `rhs_decimal_type` | ||
/// | ||
/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`. | ||
|
@@ -1209,14 +1307,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy | |
} | ||
|
||
fn create_decimal32_type(precision: u8, scale: i8) -> DataType { | ||
DataType::Decimal128( | ||
DataType::Decimal32( | ||
DECIMAL32_MAX_PRECISION.min(precision), | ||
DECIMAL32_MAX_SCALE.min(scale), | ||
) | ||
} | ||
|
||
fn create_decimal64_type(precision: u8, scale: i8) -> DataType { | ||
DataType::Decimal128( | ||
DataType::Decimal64( | ||
DECIMAL64_MAX_PRECISION.min(precision), | ||
DECIMAL64_MAX_SCALE.min(scale), | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be cleaner like so:
Following what was done above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops forgot the
is_decimal()
checks for the first branchThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah I've added them locally :) should have something soon
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done as part of 4145a04