Skip to content

Commit

Permalink
fix rb
Browse files Browse the repository at this point in the history
  • Loading branch information
Light-City committed Sep 28, 2023
1 parent b346799 commit 7bef38b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 21 deletions.
15 changes: 9 additions & 6 deletions cpp/src/arrow/record_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,6 @@ Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
if (n == 0) {
return Status::Invalid("Must pass at least one recordbatch");
}
if (n == 1) {
return batches[0];
}
int cols = batches[0]->num_columns();
auto schema = batches[0]->schema();
std::vector<std::shared_ptr<Array>> columns;
Expand All @@ -464,12 +461,18 @@ Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
auto column_data = batches[i]->column(col);
data.push_back(column_data);
}
auto array = Concatenate(data, pool).ValueOrDie();
length = array->length();
ARROW_ASSIGN_OR_RAISE(auto array, Concatenate(data, pool));
if (col == 0) {
length = array->length();
} else if (length != array->length()) {
return Status::Invalid(
"column index ", i, " length is ", array->length(),
", does not match the length of previous columns: ", length);
}
columns.push_back(array);
}
}
return RecordBatch::Make(std::move(schema), length, columns);
return RecordBatch::Make(std::move(schema), length, std::move(columns));
}

} // namespace arrow
6 changes: 3 additions & 3 deletions cpp/src/arrow/record_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,11 @@ class ARROW_EXPORT RecordBatchReader {
Iterator<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema> schema);
};

/// \brief Concatenate recordbatches
/// \brief Concatenate record batches
///
/// \param[in] batches a vector of recordbatches to be concatenated
/// \param[in] batches a vector of record batches to be concatenated
/// \param[in] pool memory to store the result will be allocated from this memory pool
/// \return the concatenated recordbatch
/// \return the concatenated record batch
Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
const RecordBatchVector& batches, MemoryPool* pool = default_memory_pool());

Expand Down
15 changes: 3 additions & 12 deletions cpp/src/arrow/record_batch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -565,17 +565,11 @@ TEST_F(TestRecordBatch, ConcatenateRecordBatches) {

random::RandomArrayGenerator gen(42);

auto a0 = gen.ArrayOf(int32(), length);
auto a1 = gen.ArrayOf(uint8(), length);

auto b1 = RecordBatch::Make(schema, length, {a0, a1});
auto b1 = gen.BatchOf(schema->fields(), length);

length = 5;

a0 = gen.ArrayOf(int32(), length);
a1 = gen.ArrayOf(uint8(), length);

auto b2 = RecordBatch::Make(schema, length, {a0, a1});
auto b2 = gen.BatchOf(schema->fields(), length);

ASSERT_OK_AND_ASSIGN(auto batch, ConcatenateRecordBatches({b1, b2}));
ASSERT_EQ(batch->num_rows(), b1->num_rows() + b2->num_rows());
Expand All @@ -585,10 +579,7 @@ TEST_F(TestRecordBatch, ConcatenateRecordBatches) {

schema = ::arrow::schema({f0, f1});

a0 = gen.ArrayOf(int32(), length);
a1 = gen.ArrayOf(uint8(), length);

auto b3 = RecordBatch::Make(schema, length, {a0, a1});
auto b3 = gen.BatchOf(schema->fields(), length);

ASSERT_RAISES(Invalid, ConcatenateRecordBatches({b1, b3}));
}
Expand Down

0 comments on commit 7bef38b

Please sign in to comment.