Skip to content

Commit

Permalink
Now that we are sorting by keys I needed to fix to also sort by keys …
Browse files Browse the repository at this point in the history
…in NaiveGroupBy. Also, we need to decode dictionary columns which cannot be sorted
  • Loading branch information
westonpace committed Jan 4, 2023
1 parent 9e71dc2 commit 412a32d
Showing 1 changed file with 42 additions and 17 deletions.
59 changes: 42 additions & 17 deletions cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/string.h"
#include "arrow/util/thread_pool.h"
#include "arrow/util/vector.h"

Expand All @@ -66,6 +67,7 @@ namespace arrow {
using internal::BitmapReader;
using internal::checked_cast;
using internal::checked_pointer_cast;
using internal::ToChars;

namespace compute {
namespace {
Expand Down Expand Up @@ -112,12 +114,25 @@ Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, std::vector<Datum> keys

int i = 0;
ARROW_ASSIGN_OR_RAISE(auto uniques, grouper->GetUniques());
std::vector<SortKey> sort_keys;
std::vector<std::shared_ptr<Field>> sort_table_fields;
for (const Datum& key : uniques.values) {
out_columns.push_back(key.make_array());
out_names.push_back("key_" + std::to_string(i++));
sort_keys.emplace_back(FieldRef(i));
sort_table_fields.push_back(field("key_" + ToChars(i), key.type()));
out_names.push_back("key_" + ToChars(i++));
}

return StructArray::Make(std::move(out_columns), std::move(out_names));
// Return a struct array sorted by the keys
SortOptions sort_options(std::move(sort_keys));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> sort_batch,
uniques.ToRecordBatch(schema(std::move(sort_table_fields))));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> sorted_indices,
SortIndices(sort_batch, sort_options));

ARROW_ASSIGN_OR_RAISE(auto struct_arr,
StructArray::Make(std::move(out_columns), std::move(out_names)));
return Take(struct_arr, sorted_indices);
}

Result<Datum> RunGroupBy(const BatchesWithSchema& input,
Expand Down Expand Up @@ -185,8 +200,18 @@ Result<Datum> RunGroupBy(const BatchesWithSchema& input,
std::vector<SortKey> sort_keys;
for (std::size_t i = 0; i < key_names.size(); i++) {
const std::shared_ptr<Array>& arr = out_arrays[i + aggregates.size()];
key_columns.push_back(arr);
key_fields.push_back(field("name_does_not_matter", arr->type()));
if (arr->type_id() == Type::DICTIONARY) {
// Can't sort dictionary columns so need to decode
auto dict_arr = checked_pointer_cast<DictionaryArray>(arr);
ARROW_ASSIGN_OR_RAISE(auto decoded_arr,
Take(*dict_arr->dictionary(), *dict_arr->indices()));
key_columns.push_back(decoded_arr);
key_fields.push_back(
field("name_does_not_matter", dict_arr->dict_type()->value_type()));
} else {
key_columns.push_back(arr);
key_fields.push_back(field("name_does_not_matter", arr->type()));
}
sort_keys.emplace_back(static_cast<int>(i));
}
std::shared_ptr<Schema> key_schema = schema(std::move(key_fields));
Expand All @@ -212,11 +237,11 @@ Result<Datum> RunGroupBy(const std::vector<Datum>& arguments,
FieldVector scan_fields(arguments.size() + keys.size());
std::vector<std::string> key_names(keys.size());
for (size_t i = 0; i < arguments.size(); ++i) {
auto name = std::string("agg_") + std::to_string(i);
auto name = std::string("agg_") + ToChars(i);
scan_fields[i] = field(name, arguments[i].type());
}
for (size_t i = 0; i < keys.size(); ++i) {
auto name = std::string("key_") + std::to_string(i);
auto name = std::string("key_") + ToChars(i);
scan_fields[arguments.size() + i] = field(name, keys[i].type());
key_names[i] = std::move(name);
}
Expand Down Expand Up @@ -273,7 +298,7 @@ Result<Datum> GroupByTest(const std::vector<Datum>& arguments,
int idx = 0;
for (auto t_agg : aggregates) {
internal_aggregates.push_back(
{t_agg.function, t_agg.options, "agg_" + std::to_string(idx), t_agg.function});
{t_agg.function, t_agg.options, "agg_" + ToChars(idx), t_agg.function});
idx = idx + 1;
}
return RunGroupBy(arguments, keys, internal_aggregates, use_threads);
Expand Down Expand Up @@ -439,7 +464,7 @@ struct TestGrouper {
ValidateOutput(*ids);

for (int i = 0; i < key_batch.num_values(); ++i) {
SCOPED_TRACE(std::to_string(i) + "th key array");
SCOPED_TRACE(ToChars(i) + "th key array");
auto original =
key_batch[i].is_array()
? key_batch[i].make_array()
Expand Down Expand Up @@ -618,7 +643,7 @@ TEST(Grouper, DoubleStringInt64Key) {
TEST(Grouper, RandomInt64Keys) {
TestGrouper g({int64()});
for (int i = 0; i < 4; ++i) {
SCOPED_TRACE(std::to_string(i) + "th key batch");
SCOPED_TRACE(ToChars(i) + "th key batch");

ExecBatch key_batch{
*random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
Expand All @@ -629,7 +654,7 @@ TEST(Grouper, RandomInt64Keys) {
TEST(Grouper, RandomStringInt64Keys) {
TestGrouper g({utf8(), int64()});
for (int i = 0; i < 4; ++i) {
SCOPED_TRACE(std::to_string(i) + "th key batch");
SCOPED_TRACE(ToChars(i) + "th key batch");

ExecBatch key_batch{
*random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
Expand All @@ -640,7 +665,7 @@ TEST(Grouper, RandomStringInt64Keys) {
TEST(Grouper, RandomStringInt64DoubleInt32Keys) {
TestGrouper g({utf8(), int64(), float64(), int32()});
for (int i = 0; i < 4; ++i) {
SCOPED_TRACE(std::to_string(i) + "th key batch");
SCOPED_TRACE(ToChars(i) + "th key batch");

ExecBatch key_batch{
*random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
Expand Down Expand Up @@ -3411,7 +3436,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) {
SCOPED_TRACE("key type: " + key_type->ToString());

auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", key_type)}), R"([
schema({field("agg_0", float64()), field("key", key_type)}), R"([
[1.0, "alfa"],
[null, "alfa"],
[0.0, "beta"],
Expand All @@ -3426,7 +3451,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) {

ASSERT_OK_AND_ASSIGN(
Datum aggregated_and_grouped,
RunGroupBy({batch->GetColumnByName("argument")}, {batch->GetColumnByName("key")},
RunGroupBy({batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")},
{
{"hash_sum", nullptr, "agg_0", "hash_sum"},
}));
Expand Down Expand Up @@ -3513,9 +3538,9 @@ TEST(GroupBy, RandomArraySum) {
for (auto null_probability : {0.0, 0.01, 0.5, 1.0}) {
auto batch = random::GenerateBatch(
{
field("argument", float32(),
key_value_metadata(
{{"null_probability", std::to_string(null_probability)}})),
field(
"agg_0", float32(),
key_value_metadata({{"null_probability", ToChars(null_probability)}})),
field("key", int64(), key_value_metadata({{"min", "0"}, {"max", "100"}})),
},
length, 0xDEADBEEF);
Expand All @@ -3524,7 +3549,7 @@ TEST(GroupBy, RandomArraySum) {
{
{"hash_sum", options, "agg_0", "hash_sum"},
},
{batch->GetColumnByName("argument")}, {batch->GetColumnByName("key")});
{batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")});
}
}
}
Expand Down

0 comments on commit 412a32d

Please sign in to comment.