diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index c14381262366e..2f26577f5b291 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -55,6 +55,7 @@ #include "arrow/util/int_util_overflow.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" +#include "arrow/util/string.h" #include "arrow/util/thread_pool.h" #include "arrow/util/vector.h" @@ -66,6 +67,7 @@ namespace arrow { using internal::BitmapReader; using internal::checked_cast; using internal::checked_pointer_cast; +using internal::ToChars; namespace compute { namespace { @@ -112,12 +114,25 @@ Result NaiveGroupBy(std::vector arguments, std::vector keys int i = 0; ARROW_ASSIGN_OR_RAISE(auto uniques, grouper->GetUniques()); + std::vector sort_keys; + std::vector> sort_table_fields; for (const Datum& key : uniques.values) { out_columns.push_back(key.make_array()); - out_names.push_back("key_" + std::to_string(i++)); + sort_keys.emplace_back(FieldRef(i)); + sort_table_fields.push_back(field("key_" + ToChars(i), key.type())); + out_names.push_back("key_" + ToChars(i++)); } - return StructArray::Make(std::move(out_columns), std::move(out_names)); + // Return a struct array sorted by the keys + SortOptions sort_options(std::move(sort_keys)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr sort_batch, + uniques.ToRecordBatch(schema(std::move(sort_table_fields)))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr sorted_indices, + SortIndices(sort_batch, sort_options)); + + ARROW_ASSIGN_OR_RAISE(auto struct_arr, + StructArray::Make(std::move(out_columns), std::move(out_names))); + return Take(struct_arr, sorted_indices); } Result RunGroupBy(const BatchesWithSchema& input, @@ -185,8 +200,18 @@ Result RunGroupBy(const BatchesWithSchema& input, std::vector sort_keys; for (std::size_t i = 0; i < key_names.size(); i++) { const std::shared_ptr& arr = out_arrays[i + aggregates.size()]; - key_columns.push_back(arr); - key_fields.push_back(field("name_does_not_matter", arr->type())); + if (arr->type_id() == Type::DICTIONARY) { + // Can't sort dictionary columns so need to decode + auto dict_arr = checked_pointer_cast(arr); + ARROW_ASSIGN_OR_RAISE(auto decoded_arr, + Take(*dict_arr->dictionary(), *dict_arr->indices())); + key_columns.push_back(decoded_arr); + key_fields.push_back( + field("name_does_not_matter", dict_arr->dict_type()->value_type())); + } else { + key_columns.push_back(arr); + key_fields.push_back(field("name_does_not_matter", arr->type())); + } sort_keys.emplace_back(static_cast(i)); } std::shared_ptr key_schema = schema(std::move(key_fields)); @@ -212,11 +237,11 @@ Result RunGroupBy(const std::vector& arguments, FieldVector scan_fields(arguments.size() + keys.size()); std::vector key_names(keys.size()); for (size_t i = 0; i < arguments.size(); ++i) { - auto name = std::string("agg_") + std::to_string(i); + auto name = std::string("agg_") + ToChars(i); scan_fields[i] = field(name, arguments[i].type()); } for (size_t i = 0; i < keys.size(); ++i) { - auto name = std::string("key_") + std::to_string(i); + auto name = std::string("key_") + ToChars(i); scan_fields[arguments.size() + i] = field(name, keys[i].type()); key_names[i] = std::move(name); } @@ -273,7 +298,7 @@ Result GroupByTest(const std::vector& arguments, int idx = 0; for (auto t_agg : aggregates) { internal_aggregates.push_back( - {t_agg.function, t_agg.options, "agg_" + std::to_string(idx), t_agg.function}); + {t_agg.function, t_agg.options, "agg_" + ToChars(idx), t_agg.function}); idx = idx + 1; } return RunGroupBy(arguments, keys, internal_aggregates, use_threads); @@ -439,7 +464,7 @@ struct TestGrouper { ValidateOutput(*ids); for (int i = 0; i < key_batch.num_values(); ++i) { - SCOPED_TRACE(std::to_string(i) + "th key array"); + SCOPED_TRACE(ToChars(i) + "th key array"); auto original = key_batch[i].is_array() ? key_batch[i].make_array() @@ -618,7 +643,7 @@ TEST(Grouper, DoubleStringInt64Key) { TEST(Grouper, RandomInt64Keys) { TestGrouper g({int64()}); for (int i = 0; i < 4; ++i) { - SCOPED_TRACE(std::to_string(i) + "th key batch"); + SCOPED_TRACE(ToChars(i) + "th key batch"); ExecBatch key_batch{ *random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)}; @@ -629,7 +654,7 @@ TEST(Grouper, RandomInt64Keys) { TEST(Grouper, RandomStringInt64Keys) { TestGrouper g({utf8(), int64()}); for (int i = 0; i < 4; ++i) { - SCOPED_TRACE(std::to_string(i) + "th key batch"); + SCOPED_TRACE(ToChars(i) + "th key batch"); ExecBatch key_batch{ *random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)}; @@ -640,7 +665,7 @@ TEST(Grouper, RandomStringInt64Keys) { TEST(Grouper, RandomStringInt64DoubleInt32Keys) { TestGrouper g({utf8(), int64(), float64(), int32()}); for (int i = 0; i < 4; ++i) { - SCOPED_TRACE(std::to_string(i) + "th key batch"); + SCOPED_TRACE(ToChars(i) + "th key batch"); ExecBatch key_batch{ *random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)}; @@ -3411,7 +3436,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) { SCOPED_TRACE("key type: " + key_type->ToString()); auto batch = RecordBatchFromJSON( - schema({field("argument", float64()), field("key", key_type)}), R"([ + schema({field("agg_0", float64()), field("key", key_type)}), R"([ [1.0, "alfa"], [null, "alfa"], [0.0, "beta"], @@ -3426,7 +3451,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) { ASSERT_OK_AND_ASSIGN( Datum aggregated_and_grouped, - RunGroupBy({batch->GetColumnByName("argument")}, {batch->GetColumnByName("key")}, + RunGroupBy({batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")}, { {"hash_sum", nullptr, "agg_0", "hash_sum"}, })); @@ -3513,9 +3538,9 @@ TEST(GroupBy, RandomArraySum) { for (auto null_probability : {0.0, 0.01, 0.5, 1.0}) { auto batch = random::GenerateBatch( { - field("argument", float32(), - key_value_metadata( - {{"null_probability", std::to_string(null_probability)}})), + field( + "agg_0", float32(), + key_value_metadata({{"null_probability", ToChars(null_probability)}})), field("key", int64(), key_value_metadata({{"min", "0"}, {"max", "100"}})), }, length, 0xDEADBEEF); @@ -3524,7 +3549,7 @@ TEST(GroupBy, RandomArraySum) { { {"hash_sum", options, "agg_0", "hash_sum"}, }, - {batch->GetColumnByName("argument")}, {batch->GetColumnByName("key")}); + {batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")}); } } }