-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Get vectors orchestrator (#2348)
This PR adds the GetVectors() RPC for query nodes.

1. Adds the GetVectors instrumented RPC call and associated types.
2. Adds a GetVectorsOrchestrator for managing these queries.
3. Adds the GetVectorsOperator for reading from the log and record segment to respond to get_vectors.
4. Moves common orchestrator functions into a /common module and refactors the hnsw orchestrator to use it instead (DRY).
5. Adds the `ChromaError` trait for ChannelError in Sender - needed to use that as Box<dyn ChromaError> elsewhere.

I explicitly make the choice to not materialize in this operator, since it does not need to for any reason.

Tasks
- [x] Stub out
- [x] Types
- [x] Implement return
- [x] Server input parse
- [x] Server output parse
- [x] Init segments
- [x] Pull log + filter
- [x] Operator for reading
- [x] Add test (Will handle in separate PR)
- [x] Validate that the get() semantics for missing ids is correct (I have validated that the semantics are the same - we just ignore invalid entries in the output)
- Loading branch information
Showing
9 changed files
with
833 additions
and
126 deletions.
There are no files selected for viewing
231 changes: 231 additions & 0 deletions
231
rust/worker/src/execution/operators/get_vectors_operator.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
use crate::{ | ||
blockstore::provider::BlockfileProvider, | ||
errors::{ChromaError, ErrorCodes}, | ||
execution::{data::data_chunk::Chunk, operator::Operator}, | ||
segment::record_segment::{self, RecordSegmentReader}, | ||
types::{LogRecord, Segment}, | ||
}; | ||
use async_trait::async_trait; | ||
use std::collections::{HashMap, HashSet}; | ||
use thiserror::Error; | ||
|
||
/// Stateless operator that resolves the vectors for a set of user ids.
///
/// All inputs are supplied per invocation, so the operator itself carries no
/// fields.
#[derive(Debug)]
pub struct GetVectorsOperator {}

impl GetVectorsOperator {
    /// Creates a new, boxed operator instance.
    pub fn new() -> Box<Self> {
        Box::new(Self {})
    }
}
|
||
/// The input to the get vectors operator. | ||
/// # Parameters | ||
/// * `record_segment_definition` - The segment definition for the record segment. | ||
/// * `blockfile_provider` - The blockfile provider. | ||
/// * `log_records` - The log records. | ||
/// * `search_user_ids` - The user ids to search for. | ||
#[derive(Debug)] | ||
pub struct GetVectorsOperatorInput { | ||
record_segment_definition: Segment, | ||
blockfile_provider: BlockfileProvider, | ||
log_records: Chunk<LogRecord>, | ||
search_user_ids: Vec<String>, | ||
} | ||
|
||
impl GetVectorsOperatorInput { | ||
pub fn new( | ||
record_segment_definition: Segment, | ||
blockfile_provider: BlockfileProvider, | ||
log_records: Chunk<LogRecord>, | ||
search_user_ids: Vec<String>, | ||
) -> Self { | ||
return GetVectorsOperatorInput { | ||
record_segment_definition, | ||
blockfile_provider, | ||
log_records, | ||
search_user_ids, | ||
}; | ||
} | ||
} | ||
|
||
/// Result produced by the get vectors operator.
///
/// `ids` and `vectors` are parallel collections: the vector at position `i`
/// belongs to the id at position `i`.
#[derive(Debug)]
pub struct GetVectorsOperatorOutput {
    /// The user ids for which a vector was found.
    pub(crate) ids: Vec<String>,
    /// The vectors, in the same order as `ids`.
    pub(crate) vectors: Vec<Vec<f32>>,
}
|
||
#[derive(Debug, Error)] | ||
pub enum GetVectorsOperatorError { | ||
#[error("Error creating record segment reader {0}")] | ||
RecordSegmentReaderCreationError( | ||
#[from] crate::segment::record_segment::RecordSegmentReaderCreationError, | ||
), | ||
#[error(transparent)] | ||
RecordSegmentReaderError(#[from] Box<dyn ChromaError>), | ||
} | ||
|
||
impl ChromaError for GetVectorsOperatorError { | ||
fn code(&self) -> ErrorCodes { | ||
ErrorCodes::Internal | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl Operator<GetVectorsOperatorInput, GetVectorsOperatorOutput> for GetVectorsOperator { | ||
type Error = GetVectorsOperatorError; | ||
|
||
async fn run( | ||
&self, | ||
input: &GetVectorsOperatorInput, | ||
) -> Result<GetVectorsOperatorOutput, Self::Error> { | ||
let mut output_vectors = HashMap::new(); | ||
|
||
let record_segment_reader = match RecordSegmentReader::from_segment( | ||
&input.record_segment_definition, | ||
&input.blockfile_provider, | ||
) | ||
.await | ||
{ | ||
Ok(reader) => Some(reader), | ||
Err(e) => match *e { | ||
record_segment::RecordSegmentReaderCreationError::UninitializedSegment => None, | ||
record_segment::RecordSegmentReaderCreationError::BlockfileOpenError(_) => { | ||
return Err(GetVectorsOperatorError::RecordSegmentReaderCreationError( | ||
*e, | ||
)) | ||
} | ||
record_segment::RecordSegmentReaderCreationError::InvalidNumberOfFiles => { | ||
return Err(GetVectorsOperatorError::RecordSegmentReaderCreationError( | ||
*e, | ||
)) | ||
} | ||
}, | ||
}; | ||
|
||
// Search the log records for the user ids | ||
let logs = input.log_records.clone(); | ||
let mut remaining_search_user_ids: HashSet<String> = | ||
HashSet::from_iter(input.search_user_ids.iter().cloned()); | ||
for (log_record, _) in logs.iter() { | ||
if remaining_search_user_ids.contains(&log_record.record.id) { | ||
match log_record.record.operation { | ||
crate::types::Operation::Add => { | ||
// If there is a record segment, validate the add | ||
if let Some(ref reader) = record_segment_reader { | ||
match reader.data_exists_for_user_id(&log_record.record.id).await { | ||
Ok(true) => { | ||
// The record exists in the record segment, so this add is faulty | ||
// and we should skip it | ||
continue; | ||
} | ||
Ok(false) => { | ||
// The record does not exist in the record segment, | ||
// the add is valid | ||
|
||
// If the user id is already present in the log set, skip it | ||
// We use the first add log entry for a user id if a | ||
// user has multiple log entries | ||
if output_vectors.contains_key(&log_record.record.id) { | ||
continue; | ||
} | ||
let vector = log_record.record.embedding.as_ref().expect("Invariant violation. The log record for an add does not have an embedding."); | ||
output_vectors | ||
.insert(log_record.record.id.clone(), vector.clone()); | ||
remaining_search_user_ids.remove(&log_record.record.id); | ||
} | ||
Err(e) => { | ||
// If there is an error, we skip the add | ||
return Err(GetVectorsOperatorError::RecordSegmentReaderError( | ||
e.into(), | ||
)); | ||
} | ||
} | ||
} | ||
} | ||
crate::types::Operation::Update => { | ||
// If there is a record segment, validate the update | ||
if let Some(ref reader) = record_segment_reader { | ||
match reader.data_exists_for_user_id(&log_record.record.id).await { | ||
Ok(true) => { | ||
// The record exists in the record segment, so this update is valid | ||
// and we should include it in the output | ||
|
||
// If the update mutates the vector, we need to update the output | ||
match &log_record.record.embedding { | ||
Some(vector) => { | ||
// This will overwrite the vector if it already exists | ||
// (e.g if it was added previously in the log) | ||
output_vectors.insert( | ||
log_record.record.id.clone(), | ||
vector.clone(), | ||
); | ||
remaining_search_user_ids.remove(&log_record.record.id); | ||
} | ||
None => { | ||
// Nothing to do with this as the vector was not updated | ||
} | ||
} | ||
} | ||
Ok(false) => { | ||
// The record does not exist in the record segment, | ||
// the update is faulty. | ||
// We skip the update | ||
continue; | ||
} | ||
Err(e) => { | ||
// If there is an error, we skip the update | ||
return Err(GetVectorsOperatorError::RecordSegmentReaderError( | ||
e.into(), | ||
)); | ||
} | ||
} | ||
} | ||
} | ||
crate::types::Operation::Upsert => { | ||
// The upsert operation does not allow embeddings to be None | ||
// So the final value is always present in the log | ||
let vector = log_record.record.embedding.as_ref().expect("Invariant violation. The log record for an upsert does not have an embedding."); | ||
output_vectors.insert(log_record.record.id.clone(), vector.clone()); | ||
remaining_search_user_ids.remove(&log_record.record.id); | ||
} | ||
crate::types::Operation::Delete => { | ||
// If the user id is present in the output, remove it | ||
output_vectors.remove(&log_record.record.id); | ||
remaining_search_user_ids.remove(&log_record.record.id); | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Search the record segment for the remaining user ids | ||
if !remaining_search_user_ids.is_empty() { | ||
if let Some(reader) = record_segment_reader { | ||
for user_id in remaining_search_user_ids.iter() { | ||
let read_data = reader.get_data_and_offset_id_for_user_id(user_id).await; | ||
match read_data { | ||
Ok((record, _)) => { | ||
output_vectors.insert(record.id.to_string(), record.embedding.to_vec()); | ||
} | ||
Err(_) => { | ||
// If the user id is not found in the record segment, we do not add it to the output | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
let mut ids = Vec::new(); | ||
let mut vectors = Vec::new(); | ||
for (id, vector) in output_vectors.drain() { | ||
ids.push(id); | ||
vectors.push(vector); | ||
} | ||
return Ok(GetVectorsOperatorOutput { ids, vectors }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
use crate::{ | ||
errors::{ChromaError, ErrorCodes}, | ||
sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb}, | ||
types::{Collection, Segment, SegmentType}, | ||
}; | ||
use thiserror::Error; | ||
use tracing::{trace_span, Instrument, Span}; | ||
use uuid::Uuid; | ||
|
||
#[derive(Debug, Error)] | ||
pub(super) enum GetHnswSegmentByIdError { | ||
#[error("Hnsw segment with id: {0} not found")] | ||
HnswSegmentNotFound(Uuid), | ||
#[error("Get segments error")] | ||
GetSegmentsError(#[from] GetSegmentsError), | ||
} | ||
|
||
impl ChromaError for GetHnswSegmentByIdError { | ||
fn code(&self) -> ErrorCodes { | ||
match self { | ||
GetHnswSegmentByIdError::HnswSegmentNotFound(_) => ErrorCodes::NotFound, | ||
GetHnswSegmentByIdError::GetSegmentsError(e) => e.code(), | ||
} | ||
} | ||
} | ||
|
||
pub(super) async fn get_hnsw_segment_by_id( | ||
mut sysdb: Box<SysDb>, | ||
hnsw_segment_id: &Uuid, | ||
) -> Result<Segment, Box<GetHnswSegmentByIdError>> { | ||
let segments = sysdb | ||
.get_segments(Some(*hnsw_segment_id), None, None, None) | ||
.await; | ||
let segment = match segments { | ||
Ok(segments) => { | ||
if segments.is_empty() { | ||
return Err(Box::new(GetHnswSegmentByIdError::HnswSegmentNotFound( | ||
*hnsw_segment_id, | ||
))); | ||
} | ||
segments[0].clone() | ||
} | ||
Err(e) => { | ||
return Err(Box::new(GetHnswSegmentByIdError::GetSegmentsError(e))); | ||
} | ||
}; | ||
|
||
if segment.r#type != SegmentType::HnswDistributed { | ||
return Err(Box::new(GetHnswSegmentByIdError::HnswSegmentNotFound( | ||
*hnsw_segment_id, | ||
))); | ||
} | ||
Ok(segment) | ||
} | ||
|
||
#[derive(Debug, Error)] | ||
pub(super) enum GetCollectionByIdError { | ||
#[error("Collection with id: {0} not found")] | ||
CollectionNotFound(Uuid), | ||
#[error("Get collection error")] | ||
GetCollectionError(#[from] GetCollectionsError), | ||
} | ||
|
||
impl ChromaError for GetCollectionByIdError { | ||
fn code(&self) -> ErrorCodes { | ||
match self { | ||
GetCollectionByIdError::CollectionNotFound(_) => ErrorCodes::NotFound, | ||
GetCollectionByIdError::GetCollectionError(e) => e.code(), | ||
} | ||
} | ||
} | ||
|
||
pub(super) async fn get_collection_by_id( | ||
mut sysdb: Box<SysDb>, | ||
collection_id: &Uuid, | ||
) -> Result<Collection, Box<GetCollectionByIdError>> { | ||
let child_span: tracing::Span = | ||
trace_span!(parent: Span::current(), "get collection for collection id"); | ||
let collections = sysdb | ||
.get_collections(Some(*collection_id), None, None, None) | ||
.instrument(child_span.clone()) | ||
.await; | ||
match collections { | ||
Ok(mut collections) => { | ||
if collections.is_empty() { | ||
return Err(Box::new(GetCollectionByIdError::CollectionNotFound( | ||
*collection_id, | ||
))); | ||
} | ||
Ok(collections.drain(..).next().unwrap()) | ||
} | ||
Err(e) => { | ||
return Err(Box::new(GetCollectionByIdError::GetCollectionError(e))); | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Error)] | ||
pub(super) enum GetRecordSegmentByCollectionIdError { | ||
#[error("Record segment for collection with id: {0} not found")] | ||
RecordSegmentNotFound(Uuid), | ||
#[error("Get segments error")] | ||
GetSegmentsError(#[from] GetSegmentsError), | ||
} | ||
|
||
impl ChromaError for GetRecordSegmentByCollectionIdError { | ||
fn code(&self) -> ErrorCodes { | ||
match self { | ||
GetRecordSegmentByCollectionIdError::RecordSegmentNotFound(_) => ErrorCodes::NotFound, | ||
GetRecordSegmentByCollectionIdError::GetSegmentsError(e) => e.code(), | ||
} | ||
} | ||
} | ||
|
||
pub(super) async fn get_record_segment_by_collection_id( | ||
mut sysdb: Box<SysDb>, | ||
collection_id: &Uuid, | ||
) -> Result<Segment, Box<GetRecordSegmentByCollectionIdError>> { | ||
let segments = sysdb | ||
.get_segments( | ||
None, | ||
Some(SegmentType::BlockfileRecord.into()), | ||
None, | ||
Some(*collection_id), | ||
) | ||
.await; | ||
|
||
let segment = match segments { | ||
Ok(mut segments) => { | ||
if segments.is_empty() { | ||
return Err(Box::new( | ||
GetRecordSegmentByCollectionIdError::RecordSegmentNotFound(*collection_id), | ||
)); | ||
} | ||
segments.drain(..).next().unwrap() | ||
} | ||
Err(e) => { | ||
return Err(Box::new( | ||
GetRecordSegmentByCollectionIdError::GetSegmentsError(e), | ||
)); | ||
} | ||
}; | ||
|
||
if segment.r#type != SegmentType::BlockfileRecord { | ||
return Err(Box::new( | ||
GetRecordSegmentByCollectionIdError::RecordSegmentNotFound(*collection_id), | ||
)); | ||
} | ||
Ok(segment) | ||
} |
Oops, something went wrong.