[v3] Add similarity and similarity_pairwise methods to Sentence Transformers #2615

Merged: 8 commits, Apr 25, 2024
8 changes: 4 additions & 4 deletions examples/applications/image-search/README.md
````diff
@@ -12,7 +12,7 @@ Ensure that you have [transformers](https://pypi.org/project/transformers/) installed
 SentenceTransformers provides a wrapper for the [OpenAI CLIP Model](https://github.com/openai/CLIP), which was trained on a variety of (image, text)-pairs.
 
 ```python
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer
 from PIL import Image
 
 # Load CLIP model
@@ -26,9 +26,9 @@ text_emb = model.encode(
     ["Two dogs in the snow", "A cat on a table", "A picture of London at night"]
 )
 
-# Compute cosine similarities
-cos_scores = util.cos_sim(img_emb, text_emb)
-print(cos_scores)
+# Compute similarities
+similarity_scores = model.similarity(img_emb, text_emb)
+print(similarity_scores)
 ```
 
 You can use the CLIP model for:
````
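Since `model.similarity` defaults to cosine similarity, the old and new calls in this README are interchangeable. A minimal sketch to convince yourself of the equivalence (assuming the v3 API from this PR; `util.cos_sim` is not removed by this change):

```python
import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("clip-ViT-B-32")
emb = model.encode(["Two dogs in the snow", "A cat on a table"], convert_to_tensor=True)

# model.similarity falls back to cosine when similarity_fn_name is unset,
# so it should match util.cos_sim exactly.
assert torch.allclose(util.cos_sim(emb, emb), model.similarity(emb, emb))
```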
6 changes: 3 additions & 3 deletions examples/applications/semantic-search/semantic_search.py
````diff
@@ -7,7 +7,7 @@
 This script outputs for various queries the top 5 most similar sentences in the corpus.
 """
 
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer
 import torch
 
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
@@ -40,8 +40,8 @@
     query_embedding = embedder.encode(query, convert_to_tensor=True)
 
     # We use cosine-similarity and torch.topk to find the highest 5 scores
-    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
-    top_results = torch.topk(cos_scores, k=top_k)
+    similarity_scores = embedder.similarity(query_embedding, corpus_embeddings)[0]
+    top_results = torch.topk(similarity_scores, k=top_k)
 
     print("\n\n======================\n\n")
     print("Query:", query)
````
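The `[0]` indexing survives the migration because `similarity` returns a matrix of shape `[num_queries, num_corpus]`; row 0 is the score vector for the single query. A hedged sketch with a toy corpus (the sentences here are illustrative, not taken from the script):

```python
import torch
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")
corpus = ["A man is eating food.", "A monkey is playing drums.", "A cheetah chases prey."]
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)

query_embedding = embedder.encode("An animal is making music.", convert_to_tensor=True)
similarity_scores = embedder.similarity(query_embedding, corpus_embeddings)  # shape [1, 3]
top_results = torch.topk(similarity_scores[0], k=2)
for score, idx in zip(top_results.values, top_results.indices):
    print(corpus[idx], f"(score: {score:.4f})")
```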
10 changes: 5 additions & 5 deletions examples/applications/text-summarization/text-summarization.py
````diff
@@ -19,7 +19,7 @@
 """
 
 import nltk
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer
 import numpy as np
 from LexRank import degree_centrality_scores
 
@@ -43,13 +43,13 @@
 print("Num sentences:", len(sentences))
 
 # Compute the sentence embeddings
-embeddings = model.encode(sentences, convert_to_tensor=True)
+embeddings = model.encode(sentences)
 
-# Compute the pair-wise cosine similarities
-cos_scores = util.cos_sim(embeddings, embeddings).numpy()
+# Compute the similarity scores
+similarity_scores = model.similarity(embeddings, embeddings).numpy()
 
 # Compute the centrality for each sentence
-centrality_scores = degree_centrality_scores(cos_scores, threshold=None)
+centrality_scores = degree_centrality_scores(similarity_scores, threshold=None)
 
 # We argsort so that the first element is the sentence with the highest score
 most_central_sentence_indices = np.argsort(-centrality_scores)
````
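Worth noting in this change is the dtype handoff: without `convert_to_tensor=True`, `encode` returns a NumPy array, `similarity` accepts NumPy input but returns a `torch.Tensor`, and `.numpy()` converts the scores back for `degree_centrality_scores`. A small sketch of that round trip (the model name is illustrative):

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(["First sentence.", "Second sentence."])  # numpy.ndarray
similarity_scores = model.similarity(embeddings, embeddings)        # torch.Tensor
scores_np = similarity_scores.numpy()                               # back to NumPy
assert isinstance(embeddings, np.ndarray) and isinstance(scores_np, np.ndarray)
```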
3 changes: 1 addition & 2 deletions examples/training/adaptive_layer/README.md
````diff
@@ -120,7 +120,6 @@ Then we can run inference with it using <a href="../../../docs/package_reference
 
 ```python
 from sentence_transformers import SentenceTransformer
-from sentence_transformers.util import cos_sim
 
 model = SentenceTransformer("tomaarsen/mpnet-base-nli-adaptive-layer")
 new_num_layers = 3
@@ -134,7 +133,7 @@ embeddings = model.encode(
     ]
 )
 # Similarity of the first sentence with the other two
-similarities = cos_sim(embeddings[0], embeddings[1:])
+similarities = model.similarity(embeddings[0], embeddings[1:])
 # => tensor([[0.7761, 0.1655]])
 # compared to tensor([[ 0.7547, -0.0162]]) for the full model
 ```
````
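The collapsed hunk hides the layer-truncation step itself. For an MPNet-style backbone whose transformer layers live in `encoder.layer`, it presumably looks something like the sketch below; the truncation line is an assumption based on the surrounding README, not part of this diff:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("tomaarsen/mpnet-base-nli-adaptive-layer")
new_num_layers = 3
# Assumed layout: keep only the first `new_num_layers` transformer layers.
model[0].auto_model.encoder.layer = model[0].auto_model.encoder.layer[:new_num_layers]

embeddings = model.encode([
    "The weather is so nice!",
    "It's so sunny outside!",
    "He drove to the stadium.",
])
# Similarity of the first sentence with the other two, via the new API
similarities = model.similarity(embeddings[0], embeddings[1:])
```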
3 changes: 1 addition & 2 deletions examples/training/matryoshka/README.md
````diff
@@ -58,7 +58,6 @@ After a model has been trained using a Matryoshka loss, you can then run inference
 
 ```python
 from sentence_transformers import SentenceTransformer
-from sentence_transformers.util import cos_sim
 import torch.nn.functional as F
 
 matryoshka_dim = 64
@@ -77,7 +76,7 @@ embeddings = model.encode(
 )
 assert embeddings.shape[-1] == matryoshka_dim
 
-similarities = cos_sim(embeddings[0], embeddings[1:])
+similarities = model.similarity(embeddings[0], embeddings[1:])
 # => tensor([[0.7839, 0.4933]])
 ```
 As you can see, the similarity between the search query and the correct document is much higher than that of an unrelated document, despite the very small matryoshka dimension applied. Feel free to copy this script locally, modify the `matryoshka_dim`, and observe the difference in similarities.
````
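The collapsed hunk elides how the embeddings are reduced to `matryoshka_dim`; the `torch.nn.functional` import suggests manual truncation followed by re-normalization. A sketch of that step, offered as an assumption rather than a quote from the README:

```python
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

matryoshka_dim = 64
model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka")
embeddings = model.encode(
    ["The weather is so nice!", "It's so sunny outside."],
    convert_to_tensor=True,
)
# Keep the first matryoshka_dim dimensions, then re-normalize so cosine
# similarity stays meaningful on the truncated vectors.
embeddings = F.normalize(embeddings[..., :matryoshka_dim], p=2, dim=-1)
assert embeddings.shape[-1] == matryoshka_dim
```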
142 changes: 138 additions & 4 deletions sentence_transformers/SentenceTransformer.py
````diff
@@ -5,7 +5,7 @@
 from collections import OrderedDict
 from pathlib import Path
 import warnings
-from typing import List, Dict, Literal, Tuple, Iterable, Union, Optional
+from typing import Callable, List, Dict, Literal, Tuple, Iterable, Union, Optional, overload
 import numpy as np
 from numpy import ndarray
 import transformers
@@ -20,7 +20,7 @@
 import tempfile
 
 from sentence_transformers.model_card import SentenceTransformerModelCardData, generate_model_card
-
+from sentence_transformers.similarity_functions import SimilarityFunction
 
 from . import __MODEL_HUB_ORGANIZATION__
 from .evaluation import SentenceEvaluator
@@ -59,6 +59,9 @@ class SentenceTransformer(nn.Sequential, FitMixin):
         titles in "}`.
     :param default_prompt_name: The name of the prompt that should be used by default. If not set,
         no prompt will be applied.
+    :param similarity_fn_name: The name of the similarity function to use. Valid options are "cosine", "dot",
+        "euclidean", and "manhattan". If not set, it is automatically set to "cosine" if `similarity` or
+        `similarity_pairwise` is called while `model.similarity_fn_name` is still `None`.
     :param cache_folder: Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
     :param trust_remote_code: Whether or not to allow for custom models defined on the Hub in their own modeling files.
         This option should only be set to True for repositories you trust and in which you have read the code, as it
@@ -78,6 +81,7 @@ def __init__(
         device: Optional[str] = None,
         prompts: Optional[Dict[str, str]] = None,
         default_prompt_name: Optional[str] = None,
+        similarity_fn_name: Optional[Union[str, SimilarityFunction]] = None,
         cache_folder: Optional[str] = None,
         trust_remote_code: bool = False,
         revision: Optional[str] = None,
@@ -90,6 +94,7 @@ def __init__(
         # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
         self.prompts = prompts or {}
         self.default_prompt_name = default_prompt_name
+        self.similarity_fn_name = similarity_fn_name
         self.truncate_dim = truncate_dim
         self.model_card_data = model_card_data or SentenceTransformerModelCardData()
         self._model_card_vars = {}
@@ -436,6 +441,105 @@ def encode(
 
         return all_embeddings
 
+    @property
+    def similarity_fn_name(self) -> Optional[str]:
+        return self._similarity_fn_name
+
+    @similarity_fn_name.setter
+    def similarity_fn_name(self, value: Union[str, SimilarityFunction]) -> None:
+        if isinstance(value, SimilarityFunction):
+            value = value.value
+        self._similarity_fn_name = value
+
+        if value is not None:
+            self._similarity = SimilarityFunction.to_similarity_fn(value)
+            self._similarity_pairwise = SimilarityFunction.to_similarity_pairwise_fn(value)
+
+    @overload
+    def similarity(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor: ...
+
+    @overload
+    def similarity(self, embeddings1: ndarray, embeddings2: ndarray) -> Tensor: ...
+
+    @property
+    def similarity(self) -> Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]:
+        """
+        Compute the similarity between two collections of embeddings. The output will be a matrix with the
+        similarity scores between all embeddings from the first parameter and all embeddings from the second
+        parameter. This differs from `similarity_pairwise`, which computes the similarity between each pair of
+        embeddings.
+
+        Example
+            ::
+
+                >>> model = SentenceTransformer("all-mpnet-base-v2")
+                >>> sentences = [
+                ...     "The weather is so nice!",
+                ...     "It's so sunny outside.",
+                ...     "He's driving to the movie theater.",
+                ...     "She's going to the cinema.",
+                ... ]
+                >>> embeddings = model.encode(sentences, normalize_embeddings=True)
+                >>> model.similarity(embeddings, embeddings)
+                tensor([[1.0000, 0.7235, 0.0290, 0.1309],
+                        [0.7235, 1.0000, 0.0613, 0.1129],
+                        [0.0290, 0.0613, 1.0000, 0.5027],
+                        [0.1309, 0.1129, 0.5027, 1.0000]])
+                >>> model.similarity_fn_name
+                "cosine"
+                >>> model.similarity_fn_name = "euclidean"
+                >>> model.similarity(embeddings, embeddings)
+                tensor([[-0.0000, -0.7437, -1.3935, -1.3184],
+                        [-0.7437, -0.0000, -1.3702, -1.3320],
+                        [-1.3935, -1.3702, -0.0000, -0.9973],
+                        [-1.3184, -1.3320, -0.9973, -0.0000]])
+
+        :param embeddings1: [num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+        :param embeddings2: [num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+        :return: A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.
+        """
+        if self.similarity_fn_name is None:
+            self.similarity_fn_name = SimilarityFunction.COSINE
+        return self._similarity
+
+    @overload
+    def similarity_pairwise(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor: ...
+
+    @overload
+    def similarity_pairwise(self, embeddings1: ndarray, embeddings2: ndarray) -> Tensor: ...
+
+    @property
+    def similarity_pairwise(self) -> Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]:
+        """
+        Compute the similarity between two collections of embeddings. The output will be a vector with the
+        similarity scores between each pair of embeddings.
+
+        Example
+            ::
+
+                >>> model = SentenceTransformer("all-mpnet-base-v2")
+                >>> sentences = [
+                ...     "The weather is so nice!",
+                ...     "It's so sunny outside.",
+                ...     "He's driving to the movie theater.",
+                ...     "She's going to the cinema.",
+                ... ]
+                >>> embeddings = model.encode(sentences, normalize_embeddings=True)
+                >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
+                tensor([0.7235, 0.5027])
+                >>> model.similarity_fn_name
+                "cosine"
+                >>> model.similarity_fn_name = "euclidean"
+                >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
+                tensor([-0.7437, -0.9973])
+
+        :param embeddings1: [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+        :param embeddings2: [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+        :return: A [num_embeddings]-shaped torch tensor with pairwise similarity scores.
+        """
+        if self.similarity_fn_name is None:
+            self.similarity_fn_name = SimilarityFunction.COSINE
+        return self._similarity_pairwise
+
     def start_multi_process_pool(self, target_devices: List[str] = None):
         """
         Starts multi process to process the encoding with several, independent processes.
@@ -672,7 +776,8 @@ def save(
         safe_serialization: bool = True,
     ):
         """
-        Saves all elements for this seq. sentence embedder into different sub-folders
+        Saves a model and its configuration files to a directory, so that it can be loaded
+        with `SentenceTransformer(path)` again.
 
         :param path: Path on disc
         :param model_name: Optional model name
@@ -700,6 +805,7 @@ def save(
             config = self._model_config.copy()
             config["prompts"] = self.prompts
             config["default_prompt_name"] = self.default_prompt_name
+            config["similarity_fn_name"] = self.similarity_fn_name
             json.dump(config, fOut, indent=2)
 
         # Save modules
@@ -727,6 +833,32 @@ def save(
         if create_model_card:
             self._create_model_card(path, model_name, train_datasets)
 
+    def save_pretrained(
+        self,
+        path: str,
+        model_name: Optional[str] = None,
+        create_model_card: bool = True,
+        train_datasets: Optional[List[str]] = None,
+        safe_serialization: bool = True,
+    ):
+        """
+        Saves a model and its configuration files to a directory, so that it can be loaded
+        with `SentenceTransformer(path)` again. Alias of `SentenceTransformer.save`.
+
+        :param path: Path on disc
+        :param model_name: Optional model name
+        :param create_model_card: If True, create a README.md with basic information about this model
+        :param train_datasets: Optional list with the names of the datasets used to train the model
+        :param safe_serialization: If true, save the model using safetensors. If false, save the model the traditional PyTorch way
+        """
+        self.save(
+            path,
+            model_name=model_name,
+            create_model_card=create_model_card,
+            train_datasets=train_datasets,
+            safe_serialization=safe_serialization,
+        )
+
     def _create_model_card(
         self, path: str, model_name: Optional[str] = None, train_datasets: Optional[List[str]] = "deprecated"
     ):
@@ -982,7 +1114,9 @@ def _load_sbert_model(
                 )
             )
 
-        # Set prompts if not already overridden by the __init__ calls
+        # Set score functions & prompts if not already overridden by the __init__ calls
+        if self.similarity_fn_name is None:
+            self.similarity_fn_name = self._model_config.get("similarity_fn_name", None)
         if not self.prompts:
            self.prompts = self._model_config.get("prompts", {})
         if not self.default_prompt_name:
````
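Putting the pieces of this PR together: the constructor accepts `similarity_fn_name`, `save`/`save_pretrained` persist it in the model configuration, and `_load_sbert_model` restores it on reload. A short end-to-end sketch (the save path and printed values are illustrative):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2", similarity_fn_name="dot")
embeddings = model.encode(["A first sentence.", "A second one.", "And a third.", "A fourth."])

print(model.similarity(embeddings, embeddings).shape)             # torch.Size([4, 4])
print(model.similarity_pairwise(embeddings[:2], embeddings[2:]))  # two pairwise scores

model.save_pretrained("my-local-model")           # writes similarity_fn_name into the config
reloaded = SentenceTransformer("my-local-model")
print(reloaded.similarity_fn_name)                # "dot"
```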