diff --git a/src/db/index/segment/segment.cc b/src/db/index/segment/segment.cc index 40b8e2c8..7124eff8 100644 --- a/src/db/index/segment/segment.cc +++ b/src/db/index/segment/segment.cc @@ -293,13 +293,11 @@ class SegmentImpl : public Segment, TablePtr fetch_normal(const std::vector &columns, const std::shared_ptr &result_schema, - bool need_local_doc_id, int local_doc_id_col_index, const std::vector &indices) const; // For performance tuning TablePtr fetch_perf(const std::vector &columns, const std::shared_ptr &result_schema, - bool need_local_doc_id, int local_doc_id_col_index, const std::vector &indices) const; void fresh_persist_chunked_array(); @@ -2217,13 +2215,19 @@ bool SegmentImpl::validate(const std::vector &columns) const { TablePtr SegmentImpl::fetch_perf( const std::vector &columns, - const std::shared_ptr &result_schema, bool need_local_doc_id, - int local_doc_id_col_index, const std::vector &indices) const { + const std::shared_ptr &result_schema, + const std::vector &indices) const { std::vector> chunk_arrays; chunk_arrays.resize(columns.size()); + bool need_local_doc_id = false; + size_t local_doc_id_col_index = 0; + for (size_t i = 0; i < columns.size(); ++i) { if (columns[i] == LOCAL_ROW_ID) { + need_local_doc_id = true; + local_doc_id_col_index = i; + chunk_arrays[i] = nullptr; continue; } chunk_arrays[i] = persist_chunk_arrays_[col_idx_map_.at(columns[i])]; @@ -2231,28 +2235,6 @@ TablePtr SegmentImpl::fetch_perf( std::vector> result_arrays(columns.size()); - if (need_local_doc_id) { - std::vector values; - values.reserve(indices.size()); - for (const auto idx : indices) { - values.push_back(idx); - } - - arrow::UInt64Builder builder; - auto s = builder.AppendValues(values); - if (!s.ok()) { - LOG_ERROR("Failed to append values to builder: %s", s.message().c_str()); - return nullptr; - } - std::shared_ptr array; - s = builder.Finish(&array); - if (!s.ok()) { - LOG_ERROR("Failed to finish builder: %s", s.message().c_str()); - return nullptr; - } - result_arrays[local_doc_id_col_index] = array; - } - std::vector> indices_in_table; for (const auto &target_index : indices) { auto it = std::upper_bound(chunk_offsets_.begin(), chunk_offsets_.end(), @@ -2267,12 +2249,11 @@ TablePtr SegmentImpl::fetch_perf( indices_in_table.emplace_back(chunk_index, index_in_chunk); } - size_t result_col_index = 0; for (size_t i = 0; i < columns.size(); ++i) { if (columns[i] == LOCAL_ROW_ID) { continue; } - const auto &source_column = chunk_arrays[result_col_index]; + const auto &source_column = chunk_arrays[i]; std::shared_ptr array; auto status = BuildArrayFromIndicesWithType(source_column, indices_in_table, &array); @@ -2280,8 +2261,29 @@ TablePtr SegmentImpl::fetch_perf( LOG_ERROR("BuildArrayFromIndices failed: %s", status.ToString().c_str()); return nullptr; } - result_arrays[result_col_index] = array; - result_col_index++; + result_arrays[i] = array; + } + + if (need_local_doc_id) { + std::vector values; + values.reserve(indices.size()); + for (const auto idx : indices) { + values.push_back(idx); + } + + arrow::UInt64Builder builder; + auto s = builder.AppendValues(values); + if (!s.ok()) { + LOG_ERROR("Failed to append values to builder: %s", s.message().c_str()); + return nullptr; + } + std::shared_ptr array; + s = builder.Finish(&array); + if (!s.ok()) { + LOG_ERROR("Failed to finish builder: %s", s.message().c_str()); + return nullptr; + } + result_arrays[local_doc_id_col_index] = array; } return arrow::Table::Make(result_schema, result_arrays, @@ -2290,8 +2292,8 @@ TablePtr SegmentImpl::fetch_perf( TablePtr SegmentImpl::fetch_normal( const std::vector &columns, - const std::shared_ptr &result_schema, bool need_local_doc_id, - int local_doc_id_col_index, const std::vector &indices) const { + const std::shared_ptr &result_schema, + const std::vector &indices) const { // Store scalars per column: column_index -> (output_row, scalar) std::vector>>> column_results(columns.size()); @@ -2418,9 +2420,16 @@ TablePtr SegmentImpl::fetch_normal( // Phase 3: Construct result arrays std::vector> result_arrays(columns.size()); + bool need_local_doc_id = false; + size_t local_doc_id_col_index = -1; + for (size_t col_index = 0; col_index < columns.size(); ++col_index) { const std::string &col = columns[col_index]; - if (col == LOCAL_ROW_ID) continue; + if (col == LOCAL_ROW_ID) { + need_local_doc_id = true; + local_doc_id_col_index = col_index; + continue; + } auto &result_vec = column_results[col_index]; std::sort(result_vec.begin(), result_vec.end()); @@ -2493,15 +2502,11 @@ TablePtr SegmentImpl::fetch(const std::vector &columns, // Build result schema std::vector> fields; - bool need_local_doc_id = false; - int local_doc_id_col_index = -1; for (size_t i = 0; i < columns.size(); ++i) { const auto &col = columns[i]; if (col == LOCAL_ROW_ID) { fields.push_back(arrow::field(LOCAL_ROW_ID, arrow::uint64())); - need_local_doc_id = true; - local_doc_id_col_index = static_cast(i); } else if (col == GLOBAL_DOC_ID) { fields.push_back(arrow::field(GLOBAL_DOC_ID, arrow::uint64())); } else if (col == USER_ID) { @@ -2536,11 +2541,9 @@ TablePtr SegmentImpl::fetch(const std::vector &columns, } if (use_fetch_perf_) { - return fetch_perf(columns, result_schema, need_local_doc_id, - local_doc_id_col_index, indices); + return fetch_perf(columns, result_schema, indices); } - return fetch_normal(columns, result_schema, need_local_doc_id, - local_doc_id_col_index, indices); + return fetch_normal(columns, result_schema, indices); } ExecBatchPtr SegmentImpl::fetch(const std::vector &columns, diff --git a/tests/db/index/segment/segment_test.cc b/tests/db/index/segment/segment_test.cc index 39231731..91d6f6b5 100644 --- a/tests/db/index/segment/segment_test.cc +++ b/tests/db/index/segment/segment_test.cc @@ -1473,7 +1473,6 @@ TEST_P(SegmentTest, FetchPerf) { // convert writing segment meta to persisted segment meta Version version = version_manager->get_current_version(); - // auto writing_segment_meta = segment->meta(); writing_segment_meta->remove_writing_forward_block(); auto s = version.add_persisted_segment_meta(writing_segment_meta); ASSERT_TRUE(s.ok()); @@ -1528,29 +1527,42 @@ TEST_P(SegmentTest, FetchPerf) { EXPECT_TRUE(s.ok()); std::vector indices = {0, 3, 6, 1, 0, 501, 999}; - auto combined_table = - segment->fetch({"id", "name", "add_int32", LOCAL_ROW_ID}, indices); - ASSERT_TRUE(combined_table != nullptr); - EXPECT_EQ(combined_table->num_columns(), 4); - EXPECT_EQ(combined_table->num_rows(), indices.size()); - - auto field = combined_table->schema()->field(3); - EXPECT_EQ(field->name(), LOCAL_ROW_ID); - - // Get data from the LOCAL_ROW_ID column for each row - auto id_column = combined_table->column(3); - auto id_array = - std::dynamic_pointer_cast(id_column->chunk(0)); - - std::vector &expected_ids = indices; - std::vector actual_ids; + auto func = [&](const std::vector columns, + int local_row_id_idx) -> void { + auto combined_table = segment->fetch(columns, indices); + ASSERT_TRUE(combined_table != nullptr); + EXPECT_EQ(combined_table->num_columns(), columns.size()); + EXPECT_EQ(combined_table->num_rows(), indices.size()); + + auto field = combined_table->schema()->field(local_row_id_idx); + EXPECT_EQ(field->name(), LOCAL_ROW_ID); + + // Get data from the LOCAL_ROW_ID column for each row + auto id_column = combined_table->column(local_row_id_idx); + auto id_array = + std::dynamic_pointer_cast(id_column->chunk(0)); + + std::vector &expected_ids = indices; + std::vector actual_ids; + + for (int i = 0; i < id_array->length(); ++i) { + actual_ids.push_back(id_array->Value(i)); + } - for (int i = 0; i < id_array->length(); ++i) { - actual_ids.push_back(id_array->Value(i)); - } + EXPECT_EQ(actual_ids, expected_ids) + << "ID column values don't match expected order"; + }; - EXPECT_EQ(actual_ids, expected_ids) - << "ID column values don't match expected order"; + func({LOCAL_ROW_ID, "id", "name", "add_int32"}, 0); + func( + { + "id", + LOCAL_ROW_ID, + "name", + "add_int32", + }, + 1); + func({"id", "name", "add_int32", LOCAL_ROW_ID}, 3); } TEST_P(SegmentTest, AddColumn) {