Skip to content

Commit

Permalink
[GLUTEN-6975][CH] Rewrite decimal arithmetic (#7196)
Browse files Browse the repository at this point in the history
* enable tpchq1
  • Loading branch information
loneylee authored Sep 23, 2024
1 parent 57ad4e6 commit 55671cb
Show file tree
Hide file tree
Showing 9 changed files with 1,042 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.jni.JniLibLoader
import org.apache.gluten.vectorized.CHNativeExpressionEvaluator

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
import org.apache.spark.listener.CHGlutenSQLAppStatusListener
Expand Down Expand Up @@ -85,7 +85,7 @@ class CHListenerApi extends ListenerApi with Logging {
conf.setCHConfig(
"timezone" -> conf.get("spark.sql.session.timeZone", TimeZone.getDefault.getID),
"local_engine.settings.log_processors_profiles" -> "true")

conf.setCHSettings("spark_version", SPARK_VERSION)
// add memory limit for external sort
val externalSortKey = CHConf.runtimeSettings("max_bytes_before_external_sort")
if (conf.getLong(externalSortKey, -1) < 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class GlutenClickHouseDecimalSuite
(DecimalType.apply(18, 8), Seq()),
// 3/10: all value is null and compare with limit
// 1 Spark 3.5
(DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else Seq(1, 3, 10))
(DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else Seq(3, 10))
)

private def createDecimalTables(dataType: DecimalType): Unit = {
Expand Down Expand Up @@ -309,7 +309,10 @@ class GlutenClickHouseDecimalSuite
"insert into decimals_test values(1, 100.0, 999.0)" +
", (2, 12345.123, 12345.123)" +
", (3, 0.1234567891011, 1234.1)" +
", (4, 123456789123456789.0, 1.123456789123456789)"
", (4, 123456789123456789.0, 1.123456789123456789)" +
", (5, 0, 0)" +
", (6, 0, 1.23)" +
", (7, 1.23, 0)"
spark.sql(createSql)

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <Common/CHUtil.h>
#include <Common/GlutenDecimalUtils.h>
#include <Common/GlutenSettings.h>

namespace DB
{
Expand All @@ -47,7 +48,7 @@ DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type)
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
}

template <typename T>
template <typename T, bool SPARK35>
requires is_decimal<T>
class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
{
Expand All @@ -61,7 +62,7 @@ class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
{
}

DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_)
DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 /*round_scale_*/)
{
const DataTypePtr & data_type = argument_types_[0];
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision<Decimal128>);
Expand All @@ -82,7 +83,7 @@ class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
else if (which.isDecimal64())
{
assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else if (which.isDecimal128())
{
Expand Down Expand Up @@ -116,6 +117,9 @@ class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>

auto result = value / avg.denominator;

if constexpr (SPARK35)
return result;

if (round_scale > result_scale)
return result;

Expand All @@ -128,8 +132,21 @@ class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
UInt32 round_scale;
};

AggregateFunctionPtr
createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
template <bool Data, typename... TArgs>
static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionSparkAvg<Decimal32, Data>(args...);
if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionSparkAvg<Decimal64, Data>(args...);
if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionSparkAvg<Decimal128, Data>(args...);
if (which.idx == TypeIndex::Decimal256) return new AggregateFunctionSparkAvg<Decimal256, Data>(args...);
if constexpr (AggregateFunctionSparkAvg<DateTime64, Data>::DateTime64Supported)
if (which.idx == TypeIndex::DateTime64) return new AggregateFunctionSparkAvg<DateTime64, Data>(args...);
return nullptr;
}

AggregateFunctionPtr createAggregateFunctionSparkAvg(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
Expand All @@ -140,13 +157,20 @@ createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argu
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name);

std::string version;
if (tryGetString(*settings, "spark_version", version) && version.starts_with("3.5"))
{
res.reset(createWithDecimalType<true>(*data_type, argument_types, getDecimalScale(*data_type), 0));
return res;
}

bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet<bool>();
const UInt32 p1 = DB::getDecimalPrecision(*data_type);
const UInt32 s1 = DB::getDecimalScale(*data_type);
auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss);

res.reset(createWithDecimalType<AggregateFunctionSparkAvg>(*data_type, argument_types, getDecimalScale(*data_type), round_scale));
res.reset(createWithDecimalType<false>(*data_type, argument_types, getDecimalScale(*data_type), round_scale));
return res;
}

Expand Down
7 changes: 0 additions & 7 deletions cpp-ch/local-engine/Common/GlutenDecimalUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,6 @@ class GlutenDecimalUtils
}
}

static std::tuple<size_t, size_t> widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2)
{
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
auto scale = std::max(s1, s2);
auto range = std::max(p1 - s1, p2 - s2);
return std::tuple(range + scale, scale);
}

};

Expand Down
Loading

0 comments on commit 55671cb

Please sign in to comment.