Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-31]decimal support for SMJ #146

Merged
merged 2 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,6 @@ case class ColumnarSortMergeJoinExec(
for (attr <- left.output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
if (attr.dataType.isInstanceOf[DecimalType])
throw new UnsupportedOperationException(s"Unsupported data type: ${attr.dataType}")
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
Expand All @@ -358,8 +356,6 @@ case class ColumnarSortMergeJoinExec(
for (attr <- right.output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
if (attr.dataType.isInstanceOf[DecimalType])
throw new UnsupportedOperationException(s"Unsupported data type: ${attr.dataType}")
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,9 +671,6 @@ object ColumnarSortMergeJoin extends Logging {
conditionOption: Option[Expression]): TreeNode = {
/////// Build side ///////
val buildInputFieldList: List[Field] = buildInputAttributes.toList.map(attr => {
if (attr.dataType.isInstanceOf[DecimalType])
throw new UnsupportedOperationException(
s"Decimal type is not supported in ColumnarShuffledHashJoin.")
Field
.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
})
Expand All @@ -687,9 +684,6 @@ object ColumnarSortMergeJoin extends Logging {
})
/////// Streamed side ///////
val streamedInputFieldList: List[Field] = streamedInputAttributes.toList.map(attr => {
if (attr.dataType.isInstanceOf[DecimalType])
throw new UnsupportedOperationException(
s"Decimal type is not supported in ColumnarShuffledHashJoin.")
Field
.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
})
Expand Down
77 changes: 61 additions & 16 deletions cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,23 +358,42 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::FunctionNode& node) {
prepare_str_ += prepare_ss.str();
check_str_ = validity;
} else if (func_name.compare("divide") == 0) {
codes_str_ = "divide_" + std::to_string(cur_func_id);
auto validity = "divide_validity_" + std::to_string(cur_func_id);
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
<< std::endl;
prepare_ss << "bool " << validity << " = (" << child_visitor_list[0]->GetPreCheck()
<< " && " << child_visitor_list[1]->GetPreCheck() << ");" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
prepare_ss << codes_str_ << " = " << child_visitor_list[0]->GetResult() << " * 1.0 / "
<< child_visitor_list[1]->GetResult() << ";" << std::endl;
prepare_ss << "}" << std::endl;
codes_str_ = "divide_" + std::to_string(cur_func_id);
auto validity = codes_str_ + "_validity";
std::stringstream fix_ss;
if (node.return_type()->id() != arrow::Type::DECIMAL) {
fix_ss << child_visitor_list[0]->GetResult() << " / "
<< child_visitor_list[1]->GetResult();
} else {
auto leftNode = node.children().at(0);
auto rightNode = node.children().at(1);
auto leftType =
std::dynamic_pointer_cast<arrow::Decimal128Type>(leftNode->return_type());
auto rightType =
std::dynamic_pointer_cast<arrow::Decimal128Type>(rightNode->return_type());
auto resType = std::dynamic_pointer_cast<arrow::Decimal128Type>(node.return_type());
fix_ss << "divide(" << child_visitor_list[0]->GetResult() << ", "
<< leftType->precision() << ", " << leftType->scale() << ", "
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
<< std::endl;
prepare_ss << "bool " << validity << " = ("
<< CombineValidity({child_visitor_list[0]->GetPreCheck(),
child_visitor_list[1]->GetPreCheck()})
<< ");" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl;
prepare_ss << "}" << std::endl;

for (int i = 0; i < 2; i++) {
RETURN_NOT_OK(AppendProjectList(child_visitor_list, i));
}
prepare_str_ += prepare_ss.str();
check_str_ = validity;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
prepare_str_ += prepare_ss.str();
check_str_ = validity;
} else {
RETURN_NOT_OK(ProduceGandivaFunction());
}
Expand Down Expand Up @@ -531,6 +550,13 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::LiteralNode& node) {
prepare_ss << "auto literal_" << cur_func_id << R"( = ")"
<< gandiva::ToString(node.holder()) << R"(";)" << std::endl;

} else if (node.return_type()->id() == arrow::Type::DECIMAL) {
auto scalar = arrow::util::get<gandiva::DecimalScalar128>(node.holder());
auto decimal = arrow::Decimal128(scalar.value());
prepare_ss << "auto literal_" << cur_func_id << " = "
<< "arrow::Decimal128(\"" << decimal.ToString(scalar.scale()) << "\");"
<< std::endl;
decimal_scale_ = std::to_string(scalar.scale());
} else {
prepare_ss << "auto literal_" << cur_func_id << " = "
<< gandiva::ToString(node.holder()) << ";" << std::endl;
Expand Down Expand Up @@ -729,6 +755,25 @@ std::string CodeGenNodeVisitor::GetNaNCheckStr(std::string left, std::string rig
return ss.str();
}

/// Joins generated-code validity expressions into a single conjunction.
///
/// Entries equal to the literal "true" are dropped, since they cannot change
/// the value of an `&&` chain. The remaining entries are joined with " && "
/// in their original order.
///
/// \param validity_list validity expressions (as generated-code strings)
/// \return the joined expression, or "true" when no entry survives the
///         filtering (empty list, or every entry was "true"). Returning
///         "true" rather than an empty string keeps call sites that embed
///         the result as `bool v = (<result>);` emitting compilable code.
std::string CodeGenNodeVisitor::CombineValidity(
    std::vector<std::string> validity_list) {
  std::stringstream out;
  bool first = true;
  for (const auto& validity : validity_list) {
    // A constant-true operand is redundant in a conjunction.
    if (validity.compare("true") == 0) {
      continue;
    }
    if (!first) {
      out << " && ";
    }
    out << validity;
    first = false;
  }
  if (first) {
    // Nothing was emitted: the conjunction of zero operands is true.
    return "true";
  }
  return out.str();
}

} // namespace extra
} // namespace arrowcompute
} // namespace codegen
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class CodeGenNodeVisitor : public VisitorBase {
action_impl_->ProduceCodes(action_codegen);
return arrow::Status::OK();
}

std::string CombineValidity(std::vector<std::string> validity_list);
std::string GetInput();
std::string GetResult();
std::string GetResultValidity();
Expand Down Expand Up @@ -107,6 +107,7 @@ class CodeGenNodeVisitor : public VisitorBase {
std::string input_codes_str_;
std::string check_str_;
gandiva::ExpressionPtr project_;
std::string decimal_scale_;
std::vector<int>* left_indices_ = nullptr;
std::vector<std::shared_ptr<arrow::Field>>* left_field_ = nullptr;
std::vector<int>* right_indices_ = nullptr;
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/codegen/arrow_compute/ext/merge_join_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,16 @@ class ConditionedJoinArraysKernel::Impl {
}
// Emits the generated-code statement that assigns `builder_<indice_>_` a new
// std::make_shared<...Builder>. For DECIMAL columns the constructor call gets
// an extra leading argument, `arrow::<GetArrowTypeDefString(data_type_)>`,
// ahead of the memory pool; every other type passes only the memory pool.
std::string GetResultIteratorPrepare() {
  const bool is_decimal = data_type_->id() == arrow::Type::DECIMAL;
  std::stringstream code;
  // Common prefix: target variable and builder type.
  code << "builder_" << indice_ << "_ = std::make_shared<"
       << GetTypeString(data_type_, "Builder") << ">(";
  if (is_decimal) {
    // Decimal-only leading constructor argument.
    code << "arrow::" << GetArrowTypeDefString(data_type_) << ", ";
  }
  // Common suffix: memory pool argument and statement terminator.
  code << "ctx_->memory_pool());" << std::endl;
  return code.str();
}
std::string GetProcessFinish() {
Expand Down Expand Up @@ -1232,6 +1239,7 @@ typedef )" + item_content_str + " item_content;";
// TODO: fix multi columns case
std::string condition_check_str;
if (func_node) {
//TODO: move to use new API
condition_check_str =
GetConditionCheckFunc(func_node, left_field_list, right_field_list,
&left_cond_index_list, &right_cond_index_list);
Expand Down Expand Up @@ -1294,6 +1302,7 @@ typedef )" + item_content_str + " item_content;";
return BaseCodes() + R"(
#include "codegen/arrow_compute/ext/array_item_index.h"
#include "precompile/builder.h"
#include "precompile/gandiva.h"
#include <numeric>
using namespace sparkcolumnarplugin::precompile;
)" + hash_map_include_str +
Expand Down
1 change: 1 addition & 0 deletions cpp/src/codegen/common/relation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ arrow::Status MakeHashRelationColumn(uint32_t data_type_id,
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::StringType)
arrow::Status MakeRelationColumn(uint32_t data_type_id,
std::shared_ptr<RelationColumn>* out) {
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/codegen/common/relation_column.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

#include "precompile/type_traits.h"

using sparkcolumnarplugin::precompile::enable_if_number;
using sparkcolumnarplugin::precompile::enable_if_number_or_decimal;
using sparkcolumnarplugin::precompile::enable_if_string_like;
using sparkcolumnarplugin::precompile::StringArray;
using sparkcolumnarplugin::precompile::TypeTraits;
Expand All @@ -45,7 +45,7 @@ template <typename T, typename Enable = void>
class TypedRelationColumn {};

template <typename DataType>
class TypedRelationColumn<DataType, enable_if_number<DataType>> : public RelationColumn {
class TypedRelationColumn<DataType, enable_if_number_or_decimal<DataType>> : public RelationColumn {
public:
using T = typename TypeTraits<DataType>::CType;
TypedRelationColumn() {}
Expand Down