Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-183] add Timestamp in native side #184

Closed
wants to merge 3 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add timestamp in native side
  • Loading branch information
rui-mo committed Mar 23, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 107b07bf230f25cf81281521dfc87aacc6d8990d
Original file line number Diff line number Diff line change
@@ -3670,6 +3670,11 @@ arrow::Status MakeUniqueAction(
std::make_shared<UniqueAction<arrow::Decimal128Type, arrow::Decimal128>>(ctx,
type);
*out = std::dynamic_pointer_cast<ActionBase>(action_ptr);
} break;
case arrow::TimestampType::type_id: {
auto action_ptr =
std::make_shared<UniqueAction<arrow::TimestampType, int64_t>>(ctx, type);
*out = std::dynamic_pointer_cast<ActionBase>(action_ptr);
} break;
default: {
std::cout << "Not Found " << type->ToString() << ", type id is " << type->id()
Original file line number Diff line number Diff line change
@@ -87,6 +87,9 @@ using is_number_or_date = std::integral_constant<bool, arrow::is_number_type<T>:
template <typename DataType, typename R = void>
using enable_if_number_or_date = std::enable_if_t<is_number_or_date<DataType>::value, R>;

template <typename DataType, typename R = void>
using enable_if_timestamp = std::enable_if_t<arrow::is_timestamp_type<DataType>::value, R>;

template <typename DataType>
class ArrayAppender<DataType, enable_if_number_or_date<DataType>> : public AppenderBase {
public:
@@ -428,6 +431,91 @@ class ArrayAppender<DataType, enable_if_decimal<DataType>> : public AppenderBase
bool has_null_ = false;
};

template <typename DataType>
class ArrayAppender<DataType, enable_if_timestamp<DataType>> : public AppenderBase {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like we could merge this with number, date, decimal?

public:
ArrayAppender(arrow::compute::ExecContext* ctx, AppenderType type = left)
: ctx_(ctx), type_(type) {
std::unique_ptr<arrow::ArrayBuilder> array_builder;
arrow::MakeBuilder(ctx_->memory_pool(), arrow::int64(), &array_builder);
Copy link
Collaborator Author

@rui-mo rui-mo Mar 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhouyuan this is because TimestampType does not contain "arrow::TypeTraits::type_singleton()", so it is separated out.

builder_.reset(arrow::internal::checked_cast<BuilderType_*>(array_builder.release()));
}
~ArrayAppender() {}

AppenderType GetType() override { return type_; }
arrow::Status AddArray(const std::shared_ptr<arrow::Array>& arr) override {
auto typed_arr_ = std::dynamic_pointer_cast<ArrayType_>(arr);
cached_arr_.emplace_back(typed_arr_);
if (typed_arr_->null_count() > 0) has_null_ = true;
return arrow::Status::OK();
}

arrow::Status PopArray() override {
cached_arr_.pop_back();
has_null_ = false;
return arrow::Status::OK();
}

arrow::Status Append(const uint16_t& array_id, const uint16_t& item_id) override {
if (has_null_ && cached_arr_[array_id]->null_count() > 0 &&
cached_arr_[array_id]->IsNull(item_id)) {
RETURN_NOT_OK(builder_->AppendNull());
} else {
RETURN_NOT_OK(builder_->Append(cached_arr_[array_id]->GetView(item_id)));
}
return arrow::Status::OK();
}

arrow::Status Append(const uint16_t& array_id, const uint16_t& item_id,
int repeated) override {
if (repeated == 0) return arrow::Status::OK();
if (has_null_ && cached_arr_[array_id]->null_count() > 0 &&
cached_arr_[array_id]->IsNull(item_id)) {
RETURN_NOT_OK(builder_->AppendNulls(repeated));
} else {
auto val = cached_arr_[array_id]->GetView(item_id);
std::vector<CType> values;
values.resize(repeated, val);
RETURN_NOT_OK(builder_->AppendValues(values.data(), repeated));
}
return arrow::Status::OK();
}

arrow::Status Append(const std::vector<ArrayItemIndex>& index_list) {
for (auto tmp : index_list) {
if (has_null_ && cached_arr_[tmp.array_id]->null_count() > 0 &&
cached_arr_[tmp.array_id]->IsNull(tmp.id)) {
RETURN_NOT_OK(builder_->AppendNull());
} else {
RETURN_NOT_OK(builder_->Append(cached_arr_[tmp.array_id]->GetView(tmp.id)));
}
}
return arrow::Status::OK();
}

arrow::Status AppendNull() override { return builder_->AppendNull(); }
rui-mo marked this conversation as resolved.
Show resolved Hide resolved

arrow::Status Finish(std::shared_ptr<arrow::Array>* out_) override {
auto status = builder_->Finish(out_);
return status;
}

arrow::Status Reset() override {
builder_->Reset();
return arrow::Status::OK();
}

private:
using BuilderType_ = typename arrow::TypeTraits<DataType>::BuilderType;
using ArrayType_ = typename arrow::TypeTraits<DataType>::ArrayType;
using CType = typename arrow::TypeTraits<DataType>::CType;
std::unique_ptr<BuilderType_> builder_;
std::vector<std::shared_ptr<ArrayType_>> cached_arr_;
arrow::compute::ExecContext* ctx_;
AppenderType type_;
bool has_null_ = false;
};

#define PROCESS_SUPPORTED_TYPES(PROCESS) \
PROCESS(arrow::BooleanType) \
PROCESS(arrow::UInt8Type) \
@@ -442,7 +530,8 @@ class ArrayAppender<DataType, enable_if_decimal<DataType>> : public AppenderBase
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::StringType)
PROCESS(arrow::StringType) \
PROCESS(arrow::TimestampType)
static arrow::Status MakeAppender(arrow::compute::ExecContext* ctx,
std::shared_ptr<arrow::DataType> type,
AppenderBase::AppenderType appender_type,
Original file line number Diff line number Diff line change
@@ -733,7 +733,8 @@ class DecimalComparator {
PROCESS(arrow::UInt64Type) \
PROCESS(arrow::Int64Type) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type)
PROCESS(arrow::Date64Type) \
PROCESS(arrow::TimestampType)
static arrow::Status MakeCmpFunction(
const std::vector<arrow::ArrayVector>& array_vectors,
const std::vector<std::shared_ptr<arrow::Field>>& key_field_list,
Original file line number Diff line number Diff line change
@@ -79,6 +79,8 @@ std::string GetArrowTypeDefString(std::shared_ptr<arrow::DataType> type) {
return "boolean()";
case arrow::Decimal128Type::type_id:
return type->ToString();
case arrow::TimestampType::type_id:
return "timestamp(arrow::TimeUnit::MILLI)";
default:
std::cout << "GetArrowTypeString can't convert " << type->ToString() << std::endl;
throw;
@@ -116,6 +118,8 @@ std::string GetCTypeString(std::shared_ptr<arrow::DataType> type) {
return "bool";
case arrow::Decimal128Type::type_id:
return "arrow::Decimal128";
case arrow::TimestampType::type_id:
return "int64_t";
default:
std::cout << "GetCTypeString can't convert " << type->ToString() << std::endl;
throw;
@@ -153,6 +157,8 @@ std::string GetTypeString(std::shared_ptr<arrow::DataType> type, std::string tai
return "Boolean" + tail;
case arrow::Decimal128Type::type_id:
return "Decimal128" + tail;
case arrow::TimestampType::type_id:
return "Timestamp" + tail;
default:
std::cout << "GetTypeString can't convert " << type->ToString() << std::endl;
throw;
@@ -219,7 +225,7 @@ std::string GetTemplateString(std::shared_ptr<arrow::DataType> type,
return template_name + "<" + prefix + "Date32" + tail + ">";
case arrow::Date64Type::type_id:
if (tail.empty())
return template_name + "<uint64_t>";
return template_name + "<int64_t>";
else
return template_name + "<" + prefix + "Date64" + tail + ">";
case arrow::StringType::type_id:
@@ -237,6 +243,11 @@ std::string GetTemplateString(std::shared_ptr<arrow::DataType> type,
return template_name + "<arrow::Decimal128>";
else
return template_name + "<" + prefix + "Decimal128" + tail + ">";
case arrow::TimestampType::type_id:
if (tail.empty())
return template_name + "<uint64_t>";
else
return template_name + "<" + prefix + "Timestamp" + tail + ">";
default:
std::cout << "GetTemplateString can't convert " << type->ToString() << std::endl;
throw;
Original file line number Diff line number Diff line change
@@ -406,7 +406,8 @@ class ConditionedProbeKernel::Impl {
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::StringType)
PROCESS(arrow::StringType) \
PROCESS(arrow::TimestampType)
arrow::Status SetDependencies(
const std::vector<std::shared_ptr<ResultIteratorBase>>& dependent_iter_list) {
auto iter = dependent_iter_list[0];
@@ -627,7 +628,8 @@ class ConditionedProbeKernel::Impl {
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::Decimal128Type)
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::TimestampType)
class UnsafeInnerProbeFunction : public ProbeFunctionBase {
public:
UnsafeInnerProbeFunction(std::shared_ptr<HashRelation> hash_relation,
@@ -747,7 +749,8 @@ class ConditionedProbeKernel::Impl {
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::Decimal128Type)
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::TimestampType)
class UnsafeOuterProbeFunction : public ProbeFunctionBase {
public:
UnsafeOuterProbeFunction(std::shared_ptr<HashRelation> hash_relation,
@@ -874,7 +877,8 @@ class ConditionedProbeKernel::Impl {
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::Decimal128Type)
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::TimestampType)
class UnsafeAntiProbeFunction : public ProbeFunctionBase {
public:
UnsafeAntiProbeFunction(std::shared_ptr<HashRelation> hash_relation,
@@ -997,7 +1001,8 @@ class ConditionedProbeKernel::Impl {
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::Decimal128Type)
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::TimestampType)
uint64_t Evaluate(std::shared_ptr<arrow::Array> key_array,
const arrow::ArrayVector& key_payloads) override {
auto typed_key_array = std::dynamic_pointer_cast<arrow::Int32Array>(key_array);
@@ -1112,7 +1117,8 @@ class ConditionedProbeKernel::Impl {
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::Decimal128Type)
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::TimestampType)
class UnsafeExistenceProbeFunction : public ProbeFunctionBase {
public:
UnsafeExistenceProbeFunction(
Original file line number Diff line number Diff line change
@@ -368,7 +368,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
node.return_type()->id() == arrow::Type::DATE32) {
typed_func_name += "32";
} else if (node.return_type()->id() == arrow::Type::INT64 ||
node.return_type()->id() == arrow::Type::DATE64) {
node.return_type()->id() == arrow::Type::DATE64 ||
node.return_type()->id() == arrow::Type::TIMESTAMP) {
typed_func_name += "64";
} else {
return arrow::Status::NotImplemented("castDATE doesn't support ",
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ namespace arrowcompute {
namespace extra {
using ArrayList = std::vector<std::shared_ptr<arrow::Array>>;
using precompile::StringHashMap;
using precompile::enable_if_number_or_timestamp;

/////////////// SortArraysToIndices ////////////////
class HashAggregateKernel::Impl {
@@ -183,7 +184,8 @@ class HashAggregateKernel::Impl {
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::Decimal128Type)
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::TimestampType)
switch (type->id()) {
#define PROCESS(InType) \
case TypeTraits<InType>::type_id: { \
@@ -632,7 +634,7 @@ class HashAggregateKernel::Impl {
class HashAggregateResultIterator {};

template <typename DataType>
class HashAggregateResultIterator<DataType, enable_if_number<DataType>>
class HashAggregateResultIterator<DataType, enable_if_number_or_timestamp<DataType>>
: public ResultIterator<arrow::RecordBatch> {
public:
using T = typename arrow::TypeTraits<DataType>::CType;
Original file line number Diff line number Diff line change
@@ -218,7 +218,8 @@ class HashRelationKernel::Impl {
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::StringType) \
PROCESS(arrow::Decimal128Type)
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::TimestampType)
if (project_outputs.size() == 1) {
switch (project_outputs[0]->type_id()) {
#define PROCESS(InType) \
Original file line number Diff line number Diff line change
@@ -76,10 +76,9 @@ using ArrayList = std::vector<std::shared_ptr<arrow::Array>>;
using namespace sparkcolumnarplugin::precompile;

template <typename T>
using is_number_bool_date =
std::integral_constant<bool, arrow::is_number_type<T>::value ||
arrow::is_boolean_type<T>::value ||
arrow::is_date_type<T>::value>;
using is_number_bool_date = std::integral_constant<bool,
arrow::is_number_type<T>::value || arrow::is_boolean_type<T>::value ||
arrow::is_date_type<T>::value || arrow::is_timestamp_type<T>::value>;

/////////////// SortArraysToIndices ////////////////
class SortArraysToIndicesKernel::Impl {
@@ -1708,7 +1707,8 @@ arrow::Status SortArraysToIndicesKernel::Make(
PROCESS(arrow::FloatType) \
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type)
PROCESS(arrow::Date64Type) \
PROCESS(arrow::TimestampType)
SortArraysToIndicesKernel::SortArraysToIndicesKernel(
arrow::compute::ExecContext* ctx, std::shared_ptr<arrow::Schema> result_schema,
gandiva::NodeVector sort_key_node,
4 changes: 2 additions & 2 deletions native-sql-engine/cpp/src/codegen/common/hash_relation.h
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@

using sparkcolumnarplugin::codegen::arrowcompute::extra::ArrayItemIndex;
using sparkcolumnarplugin::precompile::enable_if_number;
using sparkcolumnarplugin::precompile::enable_if_number_or_decimal;
using sparkcolumnarplugin::precompile::enable_if_number_decimal_or_timestamp;
using sparkcolumnarplugin::precompile::enable_if_string_like;
using sparkcolumnarplugin::precompile::StringArray;
using sparkcolumnarplugin::precompile::TypeTraits;
@@ -54,7 +54,7 @@ template <typename T, typename Enable = void>
class TypedHashRelationColumn {};

template <typename DataType>
class TypedHashRelationColumn<DataType, enable_if_number_or_decimal<DataType>>
class TypedHashRelationColumn<DataType, enable_if_number_decimal_or_timestamp<DataType>>
: public HashRelationColumn {
public:
using T = typename TypeTraits<DataType>::CType;
Original file line number Diff line number Diff line change
@@ -20,13 +20,13 @@
#include "codegen/common/hash_relation.h"
#include "precompile/sparse_hash_map.h"
using sparkcolumnarplugin::codegen::arrowcompute::extra::ArrayItemIndex;
using sparkcolumnarplugin::precompile::enable_if_number;
using sparkcolumnarplugin::precompile::enable_if_number_or_timestamp;
using sparkcolumnarplugin::precompile::TypeTraits;

/////////////////////////////////////////////////////////////////////////

template <typename DataType>
class TypedHashRelation<DataType, enable_if_number<DataType>> : public HashRelation {
class TypedHashRelation<DataType, enable_if_number_or_timestamp<DataType>> : public HashRelation {
public:
using T = typename TypeTraits<DataType>::CType;
TypedHashRelation(
9 changes: 6 additions & 3 deletions native-sql-engine/cpp/src/codegen/common/relation.cc
Original file line number Diff line number Diff line change
@@ -35,7 +35,8 @@
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::StringType) \
PROCESS(arrow::Decimal128Type)
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::TimestampType)
arrow::Status MakeHashRelationColumn(uint32_t data_type_id,
std::shared_ptr<HashRelationColumn>* out) {
switch (data_type_id) {
@@ -72,7 +73,8 @@ arrow::Status MakeHashRelationColumn(uint32_t data_type_id,
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::Decimal128Type) \
PROCESS(arrow::StringType)
PROCESS(arrow::StringType) \
PROCESS(arrow::TimestampType)
arrow::Status MakeRelationColumn(uint32_t data_type_id,
std::shared_ptr<RelationColumn>* out) {
switch (data_type_id) {
@@ -109,7 +111,8 @@ arrow::Status MakeRelationColumn(uint32_t data_type_id,
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type) \
PROCESS(arrow::StringType)
PROCESS(arrow::StringType) \
PROCESS(arrow::TimestampType)
arrow::Status MakeHashRelation(
uint32_t key_type_id, arrow::compute::ExecContext* ctx,
const std::vector<std::shared_ptr<HashRelationColumn>>& hash_relation_column,
4 changes: 2 additions & 2 deletions native-sql-engine/cpp/src/codegen/common/relation_column.h
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@

#include "precompile/type_traits.h"

using sparkcolumnarplugin::precompile::enable_if_number_or_decimal;
using sparkcolumnarplugin::precompile::enable_if_number_decimal_or_timestamp;
using sparkcolumnarplugin::precompile::enable_if_string_like;
using sparkcolumnarplugin::precompile::StringArray;
using sparkcolumnarplugin::precompile::TypeTraits;
@@ -45,7 +45,7 @@ template <typename T, typename Enable = void>
class TypedRelationColumn {};

template <typename DataType>
class TypedRelationColumn<DataType, enable_if_number_or_decimal<DataType>>
class TypedRelationColumn<DataType, enable_if_number_decimal_or_timestamp<DataType>>
: public RelationColumn {
public:
using T = typename TypeTraits<DataType>::CType;
Loading