Skip to content

Commit

Permalink
Merge pull request #2719 from tomaarsen/fix/matryoshka_ascending_dims
Browse files Browse the repository at this point in the history
[`fix`] Fix `MatryoshkaLoss` crash if the first dimension is not the biggest
  • Loading branch information
tomaarsen authored Jun 5, 2024
2 parents d6a6347 + d72065f commit 1608eb8
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sentence_transformers/losses/MatryoshkaLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ def __init__(
warnings.warn("MatryoshkaLoss is not compatible with CachedMultipleNegativesRankingLoss.", stacklevel=2)
if isinstance(loss, CachedGISTEmbedLoss):
warnings.warn("MatryoshkaLoss is not compatible with CachedGISTEmbedLoss.", stacklevel=2)
self.matryoshka_dims = matryoshka_dims

if matryoshka_weights is None:
matryoshka_weights = [1] * len(matryoshka_dims)
self.matryoshka_weights = matryoshka_weights
# Sort the dimensions and weights in descending order
dims_weights = zip(matryoshka_dims, matryoshka_weights)
self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True))
self.n_dims_per_step = n_dims_per_step

def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
Expand Down

0 comments on commit 1608eb8

Please sign in to comment.