Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,21 @@ Status AggregateFinalize(KernelContext* ctx, Datum* out) {

// Registers a scalar aggregate kernel on `func` that uses the shared
// AggregateConsume / AggregateMerge / AggregateFinalize callbacks.
//
// \param sig signature of the kernel; moved into the kernel
// \param init state initializer for the kernel; moved into the kernel
// \param func function the kernel is added to
// \param simd_level SIMD level recorded on the kernel before registration
void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
                  ScalarAggregateFunction* func, SimdLevel::type simd_level) {
  // `sig` and `init` are by-value parameters, so move them instead of copying.
  // The kernel is defined exactly once (a second definition of `kernel` in the
  // same scope is ill-formed).
  ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
                               AggregateMerge, AggregateFinalize);
  // Set the simd level
  kernel.simd_level = simd_level;
  DCHECK_OK(func->AddKernel(std::move(kernel)));
}

// Overload of AddAggKernel that takes a caller-supplied finalize callback
// instead of the shared AggregateFinalize; consume and merge still use
// AggregateConsume / AggregateMerge.
//
// \param sig signature of the kernel; moved into the kernel
// \param init state initializer for the kernel; moved into the kernel
// \param finalize finalize callback for the kernel; moved into the kernel
// \param func function the kernel is added to
// \param simd_level SIMD level recorded on the kernel before registration
void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
                  ScalarAggregateFinalize finalize, ScalarAggregateFunction* func,
                  SimdLevel::type simd_level) {
  ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
                               AggregateMerge, std::move(finalize));
  // Set the simd level
  kernel.simd_level = simd_level;
  // Register exactly once, handing the kernel over by move; adding it a second
  // time would put a duplicate kernel on `func`.
  DCHECK_OK(func->AddKernel(std::move(kernel)));
}

namespace aggregate {
Expand Down Expand Up @@ -314,9 +324,7 @@ void AddMinOrMaxAggKernel(ScalarAggregateFunction* func,

// Note SIMD level is always NONE, but the convenience kernel will
// dispatch to an appropriate implementation
ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
AggregateMerge, std::move(finalize));
DCHECK_OK(func->AddKernel(kernel));
AddAggKernel(std::move(sig), std::move(init), std::move(finalize), func);
}

// ----------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
ScalarAggregateFunction* func,
SimdLevel::type simd_level = SimdLevel::NONE);

// Overload of AddAggKernel that additionally accepts the kernel's finalize
// callback. `simd_level` defaults to SimdLevel::NONE when omitted.
void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
ScalarAggregateFinalize finalize, ScalarAggregateFunction* func,
SimdLevel::type simd_level = SimdLevel::NONE);

namespace detail {

using arrow::internal::VisitSetBitRunsVoid;
Expand Down
51 changes: 50 additions & 1 deletion cpp/src/arrow/compute/kernels/aggregate_tdigest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ const FunctionDoc tdigest_doc{
{"array"},
"TDigestOptions"};

// User-facing documentation attached to the "approximate_median" scalar
// aggregate function (passed by address when that function is constructed).
// NOTE(review): the description says nulls are ignored, but the wrapper
// forwards `skip_nulls` from ScalarAggregateOptions to the T-Digest kernel,
// so that statement presumably only describes the default options -- confirm.
const FunctionDoc approximate_median_doc{
"Approximate median of a numeric array with T-Digest algorithm",
("Nulls and NaNs are ignored.\n"
"A null scalar is returned if there is no valid data point."),
{"array"},
"ScalarAggregateOptions"};

std::shared_ptr<ScalarAggregateFunction> AddTDigestAggKernels() {
static auto default_tdigest_options = TDigestOptions::Defaults();
auto func = std::make_shared<ScalarAggregateFunction>(
Expand All @@ -175,10 +182,52 @@ std::shared_ptr<ScalarAggregateFunction> AddTDigestAggKernels() {
return func;
}

// Builds the "approximate_median" scalar aggregate function as a thin wrapper
// around an existing T-Digest aggregate function.
//
// \param tdigest_func function whose kernels do the actual work. The pointer is
//        captured by the init callback, so it must stay valid for as long as
//        the returned function can be executed.
// \return the new function, with a single kernel accepting any input
//         descriptor and producing a float64 scalar
std::shared_ptr<ScalarAggregateFunction> AddApproximateMedianAggKernels(
    const ScalarAggregateFunction* tdigest_func) {
  static ScalarAggregateOptions default_scalar_aggregate_options;

  auto signature =
      KernelSignature::Make({InputType(ValueDescr::ANY)}, ValueDescr::Scalar(float64()));

  // Delegate state creation to whichever T-Digest kernel matches the inputs,
  // translating ScalarAggregateOptions into TDigestOptions on the way.
  auto make_state = [tdigest_func](KernelContext* ctx, const KernelInitArgs& args)
      -> Result<std::unique_ptr<KernelState>> {
    std::vector<ValueDescr> descrs = args.inputs;
    ARROW_ASSIGN_OR_RAISE(auto tdigest_kernel, tdigest_func->DispatchBest(&descrs));

    const auto& median_options =
        checked_cast<const ScalarAggregateOptions&>(*args.options);
    // Default q = 0.5
    TDigestOptions tdigest_options;
    tdigest_options.skip_nulls = median_options.skip_nulls;
    tdigest_options.min_count = median_options.min_count;

    KernelInitArgs forwarded_args{tdigest_kernel, descrs, &tdigest_options};
    return tdigest_kernel->init(ctx, forwarded_args);
  };

  // The wrapped aggregator finalizes to an array; unwrap its only element so
  // the caller receives a scalar.
  auto unwrap_scalar = [](KernelContext* ctx, Datum* out) -> Status {
    Datum digest_result;
    auto* aggregator = checked_cast<ScalarAggregator*>(ctx->state());
    RETURN_NOT_OK(aggregator->Finalize(ctx, &digest_result));
    const auto values = digest_result.make_array();
    DCHECK_EQ(values->length(), 1);
    return values->GetScalar(0).Value(out);
  };

  auto func = std::make_shared<ScalarAggregateFunction>(
      "approximate_median", Arity::Unary(), &approximate_median_doc,
      &default_scalar_aggregate_options);
  AddAggKernel(std::move(signature), std::move(make_state), std::move(unwrap_scalar),
               func.get());
  return func;
}

} // namespace

// Registers the T-Digest scalar aggregate function and the
// "approximate_median" function that wraps it.
//
// \param registry registry receiving both functions
void RegisterScalarAggregateTDigest(FunctionRegistry* registry) {
  // Build the T-Digest function once and register that single instance, so the
  // pointer handed to the median wrapper below refers to the registered object.
  auto tdigest = AddTDigestAggKernels();
  DCHECK_OK(registry->AddFunction(tdigest));

  // approximate_median keeps a raw pointer to `tdigest`.
  // NOTE(review): this relies on the registry keeping the function alive after
  // `tdigest` goes out of scope -- confirm.
  auto approx_median = AddApproximateMedianAggKernels(tdigest.get());
  DCHECK_OK(registry->AddFunction(approx_median));
}

} // namespace internal
Expand Down
74 changes: 74 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3447,5 +3447,79 @@ TEST(TestTDigestKernel, Options) {
ResultWith(ArrayFromJSON(ty, "[null]")));
}

// Checks "approximate_median" against arrays, chunked arrays and scalars for
// three option sets: keep nulls, require a minimum count, and both combined.
TEST(TestTDigestKernel, ApproximateMedian) {
  // This is a wrapper for TDigest
  for (const auto& ty : {float64(), int64(), uint16()}) {
    ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
    ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
    ScalarAggregateOptions keep_nulls_min_count(/*skip_nulls=*/false, /*min_count=*/3);

    // skip_nulls=false, min_count=0: any null in the input yields null.
    EXPECT_THAT(
        CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3]")}, &keep_nulls),
        ResultWith(ScalarFromJSON(float64(), "2.0")));
    EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3, null]")},
                             &keep_nulls),
                ResultWith(ScalarFromJSON(float64(), "null")));
    EXPECT_THAT(
        CallFunction("approximate_median",
                     {ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3]"})}, &keep_nulls),
        ResultWith(ScalarFromJSON(float64(), "2.0")));
    EXPECT_THAT(CallFunction("approximate_median",
                             {ChunkedArrayFromJSON(ty, {"[1, 2]", "[null]", "[3]"})},
                             &keep_nulls),
                ResultWith(ScalarFromJSON(float64(), "null")));
    EXPECT_THAT(
        CallFunction("approximate_median", {ScalarFromJSON(ty, "1")}, &keep_nulls),
        ResultWith(ScalarFromJSON(float64(), "1.0")));
    EXPECT_THAT(
        CallFunction("approximate_median", {ScalarFromJSON(ty, "null")}, &keep_nulls),
        ResultWith(ScalarFromJSON(float64(), "null")));

    // skip_nulls=true, min_count=3: nulls are dropped, and fewer than three
    // valid values yields null.
    EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3, null]")},
                             &min_count),
                ResultWith(ScalarFromJSON(float64(), "2.0")));
    EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, null]")},
                             &min_count),
                ResultWith(ScalarFromJSON(float64(), "null")));
    // The chunked cases must run with `min_count` too; previously they reused
    // `keep_nulls` and merely repeated the checks from the section above.
    EXPECT_THAT(
        CallFunction("approximate_median",
                     {ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3]"})}, &min_count),
        ResultWith(ScalarFromJSON(float64(), "2.0")));
    // Same data as "[1, 2, 3, null]" above, split over chunks: the null is
    // skipped and three valid values remain.
    EXPECT_THAT(CallFunction("approximate_median",
                             {ChunkedArrayFromJSON(ty, {"[1, 2]", "[null]", "[3]"})},
                             &min_count),
                ResultWith(ScalarFromJSON(float64(), "2.0")));
    EXPECT_THAT(CallFunction("approximate_median", {ScalarFromJSON(ty, "1")}, &min_count),
                ResultWith(ScalarFromJSON(float64(), "null")));
    EXPECT_THAT(
        CallFunction("approximate_median", {ScalarFromJSON(ty, "null")}, &min_count),
        ResultWith(ScalarFromJSON(float64(), "null")));

    // skip_nulls=false, min_count=3: both conditions apply.
    EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3]")},
                             &keep_nulls_min_count),
                ResultWith(ScalarFromJSON(float64(), "2.0")));
    EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2]")},
                             &keep_nulls_min_count),
                ResultWith(ScalarFromJSON(float64(), "null")));
    EXPECT_THAT(CallFunction("approximate_median",
                             {ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3]"})},
                             &keep_nulls_min_count),
                ResultWith(ScalarFromJSON(float64(), "2.0")));
    EXPECT_THAT(CallFunction("approximate_median",
                             {ChunkedArrayFromJSON(ty, {"[1, 2]", "[null]", "[3]"})},
                             &keep_nulls_min_count),
                ResultWith(ScalarFromJSON(float64(), "null")));
    EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3, null]")},
                             &keep_nulls_min_count),
                ResultWith(ScalarFromJSON(float64(), "null")));
    EXPECT_THAT(CallFunction("approximate_median", {ScalarFromJSON(ty, "1")},
                             &keep_nulls_min_count),
                ResultWith(ScalarFromJSON(float64(), "null")));
    EXPECT_THAT(CallFunction("approximate_median", {ScalarFromJSON(ty, "null")},
                             &keep_nulls_min_count),
                ResultWith(ScalarFromJSON(float64(), "null")));
  }
}

} // namespace compute
} // namespace arrow
48 changes: 48 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1734,6 +1734,37 @@ struct GroupedTDigestFactory {
InputType argument_type;
};

// Builds the grouped approximate-median kernel as a wrapper around the
// kernels of a grouped T-Digest function.
//
// \param tdigest_func function whose kernels do the actual work. The pointer is
//        captured by the init callback, so it must stay valid for as long as
//        the returned kernel can be executed.
// \return a kernel taking (any value, uint32 group-id array) and producing a
//         float64 array
HashAggregateKernel MakeApproximateMedianKernel(HashAggregateFunction* tdigest_func) {
  HashAggregateKernel median_kernel;

  median_kernel.signature =
      KernelSignature::Make({InputType(ValueDescr::ANY), InputType::Array(Type::UINT32)},
                            ValueDescr::Array(float64()));

  // Resize, consume and merge are the generic hash-aggregate callbacks.
  median_kernel.resize = HashAggregateResize;
  median_kernel.consume = HashAggregateConsume;
  median_kernel.merge = HashAggregateMerge;

  // State creation is delegated to whichever T-Digest kernel matches the
  // inputs, with ScalarAggregateOptions translated into TDigestOptions.
  median_kernel.init = [tdigest_func](KernelContext* ctx, const KernelInitArgs& args)
      -> Result<std::unique_ptr<KernelState>> {
    std::vector<ValueDescr> descrs = args.inputs;
    ARROW_ASSIGN_OR_RAISE(auto tdigest_kernel, tdigest_func->DispatchBest(&descrs));

    const auto& median_options =
        checked_cast<const ScalarAggregateOptions&>(*args.options);
    // Default q = 0.5
    TDigestOptions tdigest_options;
    tdigest_options.skip_nulls = median_options.skip_nulls;
    tdigest_options.min_count = median_options.min_count;

    KernelInitArgs forwarded_args{tdigest_kernel, descrs, &tdigest_options};
    return tdigest_kernel->init(ctx, forwarded_args);
  };

  // The wrapped aggregator's result is read as a FixedSizeListArray; emit its
  // child values array directly.
  // NOTE(review): presumably each list holds exactly one quantile -- confirm.
  median_kernel.finalize = [](KernelContext* ctx, Datum* out) {
    auto* aggregator = checked_cast<GroupedAggregator*>(ctx->state());
    ARROW_ASSIGN_OR_RAISE(Datum digest_result, aggregator->Finalize());
    *out = digest_result.array_as<FixedSizeListArray>()->values();
    return Status::OK();
  };

  return median_kernel;
}

// ----------------------------------------------------------------------
// MinMax implementation

Expand Down Expand Up @@ -2636,6 +2667,13 @@ const FunctionDoc hash_tdigest_doc{
{"array", "group_id_array"},
"TDigestOptions"};

// User-facing documentation attached to the "hash_approximate_median" grouped
// aggregate function (passed by address when that function is constructed).
// NOTE(review): the description says nulls are ignored, but the kernel
// forwards `skip_nulls` from ScalarAggregateOptions to the T-Digest kernel,
// so that statement presumably only describes the default options -- confirm.
const FunctionDoc hash_approximate_median_doc{
"Calculate approximate medians of a numeric array with the T-Digest algorithm",
("Nulls and NaNs are ignored.\n"
"Null is emitted for a group if there are no valid data points."),
{"array", "group_id_array"},
"ScalarAggregateOptions"};

const FunctionDoc hash_min_max_doc{
"Compute the minimum and maximum values of a numeric array",
("Null values are ignored by default.\n"
Expand Down Expand Up @@ -2760,6 +2798,7 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

HashAggregateFunction* tdigest_func = nullptr;
{
auto func = std::make_shared<HashAggregateFunction>(
"hash_tdigest", Arity::Binary(), &hash_tdigest_doc, &default_tdigest_options);
Expand All @@ -2769,6 +2808,15 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
AddHashAggKernels(UnsignedIntTypes(), GroupedTDigestFactory::Make, func.get()));
DCHECK_OK(
AddHashAggKernels(FloatingPointTypes(), GroupedTDigestFactory::Make, func.get()));
tdigest_func = func.get();
DCHECK_OK(registry->AddFunction(std::move(func)));
}

{
auto func = std::make_shared<HashAggregateFunction>(
"hash_approximate_median", Arity::Binary(), &hash_approximate_median_doc,
&default_scalar_aggregate_options);
DCHECK_OK(func->AddKernel(MakeApproximateMedianKernel(tdigest_func)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

Expand Down
64 changes: 64 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,70 @@ TEST(GroupBy, TDigest) {
/*verbose=*/true);
}

// Runs "hash_approximate_median" four times over the same argument column,
// once per option set, and compares the per-group results in one struct array.
TEST(GroupBy, ApproximateMedian) {
  for (const auto& type : {float64(), int8()}) {
    const auto input_schema = schema({field("argument", type), field("key", int64())});
    auto batch = RecordBatchFromJSON(input_schema, R"([
[1, 1],
[null, 1],
[0, 2],
[null, 3],
[1, 4],
[4, null],
[3, 1],
[0, 2],
[-1, 2],
[1, null],
[null, 3],
[1, 4],
[1, 4],
[null, 4]
])");

    // One option set per aggregate, in the order of the output columns.
    ScalarAggregateOptions options;
    ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
    ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
    ScalarAggregateOptions keep_nulls_min_count(/*skip_nulls=*/false, /*min_count=*/3);

    const auto argument = batch->GetColumnByName("argument");
    const auto key = batch->GetColumnByName("key");

    ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
                         internal::GroupBy({argument, argument, argument, argument},
                                           {key},
                                           {
                                               {"hash_approximate_median", &options},
                                               {"hash_approximate_median", &keep_nulls},
                                               {"hash_approximate_median", &min_count},
                                               {"hash_approximate_median",
                                                &keep_nulls_min_count},
                                           }));

    const auto expected_type = struct_({
        field("hash_approximate_median", float64()),
        field("hash_approximate_median", float64()),
        field("hash_approximate_median", float64()),
        field("hash_approximate_median", float64()),
        field("key_0", int64()),
    });
    const auto expected = ArrayFromJSON(expected_type, R"([
[1.0, null, null, null, 1],
[0.0, 0.0, 0.0, 0.0, 2],
[null, null, null, null, 3],
[1.0, null, 1.0, null, 4],
[1.0, 1.0, null, null, null]
])");

    AssertDatumsApproxEqual(expected, aggregated_and_grouped,
                            /*verbose=*/true);
  }
}

TEST(GroupBy, StddevVarianceTDigestScalar) {
BatchesWithSchema input;
input.batches = {
Expand Down
Loading