[feat] Add lightning-fast StaticEmbedding module based on model2vec #2961

Merged
tomaarsen merged 2 commits into UKPLab:master from module/static_embedding on Oct 8, 2024

Conversation

tomaarsen
Collaborator

Hello!

Pull Request overview

  • Add a StaticEmbedding module that wraps the torch EmbeddingBag (a short sketch of the idea follows this list)
  • Update MatryoshkaLoss to work with modules that skip the "token_embeddings" values
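
The speed comes from the fact that a sentence embedding here is just a token-embedding lookup plus a mean reduction, with no transformer forward pass. As a rough illustration (not the PR's actual code), torch.nn.EmbeddingBag fuses both steps into a single call:

import torch
from torch import nn

# Illustrative only: EmbeddingBag looks up token embeddings and mean-pools
# them in one fused operation.
vocab_size, embedding_dim = 30522, 256
bag = nn.EmbeddingBag(vocab_size, embedding_dim, mode="mean")

# Two "sentences" given as flattened token ids plus offsets marking where each one starts.
input_ids = torch.tensor([101, 2054, 2003, 102, 101, 7592, 102])
offsets = torch.tensor([0, 4])
sentence_embeddings = bag(input_ids, offsets)  # shape: (2, 256)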

Details

This new StaticEmbedding module can be initialized in three ways (a minimal sketch of option 1 follows the list):

  1. With random embeddings (requires training afterwards)
  2. With distillation from model2vec via from_distillation
  3. With a pre-distilled model2vec model via from_model2vec
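
For option 1, a minimal sketch, assuming the constructor accepts a tokenizers.Tokenizer plus an embedding_dim (the exact signature in this PR may differ):

from tokenizers import Tokenizer
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding

# Assumed constructor usage: randomly initialized static embeddings over an
# existing tokenizer's vocabulary; only useful after finetuning.
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=256)
model = SentenceTransformer(modules=[static_embedding])

Options 2 and 3 are shown in the full example below.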

Example

This script distills BAAI/bge-base-en-v1.5 into static embeddings and then (very efficiently) embeds questions and answers from the natural-questions dataset. It then prints the mean similarity over the positive pairs and the mean similarity over all negative pairs. You can also update it to do the same for the all-nli dataset.

The distillation process takes me ~1 second on CUDA, and encoding 10k queries and documents takes effectively 0 seconds on both CPU and CUDA.

from datasets import load_dataset
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding

# static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda", pca_dims=256)
# 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 18.04it/s]
model = SentenceTransformer(modules=[static_embedding])

natural_questions_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
corpus = natural_questions_dataset["answer"]
queries = natural_questions_dataset["query"]
"""
all_nli_dataset = load_dataset("sentence-transformers/all-nli", "pair", split="train[:10000]")
corpus = all_nli_dataset["anchor"]
queries = all_nli_dataset["positive"]
"""

corpus_embeddings = model.encode(corpus, show_progress_bar=True, batch_size=2048)
query_embeddings = model.encode(queries, show_progress_bar=True, batch_size=2048)

similarities = model.similarity(query_embeddings, corpus_embeddings)
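# The diagonal holds each query's similarity to its own (positive) answer;
# all off-diagonal entries are mismatched query-answer (negative) pairs.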
positive_similarity = similarities.diag().mean().item()
negative_similarity = similarities[~torch.eye(similarities.shape[0], dtype=bool)].mean().item()
print("Average similarity of question-answer in Natural Questions: ")
print("Positive Similarity:", positive_similarity)
print("Negative Similarity:", negative_similarity)
"""
Average similarity of question-answer in Natural Questions:
Positive Similarity: 0.616450309753418
Negative Similarity: 0.3292446732521057
"""

cc @stephantul @Pringled

  • Tom Aarsen

@tomaarsen tomaarsen changed the title Add lightning-fast StaticEmbedding module based on model2vec [feat] Add lightning-fast StaticEmbedding module based on model2vec Sep 26, 2024
@tomaarsen tomaarsen merged commit 7855327 into UKPLab:master Oct 8, 2024
11 checks passed
@tomaarsen tomaarsen deleted the module/static_embedding branch October 8, 2024 14:32