diff --git a/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala b/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala index aaefa6d34..a911f0ec0 100644 --- a/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala +++ b/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala @@ -347,8 +347,6 @@ case class ColumnarSortMergeJoinExec( for (attr <- left.output) { try { ConverterUtils.checkIfTypeSupported(attr.dataType) - if (attr.dataType.isInstanceOf[DecimalType]) - throw new UnsupportedOperationException(s"Unsupported data type: ${attr.dataType}") } catch { case e: UnsupportedOperationException => throw new UnsupportedOperationException( @@ -358,8 +356,6 @@ case class ColumnarSortMergeJoinExec( for (attr <- right.output) { try { ConverterUtils.checkIfTypeSupported(attr.dataType) - if (attr.dataType.isInstanceOf[DecimalType]) - throw new UnsupportedOperationException(s"Unsupported data type: ${attr.dataType}") } catch { case e: UnsupportedOperationException => throw new UnsupportedOperationException( diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala index 21fbf837d..cc6071a41 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala @@ -671,9 +671,6 @@ object ColumnarSortMergeJoin extends Logging { conditionOption: Option[Expression]): TreeNode = { /////// Build side /////// val buildInputFieldList: List[Field] = buildInputAttributes.toList.map(attr => { - if (attr.dataType.isInstanceOf[DecimalType]) - throw new UnsupportedOperationException( - s"Decimal type is not supported in ColumnarShuffledHashJoin.") Field .nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType)) }) @@ -687,9 +684,6 @@ object ColumnarSortMergeJoin extends Logging { }) /////// Streamed side /////// val streamedInputFieldList: List[Field] = streamedInputAttributes.toList.map(attr => { - if (attr.dataType.isInstanceOf[DecimalType]) - throw new UnsupportedOperationException( - s"Decimal type is not supported in ColumnarShuffledHashJoin.") Field .nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType)) }) diff --git a/cpp/src/codegen/arrow_compute/ext/merge_join_kernel.cc b/cpp/src/codegen/arrow_compute/ext/merge_join_kernel.cc index 53bdfb0c9..7b9d561f5 100644 --- a/cpp/src/codegen/arrow_compute/ext/merge_join_kernel.cc +++ b/cpp/src/codegen/arrow_compute/ext/merge_join_kernel.cc @@ -234,9 +234,16 @@ class ConditionedJoinArraysKernel::Impl { } std::string GetResultIteratorPrepare() { std::stringstream ss; + if (data_type_->id() == arrow::Type::DECIMAL) { + ss << "builder_" << indice_ << "_ = std::make_shared<" + << GetTypeString(data_type_, "Builder") + << ">(arrow::" << GetArrowTypeDefString(data_type_) + << ", ctx_->memory_pool());" << std::endl; + } else { ss << "builder_" << indice_ << "_ = std::make_shared<" << GetTypeString(data_type_, "Builder") << ">(ctx_->memory_pool());" << std::endl; + } return ss.str(); } std::string GetProcessFinish() { diff --git a/cpp/src/codegen/common/relation.cc b/cpp/src/codegen/common/relation.cc index 11d0f386e..8f3130d1e 100644 --- a/cpp/src/codegen/common/relation.cc +++ b/cpp/src/codegen/common/relation.cc @@ -71,6 +71,7 @@ arrow::Status MakeHashRelationColumn(uint32_t data_type_id, PROCESS(arrow::DoubleType) \ PROCESS(arrow::Date32Type) \ PROCESS(arrow::Date64Type) \ + PROCESS(arrow::Decimal128Type) \ PROCESS(arrow::StringType) arrow::Status MakeRelationColumn(uint32_t data_type_id, std::shared_ptr<RelationColumn>* out) { diff --git a/cpp/src/codegen/common/relation_column.h b/cpp/src/codegen/common/relation_column.h index 1ab289e07..8cafc657f 100644 --- a/cpp/src/codegen/common/relation_column.h +++ b/cpp/src/codegen/common/relation_column.h @@ -23,7 +23,7 @@ #include "precompile/type_traits.h" -using sparkcolumnarplugin::precompile::enable_if_number; +using sparkcolumnarplugin::precompile::enable_if_number_or_decimal; using sparkcolumnarplugin::precompile::enable_if_string_like; using sparkcolumnarplugin::precompile::StringArray; using sparkcolumnarplugin::precompile::TypeTraits; @@ -45,7 +45,7 @@ template <typename T, typename Enable = void> class TypedRelationColumn {}; template <typename DataType> -class TypedRelationColumn<DataType, enable_if_number<DataType>> : public RelationColumn { +class TypedRelationColumn<DataType, enable_if_number_or_decimal<DataType>> : public RelationColumn { public: using T = typename TypeTraits<DataType>::CType; TypedRelationColumn() {}