Skip to content

Commit

Permalink
Address review feedback:
Browse files Browse the repository at this point in the history
  - Removed retrieval functionality.
  - Increased unit test coverage for Persist by 30%.

Signed-off-by: shamb0 <r.raajey@gmail.com>
  • Loading branch information
shamb0 committed Nov 12, 2024
1 parent 744b799 commit c4fdcbe
Show file tree
Hide file tree
Showing 11 changed files with 344 additions and 794 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ indoc = { version = "2.0" }
regex = { version = "1.11.1" }
uuid = { version = "1.10", features = ["v3", "v4", "serde"] }
dyn-clone = { version = "1.0" }
once_cell = { version = "1.20.2" }

# Integrations
spider = { version = "2.13" }
Expand Down
1 change: 1 addition & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ fluvio = { workspace = true }
temp-dir = { workspace = true }
sqlx = { workspace = true }
swiftide-test-utils = { path = "../swiftide-test-utils" }
tracing = { workspace = true }

[[example]]
doc-scrape-examples = true
Expand Down
63 changes: 6 additions & 57 deletions examples/index_md_into_pgvector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,9 @@ use swiftide::{
},
EmbeddedField,
},
integrations::{self, fastembed::FastEmbed, pgvector::PgVector},
query::{self, answers, query_transformers, response_transformers},
traits::SimplePrompt,
integrations::{self, pgvector::PgVector},
};

async fn ask_query(
llm_client: impl SimplePrompt + Clone + 'static,
embed: FastEmbed,
vector_store: PgVector,
question: String,
) -> Result<String, Box<dyn std::error::Error>> {
// By default the search strategy is SimilaritySingleEmbedding
// which takes the latest query, embeds it, and does a similarity search
//
// Pgvector will return an error if multiple embeddings are set
//
// The pipeline generates subquestions to increase semantic coverage, embeds these in a single
// embedding, retrieves the default top_k documents, summarizes them and uses that as context
// for the final answer.
let pipeline = query::Pipeline::default()
.then_transform_query(query_transformers::GenerateSubquestions::from_client(
llm_client.clone(),
))
.then_transform_query(query_transformers::Embed::from_client(embed))
.then_retrieve(vector_store.clone())
.then_transform_response(response_transformers::Summary::from_client(
llm_client.clone(),
))
.then_answer(answers::Simple::from_client(llm_client.clone()));

let result = pipeline.query(question).await?;
Ok(result.answer().into())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
Expand Down Expand Up @@ -72,9 +41,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Configure Pgvector with a default vector size, a single embedding
// and in addition to embedding the text metadata, also store it in a field
let pgv_storage = PgVector::builder()
.try_connect_to_pool(pgv_db_url, Some(10))
.await
.expect("Failed to connect to postgres server")
.db_url(pgv_db_url)
.vector_size(384)
.with_vector(EmbeddedField::Combined)
.with_metadata(METADATA_QA_TEXT_NAME)
Expand All @@ -87,9 +54,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let drop_table_sql = "DROP TABLE IF EXISTS swiftide_pgvector_test";
let drop_index_sql = "DROP INDEX IF EXISTS swiftide_pgvector_test_embedding_idx";

if let Ok(pool) = pgv_storage.get_pool() {
sqlx::query(drop_table_sql).execute(&pool).await?;
sqlx::query(drop_index_sql).execute(&pool).await?;
if let Ok(pool) = pgv_storage.get_pool().await {
sqlx::query(drop_table_sql).execute(pool).await?;
sqlx::query(drop_index_sql).execute(pool).await?;
} else {
return Err("Failed to get database connection pool".into());
}
Expand All @@ -103,24 +70,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.run()
.await?;

for (i, question) in [
"What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.",
"How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.",
]
.iter()
.enumerate()
{
let result = ask_query(
llm_client.clone(),
fastembed.clone(),
pgv_storage.clone(),
question.to_string(),
).await?;
tracing::info!("*** Answer Q{} ***", i + 1);
tracing::info!("{}", result);
tracing::info!("===X===");
}

tracing::info!("PgVector Indexing & retrieval test completed successfully");
tracing::info!("PgVector Indexing test completed successfully");
Ok(())
}
2 changes: 1 addition & 1 deletion swiftide-integrations/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ strum = { workspace = true }
strum_macros = { workspace = true }
regex = { workspace = true }
futures-util = { workspace = true }

once_cell = { workspace = true }

# Integrations
async-openai = { workspace = true, optional = true }
Expand Down
98 changes: 98 additions & 0 deletions swiftide-integrations/src/pgvector/fixtures.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
//! This module implements common types and helper utilities for unit tests related to the pgvector
use crate::pgvector::PgVector;
use std::collections::HashSet;
use swiftide_core::{
indexing::{self, EmbeddedField},
Persist,
};
use testcontainers::{ContainerAsync, GenericImage};

#[derive(Clone)]
pub(crate) struct PgVectorTestData<'a> {
pub embed_mode: indexing::EmbedMode,
pub chunk: &'a str,
pub metadata: Option<indexing::Metadata>,
pub vectors: Vec<(indexing::EmbeddedField, Vec<f32>)>,
}

impl<'a> PgVectorTestData<'a> {
pub(crate) fn to_node(&self) -> indexing::Node {
// Create the initial builder
let mut base_builder = indexing::Node::builder();

// Set the required fields
let mut builder = base_builder.chunk(self.chunk).embed_mode(self.embed_mode);

// Add metadata if it exists
if let Some(metadata) = &self.metadata {
builder = builder.metadata(metadata.clone());
}

// Build the node and add vectors
let mut node = builder.build().unwrap();
node.vectors = Some(self.vectors.clone().into_iter().collect());
node
}

pub(crate) fn create_test_vector(
field: EmbeddedField,
base_value: f32,
) -> (EmbeddedField, Vec<f32>) {
(field, vec![base_value; 384])
}
}

pub(crate) struct TestContext {
pub(crate) pgv_storage: PgVector,
_pgv_db_container: ContainerAsync<GenericImage>,
}

impl TestContext {
/// Set up the test context, initializing `PostgreSQL` and `PgVector` storage
/// with configurable metadata fields
pub(crate) async fn setup_with_cfg(
metadata_fields: Option<Vec<&str>>,
vector_fields: HashSet<EmbeddedField>,
) -> Result<Self, Box<dyn std::error::Error>> {
// Start `PostgreSQL` container and obtain the connection URL
let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;
tracing::info!("Postgres database URL: {:#?}", pgv_db_url);

// Initialize the connection pool outside of the builder chain
let mut connection_pool = PgVector::builder();

// Configure PgVector storage
let mut builder = connection_pool
.db_url(pgv_db_url)
.vector_size(384)
.table_name("swiftide_pgvector_test".to_string());

// Add all vector fields
for vector_field in vector_fields {
builder = builder.with_vector(vector_field);
}

// Add all metadata fields
if let Some(metadata_fields_inner) = metadata_fields {
for field in metadata_fields_inner {
builder = builder.with_metadata(field);
}
};

let pgv_storage = builder.build().map_err(|err| {
tracing::error!("Failed to build PgVector: {}", err);
err
})?;

// Set up PgVector storage (create the table if not exists)
pgv_storage.setup().await.map_err(|err| {
tracing::error!("PgVector setup failed: {}", err);
err
})?;

Ok(Self {
pgv_storage,
_pgv_db_container: pgv_db_container,
})
}
}
Loading

0 comments on commit c4fdcbe

Please sign in to comment.