From 08b340b118348f262cfc11d28ec97c92393adc2c Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Fri, 7 Jun 2024 10:11:33 +0200
Subject: [PATCH] Set can_return_loss=True globally, instead of via the data collator

---
 sentence_transformers/data_collator.py | 2 +-
 sentence_transformers/trainer.py       | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/sentence_transformers/data_collator.py b/sentence_transformers/data_collator.py
index afdd86bfb..5350a8a71 100644
--- a/sentence_transformers/data_collator.py
+++ b/sentence_transformers/data_collator.py
@@ -19,7 +19,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
         columns = list(features[0].keys())
 
         # We should always be able to return a loss, label or not:
-        batch = {"return_loss": True}
+        batch = {}
 
         if "dataset_name" in columns:
             columns.remove("dataset_name")
diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py
index 2800571a6..de78494c7 100644
--- a/sentence_transformers/trainer.py
+++ b/sentence_transformers/trainer.py
@@ -197,7 +197,13 @@ def __init__(
             optimizers=optimizers,
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
+        # Every Sentence Transformer model can always return a loss, so we set this to True
+        # to avoid having to specify it in the data collator or model's forward
+        self.can_return_loss = True
 
+        self.model: SentenceTransformer
+        self.args: SentenceTransformerTrainingArguments
+        self.data_collator: SentenceTransformerDataCollator
         # Set the W&B project via environment variables if it's not already set
         if any([isinstance(callback, WandbCallback) for callback in self.callback_handler.callbacks]):
            os.environ.setdefault("WANDB_PROJECT", "sentence-transformers")
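
For context, a rough sketch of why the trainer-level flag can replace the per-batch key: in recent `transformers` versions, `Trainer.prediction_step` falls back to `self.can_return_loss` whenever the batch has no `return_loss` entry, so setting the attribute once in `__init__` has the same effect as the collator adding `"return_loss": True` to every batch. The helper below is a hypothetical condensation of that decision logic for illustration, not the library's verbatim code.

```python
# Hypothetical condensation of the check that `can_return_loss` feeds into
# inside transformers' Trainer.prediction_step (assumption: recent versions).
def wants_loss(inputs: dict, can_return_loss: bool, label_names: list) -> bool:
    # A batch-level "return_loss" key wins if the collator provided one;
    # otherwise the trainer-level flag is used.
    return_loss = inputs.get("return_loss", None)
    if return_loss is None:
        return_loss = can_return_loss
    # Even without label columns, the model output is treated as containing
    # a loss when return_loss is truthy.
    return len(label_names) == 0 and bool(return_loss)


# Before this patch: the collator forced the flag per batch.
assert wants_loss({"return_loss": True}, can_return_loss=False, label_names=[])
# After this patch: the trainer-level flag covers every batch.
assert wants_loss({}, can_return_loss=True, label_names=[])
```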