diff --git a/cpp/src/arrow/compute/exec/groupby.cc b/cpp/src/arrow/compute/exec/groupby.cc
index 25890a2f09baa..b10e567db7d6e 100644
--- a/cpp/src/arrow/compute/exec/groupby.cc
+++ b/cpp/src/arrow/compute/exec/groupby.cc
@@ -39,20 +39,20 @@ namespace compute {
 
 namespace {
 
-std::shared_ptr<Schema> SimpleSchemaForBatch(const ExecBatch& batch) {
+std::shared_ptr<Schema> SimpleSchemaForColumns(
+    const std::vector<std::shared_ptr<ChunkedArray>>& columns) {
   std::vector<std::shared_ptr<Field>> fields;
-  for (int i = 0; i < batch.num_values(); i++) {
-    fields.push_back(
-        field("key_" + ::arrow::internal::ToChars(i), batch.values[i].type()));
+  for (int i = 0; i < static_cast<int>(columns.size()); i++) {
+    fields.push_back(field("key_" + ::arrow::internal::ToChars(i), columns[i]->type()));
   }
   return schema(std::move(fields));
 }
 
 }  // namespace
 
-Result<std::shared_ptr<Table>> GroupBy(
-    const std::vector<std::shared_ptr<Array>>& arguments,
-    const std::vector<std::shared_ptr<Array>>& keys,
+Result<std::shared_ptr<Table>> GroupByChunked(
+    const std::vector<std::shared_ptr<ChunkedArray>>& arguments,
+    const std::vector<std::shared_ptr<ChunkedArray>>& keys,
     const std::vector<Aggregate>& aggregates, bool use_threads, ExecContext* ctx) {
   if (arguments.size() != aggregates.size()) {
     return Status::Invalid("arguments and aggregates must be the same size");
@@ -62,7 +62,7 @@ Result<std::shared_ptr<Table>> GroupBy(
     return Table::MakeEmpty(schema({}));
   }
 
-  std::vector<Datum> all_columns;
+  std::vector<std::shared_ptr<ChunkedArray>> all_columns;
   int64_t length = 0;
   for (const auto& key : keys) {
     if (length == 0) {
@@ -84,11 +84,9 @@ Result<std::shared_ptr<Table>> GroupBy(
     }
     all_columns.emplace_back(argument);
   }
-  ExecBatch input_batch(std::move(all_columns), length);
-  std::shared_ptr<Schema> batch_schema = SimpleSchemaForBatch(input_batch);
-  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> rb,
-                        input_batch.ToRecordBatch(std::move(batch_schema)));
-  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> table, Table::FromRecordBatches({rb}));
+  std::shared_ptr<Schema> batch_schema = SimpleSchemaForColumns(all_columns);
+  std::shared_ptr<Table> table =
+      Table::Make(std::move(batch_schema), std::move(all_columns));
 
   std::vector<FieldRef> key_refs;
   for (int i = 0; i < static_cast<int>(keys.size()); i++) {
@@ -127,5 +125,28 @@ Result<std::shared_ptr<Table>> GroupBy(
   return DeclarationToTable(plan);
 }
 
+namespace {
+
+std::vector<std::shared_ptr<ChunkedArray>> ToChunked(
+    const std::vector<std::shared_ptr<Array>>& arrays) {
+  std::vector<std::shared_ptr<ChunkedArray>> chunked;
+  chunked.reserve(arrays.size());
+  for (const auto& array : arrays) {
+    chunked.push_back(std::make_shared<ChunkedArray>(array));
+  }
+  return chunked;
+}
+
+}  // namespace
+
+Result<std::shared_ptr<Table>> GroupBy(
+    const std::vector<std::shared_ptr<Array>>& arguments,
+    const std::vector<std::shared_ptr<Array>>& keys,
+    const std::vector<Aggregate>& aggregates, bool use_threads, ExecContext* ctx) {
+  std::vector<std::shared_ptr<ChunkedArray>> chunked_args = ToChunked(arguments);
+  std::vector<std::shared_ptr<ChunkedArray>> chunked_keys = ToChunked(keys);
+  return GroupByChunked(chunked_args, chunked_keys, aggregates, use_threads, ctx);
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/groupby.h b/cpp/src/arrow/compute/exec/groupby.h
index f01d0635be995..68eb634811a0f 100644
--- a/cpp/src/arrow/compute/exec/groupby.h
+++ b/cpp/src/arrow/compute/exec/groupby.h
@@ -65,5 +65,13 @@ Result<std::shared_ptr<Table>> GroupBy(
     const std::vector<Aggregate>& aggregates, bool use_threads = false,
     ExecContext* ctx = default_exec_context());
 
+/// \see GroupBy
+ARROW_EXPORT
+Result<std::shared_ptr<Table>> GroupByChunked(
+    const std::vector<std::shared_ptr<ChunkedArray>>& arguments,
+    const std::vector<std::shared_ptr<ChunkedArray>>& keys,
+    const std::vector<Aggregate>& aggregates, bool use_threads = false,
+    ExecContext* ctx = default_exec_context());
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index b407bbac92b08..416cb38ce77ad 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2191,23 +2191,20 @@ class RankOptions(_RankOptions):
         self._set_options(sort_keys, null_placement, tiebreaker)
 
 
-cdef _pack_groupby_args(object values, vector[shared_ptr[CArray]]* out):
+cdef _pack_groupby_args(object values, vector[shared_ptr[CChunkedArray]]* out):
     for val in values:
-        if isinstance(val, (list, np.ndarray)):
-            val = lib.asarray(val)
-
-        if isinstance(val, Array):
-            out.push_back((<Array> val).sp_array)
+        if isinstance(val, ChunkedArray):
+            out.push_back((<ChunkedArray> val).sp_chunked_array)
             continue
 
         raise TypeError(f"Got unexpected argument type {type(val)} "
-                        "for group_by function, expected Array")
+                        "for group_by function, expected ChunkedArray")
 
 
 def _group_by(args, keys, aggregations):
     cdef:
-        vector[shared_ptr[CArray]] c_args
-        vector[shared_ptr[CArray]] c_keys
+        vector[shared_ptr[CChunkedArray]] c_args
+        vector[shared_ptr[CChunkedArray]] c_keys
         vector[CSimpleAggregate] c_aggregations
         CSimpleAggregate c_aggr
         shared_ptr[CTable] sp_table
@@ -2225,7 +2222,7 @@ def _group_by(args, keys, aggregations):
 
     with nogil:
         sp_table = GetResultValue(
-            GroupBy(c_args, c_keys, c_aggregations)
+            GroupByChunked(c_args, c_keys, c_aggregations)
         )
 
     return pyarrow_wrap_table(sp_table)
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 64dda1bce69d6..2096d4e261ad3 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2490,9 +2490,10 @@ cdef extern from "arrow/compute/exec/groupby.h" namespace \
         c_string function
         shared_ptr[CFunctionOptions] options
 
-    CResult[shared_ptr[CTable]] GroupBy(const vector[shared_ptr[CArray]]& arguments,
-                                        const vector[shared_ptr[CArray]]& keys,
-                                        const vector[CSimpleAggregate]& aggregates)
+    CResult[shared_ptr[CTable]] GroupByChunked(
+        const vector[shared_ptr[CChunkedArray]]& arguments,
+        const vector[shared_ptr[CChunkedArray]]& keys,
+        const vector[CSimpleAggregate]& aggregates)
 
 
 cdef extern from * namespace "arrow::compute":
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 758a641dcd146..30350109ee436 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -5384,11 +5384,8 @@ list[tuple(str, str, FunctionOptions)]
         ] + self.keys
 
-        agg_tables = []
-        for batch in self._table.to_batches():
-            agg_tables.append(_pc()._group_by(
-                [batch[c] for c in columns],
-                [batch[k] for k in self.keys],
-                group_by_aggrs
-            ))
-
-        return concat_tables(agg_tables).rename_columns(column_names)
+        return _pc()._group_by(
+            [self._table[c] for c in columns],
+            [self._table[k] for k in self.keys],
+            group_by_aggrs
+        ).rename_columns(column_names)
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 04e2dacc48144..ef8a507581d96 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -2024,6 +2024,13 @@ def sorted_by_keys(d):
         "values_count": [1]
     }
 
+    table = pa.table({'keys': ['a', 'b', 'a', 'b', 'a', 'b'], 'values': range(6)})
+    table_with_chunks = pa.Table.from_batches(table.to_batches(max_chunksize=3))
+    r = table_with_chunks.group_by('keys').aggregate([('values', 'sum')])
+    assert sorted_by_keys(r.to_pydict()) == {
+        "keys": ["a", "b"],
+        "values_sum": [6, 9]
+    }
 
 
 def test_table_to_recordbatchreader():
     table = pa.Table.from_pydict({'x': [1, 2, 3]})
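
Illustration (not part of the patch): a minimal Python sketch, mirroring the new test above, of the user-visible behavior this change targets. The old TableGroupBy.aggregate path aggregated each record batch separately and concatenated the partial results, so a multi-chunk table could yield one output row per (key, batch) pair; routing whole ChunkedArray columns through GroupByChunked performs a single aggregation pass and yields one row per key.

import pyarrow as pa

# A table whose columns are split into two 3-row chunks.
table = pa.table({'keys': ['a', 'b', 'a', 'b', 'a', 'b'], 'values': range(6)})
chunked = pa.Table.from_batches(table.to_batches(max_chunksize=3))
assert chunked.column('keys').num_chunks == 2

# Aggregation runs once over the whole chunked columns:
# one row per distinct key, regardless of chunking.
result = chunked.group_by('keys').aggregate([('values', 'sum')])
# keys: ['a', 'b'], values_sum: [6, 9]  (row order not guaranteed)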