From 2977ad44437ed0ddf75327c9d302f2409cc8f28c Mon Sep 17 00:00:00 2001 From: shamb0 Date: Thu, 21 Nov 2024 22:16:47 +0530 Subject: [PATCH] test bringup Signed-off-by: shamb0 --- Cargo.lock | 212 ------------------ Cargo.toml | 5 +- examples/index_md_into_pgvector.rs | 5 +- swiftide-core/src/query.rs | 7 +- swiftide-core/src/type_aliases.rs | 18 ++ .../src/pgvector/fixtures.rs | 2 + swiftide-integrations/src/pgvector/mod.rs | 187 ++++++++++++++- swiftide-integrations/src/pgvector/persist.rs | 66 ------ .../src/pgvector/pgv_table_types.rs | 25 ++- .../src/pgvector/retrieve.rs | 95 +++----- swiftide-test-utils/Cargo.toml | 2 - swiftide-test-utils/src/test_utils.rs | 1 - 12 files changed, 264 insertions(+), 361 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dda904f0..93fc6abf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6486,15 +6486,6 @@ dependencies = [ "portable-atomic", ] -[[package]] -name = "portpicker" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9" -dependencies = [ - "rand", -] - [[package]] name = "powerfmt" version = "0.2.0" @@ -8244,208 +8235,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "sqlx" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93334716a037193fac19df402f8571269c84a00852f6a7066b5d2616dcd64d3e" -dependencies = [ - "sqlx-core", - "sqlx-macros", - "sqlx-mysql", - "sqlx-postgres", - "sqlx-sqlite", -] - -[[package]] -name = "sqlx-core" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d8060b456358185f7d50c55d9b5066ad956956fddec42ee2e8567134a8936e" -dependencies = [ - "atoi", - "byteorder", - "bytes", - "chrono", - "crc", - "crossbeam-queue", - "either", - "event-listener 5.3.1", - "futures-channel", - "futures-core", - "futures-intrusive", - "futures-io", - "futures-util", - "hashbrown 0.14.5", - "hashlink", - "hex", - "indexmap 2.6.0", - "log", - "memchr", - "once_cell", - "paste", - "percent-encoding", - "serde", - "serde_json", - "sha2", - "smallvec", - "sqlformat", - "thiserror", - "tokio", - "tokio-stream", - "tracing", - "url", - "uuid", -] - -[[package]] -name = "sqlx-macros" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac0692bcc9de3b073e8d747391827297e075c7710ff6276d9f7a1f3d58c6657" -dependencies = [ - "proc-macro2", - "quote", - "sqlx-core", - "sqlx-macros-core", - "syn 2.0.79", -] - -[[package]] -name = "sqlx-macros-core" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1804e8a7c7865599c9c79be146dc8a9fd8cc86935fa641d3ea58e5f0688abaa5" -dependencies = [ - "dotenvy", - "either", - "heck 0.5.0", - "hex", - "once_cell", - "proc-macro2", - "quote", - "serde", - "serde_json", - "sha2", - "sqlx-core", - "sqlx-mysql", - "sqlx-postgres", - "sqlx-sqlite", - "syn 2.0.79", - "tempfile", - "tokio", - "url", -] - -[[package]] -name = "sqlx-mysql" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" -dependencies = [ - "atoi", - "base64 0.22.1", - "bitflags 2.6.0", - "byteorder", - "bytes", - "chrono", - "crc", - "digest", - "dotenvy", - "either", - "futures-channel", - "futures-core", - "futures-io", - "futures-util", - "generic-array", - "hex", - "hkdf", - "hmac", - "itoa", - "log", - "md-5", - "memchr", - "once_cell", - "percent-encoding", - "rand", - "rsa", - "serde", - "sha1", - "sha2", - "smallvec", - "sqlx-core", - "stringprep", - "thiserror", - "tracing", - "uuid", - "whoami", -] - -[[package]] -name = "sqlx-postgres" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" -dependencies = [ - "atoi", - "base64 0.22.1", - "bitflags 2.6.0", - "byteorder", - "chrono", - "crc", - "dotenvy", - "etcetera", - "futures-channel", - "futures-core", - "futures-io", - "futures-util", - "hex", - "hkdf", - "hmac", - "home", - "itoa", - "log", - "md-5", - "memchr", - "once_cell", - "rand", - "serde", - "serde_json", - "sha2", - "smallvec", - "sqlx-core", - "stringprep", - "thiserror", - "tracing", - "uuid", - "whoami", -] - -[[package]] -name = "sqlx-sqlite" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5b2cf34a45953bfd3daaf3db0f7a7878ab9b7a6b91b422d24a7a9e4c857b680" -dependencies = [ - "atoi", - "chrono", - "flume", - "futures-channel", - "futures-core", - "futures-executor", - "futures-intrusive", - "futures-util", - "libsqlite3-sys", - "log", - "percent-encoding", - "serde", - "serde_urlencoded", - "sqlx-core", - "tracing", - "url", - "uuid", -] - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -8771,7 +8560,6 @@ dependencies = [ "anyhow", "async-openai", "mockall", - "portpicker", "qdrant-client", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 46051a73..0086e332 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,10 +53,7 @@ arrow-array = { version = "52.2", default-features = false } arrow = { version = "52.2", default-features = false } parquet = { version = "52.2", default-features = false, features = ["async"] } redb = { version = "2.2" } -sqlx = { version = "0.8.2", features = [ - "postgres", - "uuid", -], default-features = false } +sqlx = { version = "0.8.2", features = ["postgres", "uuid"] } aws-config = "1.5" pgvector = { version = "0.4.0", features = ["sqlx"], default-features = false } aws-credential-types = "1.2" diff --git a/examples/index_md_into_pgvector.rs b/examples/index_md_into_pgvector.rs index 26e1ff0e..9c538d0e 100644 --- a/examples/index_md_into_pgvector.rs +++ b/examples/index_md_into_pgvector.rs @@ -93,6 +93,7 @@ async fn main() -> Result<(), Box> { } tracing::info!("Starting indexing pipeline"); + indexing::Pipeline::from_loader(FileLoader::new(test_dataset_path).with_extensions(&["md"])) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(llm_client.clone())) @@ -101,7 +102,8 @@ async fn main() -> Result<(), Box> { .run() .await?; - tracing::info!("PgVector Indexing test completed successfully"); + tracing::info!("PgVector Indexing completed successfully"); + for (i, question) in [ "What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.", "How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.", @@ -121,5 +123,6 @@ async fn main() -> Result<(), Box> { } tracing::info!("PgVector Indexing & retrieval test completed successfully"); + Ok(()) } diff --git a/swiftide-core/src/query.rs b/swiftide-core/src/query.rs index 5e2e0b9a..d9c8fce2 100644 --- a/swiftide-core/src/query.rs +++ b/swiftide-core/src/query.rs @@ -7,7 +7,7 @@ //! `states::Answered`: The query has been answered use derive_builder::Builder; -use crate::{util::debug_long_utf8, Embedding, SparseEmbedding}; +use crate::{util::debug_long_utf8, AdvanceEmbedding, Embedding, SparseEmbedding}; type Document = String; @@ -34,6 +34,9 @@ pub struct Query { #[builder(default)] pub sparse_embedding: Option, + + #[builder(default)] + pub adv_embedding: Option, } impl std::fmt::Debug for Query { @@ -44,6 +47,7 @@ impl std::fmt::Debug for Query { .field("state", &self.state) .field("transformation_history", &self.transformation_history) .field("embedding", &self.embedding.is_some()) + .field("adv_embedding", &self.adv_embedding.is_some()) .finish() } } @@ -71,6 +75,7 @@ impl Query { transformation_history: self.transformation_history, embedding: self.embedding, sparse_embedding: self.sparse_embedding, + adv_embedding: self.adv_embedding, } } diff --git a/swiftide-core/src/type_aliases.rs b/swiftide-core/src/type_aliases.rs index 197c56b3..6c97bc6e 100644 --- a/swiftide-core/src/type_aliases.rs +++ b/swiftide-core/src/type_aliases.rs @@ -1,5 +1,6 @@ #![cfg_attr(coverage_nightly, coverage(off))] +use crate::indexing::EmbeddedField; use serde::{Deserialize, Serialize}; pub type Embedding = Vec; @@ -20,3 +21,20 @@ impl std::fmt::Debug for SparseEmbedding { .finish() } } + +#[derive(Serialize, Deserialize, Clone, PartialEq)] +pub struct AdvanceEmbedding { + pub embedded_field: EmbeddedField, + pub field_value: Vec, +} +pub type AdvanceEmbeddings = Vec; + +impl std::fmt::Debug for AdvanceEmbedding { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Start the debug struct formatting + f.debug_struct("AdvanceEmbedding") + .field("embedded_field", &self.embedded_field) + .field("field_value", &self.field_value) + .finish() + } +} diff --git a/swiftide-integrations/src/pgvector/fixtures.rs b/swiftide-integrations/src/pgvector/fixtures.rs index 6508893a..9307c14d 100644 --- a/swiftide-integrations/src/pgvector/fixtures.rs +++ b/swiftide-integrations/src/pgvector/fixtures.rs @@ -77,6 +77,8 @@ pub(crate) struct PgVectorTestData<'a> { pub metadata: Option, /// Vector embeddings with their corresponding fields pub vectors: Vec<(indexing::EmbeddedField, Vec)>, + pub expected_in_results: bool, + pub use_adv_embedding_query: bool, } impl<'a> PgVectorTestData<'a> { diff --git a/swiftide-integrations/src/pgvector/mod.rs b/swiftide-integrations/src/pgvector/mod.rs index d9849293..f0eba79b 100644 --- a/swiftide-integrations/src/pgvector/mod.rs +++ b/swiftide-integrations/src/pgvector/mod.rs @@ -27,6 +27,7 @@ mod fixtures; mod persist; mod pgv_table_types; +mod retrieve; use anyhow::Result; use derive_builder::Builder; use sqlx::PgPool; @@ -35,7 +36,7 @@ use std::sync::Arc; use std::sync::OnceLock; use tokio::time::Duration; -use pgv_table_types::{FieldConfig, MetadataConfig, PgDBConnectionPool, VectorConfig}; +use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig}; /// Default maximum connections for the database connection pool. const DB_POOL_CONN_MAX: u32 = 10; @@ -188,10 +189,134 @@ mod tests { use std::collections::HashSet; use swiftide_core::{ indexing::{self, EmbedMode, EmbeddedField}, - Persist, + querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, + AdvanceEmbedding, Persist, Retrieve, }; use test_case::test_case; + #[test_log::test(tokio::test)] + async fn test_metadata_filter_with_vector_search() { + let test_context = TestContext::setup_with_cfg( + vec!["category", "priority"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); + + // Create nodes with different metadata and vectors + let nodes = vec![ + indexing::Node::new("content1") + .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]) + .with_metadata(vec![("category", "A"), ("priority", "1")]), + indexing::Node::new("content2") + .with_vectors([(EmbeddedField::Combined, vec![1.1; 384])]) + .with_metadata(vec![("category", "A"), ("priority", "2")]), + indexing::Node::new("content3") + .with_vectors([(EmbeddedField::Combined, vec![1.2; 384])]) + .with_metadata(vec![("category", "B"), ("priority", "1")]), + ] + .into_iter() + .map(|node| node.to_owned()) + .collect(); + + // Store all nodes + test_context + .pgv_storage + .batch_store(nodes) + .await + .try_collect::>() + .await + .unwrap(); + + // Test combined metadata and vector search + let mut query = Query::::new("test_query"); + query.embedding = Some(vec![1.0; 384]); + + let search_strategy = + SimilaritySingleEmbedding::from_filter("category = \"A\"".to_string()); + + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query.clone()) + .await + .unwrap(); + + assert_eq!(result.documents().len(), 2); + + assert!(result.documents().contains(&"content1".to_string())); + assert!(result.documents().contains(&"content2".to_string())); + + // Additional test with priority filter + let search_strategy = + SimilaritySingleEmbedding::from_filter("priority = \"1\"".to_string()); + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query) + .await + .unwrap(); + + assert_eq!(result.documents().len(), 2); + assert!(result.documents().contains(&"content1".to_string())); + assert!(result.documents().contains(&"content3".to_string())); + } + + #[test_log::test(tokio::test)] + async fn test_vector_similarity_search_accuracy() { + let test_context = TestContext::setup_with_cfg( + vec!["category", "priority"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); + + // Create nodes with known vector relationships + let base_vector = vec![1.0; 384]; + let similar_vector = base_vector.iter().map(|x| x + 0.1).collect::>(); + let dissimilar_vector = vec![-1.0; 384]; + + let nodes = vec![ + indexing::Node::new("base_content") + .with_vectors([(EmbeddedField::Combined, base_vector)]) + .with_metadata(vec![("category", "A"), ("priority", "1")]), + indexing::Node::new("similar_content") + .with_vectors([(EmbeddedField::Combined, similar_vector)]) + .with_metadata(vec![("category", "A"), ("priority", "2")]), + indexing::Node::new("dissimilar_content") + .with_vectors([(EmbeddedField::Combined, dissimilar_vector)]) + .with_metadata(vec![("category", "B"), ("priority", "1")]), + ] + .into_iter() + .map(|node| node.to_owned()) + .collect(); + + // Store all nodes + test_context + .pgv_storage + .batch_store(nodes) + .await + .try_collect::>() + .await + .unwrap(); + + // Search with base vector + let mut query = Query::::new("test_query"); + query.embedding = Some(vec![1.0; 384]); + + let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); + search_strategy.with_top_k(2); + + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query) + .await + .unwrap(); + + // Verify that similar vectors are retrieved first + assert_eq!(result.documents().len(), 2); + assert!(result.documents().contains(&"base_content".to_string())); + assert!(result.documents().contains(&"similar_content".to_string())); + } + #[test_case( // SingleWithMetadata - No Metadata vec![ @@ -200,12 +325,16 @@ mod tests { chunk: "single_no_meta_1", metadata: None, vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)], + expected_in_results: true, + use_adv_embedding_query: false, }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "single_no_meta_2", metadata: None, vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)], + expected_in_results: true, + use_adv_embedding_query: false, } ], HashSet::from([EmbeddedField::Combined]) @@ -221,6 +350,8 @@ mod tests { ("priority", "high") ].into()), vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)], + expected_in_results: true, + use_adv_embedding_query: false, }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, @@ -230,6 +361,8 @@ mod tests { ("priority", "low") ].into()), vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)], + expected_in_results: true, + use_adv_embedding_query: false, } ], HashSet::from([EmbeddedField::Combined]) @@ -246,6 +379,8 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.2), PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), ], + expected_in_results: true, + use_adv_embedding_query: true, }, PgVectorTestData { embed_mode: EmbedMode::PerField, @@ -256,6 +391,8 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.3), PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), ], + expected_in_results: true, + use_adv_embedding_query: true, } ], HashSet::from([ @@ -279,6 +416,8 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.2), PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), ], + expected_in_results: true, + use_adv_embedding_query: true, }, PgVectorTestData { embed_mode: EmbedMode::PerField, @@ -292,6 +431,8 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.3), PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), ], + expected_in_results: true, + use_adv_embedding_query: true, } ], HashSet::from([ @@ -311,6 +452,8 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.0), PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.1) ], + expected_in_results: true, + use_adv_embedding_query: true, }, PgVectorTestData { embed_mode: EmbedMode::Both, @@ -320,6 +463,8 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.2), PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.3) ], + expected_in_results: true, + use_adv_embedding_query: true, } ], HashSet::from([EmbeddedField::Combined, EmbeddedField::Chunk]) @@ -342,6 +487,8 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.7), PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 3.8) ], + expected_in_results: true, + use_adv_embedding_query: true, }, PgVectorTestData { embed_mode: EmbedMode::Both, @@ -358,6 +505,8 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 4.2), PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 4.3) ], + expected_in_results: true, + use_adv_embedding_query: true, } ], HashSet::from([ @@ -405,7 +554,7 @@ mod tests { "All nodes should be stored" ); - // Verify storage for each test case + // Verify storage and retrieval for each test case for (test_case, stored_node) in test_cases.iter().zip(stored_nodes.iter()) { // 1. Verify basic node properties assert_eq!( @@ -427,6 +576,38 @@ mod tests { test_case.vectors.len(), "Vector count should match" ); + + // 3. Test vector similarity search + for (index, (field, vector)) in test_case.vectors.iter().enumerate() { + tracing::warn!("Enter :: {:#?}!", index); + let mut query = Query::::new("test_query"); + + if test_case.use_adv_embedding_query { + query.adv_embedding = Some(AdvanceEmbedding { + embedded_field: field.clone(), + field_value: vector.clone(), + }); + } else { + query.embedding = Some(vector.clone()); + } + + let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); + search_strategy.with_top_k(nodes.len() as u64); + + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query) + .await + .expect("Retrieval should succeed"); + + if test_case.expected_in_results { + assert!( + result.documents().contains(&test_case.chunk.to_string()), + "Document should be found in results for field {field}", + ); + } + tracing::warn!("Exit :: {:#?}!", index); + } } } } diff --git a/swiftide-integrations/src/pgvector/persist.rs b/swiftide-integrations/src/pgvector/persist.rs index 2f517fe4..6b9973ae 100644 --- a/swiftide-integrations/src/pgvector/persist.rs +++ b/swiftide-integrations/src/pgvector/persist.rs @@ -91,69 +91,3 @@ mod tests { .expect("PgVector setup should not fail when the table already exists"); } } - -#[cfg(test)] -mod tests { - use crate::pgvector::PgVector; - use swiftide_core::{indexing::EmbeddedField, Persist}; - use temp_dir::TempDir; - use testcontainers::{ContainerAsync, GenericImage}; - - struct TestContext { - pgv_storage: PgVector, - _temp_dir: TempDir, - _pgv_db_container: ContainerAsync, - } - - impl TestContext { - /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage - async fn setup() -> Result> { - // Start PostgreSQL container and obtain the connection URL - let (pgv_db_container, pgv_db_url, temp_dir) = - swiftide_test_utils::start_postgres().await; - - tracing::info!("Postgres database URL: {:#?}", pgv_db_url); - - // Configure and build PgVector storage - let pgv_storage = PgVector::builder() - .try_connect_to_pool(pgv_db_url, Some(10)) - .await - .map_err(|err| { - tracing::error!("Failed to connect to Postgres server: {}", err); - err - })? - .vector_size(384) - .with_vector(EmbeddedField::Combined) - .with_metadata("filter") - .table_name("swiftide_pgvector_test".to_string()) - .build() - .map_err(|err| { - tracing::error!("Failed to build PgVector: {}", err); - err - })?; - - // Set up PgVector storage (create the table if not exists) - pgv_storage.setup().await.map_err(|err| { - tracing::error!("PgVector setup failed: {}", err); - err - })?; - - Ok(Self { - pgv_storage, - _temp_dir: temp_dir, - _pgv_db_container: pgv_db_container, - }) - } - } - - #[test_log::test(tokio::test)] - async fn test_persist_setup_no_error_when_table_exists() { - let test_context = TestContext::setup().await.expect("Test setup failed"); - - test_context - .pgv_storage - .setup() - .await - .expect("PgVector setup should not fail when the table already exists"); - } -} diff --git a/swiftide-integrations/src/pgvector/pgv_table_types.rs b/swiftide-integrations/src/pgvector/pgv_table_types.rs index 92da4dee..a66acb5b 100644 --- a/swiftide-integrations/src/pgvector/pgv_table_types.rs +++ b/swiftide-integrations/src/pgvector/pgv_table_types.rs @@ -7,12 +7,13 @@ //! - Bulk data preparation and SQL query generation //! use crate::pgvector::PgVector; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, Result}; use pgvector as ExtPgVector; use regex::Regex; use sqlx::postgres::PgArguments; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; +use std::collections::BTreeMap; use swiftide_core::indexing::{EmbeddedField, Node}; use tokio::time::sleep; @@ -23,7 +24,7 @@ use tokio::time::sleep; #[derive(Clone, Debug)] pub struct VectorConfig { embedded_field: EmbeddedField, - field: String, + pub(crate) field: String, } impl VectorConfig { @@ -75,7 +76,7 @@ impl> From for MetadataConfig { /// Represents different field types that can be configured in the table schema, /// including vector embeddings, metadata, and system fields. #[derive(Clone, Debug)] -pub enum FieldConfig { +pub(crate) enum FieldConfig { /// `Vector` - Vector embedding field configuration Vector(VectorConfig), /// `Metadata` - Metadata field configuration @@ -185,7 +186,9 @@ impl PgVector { FieldConfig::ID => "id UUID NOT NULL".to_string(), FieldConfig::Chunk => format!("{} TEXT NOT NULL", field.field_name()), FieldConfig::Metadata(_) => format!("{} JSONB", field.field_name()), - FieldConfig::Vector(_) => format!("{} VECTOR({})", field.field_name(), vector_size), + FieldConfig::Vector(_) => { + format!("{} VECTOR({})", field.field_name(), self.vector_size) + } }) .chain(std::iter::once("PRIMARY KEY (id)".to_string())) .collect(); @@ -196,6 +199,8 @@ impl PgVector { columns.join(",\n ") ); + tracing::info!("Sql statement :: {:#?}", sql); + Ok(sql) } @@ -265,11 +270,6 @@ impl PgVector { .await .map_err(|e| anyhow!("Failed to store nodes: {:?}", e))?; - query.execute(&mut *tx).await.map_err(|e| { - tracing::error!("Failed to store nodes: {:?}", e); - anyhow!("Failed to store nodes: {:?}", e) - })?; - tx.commit() .await .map_err(|e| anyhow!("Failed to commit transaction: {:?}", e)) @@ -296,7 +296,10 @@ impl PgVector { .get(&config.original_field) .ok_or_else(|| anyhow!("Missing metadata field"))?; - bulk_data.metadata_fields[idx].push(value.clone()); + let mut metadata_map = BTreeMap::new(); + metadata_map.insert(config.original_field.clone(), value.clone()); + + bulk_data.metadata_fields[idx].push(serde_json::to_value(metadata_map)?); } FieldConfig::Vector(config) => { let idx = bulk_data @@ -560,6 +563,7 @@ mod tests { use super::*; #[test] + #[ignore] fn test_valid_identifiers() { assert!(PgVector::is_valid_identifier("valid_name")); assert!(PgVector::is_valid_identifier("_valid_name")); @@ -568,6 +572,7 @@ mod tests { } #[test] + #[ignore] fn test_invalid_identifiers() { assert!(!PgVector::is_valid_identifier("")); // Empty string assert!(!PgVector::is_valid_identifier(&"a".repeat(64))); // Too long diff --git a/swiftide-integrations/src/pgvector/retrieve.rs b/swiftide-integrations/src/pgvector/retrieve.rs index 811988cb..166085cc 100644 --- a/swiftide-integrations/src/pgvector/retrieve.rs +++ b/swiftide-integrations/src/pgvector/retrieve.rs @@ -1,4 +1,4 @@ -use crate::pgvector::{PgVector, PgVectorBuilder}; +use crate::pgvector::{pgv_table_types::VectorConfig, PgVector, PgVectorBuilder}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use pgvector::Vector; @@ -24,20 +24,36 @@ impl Retrieve> for PgVector { search_strategy: &SimilaritySingleEmbedding, query_state: Query, ) -> Result> { - let embedding = query_state - .embedding - .as_ref() - .ok_or_else(|| anyhow!("No embedding for query"))?; - let embedding = Vector::from(embedding.clone()); + let (vector_column_name, embedding) = match ( + query_state.embedding.as_ref(), + query_state.adv_embedding.as_ref(), + ) { + (Some(embed), None) => { + let vector_column_name = self.get_vector_column_name()?; + let embedding = Vector::from(embed.clone()); + (vector_column_name, embedding) + } + (None, Some(adv_embed)) => { + let vector_column_name = VectorConfig::from(adv_embed.embedded_field.clone()).field; + let embedding = Vector::from(adv_embed.field_value.clone()); + (vector_column_name, embedding) + } + (None, None) => { + return Err(anyhow!("No embedding found in query state")); + } + (Some(_), Some(_)) => { + return Err(anyhow!( + "Both regular and advanced embeddings found. Please provide only one type." + )); + } + }; - // let pool = self.connection_pool.get_pool().await?; - let pool = self.connection_pool.get_pool()?; + let pool = self.pool_get_or_initialize().await?; let default_columns: Vec<_> = PgVectorBuilder::default_fields() .iter() .map(|f| f.field_name().to_string()) .collect(); - let vector_column_name = self.get_vector_column_name()?; // Start building the SQL query let mut sql = format!( @@ -83,7 +99,7 @@ impl Retrieve> for PgVector { let data: Vec = sqlx::query_as(&sql) .bind(embedding) .bind(top_k) - .fetch_all(&pool) + .fetch_all(pool) .await?; let docs = data.into_iter().map(|r| r.chunk).collect(); @@ -110,66 +126,23 @@ impl Retrieve for PgVector { #[cfg(test)] mod tests { - use crate::pgvector::PgVector; + use crate::pgvector::fixtures::TestContext; use futures_util::TryStreamExt; + use std::collections::HashSet; use swiftide_core::{indexing, indexing::EmbeddedField, Persist}; use swiftide_core::{ querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, Retrieve, }; - use temp_dir::TempDir; - use testcontainers::{ContainerAsync, GenericImage}; - - struct TestContext { - pgv_storage: PgVector, - _temp_dir: TempDir, - _pgv_db_container: ContainerAsync, - } - - impl TestContext { - /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage - async fn setup() -> Result> { - // Start PostgreSQL container and obtain the connection URL - let (pgv_db_container, pgv_db_url, temp_dir) = - swiftide_test_utils::start_postgres().await; - - tracing::info!("Postgres database URL: {:#?}", pgv_db_url); - - // Configure and build PgVector storage - let pgv_storage = PgVector::builder() - .try_connect_to_pool(pgv_db_url, Some(10)) - .await - .map_err(|err| { - tracing::error!("Failed to connect to Postgres server: {}", err); - err - })? - .vector_size(384) - .with_vector(EmbeddedField::Combined) - .with_metadata("filter") - .table_name("swiftide_pgvector_test".to_string()) - .build() - .map_err(|err| { - tracing::error!("Failed to build PgVector: {}", err); - err - })?; - - // Set up PgVector storage (create the table if not exists) - pgv_storage.setup().await.map_err(|err| { - tracing::error!("PgVector setup failed: {}", err); - err - })?; - - Ok(Self { - pgv_storage, - _temp_dir: temp_dir, - _pgv_db_container: pgv_db_container, - }) - } - } #[test_log::test(tokio::test)] async fn test_retrieve_multiple_docs_and_filter() { - let test_context = TestContext::setup().await.expect("Test setup failed"); + let test_context = TestContext::setup_with_cfg( + vec!["filter"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); let nodes = vec![ indexing::Node::new("test_query1").with_metadata(("filter", "true")), diff --git a/swiftide-test-utils/Cargo.toml b/swiftide-test-utils/Cargo.toml index b0a344db..840729f7 100644 --- a/swiftide-test-utils/Cargo.toml +++ b/swiftide-test-utils/Cargo.toml @@ -27,8 +27,6 @@ wiremock = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true } -tempfile = { workspace = true } -portpicker = { workspace = true } [features] default = ["test-utils"] diff --git a/swiftide-test-utils/src/test_utils.rs b/swiftide-test-utils/src/test_utils.rs index 4b7069c2..3ce54df6 100644 --- a/swiftide-test-utils/src/test_utils.rs +++ b/swiftide-test-utils/src/test_utils.rs @@ -11,7 +11,6 @@ use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; use swiftide_integrations as integrations; -use temp_dir::TempDir; pub fn openai_client( mock_server_uri: &str,