diff --git a/Cargo.lock b/Cargo.lock index b8b2b397..1c26e505 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8285,6 +8285,7 @@ dependencies = [ "tracing", "uuid", "zeph-llm", + "zeph-memory", "zeph-tools", ] @@ -8321,6 +8322,7 @@ dependencies = [ "tracing", "uuid", "zeph-llm", + "zeph-memory", ] [[package]] diff --git a/crates/zeph-index/src/store.rs b/crates/zeph-index/src/store.rs index 27d39c33..400246eb 100644 --- a/crates/zeph-index/src/store.rs +++ b/crates/zeph-index/src/store.rs @@ -1,11 +1,10 @@ //! `Qdrant` collection + `SQLite` metadata for code chunks. -use qdrant_client::Qdrant; use qdrant_client::qdrant::{ - CreateCollectionBuilder, CreateFieldIndexCollectionBuilder, DeletePointsBuilder, Distance, - FieldType, Filter, PointStruct, PointsIdsList, ScalarQuantizationBuilder, ScoredPoint, - SearchPointsBuilder, UpsertPointsBuilder, VectorParamsBuilder, + CreateCollectionBuilder, CreateFieldIndexCollectionBuilder, Distance, FieldType, Filter, + PointStruct, ScalarQuantizationBuilder, ScoredPoint, VectorParamsBuilder, }; +use zeph_memory::QdrantOps; use crate::error::Result; @@ -14,7 +13,7 @@ const CODE_COLLECTION: &str = "zeph_code_chunks"; /// `Qdrant` + `SQLite` dual-write store for code chunks. #[derive(Clone)] pub struct CodeStore { - qdrant: Qdrant, + ops: QdrantOps, collection: String, pool: sqlx::SqlitePool, } @@ -49,9 +48,9 @@ impl CodeStore { /// /// Returns an error if the `Qdrant` client fails to connect. pub fn new(qdrant_url: &str, pool: sqlx::SqlitePool) -> Result { - let qdrant = Qdrant::from_url(qdrant_url).build().map_err(Box::new)?; + let ops = QdrantOps::new(qdrant_url).map_err(crate::error::IndexError::Qdrant)?; Ok(Self { - qdrant, + ops, collection: CODE_COLLECTION.into(), pool, }) @@ -73,16 +72,12 @@ impl CodeStore { /// /// Returns an error if `Qdrant` operations fail. pub async fn ensure_collection(&self, vector_size: u64) -> Result<()> { - if self - .qdrant - .collection_exists(&self.collection) - .await - .map_err(Box::new)? - { + if self.ops.collection_exists(&self.collection).await? { return Ok(()); } - self.qdrant + self.ops + .client() .create_collection( CreateCollectionBuilder::new(&self.collection) .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)) @@ -91,30 +86,17 @@ impl CodeStore { .await .map_err(Box::new)?; - self.qdrant - .create_field_index(CreateFieldIndexCollectionBuilder::new( - &self.collection, - "language", - FieldType::Keyword, - )) - .await - .map_err(Box::new)?; - self.qdrant - .create_field_index(CreateFieldIndexCollectionBuilder::new( - &self.collection, - "file_path", - FieldType::Keyword, - )) - .await - .map_err(Box::new)?; - self.qdrant - .create_field_index(CreateFieldIndexCollectionBuilder::new( - &self.collection, - "node_type", - FieldType::Keyword, - )) - .await - .map_err(Box::new)?; + for field in ["language", "file_path", "node_type"] { + self.ops + .client() + .create_field_index(CreateFieldIndexCollectionBuilder::new( + &self.collection, + field, + FieldType::Keyword, + )) + .await + .map_err(Box::new)?; + } Ok(()) } @@ -127,26 +109,24 @@ impl CodeStore { pub async fn upsert_chunk(&self, chunk: &ChunkInsert<'_>, vector: Vec) -> Result { let point_id = uuid::Uuid::new_v4().to_string(); - let payload: std::collections::HashMap = - serde_json::from_value(serde_json::json!({ - "file_path": chunk.file_path, - "language": chunk.language, - "node_type": chunk.node_type, - "entity_name": chunk.entity_name, - "line_start": chunk.line_start, - "line_end": chunk.line_end, - "code": chunk.code, - "scope_chain": chunk.scope_chain, - "content_hash": chunk.content_hash, - }))?; - - self.qdrant - .upsert_points(UpsertPointsBuilder::new( + let payload = QdrantOps::json_to_payload(serde_json::json!({ + "file_path": chunk.file_path, + "language": chunk.language, + "node_type": chunk.node_type, + "entity_name": chunk.entity_name, + "line_start": chunk.line_start, + "line_end": chunk.line_end, + "code": chunk.code, + "scope_chain": chunk.scope_chain, + "content_hash": chunk.content_hash, + }))?; + + self.ops + .upsert( &self.collection, vec![PointStruct::new(point_id.clone(), vector, payload)], - )) - .await - .map_err(Box::new)?; + ) + .await?; let line_start = i64::try_from(chunk.line_start)?; let line_end = i64::try_from(chunk.line_end)?; @@ -205,12 +185,7 @@ impl CodeStore { .map(|(id,)| id.clone().into()) .collect::>(); - self.qdrant - .delete_points( - DeletePointsBuilder::new(&self.collection).points(PointsIdsList { ids: point_ids }), - ) - .await - .map_err(Box::new)?; + self.ops.delete_by_ids(&self.collection, point_ids).await?; let count = ids.len(); sqlx::query("DELETE FROM chunk_metadata WHERE file_path = ?") @@ -232,17 +207,13 @@ impl CodeStore { limit: usize, filter: Option, ) -> Result> { - let mut builder = SearchPointsBuilder::new(&self.collection, query_vector, limit as u64) - .with_payload(true); - - if let Some(f) = filter { - builder = builder.filter(f); - } - - let results = self.qdrant.search_points(builder).await.map_err(Box::new)?; + let limit_u64 = u64::try_from(limit)?; + let results = self + .ops + .search(&self.collection, query_vector, limit_u64, filter) + .await?; Ok(results - .result .iter() .filter_map(SearchHit::from_scored_point) .collect()) diff --git a/crates/zeph-mcp/Cargo.toml b/crates/zeph-mcp/Cargo.toml index 0a2ac662..d3c034b9 100644 --- a/crates/zeph-mcp/Cargo.toml +++ b/crates/zeph-mcp/Cargo.toml @@ -8,7 +8,7 @@ repository.workspace = true [features] default = [] -qdrant = ["dep:blake3", "dep:qdrant-client", "dep:uuid"] +qdrant = ["dep:blake3", "dep:qdrant-client", "dep:uuid", "dep:zeph-memory"] [dependencies] blake3 = { workspace = true, optional = true } @@ -21,6 +21,7 @@ tokio = { workspace = true, features = ["process", "sync", "time", "rt"] } tracing.workspace = true uuid = { workspace = true, optional = true, features = ["v5"] } zeph-llm.workspace = true +zeph-memory = { workspace = true, optional = true } zeph-tools.workspace = true [dev-dependencies] diff --git a/crates/zeph-mcp/src/registry.rs b/crates/zeph-mcp/src/registry.rs index 332706a9..f7751963 100644 --- a/crates/zeph-mcp/src/registry.rs +++ b/crates/zeph-mcp/src/registry.rs @@ -1,11 +1,7 @@ use std::collections::HashMap; -use qdrant_client::Qdrant; -use qdrant_client::qdrant::{ - CreateCollectionBuilder, DeletePointsBuilder, Distance, PointStruct, PointsIdsList, - ScrollPointsBuilder, SearchPointsBuilder, UpsertPointsBuilder, VectorParamsBuilder, - value::Kind, -}; +use qdrant_client::qdrant::{PointStruct, value::Kind}; +use zeph_memory::QdrantOps; use crate::error::McpError; use crate::tool::McpTool; @@ -30,7 +26,7 @@ pub struct SyncStats { } pub struct McpToolRegistry { - client: Qdrant, + ops: QdrantOps, collection: String, hashes: HashMap, } @@ -61,10 +57,10 @@ impl McpToolRegistry { /// /// Returns an error if the Qdrant client cannot be created. pub fn new(qdrant_url: &str) -> Result { - let client = Qdrant::from_url(qdrant_url).build().map_err(Box::new)?; + let ops = QdrantOps::new(qdrant_url)?; Ok(Self { - client, + ops, collection: COLLECTION_NAME.into(), hashes: HashMap::new(), }) @@ -88,7 +84,7 @@ impl McpToolRegistry { self.ensure_collection(&embed_fn).await?; - let existing = self.scroll_all().await?; + let existing = self.ops.scroll_all(&self.collection, "tool_key").await?; let mut current: HashMap = HashMap::with_capacity(tools.len()); for tool in tools { @@ -130,7 +126,7 @@ impl McpToolRegistry { }; let point_id = tool_point_id(key); - let payload: serde_json::Value = serde_json::json!({ + let payload = serde_json::json!({ "tool_key": key, "server_id": tool.server_id, "tool_name": tool.name, @@ -138,8 +134,7 @@ impl McpToolRegistry { "content_hash": hash, "embedding_model": embedding_model, }); - let payload_map: HashMap = - serde_json::from_value(payload)?; + let payload_map = QdrantOps::json_to_payload(payload)?; points_to_upsert.push(PointStruct::new(point_id, vector, payload_map)); @@ -152,31 +147,18 @@ impl McpToolRegistry { } if !points_to_upsert.is_empty() { - self.client - .upsert_points(UpsertPointsBuilder::new(&self.collection, points_to_upsert)) - .await - .map_err(Box::new)?; + self.ops.upsert(&self.collection, points_to_upsert).await?; } - let orphan_ids: Vec = existing + let orphan_ids: Vec = existing .keys() .filter(|key| !current.contains_key(*key)) - .map(|key| tool_point_id(key)) + .map(|key| qdrant_client::qdrant::PointId::from(tool_point_id(key).as_str())) .collect(); if !orphan_ids.is_empty() { stats.removed = orphan_ids.len(); - let point_ids: Vec = orphan_ids - .into_iter() - .map(|id| qdrant_client::qdrant::PointId::from(id.as_str())) - .collect(); - self.client - .delete_points( - DeletePointsBuilder::new(&self.collection) - .points(PointsIdsList { ids: point_ids }), - ) - .await - .map_err(Box::new)?; + self.ops.delete_by_ids(&self.collection, orphan_ids).await?; } tracing::info!( @@ -208,10 +190,8 @@ impl McpToolRegistry { }; let results = match self - .client - .search_points( - SearchPointsBuilder::new(&self.collection, query_vec, limit_u64).with_payload(true), - ) + .ops + .search(&self.collection, query_vec, limit_u64, None) .await { Ok(r) => r, @@ -222,7 +202,6 @@ impl McpToolRegistry { }; results - .result .into_iter() .filter_map(|point| { let server_id = extract_string(&point.payload, "server_id")?; @@ -242,16 +221,8 @@ impl McpToolRegistry { where F: Fn(&str) -> EmbedFuture, { - if self - .client - .collection_exists(&self.collection) - .await - .map_err(Box::new)? - { - self.client - .delete_collection(&self.collection) - .await - .map_err(Box::new)?; + if self.ops.collection_exists(&self.collection).await? { + self.ops.delete_collection(&self.collection).await?; tracing::info!( collection = &self.collection, "deleted MCP tools collection for recreation" @@ -264,12 +235,7 @@ impl McpToolRegistry { where F: Fn(&str) -> EmbedFuture, { - if self - .client - .collection_exists(&self.collection) - .await - .map_err(Box::new)? - { + if self.ops.collection_exists(&self.collection).await? { return Ok(()); } @@ -278,13 +244,9 @@ impl McpToolRegistry { .map_err(|e| McpError::Embedding(e.to_string()))?; let vector_size = u64::try_from(probe.len())?; - self.client - .create_collection( - CreateCollectionBuilder::new(&self.collection) - .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)), - ) - .await - .map_err(Box::new)?; + self.ops + .ensure_collection(&self.collection, vector_size) + .await?; tracing::info!( collection = &self.collection, @@ -294,48 +256,6 @@ impl McpToolRegistry { Ok(()) } - - async fn scroll_all(&self) -> Result>, McpError> { - let mut result = HashMap::new(); - let mut offset: Option = None; - - loop { - let mut builder = ScrollPointsBuilder::new(&self.collection) - .with_payload(true) - .with_vectors(false) - .limit(100); - - if let Some(ref off) = offset { - builder = builder.offset(off.clone()); - } - - let response = self.client.scroll(builder).await.map_err(Box::new)?; - - for point in &response.result { - let Some(key_val) = point.payload.get("tool_key") else { - continue; - }; - let Some(Kind::StringValue(key)) = &key_val.kind else { - continue; - }; - - let mut fields = HashMap::new(); - for (k, val) in &point.payload { - if let Some(Kind::StringValue(s)) = &val.kind { - fields.insert(k.clone(), s.clone()); - } - } - result.insert(key.clone(), fields); - } - - match response.next_page_offset { - Some(next) => offset = Some(next), - None => break, - } - } - - Ok(result) - } } fn extract_string( diff --git a/crates/zeph-memory/src/lib.rs b/crates/zeph-memory/src/lib.rs index f3231d31..764efcdb 100644 --- a/crates/zeph-memory/src/lib.rs +++ b/crates/zeph-memory/src/lib.rs @@ -2,11 +2,13 @@ pub mod error; pub mod qdrant; +pub mod qdrant_ops; pub mod semantic; pub mod sqlite; pub mod types; pub use error::MemoryError; pub use qdrant::ensure_qdrant_collection; +pub use qdrant_ops::QdrantOps; pub use semantic::estimate_tokens; pub use types::{ConversationId, MessageId}; diff --git a/crates/zeph-memory/src/qdrant.rs b/crates/zeph-memory/src/qdrant.rs index 14c6b96c..95289f50 100644 --- a/crates/zeph-memory/src/qdrant.rs +++ b/crates/zeph-memory/src/qdrant.rs @@ -1,12 +1,9 @@ -use qdrant_client::Qdrant; pub use qdrant_client::qdrant::Filter; -use qdrant_client::qdrant::{ - Condition, CreateCollectionBuilder, Distance, PointStruct, SearchPointsBuilder, - UpsertPointsBuilder, VectorParamsBuilder, -}; +use qdrant_client::qdrant::{Condition, PointStruct}; use sqlx::SqlitePool; use crate::error::MemoryError; +use crate::qdrant_ops::QdrantOps; use crate::types::{ConversationId, MessageId}; /// Distinguishes regular messages from summaries when storing embeddings. @@ -33,29 +30,15 @@ const COLLECTION_NAME: &str = "zeph_conversations"; /// /// Returns an error if Qdrant cannot be reached or collection creation fails. pub async fn ensure_qdrant_collection( - client: &Qdrant, + ops: &QdrantOps, collection: &str, vector_size: u64, ) -> Result<(), Box> { - if client - .collection_exists(collection) - .await - .map_err(Box::new)? - { - return Ok(()); - } - client - .create_collection( - CreateCollectionBuilder::new(collection) - .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)), - ) - .await - .map_err(Box::new)?; - Ok(()) + ops.ensure_collection(collection, vector_size).await } pub struct QdrantStore { - client: Qdrant, + ops: QdrantOps, collection: String, pool: SqlitePool, } @@ -91,15 +74,21 @@ impl QdrantStore { /// /// Returns an error if the Qdrant client cannot be created. pub fn new(url: &str, pool: SqlitePool) -> Result { - let client = Qdrant::from_url(url).build().map_err(Box::new)?; + let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?; Ok(Self { - client, + ops, collection: COLLECTION_NAME.into(), pool, }) } + /// Access the underlying `QdrantOps`. + #[must_use] + pub fn ops(&self) -> &QdrantOps { + &self.ops + } + /// Ensure the collection exists in Qdrant with the given vector size. /// /// Idempotent: no-op if the collection already exists. @@ -108,7 +97,9 @@ impl QdrantStore { /// /// Returns an error if Qdrant cannot be reached or collection creation fails. pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> { - ensure_qdrant_collection(&self.client, &self.collection, vector_size).await?; + self.ops + .ensure_collection(&self.collection, vector_size) + .await?; Ok(()) } @@ -131,21 +122,17 @@ impl QdrantStore { let point_id = uuid::Uuid::new_v4().to_string(); let dimensions = i64::try_from(vector.len())?; - let payload: serde_json::Value = serde_json::json!({ + let payload = serde_json::json!({ "message_id": message_id.0, "conversation_id": conversation_id.0, "role": role, "is_summary": kind.is_summary(), }); - let payload_map: std::collections::HashMap = - serde_json::from_value(payload)?; + let payload_map = QdrantOps::json_to_payload(payload)?; let point = PointStruct::new(point_id.clone(), vector, payload_map); - self.client - .upsert_points(UpsertPointsBuilder::new(&self.collection, vec![point])) - .await - .map_err(Box::new)?; + self.ops.upsert(&self.collection, vec![point]).await?; sqlx::query( "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \ @@ -176,36 +163,37 @@ impl QdrantStore { ) -> Result, MemoryError> { let limit_u64 = u64::try_from(limit)?; - let mut builder = - SearchPointsBuilder::new(&self.collection, query_vector.to_vec(), limit_u64) - .with_payload(true); - - if let Some(ref f) = filter { + let qdrant_filter = filter.as_ref().and_then(|f| { let mut conditions = Vec::new(); - if let Some(cid) = f.conversation_id { conditions.push(Condition::matches("conversation_id", cid.0)); } if let Some(ref role) = f.role { conditions.push(Condition::matches("role", role.clone())); } - - if !conditions.is_empty() { - builder = builder.filter(Filter::must(conditions)); + if conditions.is_empty() { + None + } else { + Some(Filter::must(conditions)) } - } + }); - let results = self.client.search_points(builder).await.map_err(Box::new)?; + let results = self + .ops + .search( + &self.collection, + query_vector.to_vec(), + limit_u64, + qdrant_filter, + ) + .await?; let search_results = results - .result .into_iter() .filter_map(|point| { let payload = &point.payload; - let message_id = MessageId(payload.get("message_id")?.as_integer()?); let conversation_id = ConversationId(payload.get("conversation_id")?.as_integer()?); - Some(SearchResult { message_id, conversation_id, @@ -227,7 +215,7 @@ impl QdrantStore { name: &str, vector_size: u64, ) -> Result<(), MemoryError> { - ensure_qdrant_collection(&self.client, name, vector_size).await?; + self.ops.ensure_collection(name, vector_size).await?; Ok(()) } @@ -245,17 +233,9 @@ impl QdrantStore { vector: Vec, ) -> Result { let point_id = uuid::Uuid::new_v4().to_string(); - - let payload_map: std::collections::HashMap = - serde_json::from_value(payload)?; - + let payload_map = QdrantOps::json_to_payload(payload)?; let point = PointStruct::new(point_id.clone(), vector, payload_map); - - self.client - .upsert_points(UpsertPointsBuilder::new(collection, vec![point])) - .await - .map_err(Box::new)?; - + self.ops.upsert(collection, vec![point]).await?; Ok(point_id) } @@ -272,17 +252,11 @@ impl QdrantStore { filter: Option, ) -> Result, MemoryError> { let limit_u64 = u64::try_from(limit)?; - - let mut builder = SearchPointsBuilder::new(collection, query_vector.to_vec(), limit_u64) - .with_payload(true); - - if let Some(f) = filter { - builder = builder.filter(f); - } - - let results = self.client.search_points(builder).await.map_err(Box::new)?; - - Ok(results.result) + let results = self + .ops + .search(collection, query_vector.to_vec(), limit_u64, filter) + .await?; + Ok(results) } /// Check whether an embedding already exists for the given message ID. diff --git a/crates/zeph-memory/src/qdrant_ops.rs b/crates/zeph-memory/src/qdrant_ops.rs new file mode 100644 index 00000000..0f16983b --- /dev/null +++ b/crates/zeph-memory/src/qdrant_ops.rs @@ -0,0 +1,244 @@ +//! Low-level Qdrant operations shared across crates. + +use std::collections::HashMap; + +use qdrant_client::Qdrant; +use qdrant_client::qdrant::{ + CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, PointId, PointStruct, + PointsIdsList, ScoredPoint, ScrollPointsBuilder, SearchPointsBuilder, UpsertPointsBuilder, + VectorParamsBuilder, value::Kind, +}; + +type QdrantResult = Result>; + +/// Thin wrapper over [`Qdrant`] client encapsulating common collection operations. +#[derive(Clone)] +pub struct QdrantOps { + client: Qdrant, +} + +impl std::fmt::Debug for QdrantOps { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QdrantOps").finish_non_exhaustive() + } +} + +impl QdrantOps { + /// Create a new `QdrantOps` connected to the given URL. + /// + /// # Errors + /// + /// Returns an error if the Qdrant client cannot be created. + pub fn new(url: &str) -> QdrantResult { + let client = Qdrant::from_url(url).build().map_err(Box::new)?; + Ok(Self { client }) + } + + /// Access the underlying Qdrant client for advanced operations. + #[must_use] + pub fn client(&self) -> &Qdrant { + &self.client + } + + /// Ensure a collection exists with cosine distance vectors. + /// + /// Idempotent: no-op if the collection already exists. + /// + /// # Errors + /// + /// Returns an error if Qdrant cannot be reached or collection creation fails. + pub async fn ensure_collection(&self, collection: &str, vector_size: u64) -> QdrantResult<()> { + if self + .client + .collection_exists(collection) + .await + .map_err(Box::new)? + { + return Ok(()); + } + self.client + .create_collection( + CreateCollectionBuilder::new(collection) + .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)), + ) + .await + .map_err(Box::new)?; + Ok(()) + } + + /// Check whether a collection exists. + /// + /// # Errors + /// + /// Returns an error if Qdrant cannot be reached. + pub async fn collection_exists(&self, collection: &str) -> QdrantResult { + self.client + .collection_exists(collection) + .await + .map_err(Box::new) + } + + /// Delete a collection. + /// + /// # Errors + /// + /// Returns an error if the collection cannot be deleted. + pub async fn delete_collection(&self, collection: &str) -> QdrantResult<()> { + self.client + .delete_collection(collection) + .await + .map_err(Box::new)?; + Ok(()) + } + + /// Upsert points into a collection. + /// + /// # Errors + /// + /// Returns an error if the upsert fails. + pub async fn upsert(&self, collection: &str, points: Vec) -> QdrantResult<()> { + self.client + .upsert_points(UpsertPointsBuilder::new(collection, points)) + .await + .map_err(Box::new)?; + Ok(()) + } + + /// Search for similar vectors, returning scored points with payloads. + /// + /// # Errors + /// + /// Returns an error if the search fails. + pub async fn search( + &self, + collection: &str, + vector: Vec, + limit: u64, + filter: Option, + ) -> QdrantResult> { + let mut builder = SearchPointsBuilder::new(collection, vector, limit).with_payload(true); + if let Some(f) = filter { + builder = builder.filter(f); + } + let results = self.client.search_points(builder).await.map_err(Box::new)?; + Ok(results.result) + } + + /// Delete points by their IDs. + /// + /// # Errors + /// + /// Returns an error if the deletion fails. + pub async fn delete_by_ids(&self, collection: &str, ids: Vec) -> QdrantResult<()> { + if ids.is_empty() { + return Ok(()); + } + self.client + .delete_points(DeletePointsBuilder::new(collection).points(PointsIdsList { ids })) + .await + .map_err(Box::new)?; + Ok(()) + } + + /// Scroll all points in a collection, extracting string payload fields. + /// + /// Returns a map of `key_field` value -> { `field_name` -> `field_value` }. + /// + /// # Errors + /// + /// Returns an error if the scroll operation fails. + pub async fn scroll_all( + &self, + collection: &str, + key_field: &str, + ) -> QdrantResult>> { + let mut result = HashMap::new(); + let mut offset: Option = None; + + loop { + let mut builder = ScrollPointsBuilder::new(collection) + .with_payload(true) + .with_vectors(false) + .limit(100); + + if let Some(ref off) = offset { + builder = builder.offset(off.clone()); + } + + let response = self.client.scroll(builder).await.map_err(Box::new)?; + + for point in &response.result { + let Some(key_val) = point.payload.get(key_field) else { + continue; + }; + let Some(Kind::StringValue(key)) = &key_val.kind else { + continue; + }; + + let mut fields = HashMap::new(); + for (k, val) in &point.payload { + if let Some(Kind::StringValue(s)) = &val.kind { + fields.insert(k.clone(), s.clone()); + } + } + result.insert(key.clone(), fields); + } + + match response.next_page_offset { + Some(next) => offset = Some(next), + None => break, + } + } + + Ok(result) + } + + /// Convert a JSON value to a Qdrant payload map. + /// + /// # Errors + /// + /// Returns a JSON error if deserialization fails. + pub fn json_to_payload( + value: serde_json::Value, + ) -> Result, serde_json::Error> { + serde_json::from_value(value) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_valid_url() { + let ops = QdrantOps::new("http://localhost:6334"); + assert!(ops.is_ok()); + } + + #[test] + fn new_invalid_url() { + let ops = QdrantOps::new("not a valid url"); + assert!(ops.is_err()); + } + + #[test] + fn debug_format() { + let ops = QdrantOps::new("http://localhost:6334").unwrap(); + let dbg = format!("{ops:?}"); + assert!(dbg.contains("QdrantOps")); + } + + #[test] + fn json_to_payload_valid() { + let value = serde_json::json!({"key": "value", "num": 42}); + let result = QdrantOps::json_to_payload(value); + assert!(result.is_ok()); + } + + #[test] + fn json_to_payload_empty() { + let result = QdrantOps::json_to_payload(serde_json::json!({})); + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + } +} diff --git a/crates/zeph-skills/Cargo.toml b/crates/zeph-skills/Cargo.toml index 6f5fe913..8917dc13 100644 --- a/crates/zeph-skills/Cargo.toml +++ b/crates/zeph-skills/Cargo.toml @@ -8,7 +8,7 @@ repository.workspace = true [features] default = [] -qdrant = ["dep:blake3", "dep:qdrant-client", "dep:serde_json", "dep:uuid"] +qdrant = ["dep:blake3", "dep:qdrant-client", "dep:serde_json", "dep:uuid", "dep:zeph-memory"] self-learning = [] [dependencies] @@ -22,6 +22,7 @@ tokio = { workspace = true, features = ["sync", "rt"] } tracing.workspace = true uuid = { workspace = true, optional = true, features = ["v5"] } zeph-llm.workspace = true +zeph-memory = { workspace = true, optional = true } [dev-dependencies] anyhow.workspace = true diff --git a/crates/zeph-skills/src/qdrant_matcher.rs b/crates/zeph-skills/src/qdrant_matcher.rs index aa19eaa1..2d36ec28 100644 --- a/crates/zeph-skills/src/qdrant_matcher.rs +++ b/crates/zeph-skills/src/qdrant_matcher.rs @@ -1,11 +1,7 @@ use std::collections::HashMap; -use qdrant_client::Qdrant; -use qdrant_client::qdrant::{ - CreateCollectionBuilder, DeletePointsBuilder, Distance, PointStruct, PointsIdsList, - ScrollPointsBuilder, SearchPointsBuilder, UpsertPointsBuilder, VectorParamsBuilder, - value::Kind, -}; +use qdrant_client::qdrant::{PointStruct, value::Kind}; +use zeph_memory::QdrantOps; use crate::error::SkillError; use crate::loader::SkillMeta; @@ -29,7 +25,7 @@ pub struct SyncStats { } pub struct QdrantSkillMatcher { - client: Qdrant, + ops: QdrantOps, collection: String, hashes: HashMap, } @@ -58,10 +54,10 @@ impl QdrantSkillMatcher { /// /// Returns an error if the Qdrant client cannot be created. pub fn new(qdrant_url: &str) -> Result { - let client = Qdrant::from_url(qdrant_url).build().map_err(Box::new)?; + let ops = QdrantOps::new(qdrant_url)?; Ok(Self { - client, + ops, collection: COLLECTION_NAME.into(), hashes: HashMap::new(), }) @@ -85,7 +81,7 @@ impl QdrantSkillMatcher { self.ensure_collection(&embed_fn).await?; - let existing = self.scroll_all().await?; + let existing = self.ops.scroll_all(&self.collection, "skill_name").await?; let mut current: HashMap = HashMap::with_capacity(meta.len()); for m in meta { @@ -126,14 +122,13 @@ impl QdrantSkillMatcher { }; let point_id = skill_point_id(name); - let payload: serde_json::Value = serde_json::json!({ + let payload = serde_json::json!({ "skill_name": name, "description": m.description, "content_hash": hash, "embedding_model": embedding_model, }); - let payload_map: HashMap = - serde_json::from_value(payload)?; + let payload_map = QdrantOps::json_to_payload(payload)?; points_to_upsert.push(PointStruct::new(point_id, vector, payload_map)); @@ -146,31 +141,18 @@ impl QdrantSkillMatcher { } if !points_to_upsert.is_empty() { - self.client - .upsert_points(UpsertPointsBuilder::new(&self.collection, points_to_upsert)) - .await - .map_err(Box::new)?; + self.ops.upsert(&self.collection, points_to_upsert).await?; } - let orphan_ids: Vec = existing + let orphan_ids: Vec = existing .keys() .filter(|name| !current.contains_key(*name)) - .map(|name| skill_point_id(name)) + .map(|name| qdrant_client::qdrant::PointId::from(skill_point_id(name).as_str())) .collect(); if !orphan_ids.is_empty() { stats.removed = orphan_ids.len(); - let point_ids: Vec = orphan_ids - .into_iter() - .map(|id| qdrant_client::qdrant::PointId::from(id.as_str())) - .collect(); - self.client - .delete_points( - DeletePointsBuilder::new(&self.collection) - .points(PointsIdsList { ids: point_ids }), - ) - .await - .map_err(Box::new)?; + self.ops.delete_by_ids(&self.collection, orphan_ids).await?; } tracing::info!( @@ -209,10 +191,8 @@ impl QdrantSkillMatcher { }; let results = match self - .client - .search_points( - SearchPointsBuilder::new(&self.collection, query_vec, limit_u64).with_payload(true), - ) + .ops + .search(&self.collection, query_vec, limit_u64, None) .await { Ok(r) => r, @@ -223,7 +203,6 @@ impl QdrantSkillMatcher { }; results - .result .into_iter() .filter_map(|point| { let name = point.payload.get("skill_name")?; @@ -240,16 +219,8 @@ impl QdrantSkillMatcher { where F: Fn(&str) -> EmbedFuture, { - if self - .client - .collection_exists(&self.collection) - .await - .map_err(Box::new)? - { - self.client - .delete_collection(&self.collection) - .await - .map_err(Box::new)?; + if self.ops.collection_exists(&self.collection).await? { + self.ops.delete_collection(&self.collection).await?; tracing::info!( collection = &self.collection, "deleted collection for recreation" @@ -262,12 +233,7 @@ impl QdrantSkillMatcher { where F: Fn(&str) -> EmbedFuture, { - if self - .client - .collection_exists(&self.collection) - .await - .map_err(Box::new)? - { + if self.ops.collection_exists(&self.collection).await? { return Ok(()); } @@ -276,13 +242,9 @@ impl QdrantSkillMatcher { .map_err(|e| SkillError::Other(format!("failed to probe embedding dimensions: {e}")))?; let vector_size = u64::try_from(probe.len())?; - self.client - .create_collection( - CreateCollectionBuilder::new(&self.collection) - .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)), - ) - .await - .map_err(Box::new)?; + self.ops + .ensure_collection(&self.collection, vector_size) + .await?; tracing::info!( collection = &self.collection, @@ -292,48 +254,6 @@ impl QdrantSkillMatcher { Ok(()) } - - async fn scroll_all(&self) -> Result>, SkillError> { - let mut result = HashMap::new(); - let mut offset: Option = None; - - loop { - let mut builder = ScrollPointsBuilder::new(&self.collection) - .with_payload(true) - .with_vectors(false) - .limit(100); - - if let Some(ref off) = offset { - builder = builder.offset(off.clone()); - } - - let response = self.client.scroll(builder).await.map_err(Box::new)?; - - for point in &response.result { - let Some(name_val) = point.payload.get("skill_name") else { - continue; - }; - let Some(Kind::StringValue(name)) = &name_val.kind else { - continue; - }; - - let mut fields = HashMap::new(); - for (key, val) in &point.payload { - if let Some(Kind::StringValue(s)) = &val.kind { - fields.insert(key.clone(), s.clone()); - } - } - result.insert(name.clone(), fields); - } - - match response.next_page_offset { - Some(next) => offset = Some(next), - None => break, - } - } - - Ok(result) - } } #[cfg(test)]