Skip to content

Commit

Permalink
Find similar embeddings.
Browse files Browse the repository at this point in the history
  • Loading branch information
kirkbyers committed May 16, 2024
1 parent cbea77e commit d1b0c3b
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ path = "src/bin/fast_embed_scrapes.rs"
name = "migrate"
path = "src/bin/migrate.rs"

[[bin]]
name = "fast_embed_similars"
path = "src/bin/fast_embed_similars.rs"

[dependencies]
actix-web = "4"
bytes = "1.5.0"
Expand Down
2 changes: 1 addition & 1 deletion src/bin/fast_embed_scrapes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(v) => Ok(v),
Err(e) => {
println!("{:?}", e);
Err(e.into())
Err(e)
}
}
}
74 changes: 74 additions & 0 deletions src/bin/fast_embed_similars.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use zero2prod::{db::start_db, models::fast_embeds};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db = start_db().await.unwrap();
let conn = db.connect().unwrap();

let limit = 10;
let mut page: u32 = 0;

let mut fast_embed_recs = fast_embeds::get_page(&conn, &limit, &(limit * page))
.await
.unwrap();
let mut all_fast_embeds: Vec<fast_embeds::FastEmbed> = vec![];
while !fast_embed_recs.is_empty() {
all_fast_embeds.append(&mut fast_embed_recs);
page += 1;
fast_embed_recs = fast_embeds::get_page(&conn, &limit, &(limit * page))
.await
.unwrap();
}

for i in 0..all_fast_embeds.len() {
let mut max_similarity: f64 = 0.0;
let mut similar_index = i;
for j in 0..all_fast_embeds.len() {
if i == j {
continue;
}

let similarity =
cosine_similarity(&all_fast_embeds[i].embedding, &all_fast_embeds[j].embedding);

if similarity > max_similarity {
max_similarity = similarity;
similar_index = j;
}

// If similarity is 1 then there is a dup in the set and its pointless to continue checking the record
if max_similarity >= 1.0 {
break;
}
}
println!(
"{:?} - {max_similarity:} - {:?}",
all_fast_embeds[i].doc_id, all_fast_embeds[similar_index].doc_id
);
}

Ok(())
}

fn cosine_similarity(vec1: &[u8], vec2: &[u8]) -> f64 {
let dot_product = dot_product(vec1, vec2);
let magnitude1 = magnitude(vec1);
let magnitude2 = magnitude(vec2);
dot_product / (magnitude1 * magnitude2)
}

fn dot_product(vec1: &[u8], vec2: &[u8]) -> f64 {
let mut result = 0.0;
for i in 0..vec1.len() {
result += (vec1[i] as f64) * (vec2[i] as f64);
}
result
}

fn magnitude(vec: &[u8]) -> f64 {
let mut result = 0.0;
for i in vec {
result += (*i as f64) * (*i as f64);
}
result.sqrt()
}
32 changes: 32 additions & 0 deletions src/models/fast_embeds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,35 @@ impl Default for FastEmbed {
Self::new()
}
}

pub async fn get_page(
conn: &Connection,
limit: &u32,
offset: &u32,
) -> Result<Vec<FastEmbed>, Box<dyn std::error::Error>> {
let mut stmt = conn
.prepare(
r#"
SELECT id, doc_type, doc_id, embedding
FROM fast_embeds
LIMIT ? OFFSET ?
"#,
)
.await?;
let mut rows = stmt.query((*limit, *offset)).await?;
let mut fast_embeds = Vec::new();
while let Some(row) = rows.next().await? {
let id: String = row.get(0)?;
let doc_type: String = row.get(1)?;
let doc_id: String = row.get(2)?;
let embedding: Vec<u8> = row.get(3)?;
let fast_embed = FastEmbed {
id,
doc_type,
doc_id,
embedding,
};
fast_embeds.push(fast_embed);
}
Ok(fast_embeds)
}

0 comments on commit d1b0c3b

Please sign in to comment.