Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ gateway = ["dep:zeph-gateway"]
daemon = ["zeph-core/daemon"]
scheduler = ["dep:zeph-scheduler"]
otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp", "dep:tracing-opentelemetry"]
mock = ["zeph-llm/mock"]
mock = ["zeph-llm/mock", "zeph-memory/mock"]

[dependencies]
anyhow.workspace = true
Expand Down
4 changes: 4 additions & 0 deletions crates/zeph-memory/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ zeph-llm.workspace = true
name = "token_estimation"
harness = false

[features]
default = []
mock = []

[dev-dependencies]
anyhow.workspace = true
criterion.workspace = true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
pub use qdrant_client::qdrant::Filter;
use qdrant_client::qdrant::{Condition, PointStruct};
use sqlx::SqlitePool;

use crate::error::MemoryError;
use crate::qdrant_ops::QdrantOps;
use crate::types::{ConversationId, MessageId};
use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};

/// Distinguishes regular messages from summaries when storing embeddings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -37,15 +37,15 @@ pub async fn ensure_qdrant_collection(
ops.ensure_collection(collection, vector_size).await
}

pub struct QdrantStore {
ops: QdrantOps,
pub struct EmbeddingStore {
ops: Box<dyn VectorStore>,
collection: String,
pool: SqlitePool,
}

impl std::fmt::Debug for QdrantStore {
impl std::fmt::Debug for EmbeddingStore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("QdrantStore")
f.debug_struct("EmbeddingStore")
.field("collection", &self.collection)
.finish_non_exhaustive()
}
Expand All @@ -64,8 +64,8 @@ pub struct SearchResult {
pub score: f32,
}

impl QdrantStore {
/// Create a new `QdrantStore` connected to the given Qdrant URL.
impl EmbeddingStore {
/// Create a new `EmbeddingStore` connected to the given Qdrant URL.
///
/// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
/// table (which must already exist via sqlx migrations).
Expand All @@ -77,16 +77,19 @@ impl QdrantStore {
let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;

Ok(Self {
ops,
ops: Box::new(ops),
collection: COLLECTION_NAME.into(),
pool,
})
}

/// Access the underlying `QdrantOps`.
#[must_use]
pub fn ops(&self) -> &QdrantOps {
&self.ops
pub fn with_store(store: Box<dyn VectorStore>, pool: SqlitePool) -> Self {
Self {
ops: store,
collection: COLLECTION_NAME.into(),
pool,
}
}

/// Ensure the collection exists in Qdrant with the given vector size.
Expand Down Expand Up @@ -122,15 +125,24 @@ impl QdrantStore {
let point_id = uuid::Uuid::new_v4().to_string();
let dimensions = i64::try_from(vector.len())?;

let payload = serde_json::json!({
"message_id": message_id.0,
"conversation_id": conversation_id.0,
"role": role,
"is_summary": kind.is_summary(),
});
let payload_map = QdrantOps::json_to_payload(payload)?;

let point = PointStruct::new(point_id.clone(), vector, payload_map);
let payload = std::collections::HashMap::from([
("message_id".to_owned(), serde_json::json!(message_id.0)),
(
"conversation_id".to_owned(),
serde_json::json!(conversation_id.0),
),
("role".to_owned(), serde_json::json!(role)),
(
"is_summary".to_owned(),
serde_json::json!(kind.is_summary()),
),
]);

let point = VectorPoint {
id: point_id.clone(),
vector,
payload,
};

self.ops.upsert(&self.collection, vec![point]).await?;

Expand Down Expand Up @@ -163,18 +175,27 @@ impl QdrantStore {
) -> Result<Vec<SearchResult>, MemoryError> {
let limit_u64 = u64::try_from(limit)?;

let qdrant_filter = filter.as_ref().and_then(|f| {
let mut conditions = Vec::new();
let vector_filter = filter.as_ref().and_then(|f| {
let mut must = Vec::new();
if let Some(cid) = f.conversation_id {
conditions.push(Condition::matches("conversation_id", cid.0));
must.push(FieldCondition {
field: "conversation_id".into(),
value: FieldValue::Integer(cid.0),
});
}
if let Some(ref role) = f.role {
conditions.push(Condition::matches("role", role.clone()));
must.push(FieldCondition {
field: "role".into(),
value: FieldValue::Text(role.clone()),
});
}
if conditions.is_empty() {
if must.is_empty() {
None
} else {
Some(Filter::must(conditions))
Some(VectorFilter {
must,
must_not: vec![],
})
}
});

Expand All @@ -184,16 +205,16 @@ impl QdrantStore {
&self.collection,
query_vector.to_vec(),
limit_u64,
qdrant_filter,
vector_filter,
)
.await?;

let search_results = results
.into_iter()
.filter_map(|point| {
let payload = &point.payload;
let message_id = MessageId(payload.get("message_id")?.as_integer()?);
let conversation_id = ConversationId(payload.get("conversation_id")?.as_integer()?);
let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
let conversation_id =
ConversationId(point.payload.get("conversation_id")?.as_i64()?);
Some(SearchResult {
message_id,
conversation_id,
Expand Down Expand Up @@ -233,8 +254,13 @@ impl QdrantStore {
vector: Vec<f32>,
) -> Result<String, MemoryError> {
let point_id = uuid::Uuid::new_v4().to_string();
let payload_map = QdrantOps::json_to_payload(payload)?;
let point = PointStruct::new(point_id.clone(), vector, payload_map);
let payload_map: std::collections::HashMap<String, serde_json::Value> =
serde_json::from_value(payload)?;
let point = VectorPoint {
id: point_id.clone(),
vector,
payload: payload_map,
};
self.ops.upsert(collection, vec![point]).await?;
Ok(point_id)
}
Expand All @@ -249,8 +275,8 @@ impl QdrantStore {
collection: &str,
query_vector: &[f32],
limit: usize,
filter: Option<Filter>,
) -> Result<Vec<qdrant_client::qdrant::ScoredPoint>, MemoryError> {
filter: Option<VectorFilter>,
) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
let limit_u64 = u64::try_from(limit)?;
let results = self
.ops
Expand Down
3 changes: 3 additions & 0 deletions crates/zeph-memory/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ pub enum MemoryError {
#[error("Qdrant error: {0}")]
Qdrant(#[from] Box<qdrant_client::QdrantError>),

#[error("vector store error: {0}")]
VectorStore(#[from] crate::vector_store::VectorStoreError),

#[error("migration failed: {0}")]
Migration(#[from] sqlx::migrate::MigrateError),

Expand Down
Loading
Loading