From 7bef38b135569dba6564af07276853b853def00b Mon Sep 17 00:00:00 2001 From: light-city <455954986@qq.com> Date: Thu, 28 Sep 2023 09:30:24 +0800 Subject: [PATCH] fix rb --- cpp/src/arrow/record_batch.cc | 15 +++++++++------ cpp/src/arrow/record_batch.h | 6 +++--- cpp/src/arrow/record_batch_test.cc | 15 +++------------ 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 3522549b53b4a..5f1f192133beb 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -440,9 +440,6 @@ Result> ConcatenateRecordBatches( if (n == 0) { return Status::Invalid("Must pass at least one recordbatch"); } - if (n == 1) { - return batches[0]; - } int cols = batches[0]->num_columns(); auto schema = batches[0]->schema(); std::vector> columns; @@ -464,12 +461,18 @@ Result> ConcatenateRecordBatches( auto column_data = batches[i]->column(col); data.push_back(column_data); } - auto array = Concatenate(data, pool).ValueOrDie(); - length = array->length(); + ARROW_ASSIGN_OR_RAISE(auto array, Concatenate(data, pool)); + if (col == 0) { + length = array->length(); + } else if (length != array->length()) { + return Status::Invalid( + "column index ", i, " length is ", array->length(), + ", does not match the length of previous columns: ", length); + } columns.push_back(array); } } - return RecordBatch::Make(std::move(schema), length, columns); + return RecordBatch::Make(std::move(schema), length, std::move(columns)); } } // namespace arrow diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 9514617d06193..dac48d83c3ccc 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -350,11 +350,11 @@ class ARROW_EXPORT RecordBatchReader { Iterator> batches, std::shared_ptr schema); }; -/// \brief Concatenate recordbatches +/// \brief Concatenate record batches /// -/// \param[in] batches a vector of recordbatches to be concatenated +/// \param[in] batches a vector of record batches to be concatenated /// \param[in] pool memory to store the result will be allocated from this memory pool -/// \return the concatenated recordbatch +/// \return the concatenated record batch Result> ConcatenateRecordBatches( const RecordBatchVector& batches, MemoryPool* pool = default_memory_pool()); diff --git a/cpp/src/arrow/record_batch_test.cc b/cpp/src/arrow/record_batch_test.cc index 8e04f1d218b18..9bef567d0ca94 100644 --- a/cpp/src/arrow/record_batch_test.cc +++ b/cpp/src/arrow/record_batch_test.cc @@ -565,17 +565,11 @@ TEST_F(TestRecordBatch, ConcatenateRecordBatches) { random::RandomArrayGenerator gen(42); - auto a0 = gen.ArrayOf(int32(), length); - auto a1 = gen.ArrayOf(uint8(), length); - - auto b1 = RecordBatch::Make(schema, length, {a0, a1}); + auto b1 = gen.BatchOf(schema->fields(), length); length = 5; - a0 = gen.ArrayOf(int32(), length); - a1 = gen.ArrayOf(uint8(), length); - - auto b2 = RecordBatch::Make(schema, length, {a0, a1}); + auto b2 = gen.BatchOf(schema->fields(), length); ASSERT_OK_AND_ASSIGN(auto batch, ConcatenateRecordBatches({b1, b2})); ASSERT_EQ(batch->num_rows(), b1->num_rows() + b2->num_rows()); @@ -585,10 +579,7 @@ TEST_F(TestRecordBatch, ConcatenateRecordBatches) { schema = ::arrow::schema({f0, f1}); - a0 = gen.ArrayOf(int32(), length); - a1 = gen.ArrayOf(uint8(), length); - - auto b3 = RecordBatch::Make(schema, length, {a0, a1}); + auto b3 = gen.BatchOf(schema->fields(), length); ASSERT_RAISES(Invalid, ConcatenateRecordBatches({b1, b3})); }