diff --git a/Cargo.lock b/Cargo.lock index e2b54769..257c79a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8432,6 +8432,7 @@ dependencies = [ "swiftide-test-utils", "temp-dir", "tokio", + "tracing", "tracing-subscriber", ] @@ -8489,6 +8490,7 @@ dependencies = [ "lancedb", "mockall", "ollama-rs", + "once_cell", "parquet", "pgvector", "qdrant-client", diff --git a/Cargo.toml b/Cargo.toml index 3b9ac0e4..afd616ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ indoc = { version = "2.0" } regex = { version = "1.11.1" } uuid = { version = "1.10", features = ["v3", "v4", "serde"] } dyn-clone = { version = "1.0" } +once_cell = { version = "1.20.2" } # Integrations spider = { version = "2.13" } diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 6f07a0ad..b0148ea2 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -31,6 +31,7 @@ fluvio = { workspace = true } temp-dir = { workspace = true } sqlx = { workspace = true } swiftide-test-utils = { path = "../swiftide-test-utils" } +tracing = { workspace = true } [[example]] doc-scrape-examples = true diff --git a/examples/index_md_into_pgvector.rs b/examples/index_md_into_pgvector.rs index 9b298def..a8dfc5c1 100644 --- a/examples/index_md_into_pgvector.rs +++ b/examples/index_md_into_pgvector.rs @@ -11,40 +11,9 @@ use swiftide::{ }, EmbeddedField, }, - integrations::{self, fastembed::FastEmbed, pgvector::PgVector}, - query::{self, answers, query_transformers, response_transformers}, - traits::SimplePrompt, + integrations::{self, pgvector::PgVector}, }; -async fn ask_query( - llm_client: impl SimplePrompt + Clone + 'static, - embed: FastEmbed, - vector_store: PgVector, - question: String, -) -> Result> { - // By default the search strategy is SimilaritySingleEmbedding - // which takes the latest query, embeds it, and does a similarity search - // - // Pgvector will return an error if multiple embeddings are set - // - // The pipeline generates subquestions to increase semantic coverage, embeds these in a single - // embedding, retrieves the default top_k documents, summarizes them and uses that as context - // for the final answer. - let pipeline = query::Pipeline::default() - .then_transform_query(query_transformers::GenerateSubquestions::from_client( - llm_client.clone(), - )) - .then_transform_query(query_transformers::Embed::from_client(embed)) - .then_retrieve(vector_store.clone()) - .then_transform_response(response_transformers::Summary::from_client( - llm_client.clone(), - )) - .then_answer(answers::Simple::from_client(llm_client.clone())); - - let result = pipeline.query(question).await?; - Ok(result.answer().into()) -} - #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); @@ -72,9 +41,7 @@ async fn main() -> Result<(), Box> { // Configure Pgvector with a default vector size, a single embedding // and in addition to embedding the text metadata, also store it in a field let pgv_storage = PgVector::builder() - .try_connect_to_pool(pgv_db_url, Some(10)) - .await - .expect("Failed to connect to postgres server") + .db_url(pgv_db_url) .vector_size(384) .with_vector(EmbeddedField::Combined) .with_metadata(METADATA_QA_TEXT_NAME) @@ -87,9 +54,9 @@ async fn main() -> Result<(), Box> { let drop_table_sql = "DROP TABLE IF EXISTS swiftide_pgvector_test"; let drop_index_sql = "DROP INDEX IF EXISTS swiftide_pgvector_test_embedding_idx"; - if let Ok(pool) = pgv_storage.get_pool() { - sqlx::query(drop_table_sql).execute(&pool).await?; - sqlx::query(drop_index_sql).execute(&pool).await?; + if let Ok(pool) = pgv_storage.get_pool().await { + sqlx::query(drop_table_sql).execute(pool).await?; + sqlx::query(drop_index_sql).execute(pool).await?; } else { return Err("Failed to get database connection pool".into()); } @@ -103,24 +70,6 @@ async fn main() -> Result<(), Box> { .run() .await?; - for (i, question) in [ - "What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.", - "How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.", - ] - .iter() - .enumerate() - { - let result = ask_query( - llm_client.clone(), - fastembed.clone(), - pgv_storage.clone(), - question.to_string(), - ).await?; - tracing::info!("*** Answer Q{} ***", i + 1); - tracing::info!("{}", result); - tracing::info!("===X==="); - } - - tracing::info!("PgVector Indexing & retrieval test completed successfully"); + tracing::info!("PgVector Indexing test completed successfully"); Ok(()) } diff --git a/swiftide-integrations/Cargo.toml b/swiftide-integrations/Cargo.toml index 1e714379..0737bd02 100644 --- a/swiftide-integrations/Cargo.toml +++ b/swiftide-integrations/Cargo.toml @@ -27,7 +27,7 @@ strum = { workspace = true } strum_macros = { workspace = true } regex = { workspace = true } futures-util = { workspace = true } - +once_cell = { workspace = true } # Integrations async-openai = { workspace = true, optional = true } diff --git a/swiftide-integrations/src/pgvector/fixtures.rs b/swiftide-integrations/src/pgvector/fixtures.rs new file mode 100644 index 00000000..9e64d930 --- /dev/null +++ b/swiftide-integrations/src/pgvector/fixtures.rs @@ -0,0 +1,98 @@ +//! This module implements common types and helper utilities for unit tests related to the pgvector +use crate::pgvector::PgVector; +use std::collections::HashSet; +use swiftide_core::{ + indexing::{self, EmbeddedField}, + Persist, +}; +use testcontainers::{ContainerAsync, GenericImage}; + +#[derive(Clone)] +pub(crate) struct PgVectorTestData<'a> { + pub embed_mode: indexing::EmbedMode, + pub chunk: &'a str, + pub metadata: Option, + pub vectors: Vec<(indexing::EmbeddedField, Vec)>, +} + +impl<'a> PgVectorTestData<'a> { + pub(crate) fn to_node(&self) -> indexing::Node { + // Create the initial builder + let mut base_builder = indexing::Node::builder(); + + // Set the required fields + let mut builder = base_builder.chunk(self.chunk).embed_mode(self.embed_mode); + + // Add metadata if it exists + if let Some(metadata) = &self.metadata { + builder = builder.metadata(metadata.clone()); + } + + // Build the node and add vectors + let mut node = builder.build().unwrap(); + node.vectors = Some(self.vectors.clone().into_iter().collect()); + node + } + + pub(crate) fn create_test_vector( + field: EmbeddedField, + base_value: f32, + ) -> (EmbeddedField, Vec) { + (field, vec![base_value; 384]) + } +} + +pub(crate) struct TestContext { + pub(crate) pgv_storage: PgVector, + _pgv_db_container: ContainerAsync, +} + +impl TestContext { + /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage + /// with configurable metadata fields + pub(crate) async fn setup_with_cfg( + metadata_fields: Option>, + vector_fields: HashSet, + ) -> Result> { + // Start `PostgreSQL` container and obtain the connection URL + let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; + tracing::info!("Postgres database URL: {:#?}", pgv_db_url); + + // Initialize the connection pool outside of the builder chain + let mut connection_pool = PgVector::builder(); + + // Configure PgVector storage + let mut builder = connection_pool + .db_url(pgv_db_url) + .vector_size(384) + .table_name("swiftide_pgvector_test".to_string()); + + // Add all vector fields + for vector_field in vector_fields { + builder = builder.with_vector(vector_field); + } + + // Add all metadata fields + if let Some(metadata_fields_inner) = metadata_fields { + for field in metadata_fields_inner { + builder = builder.with_metadata(field); + } + }; + + let pgv_storage = builder.build().map_err(|err| { + tracing::error!("Failed to build PgVector: {}", err); + err + })?; + + // Set up PgVector storage (create the table if not exists) + pgv_storage.setup().await.map_err(|err| { + tracing::error!("PgVector setup failed: {}", err); + err + })?; + + Ok(Self { + pgv_storage, + _pgv_db_container: pgv_db_container, + }) + } +} diff --git a/swiftide-integrations/src/pgvector/mod.rs b/swiftide-integrations/src/pgvector/mod.rs index e4ca677f..ecf0410e 100644 --- a/swiftide-integrations/src/pgvector/mod.rs +++ b/swiftide-integrations/src/pgvector/mod.rs @@ -2,17 +2,33 @@ //! store data, and optimize indexing for efficient searches. //! //! pgvector is utilized in both the `indexing::Pipeline` and `query::Pipeline` modules. + +#[cfg(test)] +mod fixtures; + mod persist; mod pgv_table_types; -mod retrieve; use anyhow::Result; use derive_builder::Builder; +use once_cell::sync::OnceCell; use sqlx::PgPool; use std::fmt; +use std::sync::Arc; +use tokio::time::Duration; + +use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig}; -use pgv_table_types::{FieldConfig, MetadataConfig, PgDBConnectionPool, VectorConfig}; +/// Default maximum connections for the database connection pool. +const DB_POOL_CONN_MAX: u32 = 10; -const DEFAULT_BATCH_SIZE: usize = 50; +/// Default maximum retries for database connection attempts. +const DB_POOL_CONN_RETRY_MAX: u32 = 3; + +/// Delay between connection retry attempts, in seconds. +const DB_POOL_CONN_RETRY_DELAY_SECS: u64 = 3; + +/// Default batch size for storing nodes. +const BATCH_SIZE: usize = 50; /// Represents a Pgvector client with configuration options. /// @@ -21,27 +37,41 @@ const DEFAULT_BATCH_SIZE: usize = 50; #[derive(Builder, Clone)] #[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))] pub struct PgVector { - /// Database connection pool. - #[builder(default = "PgDBConnectionPool::default()")] - connection_pool: PgDBConnectionPool, - - /// Table name to store vectors in. + /// Name of the table to store vectors. #[builder(default = "String::from(\"swiftide_pgv_store\")")] table_name: String, - /// Default sizes of vectors. Vectors can also be of different - /// sizes by specifying the size in the vector configuration. - vector_size: Option, + /// Default vector size; can be customized per configuration. + vector_size: i32, /// Batch size for storing nodes. - #[builder(default = "Some(DEFAULT_BATCH_SIZE)")] - batch_size: Option, + #[builder(default = "BATCH_SIZE")] + batch_size: usize, - /// Field configuration for the Pgvector table, determining the eventual table schema. + /// Field configurations for the `PgVector` table schema. /// - /// Supports multiple field types; see [`FieldConfig`] for details. + /// Supports multiple field types (see [`FieldConfig`]). #[builder(default)] fields: Vec, + + /// Database connection URL. + db_url: String, + + /// Maximum connections allowed in the connection pool. + #[builder(default = "DB_POOL_CONN_MAX")] + db_max_connections: u32, + + /// Maximum retry attempts for establishing a database connection. + #[builder(default = "DB_POOL_CONN_RETRY_MAX")] + db_max_retry: u32, + + /// Delay between retry attempts for database connections. + #[builder(default = "Duration::from_secs(DB_POOL_CONN_RETRY_DELAY_SECS)")] + db_conn_retry_delay: Duration, + + /// Lazy-initialized database connection pool. + #[builder(default = "Arc::new(OnceCell::new())")] + connection_pool: Arc>, } impl fmt::Debug for PgVector { @@ -78,43 +108,12 @@ impl PgVector { /// /// This function will return an error if it fails to retrieve the connection pool, which could occur /// if the underlying connection to `PostgreSQL` has not been properly established. - pub fn get_pool(&self) -> Result { - self.connection_pool.get_pool() + pub async fn get_pool(&self) -> Result<&PgPool> { + self.pool_get_or_initialize().await } } impl PgVectorBuilder { - /// Tries to asynchronously connect to a `Postgres` server and initialize a connection pool. - /// - /// This function attempts to establish a connection to the specified `Postgres` server and - /// sets up a connection pool with an optional maximum number of connections. - /// - /// # Arguments - /// - /// * `url` - A string reference representing the URL of the `Postgres` server to connect to. - /// * `connection_max` - An optional value specifying the maximum number of connections in the pool. - /// - /// # Returns - /// - /// A `Result` that contains an updated `PgVector` instance with the new connection pool on success. - /// On failure, an error is returned. - /// - /// # Errors - /// - /// This function returns an error if the connection to the database fails or if retries are exhausted. - /// Possible reasons include invalid database URLs, unreachable servers, or exceeded retry limits. - pub async fn try_connect_to_pool( - mut self, - url: impl AsRef, - connection_max: Option, - ) -> Result { - let pool = self.connection_pool.clone().unwrap_or_default(); - - self.connection_pool = Some(pool.try_connect_to_url(url, connection_max).await?); - - Ok(self) - } - /// Adds a vector configuration to the builder. /// /// # Arguments @@ -161,221 +160,14 @@ impl PgVectorBuilder { #[cfg(test)] mod tests { - use crate::pgvector::PgVector; + use crate::pgvector::fixtures::{PgVectorTestData, TestContext}; use futures_util::TryStreamExt; - use swiftide_core::{indexing, indexing::EmbeddedField, Persist}; + use std::collections::HashSet; use swiftide_core::{ - indexing::EmbedMode, - querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, - Retrieve, + indexing::{self, EmbedMode, EmbeddedField}, + Persist, }; use test_case::test_case; - use testcontainers::{ContainerAsync, GenericImage}; - - struct TestContext { - pgv_storage: PgVector, - _pgv_db_container: ContainerAsync, - } - - impl TestContext { - /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage - /// with configurable metadata fields - async fn setup_with_cfg( - metadata_fields: Option>, - embedded_field: indexing::EmbeddedField, - ) -> Result> { - // Start `PostgreSQL` container and obtain the connection URL - let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; - tracing::info!("Postgres database URL: {:#?}", pgv_db_url); - - // Initialize the connection pool outside of the builder chain - let mut connection_pool = PgVector::builder() - .try_connect_to_pool(pgv_db_url, Some(10)) - .await - .map_err(|err| { - tracing::error!("Failed to connect to Postgres server: {}", err); - err - })?; - - // Configure PgVector storage - let mut builder = connection_pool - .vector_size(384) - .with_vector(embedded_field) - .table_name("swiftide_pgvector_test".to_string()); - - // Add all metadata fields - if let Some(metadata_fields_inner) = metadata_fields { - for field in metadata_fields_inner { - builder = builder.with_metadata(field); - } - }; - - let pgv_storage = builder.build().map_err(|err| { - tracing::error!("Failed to build PgVector: {}", err); - err - })?; - - // Set up PgVector storage (create the table if not exists) - pgv_storage.setup().await.map_err(|err| { - tracing::error!("PgVector setup failed: {}", err); - err - })?; - - Ok(Self { - pgv_storage, - _pgv_db_container: pgv_db_container, - }) - } - } - - #[test_log::test(tokio::test)] - async fn test_metadata_filter_with_vector_search() { - let test_context = TestContext::setup_with_cfg( - vec!["category", "priority"].into(), - EmbeddedField::Combined, - ) - .await - .expect("Test setup failed"); - - // Create nodes with different metadata and vectors - let nodes = vec![ - indexing::Node::new("content1") - .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]) - .with_metadata(vec![("category", "A"), ("priority", "1")]), - indexing::Node::new("content2") - .with_vectors([(EmbeddedField::Combined, vec![1.1; 384])]) - .with_metadata(vec![("category", "A"), ("priority", "2")]), - indexing::Node::new("content3") - .with_vectors([(EmbeddedField::Combined, vec![1.2; 384])]) - .with_metadata(vec![("category", "B"), ("priority", "1")]), - ] - .into_iter() - .map(|node| node.to_owned()) - .collect(); - - // Store all nodes - test_context - .pgv_storage - .batch_store(nodes) - .await - .try_collect::>() - .await - .unwrap(); - - // Test combined metadata and vector search - let mut query = Query::::new("test_query"); - query.embedding = Some(vec![1.0; 384]); - - // Search with category filter - let search_strategy = - SimilaritySingleEmbedding::from_filter("category = \"A\"".to_string()); - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query.clone()) - .await - .unwrap(); - - assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"content1".to_string())); - assert!(result.documents().contains(&"content2".to_string())); - - // Additional test with priority filter - let search_strategy = - SimilaritySingleEmbedding::from_filter("priority = \"1\"".to_string()); - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query) - .await - .unwrap(); - - assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"content1".to_string())); - assert!(result.documents().contains(&"content3".to_string())); - } - - #[test_log::test(tokio::test)] - async fn test_vector_similarity_search_accuracy() { - let test_context = TestContext::setup_with_cfg(None, EmbeddedField::Combined) - .await - .expect("Test setup failed"); - - // Create nodes with known vector relationships - let base_vector = vec![1.0; 384]; - let similar_vector = base_vector.iter().map(|x| x + 0.1).collect::>(); - let dissimilar_vector = vec![-1.0; 384]; - - let nodes = vec![ - indexing::Node::new("base_content") - .with_vectors([(EmbeddedField::Combined, base_vector)]), - indexing::Node::new("similar_content") - .with_vectors([(EmbeddedField::Combined, similar_vector)]), - indexing::Node::new("dissimilar_content") - .with_vectors([(EmbeddedField::Combined, dissimilar_vector)]), - ] - .into_iter() - .map(|node| node.to_owned()) - .collect(); - - // Store all nodes - test_context - .pgv_storage - .batch_store(nodes) - .await - .try_collect::>() - .await - .unwrap(); - - // Search with base vector - let mut query = Query::::new("test_query"); - query.embedding = Some(vec![1.0; 384]); - - let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); - search_strategy.with_top_k(2); - - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query) - .await - .unwrap(); - - // Verify that similar vectors are retrieved first - assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"base_content".to_string())); - assert!(result.documents().contains(&"similar_content".to_string())); - } - - #[derive(Clone)] - struct PgVectorTestData<'a> { - pub embed_mode: indexing::EmbedMode, - pub chunk: &'a str, - pub metadata: Option, - pub vectors: Vec<(indexing::EmbeddedField, Vec)>, - pub expected_in_results: bool, - } - - impl<'a> PgVectorTestData<'a> { - fn to_node(&self) -> indexing::Node { - // Create the initial builder - let mut base_builder = indexing::Node::builder(); - - // Set the required fields - let mut builder = base_builder.chunk(self.chunk).embed_mode(self.embed_mode); - - // Add metadata if it exists - if let Some(metadata) = &self.metadata { - builder = builder.metadata(metadata.clone()); - } - - // Build the node and add vectors - let mut node = builder.build().unwrap(); - node.vectors = Some(self.vectors.clone().into_iter().collect()); - node - } - } - - fn create_test_vector(field: EmbeddedField, base_value: f32) -> (EmbeddedField, Vec) { - (field, vec![base_value; 384]) - } #[test_case( // SingleWithMetadata - No Metadata @@ -384,17 +176,16 @@ mod tests { embed_mode: EmbedMode::SingleWithMetadata, chunk: "single_no_meta_1", metadata: None, - vectors: vec![create_test_vector(EmbeddedField::Combined, 1.0)], - expected_in_results: true, + vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)], }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "single_no_meta_2", metadata: None, - vectors: vec![create_test_vector(EmbeddedField::Combined, 1.1)], - expected_in_results: true, + vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)], } - ] + ], + HashSet::from([EmbeddedField::Combined]) ; "SingleWithMetadata mode without metadata")] #[test_case( // SingleWithMetadata - With Metadata @@ -406,8 +197,7 @@ mod tests { ("category", "A"), ("priority", "high") ].into()), - vectors: vec![create_test_vector(EmbeddedField::Combined, 1.2)], - expected_in_results: true, + vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)], }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, @@ -416,11 +206,77 @@ mod tests { ("category", "B"), ("priority", "low") ].into()), - vectors: vec![create_test_vector(EmbeddedField::Combined, 1.3)], - expected_in_results: true, + vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)], } - ] + ], + HashSet::from([EmbeddedField::Combined]) ; "SingleWithMetadata mode with metadata")] + #[test_case( + // PerField - No Metadata + vec![ + PgVectorTestData { + embed_mode: EmbedMode::PerField, + chunk: "per_field_no_meta_1", + metadata: None, + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), + ], + }, + PgVectorTestData { + embed_mode: EmbedMode::PerField, + chunk: "per_field_no_meta_2", + metadata: None, + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.3), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.3), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), + ], + } + ], + HashSet::from([ + EmbeddedField::Chunk, + EmbeddedField::Metadata("category".into()), + EmbeddedField::Metadata("priority".into()), + ]) + ; "PerField mode without metadata")] + #[test_case( + // PerField - With Metadata + vec![ + PgVectorTestData { + embed_mode: EmbedMode::PerField, + chunk: "single_with_meta_1", + metadata: Some(vec![ + ("category", "A"), + ("priority", "high") + ].into()), + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), + ], + }, + PgVectorTestData { + embed_mode: EmbedMode::PerField, + chunk: "single_with_meta_2", + metadata: Some(vec![ + ("category", "B"), + ("priority", "low") + ].into()), + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.3), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.3), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), + ], + } + ], + HashSet::from([ + EmbeddedField::Chunk, + EmbeddedField::Metadata("category".into()), + EmbeddedField::Metadata("priority".into()), + ]) + ; "PerField mode with metadata")] #[test_case( // Both - No Metadata vec![ @@ -429,22 +285,21 @@ mod tests { chunk: "both_no_meta_1", metadata: None, vectors: vec![ - create_test_vector(EmbeddedField::Combined, 3.0), - create_test_vector(EmbeddedField::Chunk, 3.1) + PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.0), + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.1) ], - expected_in_results: true, }, PgVectorTestData { embed_mode: EmbedMode::Both, chunk: "both_no_meta_2", metadata: None, vectors: vec![ - create_test_vector(EmbeddedField::Combined, 3.2), - create_test_vector(EmbeddedField::Chunk, 3.3) + PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.2), + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.3) ], - expected_in_results: true, } - ] + ], + HashSet::from([EmbeddedField::Combined, EmbeddedField::Chunk]) ; "Both mode without metadata")] #[test_case( // Both - With Metadata @@ -458,13 +313,12 @@ mod tests { ("tag", "test1") ].into()), vectors: vec![ - create_test_vector(EmbeddedField::Combined, 3.4), - create_test_vector(EmbeddedField::Chunk, 3.5), - create_test_vector(EmbeddedField::Metadata("category".into()), 3.6), - create_test_vector(EmbeddedField::Metadata("priority".into()), 3.7), - create_test_vector(EmbeddedField::Metadata("tag".into()), 3.8) + PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.4), + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.5), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 3.6), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.7), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 3.8) ], - expected_in_results: true, }, PgVectorTestData { embed_mode: EmbedMode::Both, @@ -475,18 +329,27 @@ mod tests { ("tag", "test2") ].into()), vectors: vec![ - create_test_vector(EmbeddedField::Combined, 3.9), - create_test_vector(EmbeddedField::Chunk, 4.0), - create_test_vector(EmbeddedField::Metadata("category".into()), 4.1), - create_test_vector(EmbeddedField::Metadata("priority".into()), 4.2), - create_test_vector(EmbeddedField::Metadata("tag".into()), 4.3) + PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.9), + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 4.0), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 4.1), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 4.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 4.3) ], - expected_in_results: true, } - ] + ], + HashSet::from([ + EmbeddedField::Combined, + EmbeddedField::Chunk, + EmbeddedField::Metadata("category".into()), + EmbeddedField::Metadata("priority".into()), + EmbeddedField::Metadata("tag".into()), + ]) ; "Both mode with metadata")] #[test_log::test(tokio::test)] - async fn test_persist_and_retrieve_nodes(test_cases: Vec>) { + async fn test_persist_nodes( + test_cases: Vec>, + vector_fields: HashSet, + ) { // Extract all possible metadata fields from test cases let metadata_fields: Vec<&str> = test_cases .iter() @@ -497,10 +360,9 @@ mod tests { .collect(); // Initialize test context with all required metadata fields - let test_context = - TestContext::setup_with_cfg(Some(metadata_fields), EmbeddedField::Combined) - .await - .expect("Test setup failed"); + let test_context = TestContext::setup_with_cfg(Some(metadata_fields), vector_fields) + .await + .expect("Test setup failed"); // Convert test cases to nodes and store them let nodes: Vec = test_cases.iter().map(PgVectorTestData::to_node).collect(); @@ -520,7 +382,7 @@ mod tests { "All nodes should be stored" ); - // Verify storage and retrieval for each test case + // Verify storage for each test case for (test_case, stored_node) in test_cases.iter().zip(stored_nodes.iter()) { // 1. Verify basic node properties assert_eq!( @@ -542,52 +404,6 @@ mod tests { test_case.vectors.len(), "Vector count should match" ); - - // 3. Test vector similarity search - for (field, vector) in &test_case.vectors { - let mut query = Query::::new("test_query"); - query.embedding = Some(vector.clone()); - - let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); - search_strategy.with_top_k(nodes.len() as u64); - - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query.clone()) - .await - .expect("Retrieval should succeed"); - - if test_case.expected_in_results { - assert!( - result.documents().contains(&test_case.chunk.to_string()), - "Document should be found in results for field {field}", - ); - } - } - - // 4. Test metadata filtering if present - if let Some(metadata) = &test_case.metadata { - for (key, value) in metadata { - let filter_query = format!("{key} = \"{value}\""); - let search_strategy = SimilaritySingleEmbedding::from_filter(filter_query); - - let mut query = Query::::new("test_query"); - query.embedding = Some(test_case.vectors[0].1.clone()); - - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query) - .await - .expect("Filtered retrieval should succeed"); - - if test_case.expected_in_results { - assert!( - result.documents().contains(&test_case.chunk.to_string()), - "Document should be found when filtering by metadata {key}={value}" - ); - } - } - } } } } diff --git a/swiftide-integrations/src/pgvector/persist.rs b/swiftide-integrations/src/pgvector/persist.rs index c0706576..43bb723d 100644 --- a/swiftide-integrations/src/pgvector/persist.rs +++ b/swiftide-integrations/src/pgvector/persist.rs @@ -13,7 +13,10 @@ use swiftide_core::{ impl Persist for PgVector { #[tracing::instrument(skip_all)] async fn setup(&self) -> Result<()> { - let mut tx = self.connection_pool.get_pool()?.begin().await?; + // Get or initialize the connection pool + let pool = self.pool_get_or_initialize().await?; + + let mut tx = pool.begin().await?; // Create extension let sql = "CREATE EXTENSION IF NOT EXISTS vector"; @@ -48,50 +51,24 @@ impl Persist for PgVector { } fn batch_size(&self) -> Option { - self.batch_size + Some(self.batch_size) } } #[cfg(test)] mod tests { - use crate::pgvector::PgVector; + use crate::pgvector::fixtures::TestContext; + use std::collections::HashSet; use swiftide_core::{indexing::EmbeddedField, Persist}; - use testcontainers::{ContainerAsync, GenericImage}; - - struct TestContext { - pgv_storage: PgVector, - _pgv_db_container: ContainerAsync, - } - - impl TestContext { - /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage - async fn setup() -> Result> { - // Start PostgreSQL container and obtain the connection URL - let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; - - // Configure and build PgVector storage - let pgv_storage = PgVector::builder() - .try_connect_to_pool(pgv_db_url, Some(10)) - .await? - .vector_size(384) - .with_vector(EmbeddedField::Combined) - .with_metadata("filter") - .table_name("swiftide_pgvector_test".to_string()) - .build()?; - - // Set up PgVector storage (create the table if not exists) - pgv_storage.setup().await?; - - Ok(Self { - pgv_storage, - _pgv_db_container: pgv_db_container, - }) - } - } #[test_log::test(tokio::test)] async fn test_persist_setup_no_error_when_table_exists() { - let test_context = TestContext::setup().await.expect("Test setup failed"); + let test_context = TestContext::setup_with_cfg( + vec!["filter"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); test_context .pgv_storage diff --git a/swiftide-integrations/src/pgvector/pgv_table_types.rs b/swiftide-integrations/src/pgvector/pgv_table_types.rs index 3598d873..f133c63d 100644 --- a/swiftide-integrations/src/pgvector/pgv_table_types.rs +++ b/swiftide-integrations/src/pgvector/pgv_table_types.rs @@ -4,81 +4,17 @@ //! with `PostgreSQL`'s required data format. use crate::pgvector::PgVector; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, Result}; +// use once_cell::sync::OnceCell; use pgvector as ExtPgVector; use regex::Regex; use sqlx::postgres::PgArguments; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; use std::collections::BTreeMap; -use std::sync::Arc; +// use std::sync::Arc; use swiftide_core::indexing::{EmbeddedField, Node}; -use tokio::time::{sleep, Duration}; - -#[derive(Clone)] -pub struct PgDBConnectionPool(Arc>); - -impl Default for PgDBConnectionPool { - fn default() -> Self { - Self(Arc::new(None)) - } -} - -impl PgDBConnectionPool { - /// Attempts to connect to the database with retries. - async fn connect_with_retry( - database_url: impl AsRef, - max_retries: u32, - pool_options: &PgPoolOptions, - ) -> Result { - for attempt in 1..=max_retries { - match pool_options.clone().connect(database_url.as_ref()).await { - Ok(pool) => { - return Ok(pool); - } - Err(_err) if attempt < max_retries => { - sleep(Duration::from_secs(2)).await; - } - Err(err) => return Err(err), - } - } - unreachable!() - } - - /// Connects to the database using the provided URL and sets the connection pool. - pub async fn try_connect_to_url( - mut self, - database_url: impl AsRef, - connection_max: Option, - ) -> Result { - let pool_options = PgPoolOptions::new().max_connections(connection_max.unwrap_or(10)); - - let pool = Self::connect_with_retry(database_url, 10, &pool_options) - .await - .context("Failed to connect to the database")?; - - self.0 = Arc::new(Some(pool)); - - Ok(self) - } - - /// Retrieves the connection pool, returning an error if the pool is not initialized. - pub fn get_pool(&self) -> Result { - self.0 - .as_ref() - .clone() - .ok_or_else(|| anyhow!("Database connection pool is not initialized")) - } - - /// Returns the connection status of the pool. - pub fn connection_status(&self) -> &'static str { - match self.0.as_ref() { - Some(pool) if !pool.is_closed() => "Open", - Some(_) => "Closed", - None => "Not initialized", - } - } -} +use tokio::time::sleep; #[derive(Clone, Debug)] pub struct VectorConfig { @@ -172,10 +108,6 @@ impl PgVector { return Err(anyhow::anyhow!("Invalid table name")); } - let vector_size = self - .vector_size - .ok_or_else(|| anyhow!("vector_size must be configured"))?; - let columns: Vec = self .fields .iter() @@ -183,7 +115,9 @@ impl PgVector { FieldConfig::ID => "id UUID NOT NULL".to_string(), FieldConfig::Chunk => format!("{} TEXT NOT NULL", field.field_name()), FieldConfig::Metadata(_) => format!("{} JSONB", field.field_name()), - FieldConfig::Vector(_) => format!("{} VECTOR({})", field.field_name(), vector_size), + FieldConfig::Vector(_) => { + format!("{} VECTOR({})", field.field_name(), self.vector_size) + } }) .chain(std::iter::once("PRIMARY KEY (id)".to_string())) .collect(); @@ -244,7 +178,7 @@ impl PgVector { /// - Any of the SQL queries fail to execute due to schema mismatch, constraint violations, or connectivity issues. /// - Committing the transaction fails. pub async fn store_nodes(&self, nodes: &[Node]) -> Result<()> { - let pool = self.connection_pool.get_pool()?; + let pool = self.pool_get_or_initialize().await?; let mut tx = pool.begin().await?; let bulk_data = self.prepare_bulk_data(nodes)?; @@ -471,6 +405,59 @@ impl PgVector { } } +impl PgVector { + async fn create_pool(&self) -> Result { + let pool_options = PgPoolOptions::new().max_connections(self.db_max_connections); + + for attempt in 1..=self.db_max_retry { + match pool_options.clone().connect(self.db_url.as_ref()).await { + Ok(pool) => { + tracing::info!("Successfully established database connection"); + return Ok(pool); + } + Err(err) if attempt < self.db_max_retry => { + tracing::warn!( + error = %err, + attempt = attempt, + max_retries = self.db_max_retry, + "Database connection attempt failed, retrying..." + ); + sleep(self.db_conn_retry_delay).await; + } + Err(err) => { + return Err(anyhow!(err).context("Failed to establish database connection")); + } + } + } + + Err(anyhow!( + "Max connection retries ({}) exceeded", + self.db_max_retry + )) + } + + /// Returns a reference to the `PgPool` if it is already initialized, + /// or creates and initializes it if it is not. + /// + /// # Errors + /// This function will return an error if pool creation fails. + pub async fn pool_get_or_initialize(&self) -> Result<&PgPool> { + if let Some(pool) = self.connection_pool.get() { + return Ok(pool); + } + + let pool = self.create_pool().await?; + self.connection_pool + .set(pool) + .map_err(|_| anyhow!("Pool already initialized"))?; + + // Re-check if the pool was set successfully, otherwise return an error + self.connection_pool + .get() + .ok_or_else(|| anyhow!("Failed to retrieve connection pool after setting it")) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/swiftide-integrations/src/pgvector/retrieve.rs b/swiftide-integrations/src/pgvector/retrieve.rs deleted file mode 100644 index 7987650d..00000000 --- a/swiftide-integrations/src/pgvector/retrieve.rs +++ /dev/null @@ -1,223 +0,0 @@ -use crate::pgvector::{PgVector, PgVectorBuilder}; -use anyhow::{anyhow, Result}; -use async_trait::async_trait; -use pgvector::Vector; -use sqlx::{prelude::FromRow, types::Uuid}; -use swiftide_core::{ - querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, - Retrieve, -}; - -#[allow(dead_code)] -#[derive(Debug, Clone, FromRow)] -struct VectorSearchResult { - id: Uuid, - chunk: String, -} - -#[allow(clippy::redundant_closure_for_method_calls)] -#[async_trait] -impl Retrieve> for PgVector { - #[tracing::instrument] - async fn retrieve( - &self, - search_strategy: &SimilaritySingleEmbedding, - query_state: Query, - ) -> Result> { - let embedding = query_state - .embedding - .as_ref() - .ok_or_else(|| anyhow!("No embedding for query"))?; - let embedding = Vector::from(embedding.clone()); - - // let pool = self.connection_pool.get_pool().await?; - let pool = self.connection_pool.get_pool()?; - - let default_columns: Vec<_> = PgVectorBuilder::default_fields() - .iter() - .map(|f| f.field_name().to_string()) - .collect(); - let vector_column_name = self.get_vector_column_name()?; - - // Start building the SQL query - let mut sql = format!( - "SELECT {} FROM {}", - default_columns.join(", "), - self.table_name - ); - - if let Some(filter) = search_strategy.filter() { - let filter_parts: Vec<&str> = filter.split('=').collect(); - if filter_parts.len() == 2 { - let key = filter_parts[0].trim(); - let value = filter_parts[1].trim().trim_matches('"'); - tracing::debug!( - "Filter being applied: key = {:#?}, value = {:#?}", - key, - value - ); - - let sql_filter = format!( - " WHERE meta_{}->>'{}' = '{}'", - PgVector::normalize_field_name(key), - key, - value - ); - sql.push_str(&sql_filter); - } else { - return Err(anyhow!("Invalid filter format")); - } - } - - // Add the ORDER BY clause for vector similarity search - sql.push_str(&format!( - " ORDER BY {} <=> $1 LIMIT $2", - &vector_column_name - )); - - tracing::debug!("Running retrieve with SQL: {}", sql); - - let top_k = i32::try_from(search_strategy.top_k()) - .map_err(|_| anyhow!("Failed to convert top_k to i32"))?; - - let data: Vec = sqlx::query_as(&sql) - .bind(embedding) - .bind(top_k) - .fetch_all(&pool) - .await?; - - let docs = data.into_iter().map(|r| r.chunk).collect(); - - Ok(query_state.retrieved_documents(docs)) - } -} - -#[async_trait] -impl Retrieve for PgVector { - async fn retrieve( - &self, - search_strategy: &SimilaritySingleEmbedding, - query: Query, - ) -> Result> { - Retrieve::>::retrieve( - self, - &search_strategy.into_concrete_filter::(), - query, - ) - .await - } -} - -#[cfg(test)] -mod tests { - use crate::pgvector::PgVector; - use futures_util::TryStreamExt; - use swiftide_core::{indexing, indexing::EmbeddedField, Persist}; - use swiftide_core::{ - querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, - Retrieve, - }; - use testcontainers::{ContainerAsync, GenericImage}; - - struct TestContext { - pgv_storage: PgVector, - _pgv_db_container: ContainerAsync, - } - - impl TestContext { - /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage - async fn setup() -> Result> { - // Start PostgreSQL container and obtain the connection URL - let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; - - tracing::info!("Postgres database URL: {:#?}", pgv_db_url); - - // Configure and build PgVector storage - let pgv_storage = PgVector::builder() - .try_connect_to_pool(pgv_db_url, Some(10)) - .await - .map_err(|err| { - tracing::error!("Failed to connect to Postgres server: {}", err); - err - })? - .vector_size(384) - .with_vector(EmbeddedField::Combined) - .with_metadata("filter") - .table_name("swiftide_pgvector_test".to_string()) - .build() - .map_err(|err| { - tracing::error!("Failed to build PgVector: {}", err); - err - })?; - - // Set up PgVector storage (create the table if not exists) - pgv_storage.setup().await.map_err(|err| { - tracing::error!("PgVector setup failed: {}", err); - err - })?; - - Ok(Self { - pgv_storage, - _pgv_db_container: pgv_db_container, - }) - } - } - - #[test_log::test(tokio::test)] - async fn test_retrieve_multiple_docs_and_filter() { - let test_context = TestContext::setup().await.expect("Test setup failed"); - - let nodes = vec![ - indexing::Node::new("test_query1").with_metadata(("filter", "true")), - indexing::Node::new("test_query2").with_metadata(("filter", "true")), - indexing::Node::new("test_query3").with_metadata(("filter", "false")), - ] - .into_iter() - .map(|node| { - node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]); - node.to_owned() - }) - .collect(); - - test_context - .pgv_storage - .batch_store(nodes) - .await - .try_collect::>() - .await - .unwrap(); - - let mut query = Query::::new("test_query"); - query.embedding = Some(vec![1.0; 384]); - - let search_strategy = SimilaritySingleEmbedding::<()>::default(); - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query.clone()) - .await - .unwrap(); - - assert_eq!(result.documents().len(), 3); - - let search_strategy = - SimilaritySingleEmbedding::from_filter("filter = \"true\"".to_string()); - - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query.clone()) - .await - .unwrap(); - - assert_eq!(result.documents().len(), 2); - - let search_strategy = - SimilaritySingleEmbedding::from_filter("filter = \"banana\"".to_string()); - - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query.clone()) - .await - .unwrap(); - assert_eq!(result.documents().len(), 0); - } -} diff --git a/swiftide/tests/lancedb.rs b/swiftide/tests/lancedb.rs index 5464b5fd..5202c5d0 100644 --- a/swiftide/tests/lancedb.rs +++ b/swiftide/tests/lancedb.rs @@ -1,16 +1,9 @@ -use arrow_array::{cast::AsArray, Array, RecordBatch, StringArray}; -use lancedb::query::ExecutableQuery; use swiftide::indexing; -use swiftide::query::{self, states, Query, TransformationEvent}; -use swiftide::{ - indexing::{ - transformers::{ - metadata_qa_code::NAME as METADATA_QA_CODE_NAME, ChunkCode, MetadataQACode, - }, - EmbeddedField, - }, - query::TryStreamExt as _, +use swiftide::indexing::{ + transformers::{metadata_qa_code::NAME as METADATA_QA_CODE_NAME, ChunkCode, MetadataQACode}, + EmbeddedField, }; +use swiftide::query::{self, states, Query, TransformationEvent}; use swiftide_indexing::{loaders, transformers, Pipeline}; use swiftide_integrations::{fastembed::FastEmbed, lancedb::LanceDB}; use swiftide_query::{answers, query_transformers, response_transformers}; @@ -95,55 +88,4 @@ async fn test_lancedb() { documents.first().unwrap(), "fn main() { println!(\"Hello, World!\"); }" ); - - // Manually assert everything was stored as expected - let conn = lancedb.get_connection().await.unwrap(); - let table = conn.open_table("swiftide_test").execute().await.unwrap(); - - let result: RecordBatch = table - .query() - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap() - .first() - .unwrap() - .clone(); - - assert_eq!(result.num_rows(), 1); - assert_eq!(result.num_columns(), 5); - dbg!(result.columns()); - assert!(result.column_by_name("id").is_some()); - assert_eq!( - result - .column_by_name("chunk") - .unwrap() - .as_any() - .downcast_ref::() // as_string() doesn't work, wtf - .unwrap() - .value(0), - code - ); - assert_eq!( - result - .column_by_name("questions_and_answers__code_") - .unwrap() - .as_any() - .downcast_ref::() // as_string() doesn't work, wtf - .unwrap() - .value(0), - "\n\nHello there, how may I assist you today?" - ); - - assert_eq!( - result - .column_by_name("vector_combined") - .unwrap() - .as_fixed_size_list() - .value(0) - .len(), - 384 - ); }