Skip to content

Commit

Permalink
support check_overflow (facebookincubator#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh authored and liujiayi771 committed Mar 19, 2023
1 parent 640628d commit 284c5c2
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 7 deletions.
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_library(
velox_functions_spark
ArraySort.cpp
Bitwise.cpp
CheckOverflow.cpp
CompareFunctionsNullSafe.cpp
Hash.cpp
In.cpp
Expand Down
103 changes: 103 additions & 0 deletions velox/functions/sparksql/CheckOverflow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/CheckOverflow.h"

#include "velox/expression/DecodedArgs.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::functions::sparksql {
namespace {
template <typename TInput, typename TOutput>
class CheckOverflowFunction final : public exec::VectorFunction {
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& resultRef) const final {
VELOX_CHECK_EQ(args.size(), 3);
auto fromType = args[0]->type();
auto toType = args[1]->type();
context.ensureWritable(rows, toType, resultRef);

auto result =
resultRef->asUnchecked<FlatVector<TOutput>>()->mutableRawValues();
exec::DecodedArgs decodedArgs(rows, args, context);
auto decimalValue = decodedArgs.at(0);
VELOX_CHECK(decodedArgs.at(2)->isConstantMapping());
auto nullOnOverflow = decodedArgs.at(2)->valueAt<bool>(0);

const auto& fromPrecisionScale = getDecimalPrecisionScale(*fromType);
const auto& toPrecisionScale = getDecimalPrecisionScale(*toType);
rows.applyToSelected([&](int row) {
auto rescaledValue = DecimalUtil::rescaleWithRoundUp<TInput, TOutput>(
decimalValue->valueAt<TInput>(row),
fromPrecisionScale.first,
fromPrecisionScale.second,
toPrecisionScale.first,
toPrecisionScale.second,
nullOnOverflow);
if (rescaledValue.has_value()) {
result[row] = rescaledValue.value();
} else {
resultRef->setNull(row, true);
}
});
}
};
} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>>
checkOverflowSignatures() {
return {exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("r_precision")
.integerVariable("r_scale")
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(r_precision, r_scale)")
.argumentType("boolean")
.build()};
}

std::shared_ptr<exec::VectorFunction> makeCheckOverflow(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
VELOX_CHECK_EQ(inputArgs.size(), 3);
auto fromType = inputArgs[0].type;
auto toType = inputArgs[1].type;
if (toType->kind() == TypeKind::SHORT_DECIMAL) {
if (fromType->kind() == TypeKind::SHORT_DECIMAL) {
std::cout << "from and to is short decimal" << std::endl;
return std::make_shared<
CheckOverflowFunction<UnscaledShortDecimal, UnscaledShortDecimal>>();
} else {
return std::make_shared<
CheckOverflowFunction<UnscaledLongDecimal, UnscaledShortDecimal>>();
}
} else {
if (fromType->kind() == TypeKind::SHORT_DECIMAL) {
return std::make_shared<
CheckOverflowFunction<UnscaledShortDecimal, UnscaledLongDecimal>>();
} else {
return std::make_shared<
CheckOverflowFunction<UnscaledLongDecimal, UnscaledLongDecimal>>();
}
}
}

} // namespace facebook::velox::functions::sparksql
26 changes: 26 additions & 0 deletions velox/functions/sparksql/CheckOverflow.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

std::vector<std::shared_ptr<exec::FunctionSignature>> checkOverflowSignatures();

std::shared_ptr<exec::VectorFunction> makeCheckOverflow(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

} // namespace facebook::velox::functions::sparksql
4 changes: 4 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/sparksql/ArraySort.h"
#include "velox/functions/sparksql/Bitwise.h"
#include "velox/functions/sparksql/CheckOverflow.h"
#include "velox/functions/sparksql/CompareFunctionsNullSafe.h"
#include "velox/functions/sparksql/DateTime.h"
#include "velox/functions/sparksql/DateTimeFunctions.h"
Expand Down Expand Up @@ -164,6 +165,9 @@ void registerFunctions(const std::string& prefix) {
exec::registerStatefulVectorFunction(
prefix + "sort_array", sortArraySignatures(), makeSortArray);

exec::registerStatefulVectorFunction(
prefix + "check_overflow", checkOverflowSignatures(), makeCheckOverflow);

// Register bloom filter function
exec::registerStatefulVectorFunction(
prefix + "might_contain", mightContainSignatures(), makeMightContain);
Expand Down
19 changes: 12 additions & 7 deletions velox/type/DecimalUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class DecimalUtil {
const int fromPrecision,
const int fromScale,
const int toPrecision,
const int toScale) {
const int toScale,
bool nullOnOverflow = false) {
int128_t rescaledValue = inputValue.unscaledValue();
auto scaleDifference = toScale - fromScale;
bool isOverflow = false;
Expand All @@ -68,12 +69,16 @@ class DecimalUtil {
// Check overflow.
if (rescaledValue < -DecimalUtil::kPowersOfTen[toPrecision] ||
rescaledValue > DecimalUtil::kPowersOfTen[toPrecision] || isOverflow) {
VELOX_USER_FAIL(
"Cannot cast DECIMAL '{}' to DECIMAL({},{})",
DecimalUtil::toString<TInput>(
inputValue, DECIMAL(fromPrecision, fromScale)),
toPrecision,
toScale);
if (nullOnOverflow) {
return std::nullopt;
} else {
VELOX_USER_FAIL(
"Cannot cast DECIMAL '{}' to DECIMAL({},{})",
DecimalUtil::toString<TInput>(
inputValue, DECIMAL(fromPrecision, fromScale)),
toPrecision,
toScale);
}
}
if constexpr (std::is_same_v<TOutput, UnscaledShortDecimal>) {
return UnscaledShortDecimal(static_cast<int64_t>(rescaledValue));
Expand Down

0 comments on commit 284c5c2

Please sign in to comment.