diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala index 58acda88fba5..8065d35c8565 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala @@ -25,7 +25,7 @@ import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.jni.JniLibLoader import org.apache.gluten.vectorized.CHNativeExpressionEvaluator -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.api.plugin.PluginContext import org.apache.spark.internal.Logging import org.apache.spark.listener.CHGlutenSQLAppStatusListener @@ -85,7 +85,7 @@ class CHListenerApi extends ListenerApi with Logging { conf.setCHConfig( "timezone" -> conf.get("spark.sql.session.timeZone", TimeZone.getDefault.getID), "local_engine.settings.log_processors_profiles" -> "true") - + conf.setCHSettings("spark_version", SPARK_VERSION) // add memory limit for external sort val externalSortKey = CHConf.runtimeSettings("max_bytes_before_external_sort") if (conf.getLong(externalSortKey, -1) < 0) { diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala index 40f442bc2948..f772c909c8f2 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala @@ -69,7 +69,7 @@ class GlutenClickHouseDecimalSuite (DecimalType.apply(18, 8), Seq()), // 3/10: all value is null and compare with limit // 1 Spark 3.5 - (DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else Seq(1, 3, 10)) + (DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else Seq(3, 10)) ) private def createDecimalTables(dataType: DecimalType): Unit = { @@ -309,7 +309,10 @@ class GlutenClickHouseDecimalSuite "insert into decimals_test values(1, 100.0, 999.0)" + ", (2, 12345.123, 12345.123)" + ", (3, 0.1234567891011, 1234.1)" + - ", (4, 123456789123456789.0, 1.123456789123456789)" + ", (4, 123456789123456789.0, 1.123456789123456789)" + + ", (5, 0, 0)" + + ", (6, 0, 1.23)" + + ", (7, 1.23, 0)" spark.sql(createSql) try { diff --git a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp index 0aa233145728..524a39d3a155 100644 --- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp +++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp @@ -24,6 +24,7 @@ #include #include +#include namespace DB { @@ -47,7 +48,7 @@ DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type) return createDecimal(precision_value, scale_value); } -template +template requires is_decimal class AggregateFunctionSparkAvg final : public AggregateFunctionAvg { @@ -61,7 +62,7 @@ class AggregateFunctionSparkAvg final : public AggregateFunctionAvg { } - DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) + DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 /*round_scale_*/) { const 
DataTypePtr & data_type = argument_types_[0]; const UInt32 precision_value = std::min(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision); @@ -82,7 +83,7 @@ class AggregateFunctionSparkAvg final : public AggregateFunctionAvg else if (which.isDecimal64()) { assert_cast &>(to).getData().push_back( - divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); } else if (which.isDecimal128()) { @@ -116,6 +117,9 @@ class AggregateFunctionSparkAvg final : public AggregateFunctionAvg auto result = value / avg.denominator; + if constexpr (SPARK35) + return result; + if (round_scale > result_scale) return result; @@ -128,8 +132,21 @@ class AggregateFunctionSparkAvg final : public AggregateFunctionAvg UInt32 round_scale; }; -AggregateFunctionPtr -createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings) +template +static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args) +{ + WhichDataType which(argument_type); + if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionSparkAvg(args...); + if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionSparkAvg(args...); + if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionSparkAvg(args...); + if (which.idx == TypeIndex::Decimal256) return new AggregateFunctionSparkAvg(args...); + if constexpr (AggregateFunctionSparkAvg::DateTime64Supported) + if (which.idx == TypeIndex::DateTime64) return new AggregateFunctionSparkAvg(args...); + return nullptr; +} + +AggregateFunctionPtr createAggregateFunctionSparkAvg( + const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings) { assertNoParameters(name, parameters); assertUnary(name, argument_types); @@ -140,13 +157,20 @@ createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argu throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name); + std::string version; + if (tryGetString(*settings, "spark_version", version) && version.starts_with("3.5")) + { + res.reset(createWithDecimalType(*data_type, argument_types, getDecimalScale(*data_type), 0)); + return res; + } + bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet(); const UInt32 p1 = DB::getDecimalPrecision(*data_type); const UInt32 s1 = DB::getDecimalScale(*data_type); auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL; auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss); - res.reset(createWithDecimalType(*data_type, argument_types, getDecimalScale(*data_type), round_scale)); + res.reset(createWithDecimalType(*data_type, argument_types, getDecimalScale(*data_type), round_scale)); return res; } diff --git a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h index 32af66ec590e..cf600a5cc9ef 100644 --- a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h +++ b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h @@ -95,13 +95,6 @@ class GlutenDecimalUtils } } - static std::tuple widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2) - { - // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) - auto scale = std::max(s1, s2); - auto range = std::max(p1 - s1, p2 - s2); - return std::tuple(range + 
scale, scale); - } }; diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp new file mode 100644 index 000000000000..26d8e0deb3a9 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp @@ -0,0 +1,591 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "SparkFunctionDecimalBinaryArithmetic.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int ILLEGAL_COLUMN; +extern const int TYPE_MISMATCH; +extern const int LOGICAL_ERROR; +} + +} + +namespace local_engine +{ +using namespace DB; + +namespace +{ +enum class OpCase : uint8_t +{ + Vector, + LeftConstant, + RightConstant +}; + +enum class OpMode : uint8_t +{ + Default, + Effect +}; + +template +bool calculateWith256(const IDataType & left, const IDataType & right) +{ + const size_t p1 = getDecimalPrecision(left); + const size_t s1 = getDecimalScale(left); + const size_t p2 = getDecimalPrecision(right); + const size_t s2 = getDecimalScale(right); + + size_t precision; + if constexpr (is_plus_minus) + precision = std::max(s1, s2) + std::max(p1 - s1, p2 - s2) + 1; + else if constexpr (is_multiply) + precision = p1 + p2 + 1; + else if constexpr (is_division) + precision = p1 - s1 + s2 + std::max(static_cast(6), s1 + p2 + 1); + else if constexpr (is_modulo) + precision = std::min(p1 - s1, p2 - s2) + std::max(s1, s2); + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported."); + + return precision > DataTypeDecimal128::maxPrecision(); +} + +template +struct SparkDecimalBinaryOperation +{ +private: + static constexpr bool is_plus_minus = SparkIsOperation::plus || SparkIsOperation::minus; + static constexpr bool is_multiply = SparkIsOperation::multiply; + static constexpr bool is_division = SparkIsOperation::division; + static constexpr bool is_modulo = SparkIsOperation::modulo; + +public: + template + static ColumnPtr executeDecimal(const ColumnsWithTypeAndName & arguments, const A & left, const B & right, const R & result) + { + using LeftDataType = std::decay_t; // e.g. DataTypeDecimal + using RightDataType = std::decay_t; // e.g. DataTypeDecimal + using ResultDataType = std::decay_t; // e.g. 
DataTypeDecimal + + using ColVecLeft = ColumnVectorOrDecimal; + using ColVecRight = ColumnVectorOrDecimal; + + const ColumnPtr left_col = arguments[0].column; + const ColumnPtr right_col = arguments[1].column; + + const auto * const col_left_raw = left_col.get(); + const auto * const col_right_raw = right_col.get(); + + const size_t col_left_size = col_left_raw->size(); + + const ColumnConst * const col_left_const = checkAndGetColumnConst(col_left_raw); + const ColumnConst * const col_right_const = checkAndGetColumnConst(col_right_raw); + + const ColVecLeft * const col_left = checkAndGetColumn(col_left_raw); + const ColVecRight * const col_right = checkAndGetColumn(col_right_raw); + + if constexpr (Mode == OpMode::Effect) + { + return executeDecimalImpl( + left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result); + } + + if (calculateWith256(*arguments[0].type.get(), *arguments[1].type.get())) + { + return executeDecimalImpl( + left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result); + } + + return executeDecimalImpl( + left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result); + } + +private: + // ResultDataType e.g. DataTypeDecimal + template + static ColumnPtr executeDecimalImpl( + const auto & left, + const auto & right, + const ColumnConst * const col_left_const, + const ColumnConst * const col_right_const, + const auto * const col_left, + const auto * const col_right, + size_t col_left_size, + const ResultDataType & resultDataType) + { + using LeftFieldType = typename LeftDataType::FieldType; + using RightFieldType = typename RightDataType::FieldType; + using ResultFieldType = typename ResultDataType::FieldType; + + using NativeResultType = NativeType; + using ColVecResult = ColumnVectorOrDecimal; + + size_t max_scale; + if constexpr (is_multiply) + max_scale = left.getScale() + right.getScale(); + else + max_scale = std::max(resultDataType.getScale(), std::max(left.getScale(), right.getScale())); + + NativeResultType scale_left = [&] + { + if constexpr (is_multiply) + return NativeResultType{1}; + + // cast scale same to left + auto diff_scale = max_scale - left.getScale(); + if constexpr (is_division) + return DecimalUtils::scaleMultiplier(diff_scale + max_scale); + else + return DecimalUtils::scaleMultiplier(diff_scale); + }(); + + const NativeResultType scale_right = [&] + { + if constexpr (is_multiply) + return NativeResultType{1}; + else + return DecimalUtils::scaleMultiplier(max_scale - right.getScale()); + }(); + + + bool calculate_with_256 = false; + if constexpr (CalculateWith256) + calculate_with_256 = true; + else + { + auto p1 = left.getPrecision(); + auto p2 = right.getPrecision(); + if (DataTypeDecimal::maxPrecision() < p1 + max_scale - left.getScale() + || DataTypeDecimal::maxPrecision() < p2 + max_scale - right.getScale()) + calculate_with_256 = true; + } + + ColumnUInt8::MutablePtr col_null_map_to = ColumnUInt8::create(col_left_size, false); + ColumnUInt8::Container * vec_null_map_to = &col_null_map_to->getData(); + + typename ColVecResult::MutablePtr col_res = ColVecResult::create(0, resultDataType.getScale()); + auto & vec_res = col_res->getData(); + vec_res.resize(col_left_size); + + if (col_left && col_right) + { + if (calculate_with_256) + { + process( + col_left->getData(), + col_right->getData(), + vec_res, + scale_left, + scale_right, + *vec_null_map_to, + resultDataType, + max_scale); + } + else + { + process( + col_left->getData(), + col_right->getData(), + 
vec_res, + scale_left, + scale_right, + *vec_null_map_to, + resultDataType, + max_scale); + } + } + else if (col_left_const && col_right) + { + LeftFieldType const_left = col_left_const->getValue(); + + if (calculate_with_256) + { + process( + const_left, col_right->getData(), vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale); + } + else + { + process( + const_left, col_right->getData(), vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale); + } + } + else if (col_left && col_right_const) + { + RightFieldType const_right = col_right_const->getValue(); + if (calculate_with_256) + { + process( + col_left->getData(), const_right, vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale); + } + else + { + process( + col_left->getData(), const_right, vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale); + } + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported."); + } + + return ColumnNullable::create(std::move(col_res), std::move(col_null_map_to)); + } + + template + static void NO_INLINE process( + const auto & a, + const auto & b, + ResultContainerType & result_container, + const NativeResultType & scale_a, + const NativeResultType & scale_b, + ColumnUInt8::Container & vec_null_map_to, + const ResultDataType & resultDataType, + const size_t & max_scale) + { + size_t size; + if constexpr (op_case == OpCase::LeftConstant) + size = b.size(); + else + size = a.size(); + + if constexpr (op_case == OpCase::Vector) + { + for (size_t i = 0; i < size; ++i) + { + NativeResultType res; + if (calculate( + unwrap(a, i), + unwrap(b, i), + scale_a, + scale_b, + res, + resultDataType, + max_scale)) + result_container[i] = res; + else + vec_null_map_to[i] = static_cast(1); + } + } + else if constexpr (op_case == OpCase::LeftConstant) + { + auto scaled_a = applyScaled(unwrap(a, 0), scale_a); + for (size_t i = 0; i < size; ++i) + { + NativeResultType res; + if (calculate( + scaled_a, + unwrap(b, i), + static_cast(0), + scale_b, + res, + resultDataType, + max_scale)) + result_container[i] = res; + else + vec_null_map_to[i] = static_cast(1); + } + } + else if constexpr (op_case == OpCase::RightConstant) + { + auto scaled_b = applyScaled(unwrap(b, 0), scale_b); + + for (size_t i = 0; i < size; ++i) + { + NativeResultType res; + if (calculate( + unwrap(a, i), + scaled_b, + scale_a, + static_cast(0), + res, + resultDataType, + max_scale)) + result_container[i] = res; + else + vec_null_map_to[i] = static_cast(1); + } + } + } + + // ResultNativeType = Int32/64/128/256 + template + static NO_SANITIZE_UNDEFINED bool calculate( + const LeftNativeType l, + const RightNativeType r, + const NativeResultType & scale_left, + const NativeResultType & scale_right, + NativeResultType & res, + const ResultDataType & resultDataType, + const size_t & max_scale) + { + if constexpr (CalculateWith256) + return calculateImpl(l, r, scale_left, scale_right, res, resultDataType, max_scale); + else if constexpr (is_division) + return calculateImpl(l, r, scale_left, scale_right, res, resultDataType, max_scale); + else + return calculateImpl(l, r, scale_left, scale_right, res, resultDataType, max_scale); + } + + template + static NO_SANITIZE_UNDEFINED bool calculateImpl( + const LeftNativeType & l, + const RightNativeType & r, + const NativeResultType & scale_left, + const NativeResultType & scale_right, + NativeResultType & res, + const ResultDataType & resultDataType, + const size_t & max_scale) + { + CalcType scaled_l = 
applyScaled(static_cast(l), static_cast(scale_left)); + CalcType scaled_r = applyScaled(static_cast(r), static_cast(scale_right)); + + CalcType c_res = 0; + auto success = Operation::template apply(scaled_l, scaled_r, c_res); + if (!success) + return false; + + auto result_scale = resultDataType.getScale(); + auto scale_diff = max_scale - result_scale; + chassert(scale_diff >= 0); + if (scale_diff) + { + auto scaled_diff = DecimalUtils::scaleMultiplier(scale_diff); + DecimalDivideImpl::apply(c_res, scaled_diff, c_res); + } + + // check overflow + if constexpr (std::is_same_v || is_division) + { + auto max_value = intExp10OfSize(resultDataType.getPrecision()); + if (c_res <= -max_value || c_res >= max_value) + return false; + } + + res = static_cast(c_res); + + return true; + } + + template + static auto unwrap(const E & elem, size_t i) + { + if constexpr (op_case == target) + return elem.value; + else + return elem[i].value; + } + + template + static ResultNativeType applyScaled(const NativeType & l, const ResultNativeType & scale) + { + if (scale > 1) + return common::mulIgnoreOverflow(l, scale); + + return static_cast(l); + } +}; + + +template +class SparkFunctionDecimalBinaryArithmetic final : public IFunction +{ + static constexpr bool is_plus_minus = SparkIsOperation::plus || SparkIsOperation::minus; + static constexpr bool is_multiply = SparkIsOperation::multiply; + static constexpr bool is_division = SparkIsOperation::division; + static constexpr bool is_modulo = SparkIsOperation::modulo; + +public: + static constexpr auto name = Name::name; + + static FunctionPtr create(ContextPtr context_) { return std::make_shared(context_); } + + explicit SparkFunctionDecimalBinaryArithmetic(ContextPtr context_) : context(context_) { } + + String getName() const override { return name; } + size_t getNumberOfArguments() const override { return 3; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + bool useDefaultImplementationForConstants() const override { return true; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {2}; } + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + if (arguments.size() != 3) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function '{}' expects 3 arguments", getName()); + + if (!isDecimal(arguments[0].type) || !isDecimal(arguments[1].type) || !isDecimal(arguments[2].type)) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} {} {} of argument of function {}", + arguments[0].type->getName(), + arguments[1].type->getName(), + arguments[2].type->getName(), + getName()); + + return std::make_shared(arguments[2].type); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override + { + const auto & left_argument = arguments[0]; + const auto & right_argument = arguments[1]; + + const auto * left_generic = left_argument.type.get(); + const auto * right_generic = right_argument.type.get(); + + ColumnPtr res; + const bool valid = castBothTypes( + left_generic, + right_generic, + removeNullable(arguments[2].type).get(), + [&](const auto & left, const auto & right, const auto & result) { + return (res = SparkDecimalBinaryOperation::template executeDecimal(arguments, left, right, result)) + != nullptr; + }); + + if (!valid) + { + // This is a logical error, because the types should have been checked + // by 
getReturnTypeImpl(). + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Arguments of '{}' have incorrect data types: '{}' of type '{}'," + " '{}' of type '{}'", + getName(), + left_argument.name, + left_argument.type->getName(), + right_argument.name, + right_argument.type->getName()); + } + + return res; + } + +private: + template + static bool castBothTypes(const IDataType * left, const IDataType * right, const IDataType * result, F && f) + { + return castType( + left, + [&](const auto & left_) + { + return castType( + right, + [&](const auto & right_) { return castType(result, [&](const auto & result_) { return f(left_, right_, result_); }); }); + }); + } + + static bool castType(const IDataType * type, auto && f) + { + using Types = TypeList; + return castTypeToEither(Types{}, type, std::forward(f)); + } + + ContextPtr context; +}; + +struct NameSparkDecimalPlus +{ + static constexpr auto name = "sparkDecimalPlus"; +}; +struct NameSparkDecimalPlusEffect +{ + static constexpr auto name = "sparkDecimalPlusEffect"; +}; +struct NameSparkDecimalMinus +{ + static constexpr auto name = "sparkDecimalMinus"; +}; +struct NameSparkDecimalMinusEffect +{ + static constexpr auto name = "sparkDecimalMinusEffect"; +}; +struct NameSparkDecimalMultiply +{ + static constexpr auto name = "sparkDecimalMultiply"; +}; +struct NameSparkDecimalMultiplyEffect +{ + static constexpr auto name = "sparkDecimalMultiplyEffect"; +}; +struct NameSparkDecimalDivide +{ + static constexpr auto name = "sparkDecimalDivide"; +}; +struct NameSparkDecimalDivideEffect +{ + static constexpr auto name = "sparkDecimalDivideEffect"; +}; +struct NameSparkDecimalModulo +{ + static constexpr auto name = "sparkDecimalModulo"; +}; +struct NameSparkDecimalModuloEffect +{ + static constexpr auto name = "sparkDecimalModuloEffect"; +}; + + +using DecimalPlus = SparkFunctionDecimalBinaryArithmetic; +using DecimalMinus = SparkFunctionDecimalBinaryArithmetic; +using DecimalMultiply = SparkFunctionDecimalBinaryArithmetic; +using DecimalDivide = SparkFunctionDecimalBinaryArithmetic; +using DecimalModulo = SparkFunctionDecimalBinaryArithmetic; + +using DecimalPlusEffect = SparkFunctionDecimalBinaryArithmetic; +using DecimalMinusEffect = SparkFunctionDecimalBinaryArithmetic; +using DecimalMultiplyEffect = SparkFunctionDecimalBinaryArithmetic; +using DecimalDivideEffect = SparkFunctionDecimalBinaryArithmetic; +using DecimalModuloEffect = SparkFunctionDecimalBinaryArithmetic; +} + +REGISTER_FUNCTION(SparkDecimalFunctionArithmetic) +{ + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); +} +} diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h new file mode 100644 index 000000000000..05e5f4ff9f07 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace local_engine +{ + +static bool canCastLower(const Int256 & a, const Int256 & b) +{ + if (a.items[2] == 0 && a.items[3] == 0 && b.items[2] == 0 && b.items[3] == 0) + return true; + + return false; +} + +static bool canCastLower(const Int128 & a, const Int128 & b) +{ + if (a.items[1] == 0 && b.items[1] == 0) + return true; + + return false; +} + +struct DecimalPlusImpl +{ + template + static bool apply(A & a, A & b, A & r) + { + return !common::addOverflow(a, b, r); + } + + template <> + static bool apply(Int128 & a, Int128 & b, Int128 & r) + { + if (canCastLower(a, b)) + { + UInt64 low_result; + if (common::addOverflow(static_cast(a), static_cast(b), low_result)) + return !common::addOverflow(a, b, r); + + r = static_cast(low_result); + return true; + } + + return !common::addOverflow(a, b, r); + } + + + template <> + static bool apply(Int256 & a, Int256 & b, Int256 & r) + { + if (canCastLower(a, b)) + { + UInt128 low_result; + if (common::addOverflow(static_cast(a), static_cast(b), low_result)) + return !common::addOverflow(a, b, r); + + r = static_cast(low_result); + return true; + } + + return !common::addOverflow(a, b, r); + } + + +#if USE_EMBEDDED_COMPILER + static constexpr bool compilable = true; + + static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool) + { + return left->getType()->isIntegerTy() ? b.CreateAdd(left, right) : b.CreateFAdd(left, right); + } +#endif +}; + +struct DecimalMinusImpl +{ + /// Apply operation and check overflow. It's used for Decimal operations. @returns true on success, false on overflow. + template + static bool apply(A & a, A & b, A & r) + { + return !common::subOverflow(a, b, r); + } + + template <> + static bool apply(Int128 & a, Int128 & b, Int128 & r) + { + if (canCastLower(a, b)) + { + UInt64 low_result; + if (common::subOverflow(static_cast(a), static_cast(b), low_result)) + return !common::subOverflow(a, b, r); + + r = static_cast(low_result); + return true; + } + + return !common::subOverflow(a, b, r); + } + + template <> + static bool apply(Int256 & a, Int256 & b, Int256 & r) + { + if (canCastLower(a, b)) + { + UInt128 low_result; + if (common::subOverflow(static_cast(a), static_cast(b), low_result)) + return !common::subOverflow(a, b, r); + + r = static_cast(low_result); + return true; + } + + return !common::subOverflow(a, b, r); + } + +#if USE_EMBEDDED_COMPILER + static constexpr bool compilable = true; + + static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool) + { + return left->getType()->isIntegerTy() ? b.CreateSub(left, right) : b.CreateFSub(left, right); + } +#endif +}; + + +struct DecimalMultiplyImpl +{ + /// Apply operation and check overflow. It's used for Decimal operations. @returns true on success, false on overflow. 
+ template + static bool apply(A & a, A & b, A & c) + { + return !common::mulOverflow(a, b, c); + } + + template + static bool apply(Int128 & a, Int128 & b, Int128 & r) + { + if (canCastLower(a, b)) + { + UInt64 low_result = 0; + if (common::mulOverflow(static_cast(a), static_cast(b), low_result)) + return !common::mulOverflow(a, b, r); + + r = static_cast(low_result); + return true; + } + + return !common::mulOverflow(a, b, r); + } + +#if USE_EMBEDDED_COMPILER + static constexpr bool compilable = true; + + static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool) + { + return left->getType()->isIntegerTy() ? b.CreateMul(left, right) : b.CreateFMul(left, right); + } +#endif +}; + +struct DecimalDivideImpl +{ + template + static bool apply(A & a, A & b, A & r) + { + if (b == 0) + return false; + + r = a / b; + return true; + } + + template <> + static bool apply(Int128 & a, Int128 & b, Int128 & r) + { + if (b == 0) + return false; + + if (canCastLower(a, b)) + { + r = static_cast(static_cast(a) / static_cast(b)); + return true; + } + + r = a / b; + return true; + } + + template <> + static bool apply(Int256 & a, Int256 & b, Int256 & r) + { + if (b == 0) + return false; + + if (canCastLower(a, b)) + { + UInt128 low_result = 0; + UInt128 low_a = static_cast(a); + UInt128 low_b = static_cast(b); + apply(low_a, low_b, low_result); + r = static_cast(low_result); + return true; + } + + r = a / b; + return true; + } + +#if USE_EMBEDDED_COMPILER + static constexpr bool compilable = true; + + static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool) + { + return left->getType()->isIntegerTy() ? b.CreateSDiv(left, right) : b.CreateFDiv(left, right); + } +#endif +}; + + +// ModuloImpl +struct DecimalModuloImpl +{ + template + static bool apply(A & a, A & b, A & r) + { + if (b == 0) + return false; + + r = a % b; + return true; + } + +#if USE_EMBEDDED_COMPILER + static constexpr bool compilable = true; + + static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool) + { + return left->getType()->isIntegerTy() ? 
b.CreateSRem(left, right) : b.CreateFRem(left, right); +#endif +}; + +template +struct IsSameOperation +{ + static constexpr bool value = std::is_same_v; +}; + +template +struct SparkIsOperation +{ + static constexpr bool plus = IsSameOperation::value; + static constexpr bool minus = IsSameOperation::value; + static constexpr bool multiply = IsSameOperation::value; + static constexpr bool division = IsSameOperation::value; + static constexpr bool modulo = IsSameOperation::value; +}; +} diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index c9a48106a9b3..9e1ffc0dd6a5 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -135,6 +135,7 @@ class SerializedPlanParser IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header); static std::pair parseLiteral(const substrait::Expression_Literal & literal); + ContextPtr getContext() const { return context; } std::vector extra_plan_holder; diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp index 6aba310bf095..f73305df022e 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace DB::ErrorCodes { @@ -136,8 +137,11 @@ class FunctionParserBinaryArithmetic : public FunctionParser return toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", overflow_args); } - virtual const DB::ActionsDAG::Node * - createFunctionNode(DB::ActionsDAG & actions_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const + virtual const DB::ActionsDAG::Node * createFunctionNode( + DB::ActionsDAG & actions_dag, + const String & func_name, + const DB::ActionsDAG::NodeRawConstPtrs & args, + DataTypePtr /*result_type*/) const { return toFunctionNode(actions_dag, func_name, args); } @@ -154,33 +158,8 @@ class FunctionParserBinaryArithmetic : public FunctionParser const auto left_type = DB::removeNullable(parsed_args[0]->result_type); const auto right_type = DB::removeNullable(parsed_args[1]->result_type); - const bool converted = isDecimal(left_type) && isDecimal(right_type); - - if (converted) - { - const DecimalType evalType = getDecimalType(left_type, right_type); - parsed_args = convertBinaryArithmeticFunDecimalArgs(actions_dag, parsed_args, evalType, substrait_func); - } - - const auto * func_node = createFunctionNode(actions_dag, ch_func_name, parsed_args); - - if (converted) - { - const auto parsed_output_type = removeNullable(TypeParser::parseType(substrait_func.output_type())); - assert(isDecimal(parsed_output_type)); - const Int32 parsed_precision = getDecimalPrecision(*parsed_output_type); - const Int32 parsed_scale = getDecimalScale(*parsed_output_type); - func_node = checkDecimalOverflow(actions_dag, func_node, parsed_precision, parsed_scale); -#ifndef NDEBUG - const auto output_type = removeNullable(func_node->result_type); - const Int32 output_precision = getDecimalPrecision(*output_type); - const Int32 output_scale = getDecimalScale(*output_type); - if (output_precision != parsed_precision || output_scale != parsed_scale) - throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {} has wrong output type", getName()); -#endif - - return func_node; - } + const auto result_type = 
removeNullable(TypeParser::parseType(substrait_func.output_type())); + const auto * func_node = createFunctionNode(actions_dag, ch_func_name, parsed_args, result_type); return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); } }; @@ -199,6 +178,32 @@ class FunctionParserPlus final : public FunctionParserBinaryArithmetic { return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2); } + + const DB::ActionsDAG::Node * createFunctionNode( + DB::ActionsDAG & actions_dag, + const String & func_name, + const DB::ActionsDAG::NodeRawConstPtrs & new_args, + DataTypePtr result_type) const override + { + const auto * left_arg = new_args[0]; + const auto * right_arg = new_args[1]; + + if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type))) + { + const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( + result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); + + const auto & settings = plan_parser->getContext()->getSettingsRef(); + auto function_name + = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") + ? "sparkDecimalPlusEffect" + : "sparkDecimalPlus"; + + return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); + } + + return toFunctionNode(actions_dag, "plus", {left_arg, right_arg}); + } }; class FunctionParserMinus final : public FunctionParserBinaryArithmetic @@ -215,6 +220,32 @@ class FunctionParserMinus final : public FunctionParserBinaryArithmetic { return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2); } + + const DB::ActionsDAG::Node * createFunctionNode( + DB::ActionsDAG & actions_dag, + const String & func_name, + const DB::ActionsDAG::NodeRawConstPtrs & new_args, + DataTypePtr result_type) const override + { + const auto * left_arg = new_args[0]; + const auto * right_arg = new_args[1]; + + if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type))) + { + const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( + result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); + + const auto & settings = plan_parser->getContext()->getSettingsRef(); + auto function_name + = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") + ? 
"sparkDecimalMinusEffect" + : "sparkDecimalMinus"; + + return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); + } + + return toFunctionNode(actions_dag, "minus", {left_arg, right_arg}); + } }; class FunctionParserMultiply final : public FunctionParserBinaryArithmetic @@ -230,6 +261,32 @@ class FunctionParserMultiply final : public FunctionParserBinaryArithmetic { return DecimalType::evalMultiplyDecimalType(p1, s1, p2, s2); } + + const DB::ActionsDAG::Node * createFunctionNode( + DB::ActionsDAG & actions_dag, + const String & func_name, + const DB::ActionsDAG::NodeRawConstPtrs & new_args, + DataTypePtr result_type) const override + { + const auto * left_arg = new_args[0]; + const auto * right_arg = new_args[1]; + + if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type))) + { + const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( + result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); + + const auto & settings = plan_parser->getContext()->getSettingsRef(); + auto function_name + = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") + ? "sparkDecimalMultiplyEffect" + : "sparkDecimalMultiply"; + + return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); + } + + return toFunctionNode(actions_dag, "multiply", {left_arg, right_arg}); + } }; class FunctionParserModulo final : public FunctionParserBinaryArithmetic @@ -245,6 +302,33 @@ class FunctionParserModulo final : public FunctionParserBinaryArithmetic { return DecimalType::evalModuloDecimalType(p1, s1, p2, s2); } + + const DB::ActionsDAG::Node * createFunctionNode( + DB::ActionsDAG & actions_dag, + const String & func_name, + const DB::ActionsDAG::NodeRawConstPtrs & new_args, + DataTypePtr result_type) const override + { + assert(func_name == name); + const auto * left_arg = new_args[0]; + const auto * right_arg = new_args[1]; + + if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type))) + { + const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( + result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); + + const auto & settings = plan_parser->getContext()->getSettingsRef(); + auto function_name + = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") + ? 
"NameSparkDecimalModuloEffect" + : "NameSparkDecimalModulo"; + ; + return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); + } + + return toFunctionNode(actions_dag, "modulo", {left_arg, right_arg}); + } }; class FunctionParserDivide final : public FunctionParserBinaryArithmetic @@ -262,14 +346,28 @@ class FunctionParserDivide final : public FunctionParserBinaryArithmetic } const DB::ActionsDAG::Node * createFunctionNode( - DB::ActionsDAG & actions_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & new_args) const override + DB::ActionsDAG & actions_dag, + const String & func_name, + const DB::ActionsDAG::NodeRawConstPtrs & new_args, + DataTypePtr result_type) const override { assert(func_name == name); const auto * left_arg = new_args[0]; const auto * right_arg = new_args[1]; if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type))) - return toFunctionNode(actions_dag, "sparkDivideDecimal", {left_arg, right_arg}); + { + const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( + result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); + + const auto & settings = plan_parser->getContext()->getSettingsRef(); + auto function_name + = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") + ? "sparkDecimalDivideEffect" + : "sparkDecimalDivide"; + ; + return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); + } return toFunctionNode(actions_dag, "sparkDivide", {left_arg, right_arg}); } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 0aa676158474..d14a59127794 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -446,6 +446,10 @@ object ExpressionConverter extends SQLConfHelper with Logging { LiteralTransformer(m.nullOnOverflow)), m ) + case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) + if child.dataType + .isInstanceOf[DecimalType] && !BackendsApiManager.getSettings.transformCheckOverflow => + replaceWithExpressionTransformer0(child, attributeSeq, expressionsMap) case _: NormalizeNaNAndZero | _: PromotePrecision | _: TaggingExpression => ChildTransformer( substraitExprName, @@ -466,16 +470,12 @@ object ExpressionConverter extends SQLConfHelper with Logging { if !BackendsApiManager.getSettings.transformCheckOverflow && DecimalArithmeticUtil.isDecimalArithmetic(b) => DecimalArithmeticUtil.checkAllowDecimalArithmetic() - val leftChild = + val arithmeticExprName = getAndCheckSubstraitName(b, expressionsMap) + val left = replaceWithExpressionTransformer0(b.left, attributeSeq, expressionsMap) - val rightChild = + val right = replaceWithExpressionTransformer0(b.right, attributeSeq, expressionsMap) - DecimalArithmeticExpressionTransformer( - getAndCheckSubstraitName(b, expressionsMap), - leftChild, - rightChild, - decimalType, - b) + DecimalArithmeticExpressionTransformer(arithmeticExprName, left, right, decimalType, b) case c: CheckOverflow => CheckOverflowTransformer( substraitExprName,