From e868406a834cd93106b82bf402c59e476b481a0d Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Thu, 11 Jan 2024 23:38:41 +0800 Subject: [PATCH] GH-39233: [Compute] Add some duration kernels (#39358) ### Rationale for this change Add kernels for durations. ### What changes are included in this PR? In this PR I added the ones that require only registration and unit tests. More complicated ones will be in another PR for readability. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * Closes: #39233 Authored-by: Jin Shang Signed-off-by: Antoine Pitrou --- .../compute/kernels/scalar_arithmetic.cc | 35 +++++++++ .../arrow/compute/kernels/scalar_compare.cc | 9 +++ .../compute/kernels/scalar_compare_test.cc | 7 +- .../compute/kernels/scalar_temporal_test.cc | 12 +++ .../arrow/compute/kernels/scalar_validity.cc | 6 +- .../compute/kernels/scalar_validity_test.cc | 7 ++ docs/source/cpp/compute.rst | 78 +++++++++---------- 7 files changed, 113 insertions(+), 41 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index ad33d7f8951f4..44f5fea79078a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -1286,12 +1286,27 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { auto absolute_value = MakeUnaryArithmeticFunction("abs", absolute_value_doc); AddDecimalUnaryKernels(absolute_value.get()); + + // abs(duration) + for (auto unit : TimeUnit::values()) { + auto exec = ArithmeticExecFromOp(duration(unit)); + DCHECK_OK( + absolute_value->AddKernel({duration(unit)}, OutputType(duration(unit)), exec)); + } + DCHECK_OK(registry->AddFunction(std::move(absolute_value))); // ---------------------------------------------------------------------- auto absolute_value_checked = MakeUnaryArithmeticFunctionNotNull( "abs_checked", absolute_value_checked_doc); AddDecimalUnaryKernels(absolute_value_checked.get()); + 
// abs_checked(duration) + for (auto unit : TimeUnit::values()) { + auto exec = + ArithmeticExecFromOp(duration(unit)); + DCHECK_OK(absolute_value_checked->AddKernel({duration(unit)}, + OutputType(duration(unit)), exec)); + } DCHECK_OK(registry->AddFunction(std::move(absolute_value_checked))); // ---------------------------------------------------------------------- @@ -1545,12 +1560,27 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto negate = MakeUnaryArithmeticFunction("negate", negate_doc); AddDecimalUnaryKernels(negate.get()); + + // Add neg(duration) -> duration + for (auto unit : TimeUnit::values()) { + auto exec = ArithmeticExecFromOp(duration(unit)); + DCHECK_OK(negate->AddKernel({duration(unit)}, OutputType(duration(unit)), exec)); + } + DCHECK_OK(registry->AddFunction(std::move(negate))); // ---------------------------------------------------------------------- auto negate_checked = MakeUnarySignedArithmeticFunctionNotNull( "negate_checked", negate_checked_doc); AddDecimalUnaryKernels(negate_checked.get()); + + // Add neg_checked(duration) -> duration + for (auto unit : TimeUnit::values()) { + auto exec = ArithmeticExecFromOp(duration(unit)); + DCHECK_OK( + negate_checked->AddKernel({duration(unit)}, OutputType(duration(unit)), exec)); + } + DCHECK_OK(registry->AddFunction(std::move(negate_checked))); // ---------------------------------------------------------------------- @@ -1581,6 +1611,11 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto sign = MakeUnaryArithmeticFunctionWithFixedIntOutType("sign", sign_doc); + // sign(duration) + for (auto unit : TimeUnit::values()) { + auto exec = ScalarUnary::Exec; + DCHECK_OK(sign->AddKernel({duration(unit)}, int8(), std::move(exec))); + } DCHECK_OK(registry->AddFunction(std::move(sign))); // 
---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index aad648ca275c3..daf8ed76d628d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -22,6 +22,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common_internal.h" +#include "arrow/type.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" @@ -806,6 +807,14 @@ std::shared_ptr MakeScalarMinMax(std::string name, FunctionDoc d kernel.mem_allocation = MemAllocation::type::PREALLOCATE; DCHECK_OK(func->AddKernel(std::move(kernel))); } + for (const auto& ty : DurationTypes()) { + auto exec = GeneratePhysicalNumeric(ty); + ScalarKernel kernel{KernelSignature::Make({ty}, ty, /*is_varargs=*/true), exec, + MinMaxState::Init}; + kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::type::PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(kernel))); + } for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateTypeAgnosticVarBinaryBase(ty); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 48fa780b03104..8f5952b40500a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -1281,7 +1281,7 @@ using CompareNumericBasedTypes = ::testing::Types; using CompareParametricTemporalTypes = - ::testing::Types; + ::testing::Types; using CompareFixedSizeBinaryTypes = ::testing::Types; TYPED_TEST_SUITE(TestVarArgsCompareNumeric, CompareNumericBasedTypes); @@ -2121,6 +2121,11 @@ TEST(TestMaxElementWiseMinElementWise, CommonTemporal) { ScalarFromJSON(date64(), "172800000"), }), ResultWith(ScalarFromJSON(date64(), "86400000"))); + EXPECT_THAT(MinElementWise({ + ScalarFromJSON(duration(TimeUnit::SECOND), 
"1"), + ScalarFromJSON(duration(TimeUnit::MILLI), "12000"), + }), + ResultWith(ScalarFromJSON(duration(TimeUnit::MILLI), "1000"))); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc index d4482334285bc..8dac6525fe2e6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc @@ -3665,5 +3665,17 @@ TEST_F(ScalarTemporalTest, TestCeilFloorRoundTemporalDate) { CheckScalarUnary("ceil_temporal", arr_ns, arr_ns, &round_to_2_hours); } +TEST_F(ScalarTemporalTest, DurationUnaryArithmetics) { + auto arr = ArrayFromJSON(duration(TimeUnit::SECOND), "[2, -1, null, 3, 0]"); + CheckScalarUnary("negate", arr, + ArrayFromJSON(duration(TimeUnit::SECOND), "[-2, 1, null, -3, 0]")); + CheckScalarUnary("negate_checked", arr, + ArrayFromJSON(duration(TimeUnit::SECOND), "[-2, 1, null, -3, 0]")); + CheckScalarUnary("abs", arr, + ArrayFromJSON(duration(TimeUnit::SECOND), "[2, 1, null, 3, 0]")); + CheckScalarUnary("abs_checked", arr, + ArrayFromJSON(duration(TimeUnit::SECOND), "[2, 1, null, 3, 0]")); + CheckScalarUnary("sign", arr, ArrayFromJSON(int8(), "[1, -1, null, 1, 0]")); +} } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_validity.cc b/cpp/src/arrow/compute/kernels/scalar_validity.cc index 6b1cec0f5ccc6..8505fc4c6e0af 100644 --- a/cpp/src/arrow/compute/kernels/scalar_validity.cc +++ b/cpp/src/arrow/compute/kernels/scalar_validity.cc @@ -169,6 +169,7 @@ std::shared_ptr MakeIsFiniteFunction(std::string name, FunctionD func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec)); DCHECK_OK( func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec)); + DCHECK_OK(func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec)); return func; } @@ -187,7 +188,8 @@ std::shared_ptr MakeIsInfFunction(std::string name, FunctionDoc 
func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec)); DCHECK_OK( func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec)); - + DCHECK_OK( + func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec)); return func; } @@ -205,6 +207,8 @@ std::shared_ptr MakeIsNanFunction(std::string name, FunctionDoc func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec)); DCHECK_OK( func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec)); + DCHECK_OK( + func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec)); return func; } diff --git a/cpp/src/arrow/compute/kernels/scalar_validity_test.cc b/cpp/src/arrow/compute/kernels/scalar_validity_test.cc index 94d951c838209..d1462838f3be6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_validity_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_validity_test.cc @@ -103,6 +103,9 @@ TEST(TestValidityKernels, IsFinite) { } CheckScalar("is_finite", {std::make_shared(4)}, ArrayFromJSON(boolean(), "[null, null, null, null]")); + CheckScalar("is_finite", + {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")}, + ArrayFromJSON(boolean(), "[true, true, true, null]")); } TEST(TestValidityKernels, IsInf) { @@ -116,6 +119,8 @@ TEST(TestValidityKernels, IsInf) { } CheckScalar("is_inf", {std::make_shared(4)}, ArrayFromJSON(boolean(), "[null, null, null, null]")); + CheckScalar("is_inf", {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")}, + ArrayFromJSON(boolean(), "[false, false, false, null]")); } TEST(TestValidityKernels, IsNan) { @@ -129,6 +134,8 @@ TEST(TestValidityKernels, IsNan) { } CheckScalar("is_nan", {std::make_shared(4)}, ArrayFromJSON(boolean(), "[null, null, null, null]")); + CheckScalar("is_nan", {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")}, + ArrayFromJSON(boolean(), "[false, false, false, null]")); } TEST(TestValidityKernels, IsValidIsNullNullType) { diff --git a/docs/source/cpp/compute.rst 
b/docs/source/cpp/compute.rst index 17d003b261dca..e7310d2c0c711 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -458,45 +458,45 @@ floating-point arguments will cast all arguments to floating-point, while mixed decimal and integer arguments will cast all arguments to decimals. Mixed time resolution temporal inputs will be cast to finest input resolution. -+------------------+--------+-------------------------+----------------------+-------+ -| Function name | Arity | Input types | Output type | Notes | -+==================+========+=========================+======================+=======+ -| abs | Unary | Numeric | Numeric | | -+------------------+--------+-------------------------+----------------------+-------+ -| abs_checked | Unary | Numeric | Numeric | | -+------------------+--------+-------------------------+----------------------+-------+ -| add | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | -+------------------+--------+-------------------------+----------------------+-------+ -| add_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | -+------------------+--------+-------------------------+----------------------+-------+ -| divide | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | -+------------------+--------+-------------------------+----------------------+-------+ -| divide_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | -+------------------+--------+-------------------------+----------------------+-------+ -| exp | Unary | Numeric | Float32/Float64 | | -+------------------+--------+-------------------------+----------------------+-------+ -| multiply | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | -+------------------+--------+-------------------------+----------------------+-------+ -| multiply_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | -+------------------+--------+-------------------------+----------------------+-------+ -| negate | Unary | Numeric | 
Numeric | | -+------------------+--------+-------------------------+----------------------+-------+ -| negate_checked | Unary | Signed Numeric | Signed Numeric | | -+------------------+--------+-------------------------+----------------------+-------+ -| power | Binary | Numeric | Numeric | | -+------------------+--------+-------------------------+----------------------+-------+ -| power_checked | Binary | Numeric | Numeric | | -+------------------+--------+-------------------------+----------------------+-------+ -| sign | Unary | Numeric | Int8/Float32/Float64 | \(2) | -+------------------+--------+-------------------------+----------------------+-------+ -| sqrt | Unary | Numeric | Numeric | | -+------------------+--------+-------------------------+----------------------+-------+ -| sqrt_checked | Unary | Numeric | Numeric | | -+------------------+--------+-------------------------+----------------------+-------+ -| subtract | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | -+------------------+--------+-------------------------+----------------------+-------+ -| subtract_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | -+------------------+--------+-------------------------+----------------------+-------+ ++------------------+--------+-------------------------+---------------------------+-------+ +| Function name | Arity | Input types | Output type | Notes | ++==================+========+=========================+===========================+=======+ +| abs | Unary | Numeric/Duration | Numeric/Duration | | ++------------------+--------+-------------------------+---------------------------+-------+ +| abs_checked | Unary | Numeric/Duration | Numeric/Duration | | ++------------------+--------+-------------------------+---------------------------+-------+ +| add | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | ++------------------+--------+-------------------------+---------------------------+-------+ +| add_checked | Binary | 
Numeric/Temporal | Numeric/Temporal | \(1) | ++------------------+--------+-------------------------+---------------------------+-------+ +| divide | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | ++------------------+--------+-------------------------+---------------------------+-------+ +| divide_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | ++------------------+--------+-------------------------+---------------------------+-------+ +| exp | Unary | Numeric | Float32/Float64 | | ++------------------+--------+-------------------------+---------------------------+-------+ +| multiply | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | ++------------------+--------+-------------------------+---------------------------+-------+ +| multiply_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | ++------------------+--------+-------------------------+---------------------------+-------+ +| negate | Unary | Numeric/Duration | Numeric/Duration | | ++------------------+--------+-------------------------+---------------------------+-------+ +| negate_checked | Unary | Signed Numeric/Duration | Signed Numeric/Duration | | ++------------------+--------+-------------------------+---------------------------+-------+ +| power | Binary | Numeric | Numeric | | ++------------------+--------+-------------------------+---------------------------+-------+ +| power_checked | Binary | Numeric | Numeric | | ++------------------+--------+-------------------------+---------------------------+-------+ +| sign | Unary | Numeric/Duration | Int8/Float32/Float64 | \(2) | ++------------------+--------+-------------------------+---------------------------+-------+ +| sqrt | Unary | Numeric | Numeric | | ++------------------+--------+-------------------------+---------------------------+-------+ +| sqrt_checked | Unary | Numeric | Numeric | | ++------------------+--------+-------------------------+---------------------------+-------+ +| subtract | Binary | 
Numeric/Temporal | Numeric/Temporal | \(1) | ++------------------+--------+-------------------------+---------------------------+-------+ +| subtract_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | ++------------------+--------+-------------------------+---------------------------+-------+ * \(1) Precision and scale of computed DECIMAL results