Skip to content

Commit

Permalink
apacheGH-39233: [Compute] Add some duration kernels (apache#39358)
Browse files Browse the repository at this point in the history
### Rationale for this change

Add kernels for durations.
### What changes are included in this PR?

In this PR I added the ones that require only registration and unit tests. More complicated ones will be in another PR for readability.
 

### Are these changes tested?
Yes.

### Are there any user-facing changes?
No.

* Closes: apache#39233

Authored-by: Jin Shang <shangjin1997@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
js8544 authored and zanmato1984 committed Feb 28, 2024
1 parent df28b07 commit 3712c54
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 41 deletions.
35 changes: 35 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1286,12 +1286,27 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
auto absolute_value =
MakeUnaryArithmeticFunction<AbsoluteValue>("abs", absolute_value_doc);
AddDecimalUnaryKernels<AbsoluteValue>(absolute_value.get());

// abs(duration)
for (auto unit : TimeUnit::values()) {
auto exec = ArithmeticExecFromOp<ScalarUnary, AbsoluteValue>(duration(unit));
DCHECK_OK(
absolute_value->AddKernel({duration(unit)}, OutputType(duration(unit)), exec));
}

DCHECK_OK(registry->AddFunction(std::move(absolute_value)));

// ----------------------------------------------------------------------
auto absolute_value_checked = MakeUnaryArithmeticFunctionNotNull<AbsoluteValueChecked>(
"abs_checked", absolute_value_checked_doc);
AddDecimalUnaryKernels<AbsoluteValueChecked>(absolute_value_checked.get());
// abs_checked(duraton)
for (auto unit : TimeUnit::values()) {
auto exec =
ArithmeticExecFromOp<ScalarUnaryNotNull, AbsoluteValueChecked>(duration(unit));
DCHECK_OK(absolute_value_checked->AddKernel({duration(unit)},
OutputType(duration(unit)), exec));
}
DCHECK_OK(registry->AddFunction(std::move(absolute_value_checked)));

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -1545,12 +1560,27 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
// ----------------------------------------------------------------------
auto negate = MakeUnaryArithmeticFunction<Negate>("negate", negate_doc);
AddDecimalUnaryKernels<Negate>(negate.get());

// Add neg(duration) -> duration
for (auto unit : TimeUnit::values()) {
auto exec = ArithmeticExecFromOp<ScalarUnary, Negate>(duration(unit));
DCHECK_OK(negate->AddKernel({duration(unit)}, OutputType(duration(unit)), exec));
}

DCHECK_OK(registry->AddFunction(std::move(negate)));

// ----------------------------------------------------------------------
auto negate_checked = MakeUnarySignedArithmeticFunctionNotNull<NegateChecked>(
"negate_checked", negate_checked_doc);
AddDecimalUnaryKernels<NegateChecked>(negate_checked.get());

// Add neg_checked(duration) -> duration
for (auto unit : TimeUnit::values()) {
auto exec = ArithmeticExecFromOp<ScalarUnaryNotNull, Negate>(duration(unit));
DCHECK_OK(
negate_checked->AddKernel({duration(unit)}, OutputType(duration(unit)), exec));
}

DCHECK_OK(registry->AddFunction(std::move(negate_checked)));

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -1581,6 +1611,11 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
// ----------------------------------------------------------------------
auto sign =
MakeUnaryArithmeticFunctionWithFixedIntOutType<Sign, Int8Type>("sign", sign_doc);
// sign(duration)
for (auto unit : TimeUnit::values()) {
auto exec = ScalarUnary<Int8Type, Int64Type, Sign>::Exec;
DCHECK_OK(sign->AddKernel({duration(unit)}, int8(), std::move(exec)));
}
DCHECK_OK(registry->AddFunction(std::move(sign)));

// ----------------------------------------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/type.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"

Expand Down Expand Up @@ -806,6 +807,14 @@ std::shared_ptr<ScalarFunction> MakeScalarMinMax(std::string name, FunctionDoc d
kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
DCHECK_OK(func->AddKernel(std::move(kernel)));
}
for (const auto& ty : DurationTypes()) {
auto exec = GeneratePhysicalNumeric<ScalarMinMax, Op>(ty);
ScalarKernel kernel{KernelSignature::Make({ty}, ty, /*is_varargs=*/true), exec,
MinMaxState::Init};
kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
DCHECK_OK(func->AddKernel(std::move(kernel)));
}
for (const auto& ty : BaseBinaryTypes()) {
auto exec =
GenerateTypeAgnosticVarBinaryBase<BinaryScalarMinMax, ArrayKernelExec, Op>(ty);
Expand Down
7 changes: 6 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,7 @@ using CompareNumericBasedTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
Int32Type, Int64Type, FloatType, DoubleType, Date32Type, Date64Type>;
using CompareParametricTemporalTypes =
::testing::Types<TimestampType, Time32Type, Time64Type>;
::testing::Types<TimestampType, Time32Type, Time64Type, DurationType>;
using CompareFixedSizeBinaryTypes = ::testing::Types<FixedSizeBinaryType>;

TYPED_TEST_SUITE(TestVarArgsCompareNumeric, CompareNumericBasedTypes);
Expand Down Expand Up @@ -2121,6 +2121,11 @@ TEST(TestMaxElementWiseMinElementWise, CommonTemporal) {
ScalarFromJSON(date64(), "172800000"),
}),
ResultWith(ScalarFromJSON(date64(), "86400000")));
EXPECT_THAT(MinElementWise({
ScalarFromJSON(duration(TimeUnit::SECOND), "1"),
ScalarFromJSON(duration(TimeUnit::MILLI), "12000"),
}),
ResultWith(ScalarFromJSON(duration(TimeUnit::MILLI), "1000")));
}

} // namespace compute
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3665,5 +3665,17 @@ TEST_F(ScalarTemporalTest, TestCeilFloorRoundTemporalDate) {
CheckScalarUnary("ceil_temporal", arr_ns, arr_ns, &round_to_2_hours);
}

TEST_F(ScalarTemporalTest, DurationUnaryArithmetics) {
auto arr = ArrayFromJSON(duration(TimeUnit::SECOND), "[2, -1, null, 3, 0]");
CheckScalarUnary("negate", arr,
ArrayFromJSON(duration(TimeUnit::SECOND), "[-2, 1, null, -3, 0]"));
CheckScalarUnary("negate_checked", arr,
ArrayFromJSON(duration(TimeUnit::SECOND), "[-2, 1, null, -3, 0]"));
CheckScalarUnary("abs", arr,
ArrayFromJSON(duration(TimeUnit::SECOND), "[2, 1, null, 3, 0]"));
CheckScalarUnary("abs_checked", arr,
ArrayFromJSON(duration(TimeUnit::SECOND), "[2, 1, null, 3, 0]"));
CheckScalarUnary("sign", arr, ArrayFromJSON(int8(), "[1, -1, null, 1, 0]"));
}
} // namespace compute
} // namespace arrow
6 changes: 5 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_validity.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ std::shared_ptr<ScalarFunction> MakeIsFiniteFunction(std::string name, FunctionD
func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec<true>));
DCHECK_OK(
func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec<true>));
DCHECK_OK(func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec<true>));

return func;
}
Expand All @@ -187,7 +188,8 @@ std::shared_ptr<ScalarFunction> MakeIsInfFunction(std::string name, FunctionDoc
func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec<false>));
DCHECK_OK(
func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec<false>));

DCHECK_OK(
func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec<false>));
return func;
}

Expand All @@ -205,6 +207,8 @@ std::shared_ptr<ScalarFunction> MakeIsNanFunction(std::string name, FunctionDoc
func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec<false>));
DCHECK_OK(
func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec<false>));
DCHECK_OK(
func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec<false>));

return func;
}
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_validity_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ TEST(TestValidityKernels, IsFinite) {
}
CheckScalar("is_finite", {std::make_shared<NullArray>(4)},
ArrayFromJSON(boolean(), "[null, null, null, null]"));
CheckScalar("is_finite",
{ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")},
ArrayFromJSON(boolean(), "[true, true, true, null]"));
}

TEST(TestValidityKernels, IsInf) {
Expand All @@ -116,6 +119,8 @@ TEST(TestValidityKernels, IsInf) {
}
CheckScalar("is_inf", {std::make_shared<NullArray>(4)},
ArrayFromJSON(boolean(), "[null, null, null, null]"));
CheckScalar("is_inf", {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")},
ArrayFromJSON(boolean(), "[false, false, false, null]"));
}

TEST(TestValidityKernels, IsNan) {
Expand All @@ -129,6 +134,8 @@ TEST(TestValidityKernels, IsNan) {
}
CheckScalar("is_nan", {std::make_shared<NullArray>(4)},
ArrayFromJSON(boolean(), "[null, null, null, null]"));
CheckScalar("is_nan", {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")},
ArrayFromJSON(boolean(), "[false, false, false, null]"));
}

TEST(TestValidityKernels, IsValidIsNullNullType) {
Expand Down
78 changes: 39 additions & 39 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -458,45 +458,45 @@ floating-point arguments will cast all arguments to floating-point, while mixed
decimal and integer arguments will cast all arguments to decimals.
Mixed time resolution temporal inputs will be cast to finest input resolution.

+------------------+--------+-------------------------+----------------------+-------+
| Function name | Arity | Input types | Output type | Notes |
+==================+========+=========================+======================+=======+
| abs | Unary | Numeric | Numeric | |
+------------------+--------+-------------------------+----------------------+-------+
| abs_checked | Unary | Numeric | Numeric | |
+------------------+--------+-------------------------+----------------------+-------+
| add | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+----------------------+-------+
| add_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+----------------------+-------+
| divide | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+----------------------+-------+
| divide_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+----------------------+-------+
| exp | Unary | Numeric | Float32/Float64 | |
+------------------+--------+-------------------------+----------------------+-------+
| multiply | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+----------------------+-------+
| multiply_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+----------------------+-------+
| negate | Unary | Numeric | Numeric | |
+------------------+--------+-------------------------+----------------------+-------+
| negate_checked | Unary | Signed Numeric | Signed Numeric | |
+------------------+--------+-------------------------+----------------------+-------+
| power | Binary | Numeric | Numeric | |
+------------------+--------+-------------------------+----------------------+-------+
| power_checked | Binary | Numeric | Numeric | |
+------------------+--------+-------------------------+----------------------+-------+
| sign | Unary | Numeric | Int8/Float32/Float64 | \(2) |
+------------------+--------+-------------------------+----------------------+-------+
| sqrt | Unary | Numeric | Numeric | |
+------------------+--------+-------------------------+----------------------+-------+
| sqrt_checked | Unary | Numeric | Numeric | |
+------------------+--------+-------------------------+----------------------+-------+
| subtract | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+----------------------+-------+
| subtract_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+----------------------+-------+
+------------------+--------+-------------------------+---------------------------+-------+
| Function name | Arity | Input types | Output type | Notes |
+==================+========+=========================+===========================+=======+
| abs | Unary | Numeric/Duration | Numeric/Duration | |
+------------------+--------+-------------------------+---------------------------+-------+
| abs_checked | Unary | Numeric/Duration | Numeric/Duration | |
+------------------+--------+-------------------------+---------------------------+-------+
| add | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+---------------------------+-------+
| add_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+---------------------------+-------+
| divide | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+---------------------------+-------+
| divide_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+---------------------------+-------+
| exp | Unary | Numeric | Float32/Float64 | |
+------------------+--------+-------------------------+---------------------------+-------+
| multiply | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+---------------------------+-------+
| multiply_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+---------------------------+-------+
| negate | Unary | Numeric/Duration | Numeric/Duration | |
+------------------+--------+-------------------------+---------------------------+-------+
| negate_checked | Unary | Signed Numeric/Duration | Signed Numeric/Duration | |
+------------------+--------+-------------------------+---------------------------+-------+
| power | Binary | Numeric | Numeric | |
+------------------+--------+-------------------------+---------------------------+-------+
| power_checked | Binary | Numeric | Numeric | |
+------------------+--------+-------------------------+---------------------------+-------+
| sign | Unary | Numeric/Duration | Int8/Float32/Float64 | \(2) |
+------------------+--------+-------------------------+---------------------------+-------+
| sqrt | Unary | Numeric | Numeric | |
+------------------+--------+-------------------------+---------------------------+-------+
| sqrt_checked | Unary | Numeric | Numeric | |
+------------------+--------+-------------------------+---------------------------+-------+
| subtract | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+---------------------------+-------+
| subtract_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+-------------------------+---------------------------+-------+

* \(1) Precision and scale of computed DECIMAL results

Expand Down

0 comments on commit 3712c54

Please sign in to comment.