Skip to content

Commit

Permalink
feat(query): Add support for single embedding retrieval with PGVector (
Browse files Browse the repository at this point in the history
  • Loading branch information
shamb0 authored Dec 4, 2024
1 parent 5ce4d21 commit 3751f49
Show file tree
Hide file tree
Showing 9 changed files with 647 additions and 148 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ arrow-array = { version = "52.2", default-features = false }
arrow = { version = "52.2", default-features = false }
parquet = { version = "52.2", default-features = false, features = ["async"] }
redb = { version = "2.2" }
sqlx = { version = "0.8.2", features = [
"postgres",
"uuid",
], default-features = false }
sqlx = { version = "0.8.2", features = ["postgres", "uuid"] }
aws-config = "1.5"
pgvector = { version = "0.4.0", features = ["sqlx"], default-features = false }
aws-credential-types = "1.2"
Expand Down Expand Up @@ -87,6 +84,7 @@ tree-sitter-ruby = "0.23"
tree-sitter-rust = "0.23"
tree-sitter-typescript = "0.23"


# Testing
test-log = "0.2.16"
testcontainers = { version = "0.23.0", features = ["http_wait"] }
Expand Down
65 changes: 63 additions & 2 deletions examples/index_md_into_pgvector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,46 @@ use swiftide::{
},
EmbeddedField,
},
integrations::{self, pgvector::PgVector},
integrations::{self, fastembed::FastEmbed, pgvector::PgVector},
query::{self, answers, query_transformers, response_transformers},
traits::SimplePrompt,
};

async fn ask_query(
llm_client: impl SimplePrompt + Clone + 'static,
embed: FastEmbed,
vector_store: PgVector,
questions: Vec<String>,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
// By default the search strategy is SimilaritySingleEmbedding
// which takes the latest query, embeds it, and does a similarity search
//
// Pgvector will return an error if multiple embeddings are set
//
// The pipeline generates subquestions to increase semantic coverage, embeds these in a single
// embedding, retrieves the default top_k documents, summarizes them and uses that as context
// for the final answer.
let pipeline = query::Pipeline::default()
.then_transform_query(query_transformers::GenerateSubquestions::from_client(
llm_client.clone(),
))
.then_transform_query(query_transformers::Embed::from_client(embed))
.then_retrieve(vector_store.clone())
.then_transform_response(response_transformers::Summary::from_client(
llm_client.clone(),
))
.then_answer(answers::Simple::from_client(llm_client.clone()));

let results: Vec<String> = pipeline
.query_all(questions)
.await?
.iter()
.map(|result| result.answer().to_string())
.collect();

Ok(results)
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
Expand Down Expand Up @@ -62,6 +99,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}

tracing::info!("Starting indexing pipeline");

indexing::Pipeline::from_loader(FileLoader::new(test_dataset_path).with_extensions(&["md"]))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(llm_client.clone()))
Expand All @@ -70,6 +108,29 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.run()
.await?;

tracing::info!("PgVector Indexing test completed successfully");
tracing::info!("PgVector Indexing completed successfully");

let questions: Vec<String> = vec![
"What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.".into(),
"How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.".into(),
];

ask_query(
llm_client.clone(),
fastembed.clone(),
pgv_storage.clone(),
questions,
)
.await?
.iter()
.enumerate()
.for_each(|(i, result)| {
tracing::info!("*** Answer Q{} ***", i + 1);
tracing::info!("{}", result);
tracing::info!("===X===");
});

tracing::info!("PgVector Indexing & retrieval test completed successfully");

Ok(())
}
1 change: 1 addition & 0 deletions swiftide-integrations/src/pgvector/fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub(crate) struct PgVectorTestData<'a> {
pub metadata: Option<indexing::Metadata>,
/// Vector embeddings with their corresponding fields
pub vectors: Vec<(indexing::EmbeddedField, Vec<f32>)>,
pub expected_in_results: bool,
}

impl PgVectorTestData<'_> {
Expand Down
Loading

0 comments on commit 3751f49

Please sign in to comment.