Skip to content

Commit

Permalink
Addressed review comments:
Browse files Browse the repository at this point in the history
- added Postgres test_util,
- completed unit tests for persist and retrieval

Signed-off-by: shamb0 <r.raajey@gmail.com>
  • Loading branch information
shamb0 committed Nov 21, 2024
1 parent 7a6b2bc commit 4f0da81
Show file tree
Hide file tree
Showing 11 changed files with 366 additions and 94 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 51 additions & 1 deletion examples/index_md_into_pgvector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,40 @@ use swiftide::{
},
EmbeddedField,
},
integrations::{self, pgvector::PgVector},
integrations::{self, fastembed::FastEmbed, pgvector::PgVector},
query::{self, answers, query_transformers, response_transformers},
traits::SimplePrompt,
};

/// Answers `question` by running a query pipeline against `vector_store`.
///
/// The pipeline stages run in this order:
/// 1. generate subquestions from the query using `llm_client`,
/// 2. embed the transformed query with `embed`,
/// 3. retrieve documents from `vector_store`,
/// 4. summarize the retrieved documents using `llm_client`,
/// 5. produce the final answer using `llm_client`.
///
/// # Errors
///
/// Propagates any error returned by `pipeline.query`.
async fn ask_query(
    llm_client: impl SimplePrompt + Clone + 'static,
    embed: FastEmbed,
    vector_store: PgVector,
    question: String,
) -> Result<String, Box<dyn std::error::Error>> {
    // By default the search strategy is SimilaritySingleEmbedding
    // which takes the latest query, embeds it, and does a similarity search
    //
    // Pgvector will return an error if multiple embeddings are set
    //
    // The pipeline generates subquestions to increase semantic coverage, embeds these in a single
    // embedding, retrieves the default top_k documents, summarizes them and uses that as context
    // for the final answer.
    let pipeline = query::Pipeline::default()
        .then_transform_query(query_transformers::GenerateSubquestions::from_client(
            llm_client.clone(),
        ))
        .then_transform_query(query_transformers::Embed::from_client(embed))
        // `vector_store` is owned and not used afterwards, so move it instead of cloning.
        .then_retrieve(vector_store)
        .then_transform_response(response_transformers::Summary::from_client(
            llm_client.clone(),
        ))
        // Last use of `llm_client`: hand over ownership rather than cloning once more.
        .then_answer(answers::Simple::from_client(llm_client));

    let result = pipeline.query(question).await?;
    Ok(result.answer().into())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
Expand Down Expand Up @@ -71,5 +102,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await?;

tracing::info!("PgVector Indexing test completed successfully");
for (i, question) in [
"What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.",
"How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.",
]
.iter()
.enumerate()
{
let result = ask_query(
llm_client.clone(),
fastembed.clone(),
pgv_storage.clone(),
question.to_string(),
).await?;
tracing::info!("*** Answer Q{} ***", i + 1);
tracing::info!("{}", result);
tracing::info!("===X===");
}

tracing::info!("PgVector Indexing & retrieval test completed successfully");
Ok(())
}
41 changes: 0 additions & 41 deletions examples/test_dataset/README.md

This file was deleted.

46 changes: 0 additions & 46 deletions scripts/docker/docker-compose-db-pg.yml

This file was deleted.

2 changes: 1 addition & 1 deletion swiftide-indexing/src/loaders/file_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl Loader for FileLoader {
.filter(|entry| entry.file_type().is_some_and(|ft| ft.is_file()))
.filter(move |entry| self.file_has_extension(entry.path()))
.map(|entry| {
tracing::debug!("Reading file: {:?}", entry);
tracing::info!("Reading file: {:?}", entry);
let content =
std::fs::read_to_string(entry.path()).context("Failed to read file")?;
let original_size = content.len();
Expand Down
2 changes: 1 addition & 1 deletion swiftide-integrations/src/pgvector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use std::sync::Arc;
use std::sync::OnceLock;
use tokio::time::Duration;

use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig};
use pgv_table_types::{FieldConfig, MetadataConfig, PgDBConnectionPool, VectorConfig};

/// Default maximum connections for the database connection pool.
const DB_POOL_CONN_MAX: u32 = 10;
Expand Down
66 changes: 66 additions & 0 deletions swiftide-integrations/src/pgvector/persist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,69 @@ mod tests {
.expect("PgVector setup should not fail when the table already exists");
}
}

#[cfg(test)]
mod tests {
    use crate::pgvector::PgVector;
    use swiftide_core::{indexing::EmbeddedField, Persist};
    use temp_dir::TempDir;
    use testcontainers::{ContainerAsync, GenericImage};

    /// Fixture bundling a ready-to-use `PgVector` with the resources returned by
    /// `swiftide_test_utils::start_postgres`.
    ///
    /// The underscore-prefixed fields are never read here; they are presumably held only so
    /// the temporary directory and the container stay alive for the duration of a test.
    struct TestContext {
        pgv_storage: PgVector,
        _temp_dir: TempDir,
        _pgv_db_container: ContainerAsync<GenericImage>,
    }

    impl TestContext {
        /// Builds the fixture: starts Postgres through the test utilities, connects a
        /// `PgVector` to the returned URL, and runs `setup()` on it once.
        ///
        /// Every failing step is logged via `tracing::error!` before the error is returned.
        async fn setup() -> Result<Self, Box<dyn std::error::Error>> {
            // Bring up the database and collect the handles that must outlive the test.
            let (container, db_url, scratch_dir) = swiftide_test_utils::start_postgres().await;

            tracing::info!("Postgres database URL: {:#?}", db_url);

            // Connect to the pool, then describe the table layout used by the tests.
            let pgv_storage = PgVector::builder()
                .try_connect_to_pool(db_url, Some(10))
                .await
                .map_err(|e| {
                    tracing::error!("Failed to connect to Postgres server: {}", e);
                    e
                })?
                .vector_size(384)
                .with_vector(EmbeddedField::Combined)
                .with_metadata("filter")
                .table_name("swiftide_pgvector_test".to_string())
                .build()
                .map_err(|e| {
                    tracing::error!("Failed to build PgVector: {}", e);
                    e
                })?;

            // First setup call on the freshly built storage.
            if let Err(e) = pgv_storage.setup().await {
                tracing::error!("PgVector setup failed: {}", e);
                return Err(e.into());
            }

            Ok(Self {
                pgv_storage,
                _temp_dir: scratch_dir,
                _pgv_db_container: container,
            })
        }
    }

    /// Calling `setup()` a second time (the fixture already called it once) must succeed.
    #[test_log::test(tokio::test)]
    async fn test_persist_setup_no_error_when_table_exists() {
        let ctx = TestContext::setup().await.expect("Test setup failed");

        let second_setup = ctx.pgv_storage.setup().await;
        second_setup.expect("PgVector setup should not fail when the table already exists");
    }
}
11 changes: 7 additions & 4 deletions swiftide-integrations/src/pgvector/pgv_table_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! - Bulk data preparation and SQL query generation
//!
use crate::pgvector::PgVector;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use pgvector as ExtPgVector;
use regex::Regex;
use sqlx::postgres::PgArguments;
Expand Down Expand Up @@ -185,9 +185,7 @@ impl PgVector {
FieldConfig::ID => "id UUID NOT NULL".to_string(),
FieldConfig::Chunk => format!("{} TEXT NOT NULL", field.field_name()),
FieldConfig::Metadata(_) => format!("{} JSONB", field.field_name()),
FieldConfig::Vector(_) => {
format!("{} VECTOR({})", field.field_name(), self.vector_size)
}
FieldConfig::Vector(_) => format!("{} VECTOR({})", field.field_name(), vector_size),
})
.chain(std::iter::once("PRIMARY KEY (id)".to_string()))
.collect();
Expand Down Expand Up @@ -267,6 +265,11 @@ impl PgVector {
.await
.map_err(|e| anyhow!("Failed to store nodes: {:?}", e))?;

query.execute(&mut *tx).await.map_err(|e| {
tracing::error!("Failed to store nodes: {:?}", e);
anyhow!("Failed to store nodes: {:?}", e)
})?;

tx.commit()
.await
.map_err(|e| anyhow!("Failed to commit transaction: {:?}", e))
Expand Down
Loading

0 comments on commit 4f0da81

Please sign in to comment.