Skip to content

Commit

Permalink
add make_decimal (facebookincubator#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh authored Mar 21, 2023
1 parent daf2f17 commit c4cf0c7
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 133 deletions.
2 changes: 1 addition & 1 deletion velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ add_library(
velox_functions_spark
ArraySort.cpp
Bitwise.cpp
CheckOverflow.cpp
CompareFunctionsNullSafe.cpp
Decimal.cpp
Hash.cpp
In.cpp
LeastGreatest.cpp
Expand Down
103 changes: 0 additions & 103 deletions velox/functions/sparksql/CheckOverflow.cpp

This file was deleted.

26 changes: 0 additions & 26 deletions velox/functions/sparksql/CheckOverflow.h

This file was deleted.

202 changes: 202 additions & 0 deletions velox/functions/sparksql/Decimal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/Decimal.h"

#include "velox/expression/DecodedArgs.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::functions::sparksql {
namespace {
/// Vector function that rescales a decimal column (argument 0) to the decimal
/// type carried by argument 1. Argument 2 is a constant boolean forwarded to
/// DecimalUtil::rescaleWithRoundUp as the null-on-overflow flag.
///
/// TInput is the unscaled representation of the input decimal, TOutput the
/// unscaled representation of the result decimal.
template <typename TInput, typename TOutput>
class CheckOverflowFunction final : public exec::VectorFunction {
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
      const TypePtr& outputType,
      exec::EvalCtx& context,
      VectorPtr& resultRef) const final {
    VELOX_CHECK_EQ(args.size(), 3);
    auto sourceType = args[0]->type();
    auto targetType = args[1]->type();
    // The result vector takes the type of the second argument.
    context.ensureWritable(rows, targetType, resultRef);

    auto rawOutput =
        resultRef->asUnchecked<FlatVector<TOutput>>()->mutableRawValues();
    exec::DecodedArgs decoded(rows, args, context);
    auto input = decoded.at(0);
    // The overflow flag has to be the same for every row.
    VELOX_CHECK(decoded.at(2)->isConstantMapping());
    auto nullOnOverflow = decoded.at(2)->valueAt<bool>(0);

    const auto& sourcePrecisionScale = getDecimalPrecisionScale(*sourceType);
    const auto& targetPrecisionScale = getDecimalPrecisionScale(*targetType);
    rows.applyToSelected([&](int row) {
      auto rescaled = DecimalUtil::rescaleWithRoundUp<TInput, TOutput>(
          input->valueAt<TInput>(row),
          sourcePrecisionScale.first,
          sourcePrecisionScale.second,
          targetPrecisionScale.first,
          targetPrecisionScale.second,
          nullOnOverflow);
      // An empty result from the rescale helper is surfaced as a null row.
      if (!rescaled.has_value()) {
        resultRef->setNull(row, true);
        return;
      }
      rawOutput[row] = rescaled.value();
    });
  }
};

/// Vector function that builds a decimal from an unscaled 64-bit integer
/// column (argument 0). Argument 1 is a constant whose decimal type supplies
/// the target precision and scale; argument 2 is a constant boolean selecting
/// between producing NULL and raising a user error when the unscaled value is
/// out of range.
///
/// Results with precision <= 18 are written as UnscaledShortDecimal, all
/// others as UnscaledLongDecimal.
///
/// NOTE(review): TInput is not consulted; the unscaled input is always read as
/// int64_t, which matches the `bigint` argument declared in
/// makeDecimalSignatures().
template <typename TInput>
class MakeDecimalFunction final : public exec::VectorFunction {
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
      const TypePtr& outputType,
      exec::EvalCtx& context,
      VectorPtr& resultRef) const final {
    VELOX_CHECK_EQ(args.size(), 3);
    auto toType = args[1]->type();
    exec::DecodedArgs decodedArgs(rows, args, context);
    auto unscaledVec = decodedArgs.at(0);
    // Both the type carrier and the overflow flag must be constant per batch.
    VELOX_CHECK(decodedArgs.at(1)->isConstantMapping());
    VELOX_CHECK(decodedArgs.at(2)->isConstantMapping());
    auto nullOnOverflow = decodedArgs.at(2)->valueAt<bool>(0);
    const auto& toPrecisionScale = getDecimalPrecisionScale(*toType);
    auto precision = toPrecisionScale.first;
    auto scale = toPrecisionScale.second;
    if (precision <= 18) {
      makeDecimals<UnscaledShortDecimal>(
          rows,
          *unscaledVec,
          nullOnOverflow,
          SHORT_DECIMAL(
              static_cast<uint8_t>(precision), static_cast<uint8_t>(scale)),
          context,
          resultRef);
    } else {
      // The long branch must write through a FlatVector<UnscaledLongDecimal>;
      // viewing the buffer as short decimals would use the wrong element type.
      makeDecimals<UnscaledLongDecimal>(
          rows,
          *unscaledVec,
          nullOnOverflow,
          LONG_DECIMAL(
              static_cast<uint8_t>(precision), static_cast<uint8_t>(scale)),
          context,
          resultRef);
    }
  }

  /// Writes one TDecimal per selected row into a result vector of
  /// `resultType`, handling out-of-range values according to `nullOnOverflow`.
  template <typename TDecimal>
  static void makeDecimals(
      const SelectivityVector& rows,
      const DecodedVector& unscaledVec,
      bool nullOnOverflow,
      const TypePtr& resultType,
      exec::EvalCtx& context,
      VectorPtr& resultRef) {
    context.ensureWritable(rows, resultType, resultRef);
    auto result =
        resultRef->asUnchecked<FlatVector<TDecimal>>()->mutableRawValues();
    rows.applyToSelected([&](int row) {
      auto unscaled = unscaledVec.valueAt<int64_t>(row);
      if (TDecimal::valueInRange(unscaled)) {
        // Construct explicitly so the full 64-bit value is preserved.
        result[row] = TDecimal(unscaled);
      } else if (nullOnOverflow) {
        resultRef->setNull(row, true);
      } else {
        VELOX_USER_FAIL("Unscaled value overflow for precision");
      }
    });
  }
};
} // namespace

/// Signatures accepted by check_overflow:
///   (DECIMAL(a_precision, a_scale), DECIMAL(b_precision, b_scale), boolean)
///     -> DECIMAL(min(38, b_precision), min(38, b_scale))
std::vector<std::shared_ptr<exec::FunctionSignature>>
checkOverflowSignatures() {
  exec::FunctionSignatureBuilder builder;
  builder.integerVariable("a_precision")
      .integerVariable("a_scale")
      .integerVariable("b_precision")
      .integerVariable("b_scale")
      .integerVariable("r_precision", "min(38, b_precision)")
      .integerVariable("r_scale", "min(38, b_scale)")
      .returnType("DECIMAL(r_precision, r_scale)")
      .argumentType("DECIMAL(a_precision, a_scale)")
      .argumentType("DECIMAL(b_precision, b_scale)")
      .argumentType("boolean");

  std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
  signatures.push_back(builder.build());
  return signatures;
}

/// Signatures accepted by make_decimal:
///   (bigint, DECIMAL(a_precision, a_scale), boolean)
///     -> DECIMAL(min(38, a_precision), min(38, a_scale))
std::vector<std::shared_ptr<exec::FunctionSignature>> makeDecimalSignatures() {
  exec::FunctionSignatureBuilder builder;
  builder.integerVariable("a_precision")
      .integerVariable("a_scale")
      .integerVariable("r_precision", "min(38, a_precision)")
      .integerVariable("r_scale", "min(38, a_scale)")
      .returnType("DECIMAL(r_precision, r_scale)")
      .argumentType("bigint")
      .argumentType("DECIMAL(a_precision, a_scale)")
      .argumentType("boolean");

  std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
  signatures.push_back(builder.build());
  return signatures;
}

/// Factory for check_overflow. Chooses the CheckOverflowFunction instantiation
/// from the decimal kinds of the first (input) and second (target) arguments.
/// Expects exactly three arguments; `name` is not used.
std::shared_ptr<exec::VectorFunction> makeCheckOverflow(
    const std::string& name,
    const std::vector<exec::VectorFunctionArg>& inputArgs) {
  VELOX_CHECK_EQ(inputArgs.size(), 3);
  const bool toShort = inputArgs[1].type->kind() == TypeKind::SHORT_DECIMAL;
  const bool fromShort = inputArgs[0].type->kind() == TypeKind::SHORT_DECIMAL;

  if (toShort && fromShort) {
    return std::make_shared<
        CheckOverflowFunction<UnscaledShortDecimal, UnscaledShortDecimal>>();
  }
  if (toShort) {
    return std::make_shared<
        CheckOverflowFunction<UnscaledLongDecimal, UnscaledShortDecimal>>();
  }
  if (fromShort) {
    return std::make_shared<
        CheckOverflowFunction<UnscaledShortDecimal, UnscaledLongDecimal>>();
  }
  return std::make_shared<
      CheckOverflowFunction<UnscaledLongDecimal, UnscaledLongDecimal>>();
}

/// Factory for make_decimal. Chooses the MakeDecimalFunction instantiation
/// from the kind of the first argument and fails for any kind other than
/// SHORT_DECIMAL, LONG_DECIMAL, INTEGER or BIGINT. Expects exactly three
/// arguments; `name` is not used.
std::shared_ptr<exec::VectorFunction> makeMakeDecimal(
    const std::string& name,
    const std::vector<exec::VectorFunctionArg>& inputArgs) {
  VELOX_CHECK_EQ(inputArgs.size(), 3);
  auto fromType = inputArgs[0].type;
  const auto inputKind = fromType->kind();

  if (inputKind == TypeKind::SHORT_DECIMAL) {
    return std::make_shared<MakeDecimalFunction<UnscaledShortDecimal>>();
  }
  if (inputKind == TypeKind::LONG_DECIMAL) {
    return std::make_shared<MakeDecimalFunction<UnscaledLongDecimal>>();
  }
  if (inputKind == TypeKind::INTEGER) {
    return std::make_shared<MakeDecimalFunction<int32_t>>();
  }
  if (inputKind == TypeKind::BIGINT) {
    return std::make_shared<MakeDecimalFunction<int64_t>>();
  }
  VELOX_FAIL(
      "Not support this type {} in make_decimal", fromType->kindName());
}

} // namespace facebook::velox::functions::sparksql
18 changes: 16 additions & 2 deletions velox/functions/sparksql/Decimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
* limitations under the License.
*/
#include <velox/type/UnscaledShortDecimal.h>
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Macros.h"
#include "velox/type/Type.h"

namespace facebook::velox::functions {
namespace facebook::velox::functions::sparksql {

template <typename T>
struct UnscaledValueFunction {
Expand All @@ -29,4 +30,17 @@ struct UnscaledValueFunction {
result = shortDecimal.unscaledValue();
}
};
} // namespace facebook::velox::functions

/// Returns the signatures of the function registered as "check_overflow":
/// (DECIMAL, DECIMAL, boolean) -> DECIMAL, where the result precision and
/// scale are derived from the second argument and capped at 38.
std::vector<std::shared_ptr<exec::FunctionSignature>> checkOverflowSignatures();

/// Creates the vector function for "check_overflow".
/// @param name Function name; not used by the implementation.
/// @param inputArgs Must contain exactly three arguments. The decimal kinds of
/// the first and second arguments select the implementation.
std::shared_ptr<exec::VectorFunction> makeCheckOverflow(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

/// Returns the signatures of the function registered as "make_decimal":
/// (bigint, DECIMAL, boolean) -> DECIMAL, where the result precision and
/// scale are derived from the second argument and capped at 38.
std::vector<std::shared_ptr<exec::FunctionSignature>> makeDecimalSignatures();

/// Creates the vector function for "make_decimal".
/// @param name Function name; not used by the implementation.
/// @param inputArgs Must contain exactly three arguments. The kind of the
/// first argument selects the implementation; kinds other than SHORT_DECIMAL,
/// LONG_DECIMAL, INTEGER and BIGINT cause a failure.
std::shared_ptr<exec::VectorFunction> makeMakeDecimal(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

} // namespace facebook::velox::functions::sparksql
3 changes: 2 additions & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/sparksql/ArraySort.h"
#include "velox/functions/sparksql/Bitwise.h"
#include "velox/functions/sparksql/CheckOverflow.h"
#include "velox/functions/sparksql/CompareFunctionsNullSafe.h"
#include "velox/functions/sparksql/DateTime.h"
#include "velox/functions/sparksql/DateTimeFunctions.h"
Expand Down Expand Up @@ -167,6 +166,8 @@ void registerFunctions(const std::string& prefix) {

exec::registerStatefulVectorFunction(
prefix + "check_overflow", checkOverflowSignatures(), makeCheckOverflow);
exec::registerStatefulVectorFunction(
prefix + "make_decimal", makeDecimalSignatures(), makeMakeDecimal);

// Register bloom filter function
exec::registerStatefulVectorFunction(
Expand Down

0 comments on commit c4cf0c7

Please sign in to comment.