Skip to content

Commit 4242ed0

Browse files
committed
Fill out more parts in expr, common and expr-common
1 parent 16f8e50 commit 4242ed0

File tree

5 files changed

+104
-14
lines changed

5 files changed

+104
-14
lines changed

datafusion/common/src/types/native.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::error::{Result, _internal_err};
2323
use arrow::compute::can_cast_types;
2424
use arrow::datatypes::{
2525
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
26+
DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
2627
};
2728
use std::{fmt::Display, sync::Arc};
2829

@@ -228,7 +229,15 @@ impl LogicalType for NativeType {
228229
(Self::Float16, _) => Float16,
229230
(Self::Float32, _) => Float32,
230231
(Self::Float64, _) => Float64,
231-
(Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s),
232+
(Self::Decimal(p, s), _) if *p <= DECIMAL32_MAX_PRECISION => {
233+
Decimal32(*p, *s)
234+
}
235+
(Self::Decimal(p, s), _) if *p <= DECIMAL64_MAX_PRECISION => {
236+
Decimal64(*p, *s)
237+
}
238+
(Self::Decimal(p, s), _) if *p <= DECIMAL128_MAX_PRECISION => {
239+
Decimal128(*p, *s)
240+
}
232241
(Self::Decimal(p, s), _) => Decimal256(*p, *s),
233242
(Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()),
234243
// If given type is Date, return the same type

datafusion/expr-common/src/casts.rs

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ use std::cmp::Ordering;
2525

2626
use arrow::datatypes::{
2727
DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
28-
MIN_DECIMAL128_FOR_EACH_PRECISION,
28+
MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
29+
MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION,
30+
MIN_DECIMAL64_FOR_EACH_PRECISION,
2931
};
3032
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
3133
use datafusion_common::ScalarValue;
@@ -69,6 +71,8 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool {
6971
| DataType::Int16
7072
| DataType::Int32
7173
| DataType::Int64
74+
| DataType::Decimal32(_, _)
75+
| DataType::Decimal64(_, _)
7276
| DataType::Decimal128(_, _)
7377
| DataType::Timestamp(_, _)
7478
)
@@ -114,6 +118,8 @@ fn try_cast_numeric_literal(
114118
| DataType::Int32
115119
| DataType::Int64 => 1_i128,
116120
DataType::Timestamp(_, _) => 1_i128,
121+
DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
122+
DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
117123
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
118124
_ => return None,
119125
};
@@ -127,6 +133,20 @@ fn try_cast_numeric_literal(
127133
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
128134
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
129135
DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
136+
DataType::Decimal32(precision, _) => (
137+
// Each decimal32 precision can represent a different range of values.
138+
// For example, when the precision is 3, the max value is `999` and the min
139+
// value is `-999`.
140+
MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
141+
MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
142+
),
143+
DataType::Decimal64(precision, _) => (
144+
// Each decimal64 precision can represent a different range of values.
145+
// For example, when the precision is 3, the max value is `999` and the min
146+
// value is `-999`.
147+
MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
148+
MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
149+
),
130150
DataType::Decimal128(precision, _) => (
131151
// Different precision for decimal128 can store different range of value.
132152
// For example, the precision is 3, the max of value is `999` and the min
@@ -149,6 +169,46 @@ fn try_cast_numeric_literal(
149169
ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
150170
ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
151171
ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
172+
ScalarValue::Decimal32(Some(v), _, scale) => {
173+
let v = *v as i128;
174+
let lit_scale_mul = 10_i128.pow(*scale as u32);
175+
if mul >= lit_scale_mul {
176+
// Example:
177+
// lit is decimal(123,3,2)
178+
// target type is decimal(5,3)
179+
// the lit can be converted to the decimal(1230,5,3)
180+
v.checked_mul(mul / lit_scale_mul)
181+
} else if v % (lit_scale_mul / mul) == 0 {
182+
// Example:
183+
// lit is decimal(123000,10,3)
184+
// target type is int32: the lit can be converted to INT32(123)
185+
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
186+
Some(v / (lit_scale_mul / mul))
187+
} else {
188+
// can't convert the lit decimal to the target data type
189+
None
190+
}
191+
}
192+
ScalarValue::Decimal64(Some(v), _, scale) => {
193+
let v = *v as i128;
194+
let lit_scale_mul = 10_i128.pow(*scale as u32);
195+
if mul >= lit_scale_mul {
196+
// Example:
197+
// lit is decimal(123,3,2)
198+
// target type is decimal(5,3)
199+
// the lit can be converted to the decimal(1230,5,3)
200+
v.checked_mul(mul / lit_scale_mul)
201+
} else if v % (lit_scale_mul / mul) == 0 {
202+
// Example:
203+
// lit is decimal(123000,10,3)
204+
// target type is int32: the lit can be converted to INT32(123)
205+
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
206+
Some(v / (lit_scale_mul / mul))
207+
} else {
208+
// can't convert the lit decimal to the target data type
209+
None
210+
}
211+
}
152212
ScalarValue::Decimal128(Some(v), _, scale) => {
153213
let lit_scale_mul = 10_i128.pow(*scale as u32);
154214
if mul >= lit_scale_mul {
@@ -218,6 +278,12 @@ fn try_cast_numeric_literal(
218278
);
219279
ScalarValue::TimestampNanosecond(value, tz.clone())
220280
}
281+
DataType::Decimal32(p, s) => {
282+
ScalarValue::Decimal32(Some(value as i32), *p, *s)
283+
}
284+
DataType::Decimal64(p, s) => {
285+
ScalarValue::Decimal64(Some(value as i64), *p, *s)
286+
}
221287
DataType::Decimal128(p, s) => {
222288
ScalarValue::Decimal128(Some(value), *p, *s)
223289
}

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -281,15 +281,14 @@ impl LogicalPlanBuilder {
281281
let value = &row[j];
282282
let data_type = value.get_type(schema)?;
283283

284-
if !data_type.equals_datatype(field_type) {
285-
if can_cast_types(&data_type, field_type) {
286-
} else {
287-
return exec_err!(
288-
"type mismatch and can't cast to got {} and {}",
289-
data_type,
290-
field_type
291-
);
292-
}
284+
if !data_type.equals_datatype(field_type)
285+
&& !can_cast_types(&data_type, field_type)
286+
{
287+
return exec_err!(
288+
"type mismatch and can't cast to got {} and {}",
289+
data_type,
290+
field_type
291+
);
293292
}
294293
}
295294
fields.push(field_type.to_owned(), field_nullable);

datafusion/expr/src/test/function_stub.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use std::any::Any;
2323

2424
use arrow::datatypes::{
2525
DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
26+
DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
2627
};
2728

2829
use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
@@ -135,9 +136,10 @@ impl AggregateUDFImpl for Sum {
135136
DataType::Dictionary(_, v) => coerced_type(v),
136137
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
137138
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
138-
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
139-
Ok(data_type.clone())
140-
}
139+
DataType::Decimal32(_, _)
140+
| DataType::Decimal64(_, _)
141+
| DataType::Decimal128(_, _)
142+
| DataType::Decimal256(_, _) => Ok(data_type.clone()),
141143
dt if dt.is_signed_integer() => Ok(DataType::Int64),
142144
dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
143145
dt if dt.is_floating() => Ok(DataType::Float64),
@@ -153,6 +155,18 @@ impl AggregateUDFImpl for Sum {
153155
DataType::Int64 => Ok(DataType::Int64),
154156
DataType::UInt64 => Ok(DataType::UInt64),
155157
DataType::Float64 => Ok(DataType::Float64),
158+
DataType::Decimal32(precision, scale) => {
159+
// Spark uses DECIMAL(min(38,precision+10), s); here the cap is DECIMAL32_MAX_PRECISION instead of 38
160+
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
161+
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
162+
Ok(DataType::Decimal32(new_precision, *scale))
163+
}
164+
DataType::Decimal64(precision, scale) => {
165+
// Spark uses DECIMAL(min(38,precision+10), s); here the cap is DECIMAL64_MAX_PRECISION instead of 38
166+
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
167+
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
168+
Ok(DataType::Decimal64(new_precision, *scale))
169+
}
156170
DataType::Decimal128(precision, scale) => {
157171
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
158172
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,8 @@ fn coerced_from<'a>(
877877
| UInt64
878878
| Float32
879879
| Float64
880+
| Decimal32(_, _)
881+
| Decimal64(_, _)
880882
| Decimal128(_, _),
881883
) => Some(type_into.clone()),
882884
(

0 commit comments

Comments
 (0)