Skip to content

Commit

Permalink
support type_matcher for combination input types in decimal compare k…
Browse files Browse the repository at this point in the history
…ernel
  • Loading branch information
ZhangHuiGui committed May 17, 2024
1 parent 8ca5f83 commit a7703b9
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 19 deletions.
14 changes: 12 additions & 2 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,11 @@ bool InputType::Matches(const Datum& value) const {
return Matches(*value.type());
}

bool InputType::Matches(const std::vector<TypeHolder>& types) const {
DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_);
return type_matcher_->Matches(types);
}

const std::shared_ptr<DataType>& InputType::type() const {
DCHECK_EQ(InputType::EXACT_TYPE, kind_);
return type_;
Expand Down Expand Up @@ -505,9 +510,14 @@ bool KernelSignature::Equals(const KernelSignature& other) const {
}

bool KernelSignature::MatchesInputs(const std::vector<TypeHolder>& types) const {
auto is_match_combination_types = [&](const InputType& in_type) {
return in_type.kind() == InputType::USE_TYPE_MATCHER ? in_type.Matches(types) : true;
};

if (is_varargs_) {
for (size_t i = 0; i < types.size(); ++i) {
if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(*types[i])) {
const auto& in_type = in_types_[std::min(i, in_types_.size() - 1)];
if (!in_type.Matches(*types[i]) || !is_match_combination_types(in_type)) {
return false;
}
}
Expand All @@ -516,7 +526,7 @@ bool KernelSignature::MatchesInputs(const std::vector<TypeHolder>& types) const
return false;
}
for (size_t i = 0; i < in_types_.size(); ++i) {
if (!in_types_[i].Matches(*types[i])) {
if (!in_types_[i].Matches(*types[i]) || !is_match_combination_types(in_types_[i])) {
return false;
}
}
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ struct ARROW_EXPORT TypeMatcher {
/// \brief Return true if this matcher accepts the data type.
virtual bool Matches(const DataType& type) const = 0;

/// \brief Return true if this matcher accepts the combination types
virtual bool Matches(const std::vector<TypeHolder>& types) const { return true; }

/// \brief A human-interpretable string representation of what the type
/// matcher checks for, usable when printing KernelSignature or formatting
/// error messages.
Expand Down Expand Up @@ -241,6 +244,10 @@ class ARROW_EXPORT InputType {
/// \brief Return true if the type matches this InputType
bool Matches(const DataType& type) const;

/// \brief Return true if the input combination types matches this
/// InputType's type_matcher matched rules.
bool Matches(const std::vector<TypeHolder>& types) const;

/// \brief The type matching rule that this InputType uses.
Kind kind() const { return kind_; }

Expand Down
65 changes: 48 additions & 17 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,21 +385,52 @@ struct VarArgsCompareFunction : ScalarFunction {
}
};

Result<TypeHolder> ResolveDecimalCompareOutputType(KernelContext*,
const std::vector<TypeHolder>& types) {
// casted types should be same size decimals
const auto& left_type = checked_cast<const DecimalType&>(*types[0]);
const auto& right_type = checked_cast<const DecimalType&>(*types[1]);
DCHECK_EQ(left_type.id(), right_type.id());

// check the casted decimal scales according kAdd promotion rule
const int32_t s1 = left_type.scale();
const int32_t s2 = right_type.scale();
if (s1 != s2) {
return Status::Invalid("Comparison of two decimal ", "types s1 != s2. (", s1, s2,
").");
}
return boolean();
class DecimalTypesCompareMatcher : public TypeMatcher {
public:
explicit DecimalTypesCompareMatcher(std::shared_ptr<TypeMatcher> decimal_type_matcher)
: decimal_type_matcher(std::move(decimal_type_matcher)) {}

bool Matches(const DataType& type) const override {
return decimal_type_matcher->Matches(type);
}

bool Matches(const std::vector<TypeHolder>& types) const override {
DCHECK_EQ(types.size(), 2);
if (!is_decimal(*types[0]) || !is_decimal(*types[1])) {
return true;
}

// Below match logic should only be executed when types are both decimal
//
const auto& left_type = checked_cast<const DecimalType&>(*types[0]);
const auto& right_type = checked_cast<const DecimalType&>(*types[1]);

// check the decimal types' scales according kAdd promotion rule
const int32_t s1 = left_type.scale();
const int32_t s2 = right_type.scale();
if (s1 != s2) {
return false;
}
return true;
}

bool Equals(const TypeMatcher& other) const override {
if (this == &other) {
return true;
}
const auto* casted = dynamic_cast<const DecimalTypesCompareMatcher*>(&other);
return casted != nullptr &&
decimal_type_matcher->Equals(*casted->decimal_type_matcher);
}

std::string ToString() const override { return "decimal-types-matcher"; }

private:
std::shared_ptr<TypeMatcher> decimal_type_matcher;
};

std::shared_ptr<TypeMatcher> DecimalTypesMatcher(Type::type type_id) {
return std::make_shared<DecimalTypesCompareMatcher>(match::SameTypeId(type_id));
}

template <typename Op>
Expand Down Expand Up @@ -450,9 +481,9 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
}

for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
OutputType out_type(ResolveDecimalCompareOutputType);
InputType in_type(DecimalTypesMatcher(id));
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, out_type, std::move(exec)));
DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
}

{
Expand Down

0 comments on commit a7703b9

Please sign in to comment.