Set can_return_loss=True globally, instead of via the data collator
tomaarsen committed Jun 7, 2024
1 parent 1608eb8 commit 08b340b
Showing 2 changed files with 7 additions and 1 deletion.
sentence_transformers/data_collator.py (1 addition, 1 deletion)

@@ -19,7 +19,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
         columns = list(features[0].keys())
 
         # We should always be able to return a loss, label or not:
-        batch = {"return_loss": True}
+        batch = {}
 
         if "dataset_name" in columns:
            columns.remove("dataset_name")
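Why the key existed at all: "return_loss" is not a real feature column. It is a hint for the Hugging Face Trainer, whose evaluation loop decides whether to compute a loss even when the dataset carries no label column. As of recent transformers versions, the relevant check in Trainer.prediction_step is roughly the following (a paraphrased sketch, not the exact upstream code):

    # Paraphrased sketch of the decision in transformers' Trainer.prediction_step.
    # Old behavior: every collated batch carried "return_loss": True, so the inputs lookup won.
    # New behavior: the key is absent, so the decision falls back to self.can_return_loss,
    # which SentenceTransformerTrainer now sets to True in __init__ (see trainer.py below).
    return_loss = inputs.get("return_loss", None)
    if return_loss is None:
        return_loss = self.can_return_loss
    loss_without_labels = len(self.label_names) == 0 and return_loss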
sentence_transformers/trainer.py (6 additions, 0 deletions)

@@ -197,7 +197,13 @@ def __init__(
             optimizers=optimizers,
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
+        # Every Sentence Transformer model can always return a loss, so we set this to True
+        # to avoid having to specify it in the data collator or model's forward
+        self.can_return_loss = True
+
+        self.model: SentenceTransformer
+        self.args: SentenceTransformerTrainingArguments
         self.data_collator: SentenceTransformerDataCollator
         # Set the W&B project via environment variables if it's not already set
         if any([isinstance(callback, WandbCallback) for callback in self.callback_handler.callbacks]):
             os.environ.setdefault("WANDB_PROJECT", "sentence-transformers")
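The attribute override works because, as of recent transformers versions, Trainer.__init__ infers can_return_loss by inspecting the model's forward signature for a return_loss parameter; SentenceTransformer.forward takes a features dictionary and has no such parameter, so the inferred value would be False. Setting the attribute once right after super().__init__() restores the old behavior without injecting a flag into every batch. A usage sketch of the end result (model name and dataset contents are illustrative examples, not part of this commit):

    # Usage sketch: an eval loss is reported even though the pair dataset has no "label"
    # column, because SentenceTransformerTrainer declares can_return_loss = True.
    from datasets import Dataset
    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
    from sentence_transformers.losses import MultipleNegativesRankingLoss

    model = SentenceTransformer("all-MiniLM-L6-v2")
    train_dataset = Dataset.from_dict({
        "anchor": ["A plane is taking off.", "A man is playing a flute."],
        "positive": ["An air plane is taking off.", "A man is playing a flute for kids."],
    })
    eval_dataset = Dataset.from_dict({
        "anchor": ["A woman is slicing an onion.", "Two dogs run through a field."],
        "positive": ["A woman is cutting an onion.", "Two dogs are running outside."],
    })

    trainer = SentenceTransformerTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=MultipleNegativesRankingLoss(model),
    )
    trainer.train()
    print(trainer.evaluate())  # includes "eval_loss" despite the missing label column

The practical effect is unchanged for users: losses such as MultipleNegativesRankingLoss still get an eval_loss without explicit labels, but the flag no longer travels inside every batch dictionary.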
