[ENH] Query merging (#2066)
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - /
 - New functionality
- Completes the query merging flow by adding an HNSW KNN operator and a
merge results operator that rehydrates user IDs from the record segment
blockfile (see the sketch below).
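For context, a minimal sketch of how the two operators compose in the hybrid query path (the wiring and the brute force inputs below are illustrative assumptions, not orchestrator code from this PR; only the operator types and their fields come from the diff):

```rust
// Hypothetical wiring: query the HNSW segment first, then merge with brute
// force results computed over records that are not yet indexed.
let hnsw_output = HnswKnnOperator {}
    .run(&HnswKnnOperatorInput { segment, query, k })
    .await?;

let merged = MergeKnnResultsOperator {}
    .run(&MergeKnnResultsOperatorInput::new(
        hnsw_output.offset_ids,    // rehydrated to user IDs inside the operator
        hnsw_output.distances,
        brute_force_user_ids,      // assumed output of the brute force KNN operator
        brute_force_distances,
        k,
        record_segment_definition, // record segment used for offset ID -> user ID lookup
        blockfile_provider,
    ))
    .await?;
```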
## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
HammadB committed May 3, 2024
1 parent 325211c commit c95edd1
Showing 13 changed files with 1,044 additions and 196 deletions.
2 changes: 1 addition & 1 deletion rust/worker/Dockerfile
@@ -36,7 +36,7 @@ FROM debian:bookworm-slim as query_service

 COPY --from=query_service_builder /chroma/query_service .
 COPY --from=query_service_builder /chroma/rust/worker/chroma_config.yaml .
-RUN apt-get update && apt-get install -y libssl-dev
+RUN apt-get update && apt-get install -y libssl-dev ca-certificates

 ENTRYPOINT [ "./query_service" ]

5 changes: 5 additions & 0 deletions rust/worker/bindings.cpp
@@ -200,4 +200,9 @@ extern "C"
    {
        index->set_ef(ef);
    }

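    // Live element count: total inserted elements minus soft-deleted ones.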
    int len(Index<float> *index)
    {
        return index->appr_alg->getCurrentElementCount() - index->appr_alg->getDeletedCount();
    }
}
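On the Rust side, the new `len` symbol would presumably be bound with an `extern "C"` declaration along these lines (a sketch under assumptions: `IndexPtrFFI` stands in for the worker's actual opaque index handle, and the real declaration lives in the crate's hnswlib bindings module):

```rust
use std::os::raw::c_int;

// Opaque stand-in for the C++ `Index<float>` handle; the worker defines
// its own handle type for this.
#[repr(C)]
pub struct IndexPtrFFI {
    _private: [u8; 0],
}

extern "C" {
    // Live element count: current elements minus soft-deleted ones.
    fn len(index: *const IndexPtrFFI) -> c_int;
}
```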
36 changes: 36 additions & 0 deletions rust/worker/src/execution/operators/hnsw_knn.rs
@@ -0,0 +1,36 @@
use crate::{
    errors::ChromaError, execution::operator::Operator,
    segment::distributed_hnsw_segment::DistributedHNSWSegment,
};
use async_trait::async_trait;

#[derive(Debug)]
pub struct HnswKnnOperator {}

#[derive(Debug)]
pub struct HnswKnnOperatorInput {
    pub segment: Box<DistributedHNSWSegment>,
    pub query: Vec<f32>,
    pub k: usize,
}

#[derive(Debug)]
pub struct HnswKnnOperatorOutput {
    pub offset_ids: Vec<usize>,
    pub distances: Vec<f32>,
}

pub type HnswKnnOperatorResult = Result<HnswKnnOperatorOutput, Box<dyn ChromaError>>;

#[async_trait]
impl Operator<HnswKnnOperatorInput, HnswKnnOperatorOutput> for HnswKnnOperator {
    type Error = Box<dyn ChromaError>;

    async fn run(&self, input: &HnswKnnOperatorInput) -> HnswKnnOperatorResult {
        let (offset_ids, distances) = input.segment.query(&input.query, input.k);
        Ok(HnswKnnOperatorOutput {
            offset_ids,
            distances,
        })
    }
}
171 changes: 171 additions & 0 deletions rust/worker/src/execution/operators/merge_knn_results.rs
@@ -0,0 +1,171 @@
use crate::{
    blockstore::provider::BlockfileProvider,
    errors::ChromaError,
    execution::operator::Operator,
    segment::record_segment::{RecordSegmentReader, RecordSegmentReaderCreationError},
    types::Segment,
};
use async_trait::async_trait;
use thiserror::Error;

#[derive(Debug)]
pub struct MergeKnnResultsOperator {}

#[derive(Debug)]
pub struct MergeKnnResultsOperatorInput {
    hnsw_result_offset_ids: Vec<usize>,
    hnsw_result_distances: Vec<f32>,
    brute_force_result_user_ids: Vec<String>,
    brute_force_result_distances: Vec<f32>,
    k: usize,
    record_segment_definition: Segment,
    blockfile_provider: BlockfileProvider,
}

impl MergeKnnResultsOperatorInput {
    pub fn new(
        hnsw_result_offset_ids: Vec<usize>,
        hnsw_result_distances: Vec<f32>,
        brute_force_result_user_ids: Vec<String>,
        brute_force_result_distances: Vec<f32>,
        k: usize,
        record_segment_definition: Segment,
        blockfile_provider: BlockfileProvider,
    ) -> Self {
        Self {
            hnsw_result_offset_ids,
            hnsw_result_distances,
            brute_force_result_user_ids,
            brute_force_result_distances,
            k,
            record_segment_definition,
            blockfile_provider,
        }
    }
}

#[derive(Debug)]
pub struct MergeKnnResultsOperatorOutput {
    pub user_ids: Vec<String>,
    pub distances: Vec<f32>,
}

#[derive(Error, Debug)]
pub enum MergeKnnResultsOperatorError {}

impl ChromaError for MergeKnnResultsOperatorError {
    fn code(&self) -> crate::errors::ErrorCodes {
        crate::errors::ErrorCodes::UNKNOWN
    }
}

pub type MergeKnnResultsOperatorResult =
    Result<MergeKnnResultsOperatorOutput, Box<dyn ChromaError>>;

#[async_trait]
impl Operator<MergeKnnResultsOperatorInput, MergeKnnResultsOperatorOutput>
    for MergeKnnResultsOperator
{
    type Error = Box<dyn ChromaError>;

    async fn run(&self, input: &MergeKnnResultsOperatorInput) -> MergeKnnResultsOperatorResult {
        let (result_user_ids, result_distances) = match RecordSegmentReader::from_segment(
            &input.record_segment_definition,
            &input.blockfile_provider,
        )
        .await
        {
            Ok(reader) => {
                // Convert the HNSW result offset IDs to user IDs via the record segment
                let mut hnsw_result_user_ids = Vec::new();
                for offset_id in &input.hnsw_result_offset_ids {
                    let user_id = reader.get_user_id_for_offset_id(*offset_id as u32).await;
                    match user_id {
                        Ok(user_id) => hnsw_result_user_ids.push(user_id),
                        Err(e) => return Err(e),
                    }
                }
                merge_results(
                    &hnsw_result_user_ids,
                    &input.hnsw_result_distances,
                    &input.brute_force_result_user_ids,
                    &input.brute_force_result_distances,
                    input.k,
                )
            }
            Err(e) => match *e {
                RecordSegmentReaderCreationError::BlockfileOpenError(e) => {
                    return Err(e);
                }
                RecordSegmentReaderCreationError::InvalidNumberOfFiles => {
                    return Err(e);
                }
                RecordSegmentReaderCreationError::UninitializedSegment => {
                    // The record segment doesn't exist yet, which implies there are
                    // no HNSW results to rehydrate; merge against empty HNSW results.
                    let hnsw_result_user_ids = Vec::new();
                    let hnsw_result_distances = Vec::new();
                    merge_results(
                        &hnsw_result_user_ids,
                        &hnsw_result_distances,
                        &input.brute_force_result_user_ids,
                        &input.brute_force_result_distances,
                        input.k,
                    )
                }
            },
        };

        Ok(MergeKnnResultsOperatorOutput {
            user_ids: result_user_ids,
            distances: result_distances,
        })
    }
}

fn merge_results(
    hnsw_result_user_ids: &Vec<&str>,
    hnsw_result_distances: &Vec<f32>,
    brute_force_result_user_ids: &Vec<String>,
    brute_force_result_distances: &Vec<f32>,
    k: usize,
) -> (Vec<String>, Vec<f32>) {
    let mut result_user_ids = Vec::with_capacity(k);
    let mut result_distances = Vec::with_capacity(k);

    // Merge the HNSW and brute force results together by the minimum distance top k
    let mut hnsw_index = 0;
    let mut brute_force_index = 0;

    // TODO: This doesn't have to clone the user IDs, but it's easier for now
    // Stop once we have k results (a `<= k` bound here would yield k + 1).
    while (result_user_ids.len() < k)
        && (hnsw_index < hnsw_result_user_ids.len()
            || brute_force_index < brute_force_result_user_ids.len())
    {
        if hnsw_index < hnsw_result_user_ids.len()
            && brute_force_index < brute_force_result_user_ids.len()
        {
            if hnsw_result_distances[hnsw_index] < brute_force_result_distances[brute_force_index]
            {
                result_user_ids.push(hnsw_result_user_ids[hnsw_index].to_string());
                result_distances.push(hnsw_result_distances[hnsw_index]);
                hnsw_index += 1;
            } else {
                result_user_ids.push(brute_force_result_user_ids[brute_force_index].to_string());
                result_distances.push(brute_force_result_distances[brute_force_index]);
                brute_force_index += 1;
            }
        } else if hnsw_index < hnsw_result_user_ids.len() {
            result_user_ids.push(hnsw_result_user_ids[hnsw_index].to_string());
            result_distances.push(hnsw_result_distances[hnsw_index]);
            hnsw_index += 1;
        } else if brute_force_index < brute_force_result_user_ids.len() {
            result_user_ids.push(brute_force_result_user_ids[brute_force_index].to_string());
            result_distances.push(brute_force_result_distances[brute_force_index]);
            brute_force_index += 1;
        }
    }

    (result_user_ids, result_distances)
}
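As a quick worked example of the merge (hypothetical values, written as a module-level test sketch rather than a test from this PR; both lists must already be sorted by ascending distance for the two-pointer walk to be correct):

```rust
#[test]
fn merge_results_interleaves_by_minimum_distance() {
    let hnsw_ids = vec!["a", "c"];
    let hnsw_dists = vec![0.1, 0.4];
    let bf_ids = vec!["b".to_string(), "d".to_string()];
    let bf_dists = vec![0.2, 0.9];

    // k = 3 takes 0.1 ("a"), then 0.2 ("b"), then 0.4 ("c"); "d" is cut off.
    let (ids, dists) = merge_results(&hnsw_ids, &hnsw_dists, &bf_ids, &bf_dists, 3);
    assert_eq!(ids, vec!["a", "b", "c"]);
    assert_eq!(dists, vec![0.1, 0.2, 0.4]);
}
```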
2 changes: 2 additions & 0 deletions rust/worker/src/execution/operators/mod.rs
@@ -1,5 +1,7 @@
pub(super) mod brute_force_knn;
pub(super) mod flush_s3;
pub(super) mod hnsw_knn;
pub(super) mod merge_knn_results;
pub(super) mod normalize_vectors;
pub(super) mod partition;
pub(super) mod pull_log;
2 changes: 2 additions & 0 deletions rust/worker/src/execution/orchestration/compact.rs
@@ -178,6 +178,8 @@ impl CompactOrchestrator {
        };
        let input = PullLogsInput::new(
            collection_id,
            // We don't need an inclusive start here: the compaction job's
            // offset already points one past the last compacted offset.
            self.compaction_job.offset,
            100,
            None,