Truncate in EmbeddingSimilarityEvaluator
kddubey committed Apr 5, 2024
1 parent 8125182 commit 9606521
Showing 2 changed files with 24 additions and 17 deletions.
3 changes: 2 additions & 1 deletion sentence_transformers/SentenceTransformer.py
@@ -70,7 +70,8 @@ class SentenceTransformer(nn.Sequential):
This option should only be set to True for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
:param token: Hugging Face authentication token to download private models.
:param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation.
:param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
only applicable during inference when `.encode` is called.
"""

def __init__(
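For context, a minimal sketch of what the clarified docstring means in practice: truncation applies when `.encode` is called, not during training. This assumes a sentence-transformers build that includes this change; the model name is purely illustrative.

from sentence_transformers import SentenceTransformer

# truncate_dim set at load time applies to every .encode call (illustrative model name)
model = SentenceTransformer("all-MiniLM-L6-v2", truncate_dim=128)
embeddings = model.encode(["first sentence", "second sentence"])
print(embeddings.shape)  # (2, 128)

# The context manager used by the evaluator change below overrides it temporarily
with model.truncate_sentence_embeddings(64):
    small = model.encode(["first sentence"])
print(small.shape)  # (1, 64)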
38 changes: 22 additions & 16 deletions sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from . import SentenceEvaluator, SimilarityFunction
import logging
import os
@@ -33,6 +34,7 @@ def __init__(
show_progress_bar: bool = False,
write_csv: bool = True,
precision: Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]] = None,
truncate_dim: Optional[int] = None,
):
"""
Constructs an evaluator for the dataset
@@ -45,12 +47,15 @@ def __init__(
:param write_csv: Write results to a CSV file
:param precision: The precision to use for the embeddings. Can be "float32", "int8", "uint8", "binary", or
"ubinary". Defaults to None.
:param truncate_dim: The dimension to truncate sentence embeddings to. `None` uses the model's current
truncation dimension. Defaults to None.
"""
self.sentences1 = sentences1
self.sentences2 = sentences2
self.scores = scores
self.write_csv = write_csv
self.precision = precision
self.truncate_dim = truncate_dim

assert len(self.sentences1) == len(self.sentences2)
assert len(self.sentences1) == len(self.scores)
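
A rough usage sketch for the new `truncate_dim` parameter (toy sentences and scores; the model name is illustrative, and `None` keeps whatever truncation dimension the model was loaded with):

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model name

sentences1 = ["A cat sits on the mat", "It is raining"]
sentences2 = ["A cat is on the mat", "The sun is shining"]
scores = [0.9, 0.1]  # gold similarity scores

# Evaluate on 256-dimensional truncated embeddings instead of the model's full dimension
evaluator = EmbeddingSimilarityEvaluator(sentences1, sentences2, scores, truncate_dim=256)
main_score = evaluator(model)  # Spearman correlation of cosine similarities by default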
@@ -107,22 +112,23 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int =

logger.info("EmbeddingSimilarityEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)

embeddings1 = model.encode(
self.sentences1,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
convert_to_numpy=True,
precision=self.precision,
normalize_embeddings=bool(self.precision),
)
embeddings2 = model.encode(
self.sentences2,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
convert_to_numpy=True,
precision=self.precision,
normalize_embeddings=bool(self.precision),
)
with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
embeddings1 = model.encode(
self.sentences1,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
convert_to_numpy=True,
precision=self.precision,
normalize_embeddings=bool(self.precision),
)
embeddings2 = model.encode(
self.sentences2,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
convert_to_numpy=True,
precision=self.precision,
normalize_embeddings=bool(self.precision),
)
# Binary and ubinary embeddings are packed, so we need to unpack them for the distance metrics
if self.precision == "binary":
embeddings1 = (embeddings1 + 128).astype(np.uint8)
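The "packed" comment above refers to the bit-packed layout of binary-quantized embeddings. The rest of the unpacking code is cut off in this view, so the following is only a standalone numpy sketch of the general idea, not the file's literal continuation:

import numpy as np

# A packed "binary" embedding stores one sign bit per original dimension,
# 8 bits per byte, offset by -128 so each byte fits the int8 range.
packed_int8 = np.array([[-128, 127]], dtype=np.int8)

# Shift back into uint8, then expand each byte into its 8 bits
packed_uint8 = (packed_int8.astype(np.int16) + 128).astype(np.uint8)
bits = np.unpackbits(packed_uint8, axis=1)
print(bits)  # [[0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]]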
