diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 6d6102aea4..5bfd76797c 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -17,9 +17,7 @@ //! Converts Spark physical plan to DataFusion physical plan -use std::{collections::HashMap, sync::Arc}; - -use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use arrow_schema::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::sum::sum_udaf; @@ -62,6 +60,8 @@ use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use itertools::Itertools; use jni::objects::GlobalRef; use num::{BigInt, ToPrimitive}; +use std::cmp::max; +use std::{collections::HashMap, sync::Arc}; use crate::{ errors::ExpressionError, @@ -410,7 +410,7 @@ impl PhysicalPlanner { // Spark Substring's start is 1-based when start > 0 let start = expr.start - i32::from(expr.start > 0); // substring negative len is treated as 0 in Spark - let len = std::cmp::max(expr.len, 0); + let len = max(expr.len, 0); Ok(Arc::new(SubstringExpr::new( child, @@ -664,7 +664,14 @@ impl PhysicalPlanner { | DataFusionOperator::Modulo, Ok(DataType::Decimal128(p1, s1)), Ok(DataType::Decimal128(p2, s2)), - ) => { + ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus) + && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) + >= DECIMAL128_MAX_PRECISION) + || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) + || (op == DataFusionOperator::Modulo + && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) + > DECIMAL128_MAX_PRECISION) => + { let data_type = return_type.map(to_arrow_datatype).unwrap(); // For some Decimal128 operations, we need wider internal digits. // Cast left and right to Decimal256 and cast the result back to Decimal128 diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs index c50b98bafe..7cbaf12aa4 100644 --- a/native/spark-expr/src/scalar_funcs.rs +++ b/native/spark-expr/src/scalar_funcs.rs @@ -25,7 +25,7 @@ use arrow::{ datatypes::{validate_decimal_precision, Decimal128Type, Int64Type}, }; use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array}; -use arrow_schema::DataType; +use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION}; use datafusion::{functions::math::round::round, physical_plan::ColumnarValue}; use datafusion_common::{ cast::as_generic_string_array, exec_err, internal_err, DataFusionError, @@ -460,27 +460,41 @@ pub fn spark_decimal_div( }; let left = left.as_primitive::(); let right = right.as_primitive::(); - let (_, s1) = get_precision_scale(left.data_type()); - let (_, s2) = get_precision_scale(right.data_type()); + let (p1, s1) = get_precision_scale(left.data_type()); + let (p2, s2) = get_precision_scale(right.data_type()); - let ten = BigInt::from(10); let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32); let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32); - let l_mul = ten.pow(l_exp); - let r_mul = ten.pow(r_exp); - let five = BigInt::from(5); - let zero = BigInt::from(0); - let result: Decimal128Array = arrow::compute::kernels::arity::binary(left, right, |l, r| { - let l = BigInt::from(l) * &l_mul; - let r = BigInt::from(r) * &r_mul; - let div = if r.eq(&zero) { zero.clone() } else { &l / &r }; - let res = if div.is_negative() { - div - &five - } else { - div + &five - } / &ten; - res.to_i128().unwrap_or(i128::MAX) - })?; + let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32 + || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32 + { + let ten = BigInt::from(10); + let l_mul = ten.pow(l_exp); + let r_mul = ten.pow(r_exp); + let five = BigInt::from(5); + let zero = BigInt::from(0); + arrow::compute::kernels::arity::binary(left, right, |l, r| { + let l = BigInt::from(l) * &l_mul; + let r = BigInt::from(r) * &r_mul; + let div = if r.eq(&zero) { zero.clone() } else { &l / &r }; + let res = if div.is_negative() { + div - &five + } else { + div + &five + } / &ten; + res.to_i128().unwrap_or(i128::MAX) + })? + } else { + let l_mul = 10_i128.pow(l_exp); + let r_mul = 10_i128.pow(r_exp); + arrow::compute::kernels::arity::binary(left, right, |l, r| { + let l = l * l_mul; + let r = r * r_mul; + let div = if r == 0 { 0 } else { l / r }; + let res = if div.is_negative() { div - 5 } else { div + 5 } / 10; + res.to_i128().unwrap_or(i128::MAX) + })? + }; let result = result.with_data_type(DataType::Decimal128(p3, s3)); Ok(ColumnarValue::Array(Arc::new(result))) }