Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Matryoshka training always patch original forward, and check matryoshka_dims #2593

Merged
merged 4 commits into from
Apr 15, 2024
Merged

[fix] Matryoshka training always patch original forward, and check matryoshka_dims #2593

merged 4 commits into from
Apr 15, 2024

Conversation

kddubey
Copy link
Contributor

@kddubey kddubey commented Apr 14, 2024

Hello,

TLDR: patch back the original forward method, even after an error, to avoid silent problems during training.

In case there's an error in, e.g., self.loss(sentence_features, labels), the closing self.model.forward = original_forward line won't get hit. So if a user is training a model in an interactive session (e.g., a notebook) and doesn't re-create the self.model object (maybe b/c they know the error came before the optimizer ever stepped, so the model's parameters didn't change), then on a second .fit run, I think the forward method will start by effectively doing:

original_forward = self.model.forward  # is already a ForwardDecorator b/c of the first (failed) run
decorated_forward = ForwardDecorator(original_forward)  # is now doubly-decorated
self.model.forward = decorated_forward

Whether or not this is a problem depends on the user's input to MatryoshkaLoss. If they don't set n_dims_per_step, they set matryoshka_dims to a list such that matryoshka_dims[0] == max(matryoshka_dims), and the error is raised at the first dimension in the list, then there's no problem. Otherwise, there is a problem and it's silent.

The problem occurs if and only if the self.model.forward.dim attribute ends up being set to something less than max(matryoshka_dims). The result will be that the model doesn't actually Matryoshka-train; it will only train up to the last dimension that was set before it errored out in the first run. Call this dimension err_dim.

Reasoning: we'll always have—

output["sentence_embedding"].shape[-1] == err_dim

—b/c self.fn is the last ForwardDecorator whose dimension was last set before erroring out. Next—

tensor = tensor[..., : self.dim]

—will give back the exact same tensor if self.dim >= err_dim, as this slicing style doesn't raise an error if self.dim > tensor.shape[-1]. In other words, self.shrink gives back tensors truncated at err_dim when self.dim > err_dim.

The downstream result is that, for example, if err_dim = 32, matryoshka_dims = [16, 32, 64, 128], and matryoshka_weights = [1, 1, 1, 1], then the user's second attempt at training effectively makes matryoshka_weights look something like [1, 3, 0, 0].

This is also making me wonder if some input checking should be done on matryoshka_dims. If self.model.get_sentence_embedding_dimension() == d but the user sets matryoshka_weights=[d/2, d, 2*d, 4*d], they should know that they're up-weighing the loss at dimension d, which might result in not-so-Matryoshka like properties at inference time.

@tomaarsen
Copy link
Collaborator

Hello!

This strikes me as a good idea to prevent some very unexpected issues, even if they only occur very rarely.

As for your last paragraph, perhaps we can simplify this by warning the user if one of their provided Matryoshka embedding dimensions is larger than the model's original embedding dimension. After all, in that case the truncation won't do anything.

  • Tom Aarsen

@kddubey
Copy link
Contributor Author

kddubey commented Apr 14, 2024

I verified that the bug happens in certain conditions—a simple one is where matryoshka_dims[0] < max(matryoshka_dims), e.g., matryoshka_dims=[64, 128, 256, 512, 768], and .fit has to be re-run after a failed initial run.

Here's a CPU-friendly script, where I modified ForwardDecorator.__call__ to print out the shape of self.shrink(output["sentence_embedding"])—the tensor that gets sent to self.loss:

script
from typing import NoReturn

import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, InputExample
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SimilarityFunction,
)


model = SentenceTransformer("paraphrase-albert-small-v2", device="cpu")
print(model.get_sentence_embedding_dimension())
# 768

# matryoshka_dims = [768, 10, 9, 8, 7, 6, 5, 4, 3, 2]
n_dims_per_step = -1
matryoshka_dims = [2, 3, 4, 5, 6, 7, 8, 9, 10, 768]

# Dummy data
train_examples = [
    InputExample(texts=["Anchor 1", "Positive 1"]),
    InputExample(texts=["somethin", "something else"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
dev_evaluator = EmbeddingSimilarityEvaluator(
    ["aljfad", "a;lkjdfasl;jf"],
    ["sentence3", "sentence4"],
    [0.9, 0.9],
    main_similarity=SimilarityFunction.COSINE,
    write_csv=False,
    show_progress_bar=True,
)

# Bad loss that will immediately raise an error
class MultipleNegativesRankingLossBad(torch.nn.Module):
    def __init__(self, model: SentenceTransformer) -> None:
        super().__init__()
        self.model = model

    def forward(*args, **kwargs) -> NoReturn:
        raise ValueError("Faaaaill")


train_loss = MultipleNegativesRankingLossBad(model)
train_loss = losses.MatryoshkaLoss(
    model, train_loss, matryoshka_dims=matryoshka_dims, n_dims_per_step=n_dims_per_step
)

# First attempt at training
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=2,
)
# raises ValueError: Faaaaill

# model.forward has been modified and will always truncate to 2
print(type(model.forward))
# <class 'sentence_transformers.losses.MatryoshkaLoss.ForwardDecorator'>
print(model.forward.dim)
# 2

# Correct the loss and run it again
train_loss = losses.MultipleNegativesRankingLoss(model)
train_loss = losses.MatryoshkaLoss(model, train_loss, matryoshka_dims=matryoshka_dims)

# Silently wrong training
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=2,
)
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])

As for your last paragraph, perhaps we can simplify this by warning the user if one of their provided Matryoshka embedding dimensions is larger than the model's original embedding dimension. After all, in that case the truncation won't do anything.

Good idea, I'll add this to the PR. Do you think this should be a warning, a ValueError, or something else? I lean towards ValueError b/c it seems quite unintentional to include dimensions past the model's output dimension, and it will cause unresearched inference behavior.

@kddubey kddubey changed the title Matryoshka training always patch original forward [fix] Matryoshka training always patch original forward, and check matryoshka_dims Apr 14, 2024
Copy link
Collaborator

@tomaarsen tomaarsen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the error text should be a bit more explicit with notifying the users what they're doing wrong. Other than that, this is looking good to go.

sentence_transformers/losses/MatryoshkaLoss.py Outdated Show resolved Hide resolved
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Copy link
Collaborator

@tomaarsen tomaarsen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much appreciated! Looks good :)

@tomaarsen tomaarsen merged commit d3d767c into UKPLab:master Apr 15, 2024
9 checks passed
@tomaarsen
Copy link
Collaborator

This is somewhat unrelated, but it may interest you nonetheless. I also don't remember if I have written about this here before, so apologies if you've seen this already:
As you're well aware at this point, Matryoshka models are trained by repeating the loss function at different embedding dimensions and summing them up. It works well in practice, but it does require users to specify exactly which dimensions they want to optimize for. Beyond that, having to apply the same loss function many times is a bit odd to me. It strikes me as suboptimal or potentially inconsistent.

Instead, a potential advancement is to consider a "continuous Matryoshka loss" or "continuous MRL" simply by creating a similarity function that prioritizes information towards the start of the embedding. This similarity score function can be applied directly in other losses that accept such functions, e.g. MultipleNegativesRankingLoss.

I experimented with a naive version, e.g.:

def mrl_cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Normalize the input embeddings such that matrix multiplication is cosine similarity
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)

    # Multiply the normalize embeddings with a decreasing multiplier to give more importance to the first dimensions
    multiplier = torch.arange(b_norm.shape[-1], 0, -1, device=a_norm.device) / b_norm.shape[-1]
    a_norm *= multiplier
    b_norm *= multiplier

    # Return the cosine similarity
    return a_norm @ b_norm.T
loss = losses.MultipleNegativesRankingLoss(model, similarity_fct=mrl_cos_sim)

I evaluated this with semantic textual similarity using STSB and the triplet evaluator using the AllNLI validation dataset, while training on the AllNLI train dataset. I used the training refactor PR which integrates with Weights and Biases to easily compare the performance with a baseline:
(w&b link)
image
image

This figure is quite interesting. With the "matryoshka cosine similarity", the Spearman Correlation reduces very gradually when reducing the dimensionality: 768 > 512 > 256 > 128 > 64 > 32 > 16 > 8, while the baseline is very jumpy: 256 > 768 > 128 > 512 > 64 > 32 > 16 > 8. As a result, the matryoshka cosine similarity model is sometimes much better, and sometimes much worse.

(w&b link)
image

This figure is a lot more straightforward. Perhaps that is to be expected, as this validation dataset originates from the same distribution as the training set, so this "gradual increase over time during training" is pretty normal. As could be somewhat expected, the model performs worse than the baseline at 768 (because the MRL cosine similarity cant use the last dimensions to store information as well as the baseline, so in essence it can store "less information"). This difference shrinks and eventually the continuous MRL model handily outperforms the baseline.
Interestingly, the continuous MRL model keeps its performance very well even down to a dimensionality of 64. It only degraded 0.4%, despite being 12 times smaller. This is what you could expect from "normal" Matryoshka models.

The full training script
from collections import defaultdict
import datasets
from datasets import Dataset
import torch
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
    SentenceTransformerTrainingArguments
)
from sentence_transformers.models import Transformer, Pooling

"""
def mrl_cos_sim(a, b):
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    # a x d, b x d -> a x b x d

    similarity_per_dim = a_norm * b_norm
    multiplier = torch.arange(similarity_per_dim.shape[-1], 0, -1, device=similarity_per_dim.device) / 768
    similarity_per_dim *= multiplier
    return similarity_per_dim.sum(-1, keepdim=True)
"""
def mrl_cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Normalize the input embeddings such that matrix multiplication is cosine similarity
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)

    # Multiply the normalize embeddings with a decreasing multiplier to give more importance to the first dimensions
    multiplier = torch.arange(b_norm.shape[-1], 0, -1, device=a_norm.device) / b_norm.shape[-1]
    a_norm *= multiplier
    b_norm *= multiplier

    # Return the cosine similarity
    return a_norm @ b_norm.T

def to_triplets(dataset):
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0]) # <- entailment
            negatives.append(sentences[2]) # <- contradiction
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })

if __name__ == "__main__":
    snli_ds = datasets.load_dataset("snli")
    snli_ds = datasets.DatasetDict({
        "train": to_triplets(snli_ds["train"]),
        "validation": to_triplets(snli_ds["validation"]),
        "test": to_triplets(snli_ds["test"]),
    })
    multi_nli_ds = datasets.load_dataset("multi_nli")
    multi_nli_ds = datasets.DatasetDict({
        "train": to_triplets(multi_nli_ds["train"]),
        "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
    })

    all_nli_ds = datasets.DatasetDict({
        "train": datasets.concatenate_datasets([snli_ds["train"], snli_ds["train"]]),
        "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),
        "test": snli_ds["test"]
    })

    stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
    stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

    training_args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        run_name="mpnet-base-allnli-baseline",
        # report_to="none",
        num_train_epochs=1,
        seed=33,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=500,
        save_steps=500,
        save_total_limit=2,
        metric_for_best_model="eval_sts-dev-768_spearman_cosine",
        greater_is_better=True,
    )

    transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
    pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
    model = SentenceTransformer(modules=[transformer, pooling])

    loss = losses.MultipleNegativesRankingLoss(model)#, similarity_fct=mrl_cos_sim)
    dev_evaluators = []
    for matryoshka_dim in [768, 512, 256, 128, 64, 32, 16, 8]:
        dev_evaluators.append(evaluation.EmbeddingSimilarityEvaluator(
            stsb_dev["sentence1"],
            stsb_dev["sentence2"],
            [score / 5 for score in stsb_dev["score"]],
            main_similarity=evaluation.SimilarityFunction.COSINE,
            name=f"sts-dev-{matryoshka_dim}",
            truncate_dim=matryoshka_dim,
        ))
        dev_evaluators.append(evaluation.TripletEvaluator(
            anchors=all_nli_ds["validation"]["anchor"],
            positives=all_nli_ds["validation"]["positive"],
            negatives=all_nli_ds["validation"]["negative"],
            name=f"allnli-validation-{matryoshka_dim}",
            main_distance_function=evaluation.SimilarityFunction.COSINE,
            truncate_dim=matryoshka_dim,
        ))
    dev_evaluator = evaluation.SequentialEvaluator(dev_evaluators)
    # dev_evaluator(model)

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=all_nli_ds["train"],
        eval_dataset=all_nli_ds["validation"],
        evaluator=dev_evaluator,
        loss=loss,
    )
    trainer.train()

    test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
        stsb_test["sentence1"],
        stsb_test["sentence2"],
        [score / 5 for score in stsb_test["score"]],
        main_similarity=evaluation.SimilarityFunction.COSINE,
        name="sts-test",
    )
    results = test_evaluator(model)
    print(results)
    model.save("mpnet-base-allnli-baseline")

# Continuous-mrl-linear:
# {'sts-test_pearson_cosine': 0.8175790710661428, 'sts-test_spearman_cosine': 0.830859176653514, 'sts-test_pearson_manhattan': 0.8362810843054216, 'sts-test_spearman_manhattan': 0.8284792813481868, 'sts-test_pearson_euclidean': 0.8261216518822675, 'sts-test_spearman_euclidean': 0.8181971965432933, 'sts-test_pearson_dot': 0.7944856099219624, 'sts-test_spearman_dot': 0.7834936544677008, 'sts-test_pearson_max': 0.8362810843054216, 'sts-test_spearman_max': 0.830859176653514}

# Base:
# {'sts-test_pearson_cosine': 0.8101063845818969, 'sts-test_spearman_cosine': 0.8319477736976867, 'sts-test_pearson_manhattan': 0.8438869778745631, 'sts-test_spearman_manhattan': 0.8365807837093596, 'sts-test_pearson_euclidean': 0.8418259869573095, 'sts-test_spearman_euclidean': 0.8352941277766923, 'sts-test_pearson_dot': 0.6568623097846887, 'sts-test_spearman_dot': 0.6654033827828304, 'sts-test_pearson_max': 0.8438869778745631, 'sts-test_spearman_max': 0.8365807837093596}

Anyways, I thought these experiments might interest you! I think I'll leave these experiments as-is for now, but perhaps a great solution is possible here (e.g. like how the original MRL is often on par with baselines, but it ALSO does pretty well at lower dimensions).
cc @aamir-s18 as you've also been working on MRL stuff.

  • Tom Aarsen

@aamir-s18
Copy link
Contributor

Hey Tom, this is an extremely cool idea! We will look more closely look into that the coming days.

@kddubey kddubey deleted the matryoshka-try-finally branch April 15, 2024 19:38
@kddubey
Copy link
Contributor Author

kddubey commented Apr 15, 2024

lol @tomaarsen we thought almost identically about Matryoshka training. I called your "continuous MRL" loss "diagonaloss", and had recently spent some time playing around w/ different versions of the "multiply each vector by a decaying list of weights and compute a distance metric on them" idea.

It's interesting to see that level of jumpiness in STSB within each dimension and across dimensions, I wouldn't have guessed that. The AllNLI charts are super clear and cool to see. Though good to know that for mpnet, the clarity of the effect could be due to continued training on NLI.

I think I'll leave these experiments as-is for now, but perhaps a great solution is possible here (e.g. like how the original MRL is often on par with baselines, but it ALSO does pretty well at lower dimensions).

One thing I'm hoping for is that a clearer geometric interpretation will help in getting to a Matryoshka-like loss.

Here's some dirty work I recently did on a few variants of continuous MRL / diagonaloss:

  • mini training script—I ended up thinking that Euclidean distance made more sense after re-scaling vectors, but I previously experimented w/ a distance metric that was only subtly different to your mrl_cos_sim function.
  • full training script—adapted from your matryoshka_nli.py script
  • gradient analysis on a toy problem—I initially hoped that the weights to multiply vectors by—AKA multiplier in mrl_cos_sim—could be mathematically derived from the gradient of the original Matryoshka loss function.

The Matryoshka eval plot (from this notebook)—

image

—shows that MRL is significantly better at 64 dimensions, and slightly better at the rest. (I realize that I should've trained the Matryoshka model down to 8 dimensions and compared all the models at 8, 16, and 32 as well.) And wish I knew about your refactor PR earlier so I'd have those nice w&b charts! :-)

Beyond that, having to apply the same loss function many times is a bit odd to me. It strikes me as suboptimal or potentially inconsistent.

I was reflecting and thinking that on one hand, MRL does work in an odd way. But on the other hand, it is pretty well-motivated—if we want a model to do well in many lower-dim spaces, then just directly do that / penalize in lower-dim spaces. In light of the results—especially your AllNLI results—I would be slightly surprised to see that the Matryoshka effect can be reproduced by re-scaling alone. The gradients of plain MRL vs continuous MRL / diagonaloss are different-enough that I couldn't come up w/ a way to reproduce the gradient via re-scaling alone.

@tomaarsen
Copy link
Collaborator

Interesting results! It indeed seems like MRL is rather challenging to beat.
Also, as a heads up, your old training scripts (with model.fit, DataLoaders and InputExample instances) will (or rather, should) still work with the new training refactor, and it'll still give you the new features like the weights and biases. So you won't have to completely change everything up.

  • Tom Aarsen

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants