diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index f0ee295c6347d..3522549b53b4a 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -25,6 +25,7 @@ #include #include "arrow/array.h" +#include "arrow/array/concatenate.h" #include "arrow/array/validate.h" #include "arrow/pretty_print.h" #include "arrow/status.h" @@ -432,4 +433,43 @@ RecordBatchReader::~RecordBatchReader() { ARROW_WARN_NOT_OK(this->Close(), "Implicitly called RecordBatchReader::Close failed"); } +Result> ConcatenateRecordBatches( + const RecordBatchVector& batches, MemoryPool* pool) { + int64_t length = 0; + size_t n = batches.size(); + if (n == 0) { + return Status::Invalid("Must pass at least one recordbatch"); + } + if (n == 1) { + return batches[0]; + } + int cols = batches[0]->num_columns(); + auto schema = batches[0]->schema(); + std::vector> columns; + if (cols == 0) { + // special case: null batch, no data, just length + for (size_t i = 0; i < batches.size(); ++i) { + length += batches[i]->num_rows(); + } + } else { + for (int col = 0; col < cols; ++col) { + ArrayVector data; + for (size_t i = 0; i < batches.size(); ++i) { + auto cur_schema = batches[i]->schema(); + if (!schema->Equals(cur_schema)) { + return Status::Invalid( + "RecordBatch index ", i, " schema is ", cur_schema->ToString(), + ", did not match index 0 recordbatch schema: ", schema->ToString()); + } + auto column_data = batches[i]->column(col); + data.push_back(column_data); + } + auto array = Concatenate(data, pool).ValueOrDie(); + length = array->length(); + columns.push_back(array); + } + } + return RecordBatch::Make(std::move(schema), length, columns); +} + } // namespace arrow diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index cb1f6d54f7cff..9514617d06193 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -350,4 +350,12 @@ class ARROW_EXPORT RecordBatchReader { Iterator> batches, std::shared_ptr schema); }; +/// \brief Concatenate recordbatches +/// +/// \param[in] batches a vector of recordbatches to be concatenated +/// \param[in] pool memory to store the result will be allocated from this memory pool +/// \return the concatenated recordbatch +Result> ConcatenateRecordBatches( + const RecordBatchVector& batches, MemoryPool* pool = default_memory_pool()); + } // namespace arrow diff --git a/cpp/src/arrow/record_batch_test.cc b/cpp/src/arrow/record_batch_test.cc index bc923a1444160..8e04f1d218b18 100644 --- a/cpp/src/arrow/record_batch_test.cc +++ b/cpp/src/arrow/record_batch_test.cc @@ -555,4 +555,42 @@ TEST_F(TestRecordBatch, ReplaceSchema) { ASSERT_RAISES(Invalid, b1->ReplaceSchema(schema)); } +TEST_F(TestRecordBatch, ConcatenateRecordBatches) { + int length = 10; + + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8()); + + auto schema = ::arrow::schema({f0, f1}); + + random::RandomArrayGenerator gen(42); + + auto a0 = gen.ArrayOf(int32(), length); + auto a1 = gen.ArrayOf(uint8(), length); + + auto b1 = RecordBatch::Make(schema, length, {a0, a1}); + + length = 5; + + a0 = gen.ArrayOf(int32(), length); + a1 = gen.ArrayOf(uint8(), length); + + auto b2 = RecordBatch::Make(schema, length, {a0, a1}); + + ASSERT_OK_AND_ASSIGN(auto batch, ConcatenateRecordBatches({b1, b2})); + ASSERT_EQ(batch->num_rows(), b1->num_rows() + b2->num_rows()); + + f0 = field("fd0", int32()); + f1 = field("fd1", uint8()); + + schema = ::arrow::schema({f0, f1}); + + a0 = gen.ArrayOf(int32(), length); + a1 = gen.ArrayOf(uint8(), length); + + auto b3 = RecordBatch::Make(schema, length, {a0, a1}); + + ASSERT_RAISES(Invalid, ConcatenateRecordBatches({b1, b3})); +} + } // namespace arrow