Skip to content

Commit

Permalink
[ENH] Get vectors orchestrator (#2348)
Browse files Browse the repository at this point in the history
This PR adds the GetVectors() RPC for query nodes.

1. Adds the GetVectors instrumented rpc call and associated types
2. Adds a GetVectorsOrchestrator for managing these queries
3. Adds the GetVectorsOperator for reading from the log and record
segment to respond to get_vectors
4. Moves common orchestrator functions into a /common and refactors hnsw
orchestrator to use this instead (DRY)
5. Implements the `ChromaError` trait for ChannelError in Sender - needed to
use it as Box<dyn ChromaError> elsewhere.

I explicitly chose not to materialize in this operator, since it has no
need to.

Tasks
- [x] Stub out
- [x] Types
- [x] Implement return
- [x] Server input parse
- [x] Server output parse
- [x] Init segments
- [x] Pull log + filter
- [x] Operator for reading
- [x] Add test (Will handle in separate PR) 
- [x] Validate that the get() semantics for missing ids is correct (I
have validated that the semantics are the same - we just ignore invalid
entries in the output)
  • Loading branch information
HammadB authored Jun 17, 2024
1 parent d78df1e commit 09df9ce
Show file tree
Hide file tree
Showing 9 changed files with 833 additions and 126 deletions.
231 changes: 231 additions & 0 deletions rust/worker/src/execution/operators/get_vectors_operator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
use crate::{
blockstore::provider::BlockfileProvider,
errors::{ChromaError, ErrorCodes},
execution::{data::data_chunk::Chunk, operator::Operator},
segment::record_segment::{self, RecordSegmentReader},
types::{LogRecord, Segment},
};
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use thiserror::Error;

/// Operator that looks up the embedding vectors for a set of user ids.
///
/// The operator has no fields; everything it works on is passed in through
/// its input on each run.
#[derive(Debug)]
pub struct GetVectorsOperator {}

impl GetVectorsOperator {
    /// Creates a new operator and returns it boxed.
    pub fn new() -> Box<Self> {
        Box::new(Self {})
    }
}

/// The input to the get vectors operator.
#[derive(Debug)]
pub struct GetVectorsOperatorInput {
    /// The segment definition for the record segment.
    record_segment_definition: Segment,
    /// The blockfile provider.
    blockfile_provider: BlockfileProvider,
    /// The log records.
    log_records: Chunk<LogRecord>,
    /// The user ids to search for.
    search_user_ids: Vec<String>,
}

impl GetVectorsOperatorInput {
    /// Builds an operator input from its four parts.
    ///
    /// # Parameters
    /// * `record_segment_definition` - The segment definition for the record segment.
    /// * `blockfile_provider` - The blockfile provider.
    /// * `log_records` - The log records.
    /// * `search_user_ids` - The user ids to search for.
    pub fn new(
        record_segment_definition: Segment,
        blockfile_provider: BlockfileProvider,
        log_records: Chunk<LogRecord>,
        search_user_ids: Vec<String>,
    ) -> Self {
        Self {
            record_segment_definition,
            blockfile_provider,
            log_records,
            search_user_ids,
        }
    }
}

/// The output of the get vectors operator.
///
/// `ids` and `vectors` are parallel lists: the vector at index `i` belongs to
/// the id at index `i`.
/// # Parameters
/// * `ids` - The ids of the vectors.
/// * `vectors` - The vectors.
/// # Notes
/// The vectors are in the same order as the ids.
#[derive(Debug)]
pub struct GetVectorsOperatorOutput {
/// The user ids for which a vector was found.
pub(crate) ids: Vec<String>,
/// The vectors, in the same order as `ids`.
pub(crate) vectors: Vec<Vec<f32>>,
}

/// Errors produced while running the get vectors operator.
#[derive(Debug, Error)]
pub enum GetVectorsOperatorError {
    /// The record segment reader could not be created.
    #[error("Error creating record segment reader {0}")]
    RecordSegmentReaderCreationError(
        #[from] crate::segment::record_segment::RecordSegmentReaderCreationError,
    ),
    /// A call on the record segment reader failed; the message of the
    /// wrapped error is forwarded unchanged.
    #[error(transparent)]
    RecordSegmentReaderError(#[from] Box<dyn ChromaError>),
}

impl ChromaError for GetVectorsOperatorError {
    /// Reports every variant as an internal error.
    fn code(&self) -> ErrorCodes {
        match self {
            Self::RecordSegmentReaderCreationError(_) | Self::RecordSegmentReaderError(_) => {
                ErrorCodes::Internal
            }
        }
    }
}

#[async_trait]
impl Operator<GetVectorsOperatorInput, GetVectorsOperatorOutput> for GetVectorsOperator {
    type Error = GetVectorsOperatorError;

    /// Resolves the vectors for `input.search_user_ids`.
    ///
    /// The log records are replayed in order and take precedence over the
    /// record segment. Ids whose vector is not decided by the log are then
    /// looked up in the record segment. Ids that are found in neither place,
    /// or whose last relevant log entry is a delete, are simply absent from
    /// the output.
    ///
    /// While replaying the log, each searched id is in one of three states:
    /// * resolved - its vector is already in the output (taken from the log),
    /// * undecided - no log entry has fixed its vector yet, so the record
    ///   segment (if there is one) is the source of truth,
    /// * deleted - a delete was seen and nothing re-created it since.
    ///
    /// # Errors
    /// * `RecordSegmentReaderCreationError` if the record segment reader
    ///   cannot be created for a reason other than the segment being
    ///   uninitialized.
    /// * `RecordSegmentReaderError` if an existence check against the record
    ///   segment fails.
    ///
    /// # Panics
    /// Panics if an applied add or upsert log record carries no embedding.
    async fn run(
        &self,
        input: &GetVectorsOperatorInput,
    ) -> Result<GetVectorsOperatorOutput, Self::Error> {
        // An uninitialized segment is not an error: it just means there is no
        // record segment to consult and only the log can provide vectors.
        let record_segment_reader = match RecordSegmentReader::from_segment(
            &input.record_segment_definition,
            &input.blockfile_provider,
        )
        .await
        {
            Ok(reader) => Some(reader),
            Err(e) => match *e {
                record_segment::RecordSegmentReaderCreationError::UninitializedSegment => None,
                record_segment::RecordSegmentReaderCreationError::BlockfileOpenError(_)
                | record_segment::RecordSegmentReaderCreationError::InvalidNumberOfFiles => {
                    return Err(GetVectorsOperatorError::RecordSegmentReaderCreationError(*e))
                }
            },
        };

        // Every id the caller asked about. This set never shrinks, so later
        // log entries for an id are still applied after an earlier entry has
        // already resolved or deleted it.
        let searched_user_ids: HashSet<&str> =
            input.search_user_ids.iter().map(String::as_str).collect();
        // Ids that are still undecided after the log entries seen so far.
        let mut remaining_search_user_ids: HashSet<String> =
            input.search_user_ids.iter().cloned().collect();
        // Vectors resolved from the log (and, at the end, from the segment).
        let mut output_vectors: HashMap<String, Vec<f32>> = HashMap::new();

        // Search the log records for the user ids
        let logs = input.log_records.clone();
        for (log_record, _) in logs.iter() {
            let user_id = &log_record.record.id;
            if !searched_user_ids.contains(user_id.as_str()) {
                continue;
            }
            let resolved_from_log = output_vectors.contains_key(user_id);
            let undecided = remaining_search_user_ids.contains(user_id);

            match log_record.record.operation {
                crate::types::Operation::Add => {
                    // The first add for a user id wins; an add for an id
                    // that already exists is faulty and skipped.
                    if resolved_from_log {
                        continue;
                    }
                    if undecided {
                        if let Some(reader) = &record_segment_reader {
                            let exists_in_segment = reader
                                .data_exists_for_user_id(user_id)
                                .await
                                .map_err(|e| {
                                    GetVectorsOperatorError::RecordSegmentReaderError(e.into())
                                })?;
                            if exists_in_segment {
                                continue;
                            }
                        }
                        // Without a record segment nothing can already
                        // exist, so the add is valid.
                    }
                    // Reaching here with the id neither resolved nor
                    // undecided means it was deleted earlier in the log, so
                    // the add re-creates it.
                    let vector = log_record.record.embedding.as_ref().expect("Invariant violation. The log record for an add does not have an embedding.");
                    output_vectors.insert(user_id.clone(), vector.clone());
                    remaining_search_user_ids.remove(user_id);
                }
                crate::types::Operation::Update => {
                    if !resolved_from_log {
                        if !undecided {
                            // Deleted earlier in the log: nothing to update.
                            continue;
                        }
                        let exists_in_segment = match &record_segment_reader {
                            Some(reader) => reader
                                .data_exists_for_user_id(user_id)
                                .await
                                .map_err(|e| {
                                    GetVectorsOperatorError::RecordSegmentReaderError(e.into())
                                })?,
                            None => false,
                        };
                        if !exists_in_segment {
                            // The record exists neither in the log nor in the
                            // record segment, so the update is faulty.
                            continue;
                        }
                    }
                    // Only an update that carries an embedding changes the
                    // vector; otherwise the current source stays in effect.
                    if let Some(vector) = &log_record.record.embedding {
                        output_vectors.insert(user_id.clone(), vector.clone());
                        remaining_search_user_ids.remove(user_id);
                    }
                }
                crate::types::Operation::Upsert => {
                    // The upsert operation does not allow embeddings to be None
                    // So the final value is always present in the log
                    let vector = log_record.record.embedding.as_ref().expect("Invariant violation. The log record for an upsert does not have an embedding.");
                    output_vectors.insert(user_id.clone(), vector.clone());
                    remaining_search_user_ids.remove(user_id);
                }
                crate::types::Operation::Delete => {
                    // Drop any vector resolved so far and make sure the id is
                    // not looked up in the record segment afterwards.
                    output_vectors.remove(user_id);
                    remaining_search_user_ids.remove(user_id);
                }
            }
        }

        // Search the record segment for the remaining user ids
        if let Some(reader) = &record_segment_reader {
            for user_id in remaining_search_user_ids.iter() {
                let read_data = reader.get_data_and_offset_id_for_user_id(user_id).await;
                // A failed read is treated as "not found" and the id is left
                // out of the output.
                // NOTE(review): this also hides any other read failure -
                // confirm that is acceptable for this reader.
                if let Ok((record, _)) = read_data {
                    output_vectors.insert(record.id.to_string(), record.embedding.to_vec());
                }
            }
        }

        // Split the map into two parallel lists; pairs stay aligned by index.
        let (ids, vectors): (Vec<String>, Vec<Vec<f32>>) = output_vectors.into_iter().unzip();
        Ok(GetVectorsOperatorOutput { ids, vectors })
    }
}
1 change: 1 addition & 0 deletions rust/worker/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub(super) mod brute_force_knn;
pub(super) mod count_records;
pub(super) mod flush_s3;
pub(super) mod get_vectors_operator;
pub(super) mod hnsw_knn;
pub(super) mod merge_knn_results;
pub(super) mod merge_metadata_results;
Expand Down
150 changes: 150 additions & 0 deletions rust/worker/src/execution/orchestration/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use crate::{
errors::{ChromaError, ErrorCodes},
sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb},
types::{Collection, Segment, SegmentType},
};
use thiserror::Error;
use tracing::{trace_span, Instrument, Span};
use uuid::Uuid;

/// Errors returned when looking up an HNSW segment by its id.
#[derive(Debug, Error)]
pub(super) enum GetHnswSegmentByIdError {
    /// No HNSW segment was found for the given segment id.
    #[error("Hnsw segment with id: {0} not found")]
    HnswSegmentNotFound(Uuid),
    /// The segment lookup itself failed.
    #[error("Get segments error")]
    GetSegmentsError(#[from] GetSegmentsError),
}

impl ChromaError for GetHnswSegmentByIdError {
    /// Maps "not found" to `NotFound` and otherwise forwards the code of the
    /// wrapped lookup error.
    fn code(&self) -> ErrorCodes {
        match self {
            Self::HnswSegmentNotFound(_) => ErrorCodes::NotFound,
            Self::GetSegmentsError(inner) => inner.code(),
        }
    }
}

/// Fetches the HNSW segment with the given id from sysdb.
///
/// Only the first segment returned by the lookup is considered.
///
/// # Errors
/// * `GetSegmentsError` if the sysdb lookup fails.
/// * `HnswSegmentNotFound` if the lookup returns no segment, or the segment
///   it returns is not of type `SegmentType::HnswDistributed`.
pub(super) async fn get_hnsw_segment_by_id(
    mut sysdb: Box<SysDb>,
    hnsw_segment_id: &Uuid,
) -> Result<Segment, Box<GetHnswSegmentByIdError>> {
    let lookup = sysdb
        .get_segments(Some(*hnsw_segment_id), None, None, None)
        .await;
    let found = lookup.map_err(|e| Box::new(GetHnswSegmentByIdError::GetSegmentsError(e)))?;

    match found.into_iter().next() {
        Some(segment) if segment.r#type == SegmentType::HnswDistributed => Ok(segment),
        // Either nothing came back or the segment has the wrong type; both
        // are reported as "not found".
        _ => Err(Box::new(GetHnswSegmentByIdError::HnswSegmentNotFound(
            *hnsw_segment_id,
        ))),
    }
}

/// Errors returned when looking up a collection by its id.
#[derive(Debug, Error)]
pub(super) enum GetCollectionByIdError {
    /// No collection was found for the given collection id.
    #[error("Collection with id: {0} not found")]
    CollectionNotFound(Uuid),
    /// The collection lookup itself failed.
    #[error("Get collection error")]
    GetCollectionError(#[from] GetCollectionsError),
}

impl ChromaError for GetCollectionByIdError {
    /// Maps "not found" to `NotFound` and otherwise forwards the code of the
    /// wrapped lookup error.
    fn code(&self) -> ErrorCodes {
        match self {
            Self::CollectionNotFound(_) => ErrorCodes::NotFound,
            Self::GetCollectionError(inner) => inner.code(),
        }
    }
}

/// Fetches the collection with the given id from sysdb.
///
/// The sysdb call is instrumented with a trace span that is a child of the
/// current span. Only the first collection returned is used.
///
/// # Errors
/// * `GetCollectionError` if the sysdb lookup fails.
/// * `CollectionNotFound` if the lookup returns no collection.
pub(super) async fn get_collection_by_id(
    mut sysdb: Box<SysDb>,
    collection_id: &Uuid,
) -> Result<Collection, Box<GetCollectionByIdError>> {
    let child_span: tracing::Span =
        trace_span!(parent: Span::current(), "get collection for collection id");
    let fetched = sysdb
        .get_collections(Some(*collection_id), None, None, None)
        .instrument(child_span.clone())
        .await
        .map_err(|e| Box::new(GetCollectionByIdError::GetCollectionError(e)))?;

    fetched.into_iter().next().ok_or_else(|| {
        Box::new(GetCollectionByIdError::CollectionNotFound(
            *collection_id,
        ))
    })
}

/// Errors returned when looking up the record segment of a collection.
#[derive(Debug, Error)]
pub(super) enum GetRecordSegmentByCollectionIdError {
    /// No record segment was found for the given collection id.
    #[error("Record segment for collection with id: {0} not found")]
    RecordSegmentNotFound(Uuid),
    /// The segment lookup itself failed.
    #[error("Get segments error")]
    GetSegmentsError(#[from] GetSegmentsError),
}

impl ChromaError for GetRecordSegmentByCollectionIdError {
    /// Maps "not found" to `NotFound` and otherwise forwards the code of the
    /// wrapped lookup error.
    fn code(&self) -> ErrorCodes {
        match self {
            Self::RecordSegmentNotFound(_) => ErrorCodes::NotFound,
            Self::GetSegmentsError(inner) => inner.code(),
        }
    }
}

/// Fetches the record segment that belongs to the given collection.
///
/// Sysdb is queried for segments of type `SegmentType::BlockfileRecord` in
/// the collection; only the first segment returned is considered.
///
/// # Errors
/// * `GetSegmentsError` if the sysdb lookup fails.
/// * `RecordSegmentNotFound` if the lookup returns no segment, or the segment
///   it returns is not of type `SegmentType::BlockfileRecord`.
pub(super) async fn get_record_segment_by_collection_id(
    mut sysdb: Box<SysDb>,
    collection_id: &Uuid,
) -> Result<Segment, Box<GetRecordSegmentByCollectionIdError>> {
    let lookup = sysdb
        .get_segments(
            None,
            Some(SegmentType::BlockfileRecord.into()),
            None,
            Some(*collection_id),
        )
        .await;
    let found = lookup
        .map_err(|e| Box::new(GetRecordSegmentByCollectionIdError::GetSegmentsError(e)))?;

    match found.into_iter().next() {
        Some(segment) if segment.r#type == SegmentType::BlockfileRecord => Ok(segment),
        // Either nothing came back or the segment has the wrong type; both
        // are reported as "not found".
        _ => Err(Box::new(
            GetRecordSegmentByCollectionIdError::RecordSegmentNotFound(*collection_id),
        )),
    }
}
Loading

0 comments on commit 09df9ce

Please sign in to comment.