diff --git a/Cargo.toml b/Cargo.toml index ff0b6199..a4612579 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -118,7 +118,7 @@ gateway = ["dep:zeph-gateway"] daemon = ["zeph-core/daemon"] scheduler = ["dep:zeph-scheduler"] otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp", "dep:tracing-opentelemetry"] -mock = ["zeph-llm/mock"] +mock = ["zeph-llm/mock", "zeph-memory/mock"] [dependencies] anyhow.workspace = true diff --git a/crates/zeph-memory/Cargo.toml b/crates/zeph-memory/Cargo.toml index 63552481..abb798ea 100644 --- a/crates/zeph-memory/Cargo.toml +++ b/crates/zeph-memory/Cargo.toml @@ -20,6 +20,10 @@ zeph-llm.workspace = true name = "token_estimation" harness = false +[features] +default = [] +mock = [] + [dev-dependencies] anyhow.workspace = true criterion.workspace = true diff --git a/crates/zeph-memory/src/qdrant.rs b/crates/zeph-memory/src/embedding_store.rs similarity index 80% rename from crates/zeph-memory/src/qdrant.rs rename to crates/zeph-memory/src/embedding_store.rs index 95289f50..0cd7a7cf 100644 --- a/crates/zeph-memory/src/qdrant.rs +++ b/crates/zeph-memory/src/embedding_store.rs @@ -1,10 +1,10 @@ pub use qdrant_client::qdrant::Filter; -use qdrant_client::qdrant::{Condition, PointStruct}; use sqlx::SqlitePool; use crate::error::MemoryError; use crate::qdrant_ops::QdrantOps; use crate::types::{ConversationId, MessageId}; +use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore}; /// Distinguishes regular messages from summaries when storing embeddings. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -37,15 +37,15 @@ pub async fn ensure_qdrant_collection( ops.ensure_collection(collection, vector_size).await } -pub struct QdrantStore { - ops: QdrantOps, +pub struct EmbeddingStore { + ops: Box, collection: String, pool: SqlitePool, } -impl std::fmt::Debug for QdrantStore { +impl std::fmt::Debug for EmbeddingStore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("QdrantStore") + f.debug_struct("EmbeddingStore") .field("collection", &self.collection) .finish_non_exhaustive() } @@ -64,8 +64,8 @@ pub struct SearchResult { pub score: f32, } -impl QdrantStore { - /// Create a new `QdrantStore` connected to the given Qdrant URL. +impl EmbeddingStore { + /// Create a new `EmbeddingStore` connected to the given Qdrant URL. /// /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata` /// table (which must already exist via sqlx migrations). @@ -77,16 +77,19 @@ impl QdrantStore { let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?; Ok(Self { - ops, + ops: Box::new(ops), collection: COLLECTION_NAME.into(), pool, }) } - /// Access the underlying `QdrantOps`. #[must_use] - pub fn ops(&self) -> &QdrantOps { - &self.ops + pub fn with_store(store: Box, pool: SqlitePool) -> Self { + Self { + ops: store, + collection: COLLECTION_NAME.into(), + pool, + } } /// Ensure the collection exists in Qdrant with the given vector size. @@ -122,15 +125,24 @@ impl QdrantStore { let point_id = uuid::Uuid::new_v4().to_string(); let dimensions = i64::try_from(vector.len())?; - let payload = serde_json::json!({ - "message_id": message_id.0, - "conversation_id": conversation_id.0, - "role": role, - "is_summary": kind.is_summary(), - }); - let payload_map = QdrantOps::json_to_payload(payload)?; - - let point = PointStruct::new(point_id.clone(), vector, payload_map); + let payload = std::collections::HashMap::from([ + ("message_id".to_owned(), serde_json::json!(message_id.0)), + ( + "conversation_id".to_owned(), + serde_json::json!(conversation_id.0), + ), + ("role".to_owned(), serde_json::json!(role)), + ( + "is_summary".to_owned(), + serde_json::json!(kind.is_summary()), + ), + ]); + + let point = VectorPoint { + id: point_id.clone(), + vector, + payload, + }; self.ops.upsert(&self.collection, vec![point]).await?; @@ -163,18 +175,27 @@ impl QdrantStore { ) -> Result, MemoryError> { let limit_u64 = u64::try_from(limit)?; - let qdrant_filter = filter.as_ref().and_then(|f| { - let mut conditions = Vec::new(); + let vector_filter = filter.as_ref().and_then(|f| { + let mut must = Vec::new(); if let Some(cid) = f.conversation_id { - conditions.push(Condition::matches("conversation_id", cid.0)); + must.push(FieldCondition { + field: "conversation_id".into(), + value: FieldValue::Integer(cid.0), + }); } if let Some(ref role) = f.role { - conditions.push(Condition::matches("role", role.clone())); + must.push(FieldCondition { + field: "role".into(), + value: FieldValue::Text(role.clone()), + }); } - if conditions.is_empty() { + if must.is_empty() { None } else { - Some(Filter::must(conditions)) + Some(VectorFilter { + must, + must_not: vec![], + }) } }); @@ -184,16 +205,16 @@ impl QdrantStore { &self.collection, query_vector.to_vec(), limit_u64, - qdrant_filter, + vector_filter, ) .await?; let search_results = results .into_iter() .filter_map(|point| { - let payload = &point.payload; - let message_id = MessageId(payload.get("message_id")?.as_integer()?); - let conversation_id = ConversationId(payload.get("conversation_id")?.as_integer()?); + let message_id = MessageId(point.payload.get("message_id")?.as_i64()?); + let conversation_id = + ConversationId(point.payload.get("conversation_id")?.as_i64()?); Some(SearchResult { message_id, conversation_id, @@ -233,8 +254,13 @@ impl QdrantStore { vector: Vec, ) -> Result { let point_id = uuid::Uuid::new_v4().to_string(); - let payload_map = QdrantOps::json_to_payload(payload)?; - let point = PointStruct::new(point_id.clone(), vector, payload_map); + let payload_map: std::collections::HashMap = + serde_json::from_value(payload)?; + let point = VectorPoint { + id: point_id.clone(), + vector, + payload: payload_map, + }; self.ops.upsert(collection, vec![point]).await?; Ok(point_id) } @@ -249,8 +275,8 @@ impl QdrantStore { collection: &str, query_vector: &[f32], limit: usize, - filter: Option, - ) -> Result, MemoryError> { + filter: Option, + ) -> Result, MemoryError> { let limit_u64 = u64::try_from(limit)?; let results = self .ops diff --git a/crates/zeph-memory/src/error.rs b/crates/zeph-memory/src/error.rs index 24331eef..62495dc7 100644 --- a/crates/zeph-memory/src/error.rs +++ b/crates/zeph-memory/src/error.rs @@ -6,6 +6,9 @@ pub enum MemoryError { #[error("Qdrant error: {0}")] Qdrant(#[from] Box), + #[error("vector store error: {0}")] + VectorStore(#[from] crate::vector_store::VectorStoreError), + #[error("migration failed: {0}")] Migration(#[from] sqlx::migrate::MigrateError), diff --git a/crates/zeph-memory/src/in_memory_store.rs b/crates/zeph-memory/src/in_memory_store.rs new file mode 100644 index 00000000..c8b41b74 --- /dev/null +++ b/crates/zeph-memory/src/in_memory_store.rs @@ -0,0 +1,405 @@ +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::RwLock; + +use crate::vector_store::{ + FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore, VectorStoreError, +}; + +type BoxFuture<'a, T> = Pin + Send + 'a>>; + +struct StoredPoint { + vector: Vec, + payload: HashMap, +} + +struct InMemoryCollection { + points: HashMap, +} + +pub struct InMemoryVectorStore { + collections: RwLock>, +} + +impl InMemoryVectorStore { + #[must_use] + pub fn new() -> Self { + Self { + collections: RwLock::new(HashMap::new()), + } + } +} + +impl Default for InMemoryVectorStore { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for InMemoryVectorStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InMemoryVectorStore") + .finish_non_exhaustive() + } +} + +fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let norm_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let norm_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + if norm_a == 0.0 || norm_b == 0.0 { + return 0.0; + } + dot / (norm_a * norm_b) +} + +fn matches_filter(payload: &HashMap, filter: &VectorFilter) -> bool { + for cond in &filter.must { + let Some(val) = payload.get(&cond.field) else { + return false; + }; + if !field_matches(val, &cond.value) { + return false; + } + } + for cond in &filter.must_not { + if let Some(val) = payload.get(&cond.field) + && field_matches(val, &cond.value) + { + return false; + } + } + true +} + +fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool { + match expected { + FieldValue::Integer(i) => val.as_i64() == Some(*i), + FieldValue::Text(s) => val.as_str() == Some(s.as_str()), + } +} + +impl VectorStore for InMemoryVectorStore { + fn ensure_collection( + &self, + collection: &str, + _vector_size: u64, + ) -> BoxFuture<'_, Result<(), VectorStoreError>> { + let collection = collection.to_owned(); + Box::pin(async move { + let mut cols = self + .collections + .write() + .map_err(|e| VectorStoreError::Collection(e.to_string()))?; + cols.entry(collection) + .or_insert_with(|| InMemoryCollection { + points: HashMap::new(), + }); + Ok(()) + }) + } + + fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result> { + let collection = collection.to_owned(); + Box::pin(async move { + let cols = self + .collections + .read() + .map_err(|e| VectorStoreError::Collection(e.to_string()))?; + Ok(cols.contains_key(&collection)) + }) + } + + fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> { + let collection = collection.to_owned(); + Box::pin(async move { + let mut cols = self + .collections + .write() + .map_err(|e| VectorStoreError::Collection(e.to_string()))?; + cols.remove(&collection); + Ok(()) + }) + } + + fn upsert( + &self, + collection: &str, + points: Vec, + ) -> BoxFuture<'_, Result<(), VectorStoreError>> { + let collection = collection.to_owned(); + Box::pin(async move { + let mut cols = self + .collections + .write() + .map_err(|e| VectorStoreError::Upsert(e.to_string()))?; + let col = cols.get_mut(&collection).ok_or_else(|| { + VectorStoreError::Upsert(format!("collection {collection} not found")) + })?; + for p in points { + col.points.insert( + p.id, + StoredPoint { + vector: p.vector, + payload: p.payload, + }, + ); + } + Ok(()) + }) + } + + fn search( + &self, + collection: &str, + vector: Vec, + limit: u64, + filter: Option, + ) -> BoxFuture<'_, Result, VectorStoreError>> { + let collection = collection.to_owned(); + Box::pin(async move { + let cols = self + .collections + .read() + .map_err(|e| VectorStoreError::Search(e.to_string()))?; + let col = cols.get(&collection).ok_or_else(|| { + VectorStoreError::Search(format!("collection {collection} not found")) + })?; + + let empty_filter = VectorFilter::default(); + let f = filter.as_ref().unwrap_or(&empty_filter); + + let mut scored: Vec = col + .points + .iter() + .filter(|(_, sp)| matches_filter(&sp.payload, f)) + .map(|(id, sp)| ScoredVectorPoint { + id: id.clone(), + score: cosine_similarity(&vector, &sp.vector), + payload: sp.payload.clone(), + }) + .collect(); + + scored.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + #[expect(clippy::cast_possible_truncation)] + scored.truncate(limit as usize); + Ok(scored) + }) + } + + fn delete_by_ids( + &self, + collection: &str, + ids: Vec, + ) -> BoxFuture<'_, Result<(), VectorStoreError>> { + let collection = collection.to_owned(); + Box::pin(async move { + if ids.is_empty() { + return Ok(()); + } + let mut cols = self + .collections + .write() + .map_err(|e| VectorStoreError::Delete(e.to_string()))?; + let col = cols.get_mut(&collection).ok_or_else(|| { + VectorStoreError::Delete(format!("collection {collection} not found")) + })?; + for id in &ids { + col.points.remove(id); + } + Ok(()) + }) + } + + fn scroll_all( + &self, + collection: &str, + key_field: &str, + ) -> BoxFuture<'_, Result>, VectorStoreError>> { + let collection = collection.to_owned(); + let key_field = key_field.to_owned(); + Box::pin(async move { + let cols = self + .collections + .read() + .map_err(|e| VectorStoreError::Scroll(e.to_string()))?; + let col = cols.get(&collection).ok_or_else(|| { + VectorStoreError::Scroll(format!("collection {collection} not found")) + })?; + + let mut result = HashMap::new(); + for sp in col.points.values() { + let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else { + continue; + }; + let mut fields = HashMap::new(); + for (k, v) in &sp.payload { + if let Some(s) = v.as_str() { + fields.insert(k.clone(), s.to_owned()); + } + } + result.insert(key_val.to_owned(), fields); + } + Ok(result) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn ensure_collection_and_exists() { + let store = InMemoryVectorStore::new(); + assert!(!store.collection_exists("test").await.unwrap()); + store.ensure_collection("test", 3).await.unwrap(); + assert!(store.collection_exists("test").await.unwrap()); + } + + #[tokio::test] + async fn ensure_collection_idempotent() { + let store = InMemoryVectorStore::new(); + store.ensure_collection("test", 3).await.unwrap(); + store.ensure_collection("test", 3).await.unwrap(); + assert!(store.collection_exists("test").await.unwrap()); + } + + #[tokio::test] + async fn delete_collection_removes() { + let store = InMemoryVectorStore::new(); + store.ensure_collection("test", 3).await.unwrap(); + store.delete_collection("test").await.unwrap(); + assert!(!store.collection_exists("test").await.unwrap()); + } + + #[tokio::test] + async fn upsert_and_search() { + let store = InMemoryVectorStore::new(); + store.ensure_collection("test", 3).await.unwrap(); + + let points = vec![ + VectorPoint { + id: "a".into(), + vector: vec![1.0, 0.0, 0.0], + payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]), + }, + VectorPoint { + id: "b".into(), + vector: vec![0.0, 1.0, 0.0], + payload: HashMap::from([("name".into(), serde_json::json!("beta"))]), + }, + ]; + store.upsert("test", points).await.unwrap(); + + let results = store + .search("test", vec![1.0, 0.0, 0.0], 2, None) + .await + .unwrap(); + assert_eq!(results.len(), 2); + assert_eq!(results[0].id, "a"); + assert!((results[0].score - 1.0).abs() < f32::EPSILON); + } + + #[tokio::test] + async fn search_with_filter() { + let store = InMemoryVectorStore::new(); + store.ensure_collection("test", 3).await.unwrap(); + + let points = vec![ + VectorPoint { + id: "a".into(), + vector: vec![1.0, 0.0, 0.0], + payload: HashMap::from([("role".into(), serde_json::json!("user"))]), + }, + VectorPoint { + id: "b".into(), + vector: vec![0.9, 0.1, 0.0], + payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]), + }, + ]; + store.upsert("test", points).await.unwrap(); + + let filter = VectorFilter { + must: vec![crate::vector_store::FieldCondition { + field: "role".into(), + value: FieldValue::Text("user".into()), + }], + must_not: vec![], + }; + let results = store + .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter)) + .await + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].id, "a"); + } + + #[tokio::test] + async fn delete_by_ids_removes_points() { + let store = InMemoryVectorStore::new(); + store.ensure_collection("test", 3).await.unwrap(); + + let points = vec![VectorPoint { + id: "a".into(), + vector: vec![1.0, 0.0, 0.0], + payload: HashMap::new(), + }]; + store.upsert("test", points).await.unwrap(); + store.delete_by_ids("test", vec!["a".into()]).await.unwrap(); + + let results = store + .search("test", vec![1.0, 0.0, 0.0], 10, None) + .await + .unwrap(); + assert!(results.is_empty()); + } + + #[tokio::test] + async fn scroll_all_extracts_strings() { + let store = InMemoryVectorStore::new(); + store.ensure_collection("test", 3).await.unwrap(); + + let points = vec![VectorPoint { + id: "a".into(), + vector: vec![1.0, 0.0, 0.0], + payload: HashMap::from([ + ("name".into(), serde_json::json!("alpha")), + ("desc".into(), serde_json::json!("first")), + ("num".into(), serde_json::json!(42)), + ]), + }]; + store.upsert("test", points).await.unwrap(); + + let result = store.scroll_all("test", "name").await.unwrap(); + assert_eq!(result.len(), 1); + let fields = result.get("alpha").unwrap(); + assert_eq!(fields.get("desc").unwrap(), "first"); + assert!(!fields.contains_key("num")); + } + + #[test] + fn cosine_similarity_orthogonal() { + let a = vec![1.0, 0.0, 0.0]; + let b = vec![0.0, 1.0, 0.0]; + assert!((cosine_similarity(&a, &b)).abs() < f32::EPSILON); + } + + #[tokio::test] + async fn default_impl() { + let store = InMemoryVectorStore::default(); + assert!(!store.collection_exists("any").await.unwrap()); + } + + #[test] + fn debug_format() { + let store = InMemoryVectorStore::new(); + let dbg = format!("{store:?}"); + assert!(dbg.contains("InMemoryVectorStore")); + } +} diff --git a/crates/zeph-memory/src/lib.rs b/crates/zeph-memory/src/lib.rs index 764efcdb..2d7f9960 100644 --- a/crates/zeph-memory/src/lib.rs +++ b/crates/zeph-memory/src/lib.rs @@ -1,14 +1,21 @@ //! SQLite-backed conversation persistence with Qdrant vector search. +pub mod embedding_store; pub mod error; -pub mod qdrant; +#[cfg(feature = "mock")] +pub mod in_memory_store; pub mod qdrant_ops; pub mod semantic; pub mod sqlite; pub mod types; +pub mod vector_store; +pub use embedding_store::ensure_qdrant_collection; pub use error::MemoryError; -pub use qdrant::ensure_qdrant_collection; pub use qdrant_ops::QdrantOps; pub use semantic::estimate_tokens; pub use types::{ConversationId, MessageId}; +pub use vector_store::{ + FieldCondition, FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore, + VectorStoreError, +}; diff --git a/crates/zeph-memory/src/qdrant_ops.rs b/crates/zeph-memory/src/qdrant_ops.rs index 0f16983b..f10bcb9c 100644 --- a/crates/zeph-memory/src/qdrant_ops.rs +++ b/crates/zeph-memory/src/qdrant_ops.rs @@ -205,6 +205,202 @@ impl QdrantOps { } } +impl crate::vector_store::VectorStore for QdrantOps { + fn ensure_collection( + &self, + collection: &str, + vector_size: u64, + ) -> std::pin::Pin< + Box> + Send + '_>, + > { + let collection = collection.to_owned(); + Box::pin(async move { + self.ensure_collection(&collection, vector_size) + .await + .map_err(|e| crate::VectorStoreError::Collection(e.to_string())) + }) + } + + fn collection_exists( + &self, + collection: &str, + ) -> std::pin::Pin< + Box> + Send + '_>, + > { + let collection = collection.to_owned(); + Box::pin(async move { + self.collection_exists(&collection) + .await + .map_err(|e| crate::VectorStoreError::Collection(e.to_string())) + }) + } + + fn delete_collection( + &self, + collection: &str, + ) -> std::pin::Pin< + Box> + Send + '_>, + > { + let collection = collection.to_owned(); + Box::pin(async move { + self.delete_collection(&collection) + .await + .map_err(|e| crate::VectorStoreError::Collection(e.to_string())) + }) + } + + fn upsert( + &self, + collection: &str, + points: Vec, + ) -> std::pin::Pin< + Box> + Send + '_>, + > { + let collection = collection.to_owned(); + Box::pin(async move { + let qdrant_points: Vec = points + .into_iter() + .map(|p| { + let payload: HashMap = + serde_json::from_value(serde_json::Value::Object( + p.payload.into_iter().collect(), + )) + .unwrap_or_default(); + PointStruct::new(p.id, p.vector, payload) + }) + .collect(); + self.upsert(&collection, qdrant_points) + .await + .map_err(|e| crate::VectorStoreError::Upsert(e.to_string())) + }) + } + + fn search( + &self, + collection: &str, + vector: Vec, + limit: u64, + filter: Option, + ) -> std::pin::Pin< + Box< + dyn std::future::Future< + Output = Result, crate::VectorStoreError>, + > + Send + + '_, + >, + > { + let collection = collection.to_owned(); + Box::pin(async move { + let qdrant_filter = filter.map(vector_filter_to_qdrant); + let results = self + .search(&collection, vector, limit, qdrant_filter) + .await + .map_err(|e| crate::VectorStoreError::Search(e.to_string()))?; + Ok(results.into_iter().map(scored_point_to_vector).collect()) + }) + } + + fn delete_by_ids( + &self, + collection: &str, + ids: Vec, + ) -> std::pin::Pin< + Box> + Send + '_>, + > { + let collection = collection.to_owned(); + Box::pin(async move { + let point_ids: Vec = ids.into_iter().map(PointId::from).collect(); + self.delete_by_ids(&collection, point_ids) + .await + .map_err(|e| crate::VectorStoreError::Delete(e.to_string())) + }) + } + + fn scroll_all( + &self, + collection: &str, + key_field: &str, + ) -> std::pin::Pin< + Box< + dyn std::future::Future< + Output = Result< + HashMap>, + crate::VectorStoreError, + >, + > + Send + + '_, + >, + > { + let collection = collection.to_owned(); + let key_field = key_field.to_owned(); + Box::pin(async move { + self.scroll_all(&collection, &key_field) + .await + .map_err(|e| crate::VectorStoreError::Scroll(e.to_string())) + }) + } +} + +fn vector_filter_to_qdrant(filter: crate::VectorFilter) -> Filter { + let must: Vec<_> = filter + .must + .into_iter() + .map(field_condition_to_qdrant) + .collect(); + let must_not: Vec<_> = filter + .must_not + .into_iter() + .map(field_condition_to_qdrant) + .collect(); + + let mut f = Filter::default(); + if !must.is_empty() { + f.must = must; + } + if !must_not.is_empty() { + f.must_not = must_not; + } + f +} + +fn field_condition_to_qdrant(cond: crate::FieldCondition) -> qdrant_client::qdrant::Condition { + match cond.value { + crate::FieldValue::Integer(v) => qdrant_client::qdrant::Condition::matches(cond.field, v), + crate::FieldValue::Text(v) => qdrant_client::qdrant::Condition::matches(cond.field, v), + } +} + +fn scored_point_to_vector(point: ScoredPoint) -> crate::ScoredVectorPoint { + let payload: HashMap = point + .payload + .into_iter() + .filter_map(|(k, v)| { + let json_val = match v.kind? { + Kind::StringValue(s) => serde_json::Value::String(s), + Kind::IntegerValue(i) => serde_json::Value::Number(i.into()), + Kind::DoubleValue(d) => { + serde_json::Number::from_f64(d).map(serde_json::Value::Number)? + } + Kind::BoolValue(b) => serde_json::Value::Bool(b), + _ => return None, + }; + Some((k, json_val)) + }) + .collect(); + + let id = match point.id.and_then(|pid| pid.point_id_options) { + Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u, + Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(), + None => String::new(), + }; + + crate::ScoredVectorPoint { + id, + score: point.score, + payload, + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/zeph-memory/src/semantic.rs b/crates/zeph-memory/src/semantic.rs index d846be64..40284509 100644 --- a/crates/zeph-memory/src/semantic.rs +++ b/crates/zeph-memory/src/semantic.rs @@ -1,11 +1,11 @@ -use qdrant_client::qdrant::Condition; use zeph_llm::any::AnyProvider; use zeph_llm::provider::{LlmProvider, Message, Role}; +use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter}; use crate::error::MemoryError; -use crate::qdrant::{Filter, MessageKind, QdrantStore, SearchFilter}; use crate::sqlite::SqliteStore; use crate::types::{ConversationId, MessageId}; +use crate::vector_store::{FieldCondition, FieldValue, VectorFilter}; const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries"; @@ -58,7 +58,7 @@ fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> Strin pub struct SemanticMemory { sqlite: SqliteStore, - qdrant: Option, + qdrant: Option, provider: AnyProvider, embedding_model: String, vector_weight: f64, @@ -98,7 +98,7 @@ impl SemanticMemory { let sqlite = SqliteStore::new(sqlite_path).await?; let pool = sqlite.pool().clone(); - let qdrant = match QdrantStore::new(qdrant_url, pool) { + let qdrant = match EmbeddingStore::new(qdrant_url, pool) { Ok(store) => Some(store), Err(e) => { tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}"); @@ -450,8 +450,13 @@ impl SemanticMemory { .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size) .await?; - let filter = exclude_conversation_id - .map(|cid| Filter::must_not(vec![Condition::matches("conversation_id", cid.0)])); + let filter = exclude_conversation_id.map(|cid| VectorFilter { + must: vec![], + must_not: vec![FieldCondition { + field: "conversation_id".into(), + value: FieldValue::Integer(cid.0), + }], + }); let points = qdrant .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter) @@ -460,9 +465,9 @@ impl SemanticMemory { let results = points .into_iter() .filter_map(|point| { - let payload = &point.payload; - let summary_text = payload.get("summary_text")?.as_str()?.to_owned(); - let conversation_id = ConversationId(payload.get("conversation_id")?.as_integer()?); + let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned(); + let conversation_id = + ConversationId(point.payload.get("conversation_id")?.as_i64()?); Some(SessionSummaryResult { summary_text, score: point.score, diff --git a/crates/zeph-memory/src/vector_store.rs b/crates/zeph-memory/src/vector_store.rs new file mode 100644 index 00000000..6964f68e --- /dev/null +++ b/crates/zeph-memory/src/vector_store.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; + +#[derive(Debug, thiserror::Error)] +pub enum VectorStoreError { + #[error("connection error: {0}")] + Connection(String), + #[error("collection error: {0}")] + Collection(String), + #[error("upsert error: {0}")] + Upsert(String), + #[error("search error: {0}")] + Search(String), + #[error("delete error: {0}")] + Delete(String), + #[error("scroll error: {0}")] + Scroll(String), + #[error("serialization error: {0}")] + Serialization(String), +} + +#[derive(Debug, Clone)] +pub struct VectorPoint { + pub id: String, + pub vector: Vec, + pub payload: HashMap, +} + +#[derive(Debug, Clone, Default)] +pub struct VectorFilter { + pub must: Vec, + pub must_not: Vec, +} + +#[derive(Debug, Clone)] +pub struct FieldCondition { + pub field: String, + pub value: FieldValue, +} + +#[derive(Debug, Clone)] +pub enum FieldValue { + Integer(i64), + Text(String), +} + +#[derive(Debug, Clone)] +pub struct ScoredVectorPoint { + pub id: String, + pub score: f32, + pub payload: HashMap, +} + +type BoxFuture<'a, T> = Pin + Send + 'a>>; + +pub type ScrollResult = HashMap>; + +pub trait VectorStore: Send + Sync { + fn ensure_collection( + &self, + collection: &str, + vector_size: u64, + ) -> BoxFuture<'_, Result<(), VectorStoreError>>; + + fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result>; + + fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>>; + + fn upsert( + &self, + collection: &str, + points: Vec, + ) -> BoxFuture<'_, Result<(), VectorStoreError>>; + + fn search( + &self, + collection: &str, + vector: Vec, + limit: u64, + filter: Option, + ) -> BoxFuture<'_, Result, VectorStoreError>>; + + fn delete_by_ids( + &self, + collection: &str, + ids: Vec, + ) -> BoxFuture<'_, Result<(), VectorStoreError>>; + + fn scroll_all( + &self, + collection: &str, + key_field: &str, + ) -> BoxFuture<'_, Result>; +} diff --git a/crates/zeph-memory/tests/qdrant_integration.rs b/crates/zeph-memory/tests/qdrant_integration.rs index 7a9f03c8..294e35d6 100644 --- a/crates/zeph-memory/tests/qdrant_integration.rs +++ b/crates/zeph-memory/tests/qdrant_integration.rs @@ -2,7 +2,7 @@ use testcontainers::ContainerAsync; use testcontainers::GenericImage; use testcontainers::core::{ContainerPort, WaitFor}; use testcontainers::runners::AsyncRunner; -use zeph_memory::qdrant::{MessageKind, QdrantStore}; +use zeph_memory::embedding_store::{EmbeddingStore, MessageKind}; use zeph_memory::sqlite::SqliteStore; const QDRANT_GRPC_PORT: ContainerPort = ContainerPort::Tcp(6334); @@ -13,14 +13,14 @@ fn qdrant_image() -> GenericImage { .with_exposed_port(QDRANT_GRPC_PORT) } -async fn setup_with_qdrant() -> (SqliteStore, QdrantStore, ContainerAsync) { +async fn setup_with_qdrant() -> (SqliteStore, EmbeddingStore, ContainerAsync) { let container = qdrant_image().start().await.unwrap(); let grpc_port = container.get_host_port_ipv4(6334).await.unwrap(); let url = format!("http://127.0.0.1:{grpc_port}"); let sqlite = SqliteStore::new(":memory:").await.unwrap(); let pool = sqlite.pool().clone(); - let store = QdrantStore::new(&url, pool).unwrap(); + let store = EmbeddingStore::new(&url, pool).unwrap(); (sqlite, store, container) } @@ -105,7 +105,7 @@ async fn search_with_conversation_filter() { .unwrap(); let query = vec![0.1, 0.2, 0.3, 0.4]; - let filter = zeph_memory::qdrant::SearchFilter { + let filter = zeph_memory::embedding_store::SearchFilter { conversation_id: Some(cid1), role: None, };