from __future__ import annotations
import logging
import os
import warnings
from collections import OrderedDict
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable
import torch
from torch import nn
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, SubsetRandomSampler
from transformers import EvalPrediction, PreTrainedTokenizerBase, Trainer, TrainerCallback
from transformers.data.data_collator import DataCollator
from transformers.integrations import WandbCallback
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import EvalLoopOutput
from sentence_transformers.data_collator import SentenceTransformerDataCollator
from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator
from sentence_transformers.losses.CoSENTLoss import CoSENTLoss
from sentence_transformers.model_card import ModelCardCallback
from sentence_transformers.models.Transformer import Transformer
from sentence_transformers.sampler import (
DefaultBatchSampler,
GroupByLabelBatchSampler,
NoDuplicatesBatchSampler,
ProportionalBatchSampler,
RoundRobinBatchSampler,
)
from sentence_transformers.training_args import (
BatchSamplers,
MultiDatasetBatchSamplers,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.util import disable_logging, is_datasets_available, is_training_available
if is_datasets_available():
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, Value
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sentence_transformers.SentenceTransformer import SentenceTransformer
class SentenceTransformerTrainer(Trainer):
"""
SentenceTransformerTrainer is a simple but feature-complete training and eval loop for PyTorch
based on the 🤗 Transformers :class:`~transformers.Trainer`.
This trainer integrates support for various :class:`transformers.TrainerCallback` subclasses, such as:
- :class:`~transformers.integrations.WandbCallback` to automatically log training metrics to W&B if `wandb` is installed
- :class:`~transformers.integrations.TensorBoardCallback` to log training metrics to TensorBoard if `tensorboard` is accessible.
- :class:`~transformers.integrations.CodeCarbonCallback` to track the carbon emissions of your model during training if `codecarbon` is installed.
- Note: These carbon emissions will be included in your automatically generated model card.
See the Transformers `Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>`_
documentation for more information on the integrated callbacks and how to write your own callbacks.
Args:
model (:class:`~sentence_transformers.SentenceTransformer`, *optional*):
The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
args (:class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`, *optional*):
The arguments to tweak for training. Will default to a basic instance of
:class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments` with the
`output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
train_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
The dataset to use for training. Must have a format accepted by your loss function, see
`Training Overview > Dataset Format <../../../docs/sentence_transformer/training_overview.html#dataset-format>`_.
eval_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
The dataset to use for evaluation. Must have a format accepted by your loss function, see
`Training Overview > Dataset Format <../../../docs/sentence_transformer/training_overview.html#dataset-format>`_.
loss (Optional[Union[:class:`torch.nn.Module`, Dict[str, :class:`torch.nn.Module`],\
Callable[[:class:`~sentence_transformers.SentenceTransformer`], :class:`torch.nn.Module`],\
Dict[str, Callable[[:class:`~sentence_transformers.SentenceTransformer`], :class:`torch.nn.Module`]]]], *optional*):
The loss function to use for training. Can either be a loss class instance, a dictionary mapping dataset names to
loss class instances, a function that returns a loss class instance given a model, or a dictionary mapping
dataset names to functions that return a loss class instance given a model. In practice, the latter two
are primarily used for hyper-parameter optimization. Will default to
:class:`~sentence_transformers.losses.CoSENTLoss` if no ``loss`` is provided.
evaluator (Union[:class:`~sentence_transformers.evaluation.SentenceEvaluator`,\
List[:class:`~sentence_transformers.evaluation.SentenceEvaluator`]], *optional*):
The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
or without an ``eval_dataset``, and vice versa. Generally, the metrics that an ``evaluator`` returns
are more useful than the loss value returned from the ``eval_dataset``. A list of evaluators will be
wrapped in a :class:`~sentence_transformers.evaluation.SequentialEvaluator` to run them sequentially.
callbacks (List of [:class:`transformers.TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in the Transformers `Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>`_ documentation.
If you want to remove one of the default callbacks, use the :meth:`~transformers.Trainer.remove_callback` method.
optimizers (`Tuple[:class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler.LambdaLR`]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of :class:`torch.optim.AdamW`
on your model and a scheduler given by :func:`transformers.get_linear_schedule_with_warmup` controlled by `args`.
Important attributes:
- **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
subclass.
- **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
data parallelism, this means some of the model layers are split on different GPUs).
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
to `False` if model parallel or deepspeed is used, or if the default
`TrainingArguments.place_model_on_device` is overridden to return `False`.
- **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
in `train`)
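Example (a minimal sketch, not taken from this file; the base model, dataset, loss, and hyperparameters
below are illustrative assumptions)::

    from datasets import load_dataset
    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
    from sentence_transformers.losses import MultipleNegativesRankingLoss
    from sentence_transformers.training_args import SentenceTransformerTrainingArguments

    # Any (anchor, positive, negative) style dataset works with this loss; "all-nli" is just one option.
    model = SentenceTransformer("microsoft/mpnet-base")
    train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:10000]")
    eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
    loss = MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(output_dir="checkpoints", num_train_epochs=1)
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
    )
    trainer.train()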
"""
def __init__(
self,
model: SentenceTransformer | None = None,
args: SentenceTransformerTrainingArguments = None,
train_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
eval_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
loss: nn.Module
| dict[str, nn.Module]
| Callable[[SentenceTransformer], torch.nn.Module]
| dict[str, Callable[[SentenceTransformer], torch.nn.Module]]
| None = None,
evaluator: SentenceEvaluator | list[SentenceEvaluator] | None = None,
data_collator: DataCollator | None = None,
tokenizer: PreTrainedTokenizerBase | Callable | None = None,
model_init: Callable[[], SentenceTransformer] | None = None,
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> None:
if not is_training_available():
raise RuntimeError(
"To train a SentenceTransformer model, you need to install the `accelerate` and `datasets` modules. "
"You can do so with the `train` extra:\n"
'pip install -U "sentence-transformers[train]"'
)
if args is None:
output_dir = "tmp_trainer"
logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
args = SentenceTransformerTrainingArguments(output_dir=output_dir)
elif not isinstance(args, SentenceTransformerTrainingArguments):
raise ValueError("Please use `TrainingArguments` imported from `sentence_transformers`.")
if model is None:
if model_init is not None:
self.model_init = model_init
model = self.call_model_init()
else:
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
else:
if model_init is not None:
warnings.warn(
"`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
" overwrite your model when calling the `train` method. This will become a fatal error in the next"
" release.",
FutureWarning,
)
self.model_init = model_init
# Get a dictionary of the default training arguments, so we can determine which arguments have been changed
# for the model card
default_args_dict = SentenceTransformerTrainingArguments(output_dir="unused").to_dict()
# If the model ID is set via the SentenceTransformerTrainingArguments, but not via the SentenceTransformerModelCardData,
# then we can set it here for the model card regardless
if args.hub_model_id and not model.model_card_data.model_id:
model.model_card_data.set_model_id(args.hub_model_id)
if tokenizer is None and isinstance(model.tokenizer, PreTrainedTokenizerBase):
tokenizer = model.tokenizer
if data_collator is None:
data_collator = SentenceTransformerDataCollator(tokenize_fn=model.tokenize)
for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]):
if isinstance(dataset, IterableDataset) and dataset.column_names is None:
sample = next(iter(dataset))
naive_type_mapping = {str: "string", int: "int64", float: "float32", bool: "bool"}
example_features = {
key: Value(naive_type_mapping.get(type(value), "null")) for key, value in sample.items()
}
raise ValueError(
f"The provided `{dataset_name}_dataset` must have Features. Specify them with e.g.:\n"
f"{dataset_name}_dataset = {dataset_name}_dataset.cast(Features({example_features}))\n"
"or by providing the Features to the IterableDataset initialization method. See the Datasets "
"documentation for more information on dataset Features: "
"https://huggingface.co/docs/datasets/en/about_dataset_features"
)
if isinstance(train_dataset, dict) and not isinstance(train_dataset, DatasetDict):
train_dataset = DatasetDict(train_dataset)
if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, DatasetDict):
eval_dataset = DatasetDict(eval_dataset)
super().__init__(
model=None if self.model_init else model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Every Sentence Transformer model can always return a loss, so we set this to True
# to avoid having to specify it in the data collator or model's forward
self.can_return_loss = True
self.model: SentenceTransformer
self.args: SentenceTransformerTrainingArguments
self.data_collator: SentenceTransformerDataCollator
# Set the W&B project via environment variables if it's not already set
if any([isinstance(callback, WandbCallback) for callback in self.callback_handler.callbacks]):
os.environ.setdefault("WANDB_PROJECT", "sentence-transformers")
if loss is None:
logger.info("No `loss` passed, using `losses.CoSENTLoss` as a default option.")
loss = CoSENTLoss(self.model)
if isinstance(loss, dict):
self.loss = {dataset_name: self.prepare_loss(loss_fn, model) for dataset_name, loss_fn in loss.items()}
for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]):
if dataset is None:
continue
if not isinstance(dataset, dict):
raise ValueError(
f"If the provided `loss` is a dict, then the `{dataset_name}_dataset` must be a `DatasetDict`."
)
if missing := set(dataset.keys()) - set(loss.keys()):
raise ValueError(
f"If the provided `loss` is a dict, then all keys from the `{dataset_name}_dataset` dictionary must occur in `loss` also. "
f"Currently, {sorted(missing)} occur{'s' if len(missing) == 1 else ''} in `{dataset_name}_dataset` but not in `loss`."
)
else:
self.loss = self.prepare_loss(loss, model)
# If evaluator is a list, we wrap it in a SequentialEvaluator
if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
evaluator = SequentialEvaluator(evaluator)
self.evaluator = evaluator
# Add a callback responsible for automatically tracking data required for the automatic model card generation
model_card_callback = ModelCardCallback(self, default_args_dict)
self.add_callback(model_card_callback)
model_card_callback.on_init_end(self.args, self.state, self.control, self.model)
def call_model_init(self, trial=None) -> SentenceTransformer:
model = super().call_model_init(trial=trial)
# If the Trainer already has a loss, then we'll want to override the model in the loss function
if not hasattr(self, "loss"):
return model
# Multi-loss training:
if isinstance(self.loss, dict):
for key, loss_fn in self.loss.items():
# If a loss function is not yet initialized, we initialize it here
if not isinstance(loss_fn, torch.nn.Module):
self.loss[key] = loss_fn(model)
# Otherwise, we override the original model with the updated model in the loss function
elif hasattr(loss_fn, "model"):
self.loss = self.override_model_in_loss(self.loss, model)
# Loss is a function accepting a model as an argument
elif not isinstance(self.loss, torch.nn.Module):
self.loss = self.loss(model)
# Loss is an initialized torch.nn.Module
elif hasattr(self.loss, "model"):
self.loss = self.override_model_in_loss(self.loss, model)
return model
def override_model_in_loss(self, loss: torch.nn.Module, model: SentenceTransformer) -> torch.nn.Module:
from sentence_transformers import SentenceTransformer
for name, child in loss.named_children():
if name == "model" and isinstance(child, SentenceTransformer):
loss.model = model
elif isinstance(child, torch.nn.Module):
setattr(loss, name, self.override_model_in_loss(child, model))
return loss
def prepare_loss(
self,
loss: Callable[[SentenceTransformer], torch.nn.Module] | torch.nn.Module,
model: SentenceTransformer,
) -> torch.nn.Module:
if isinstance(loss, torch.nn.Module):
return loss.to(model.device)
return loss(model).to(model.device)
def add_dataset_name_column(self, dataset_dict: DatasetDict) -> DatasetDict:
for key, dataset in dataset_dict.items():
if "dataset_name" not in dataset.column_names:
dataset_dict[key] = dataset.add_column("dataset_name", [key] * len(dataset))
return dataset_dict
def compute_loss(
self,
model: SentenceTransformer,
inputs: dict[str, torch.Tensor | Any],
return_outputs: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
"""
Computes the loss for the SentenceTransformer model.
It uses ``self.loss`` to compute the loss, which can be a single loss function or a dictionary of loss functions
for different datasets. If the loss is a dictionary, the dataset name is expected to be passed in the inputs
under the key "dataset_name". This is done automatically in the ``add_dataset_name_column`` method.
Note that even if ``return_outputs = True``, the outputs will be empty, as the SentenceTransformers losses do not
return outputs.
Args:
model (SentenceTransformer): The SentenceTransformer model.
inputs (Dict[str, Union[torch.Tensor, Any]]): The input data for the model.
return_outputs (bool, optional): Whether to return the outputs along with the loss. Defaults to False.
Returns:
Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: The computed loss. If `return_outputs` is True, returns a tuple of loss and outputs. Otherwise, returns only the loss.
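Example (an illustrative sketch of the multi-loss case; ``model``, ``all_nli_dataset`` and ``stsb_dataset``
are assumed to be a :class:`~sentence_transformers.SentenceTransformer` and two :class:`datasets.Dataset`
instances, and the dataset names are arbitrary)::

    from sentence_transformers import SentenceTransformerTrainer
    from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss

    trainer = SentenceTransformerTrainer(
        model=model,
        train_dataset={"all-nli": all_nli_dataset, "stsb": stsb_dataset},
        loss={
            "all-nli": MultipleNegativesRankingLoss(model),
            "stsb": CoSENTLoss(model),
        },
    )
    # Each batch produced during training carries a "dataset_name" key, which this method
    # uses to pick the matching loss from the dictionary above.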
"""
dataset_name = inputs.pop("dataset_name", None)
features, labels = self.collect_features(inputs)
loss_fn = self.loss
if isinstance(loss_fn, dict) and dataset_name:
loss_fn = loss_fn[dataset_name]
# Insert the wrapped (e.g. distributed or compiled) model into the loss function,
# if the loss stores the model. Only called once per process
if (
model == self.model_wrapped
and model != self.model # Only if the model is wrapped
and hasattr(loss_fn, "model") # Only if the loss stores the model
and loss_fn.model != model # Only if the wrapped model is not already stored
):
loss_fn = self.override_model_in_loss(loss_fn, model)
loss = loss_fn(features, labels)
if return_outputs:
# During prediction/evaluation, `compute_loss` will be called with `return_outputs=True`.
# However, Sentence Transformer losses do not return outputs, so we return an empty dictionary.
# This does not result in any problems, as the SentenceTransformerTrainingArguments sets
# `prediction_loss_only=True` which means that the output is not used.
return loss, {}
return loss
def collect_features(
self, inputs: dict[str, torch.Tensor | Any]
) -> tuple[list[dict[str, torch.Tensor]], torch.Tensor | None]:
"""Turn the inputs from the dataloader into the separate model inputs & the labels.
Example::
>>> list(inputs.keys())
['return_loss', 'label', 'sentence_0_input_ids', 'sentence_0_token_type_ids', 'sentence_0_attention_mask', 'sentence_1_input_ids', 'sentence_1_token_type_ids', 'sentence_1_attention_mask']
>>> features, labels = self.collect_features(inputs)
>>> len(features)
2
>>> list(features[0].keys())
['input_ids', 'token_type_ids', 'attention_mask']
>>> list(features[1].keys())
['input_ids', 'token_type_ids', 'attention_mask']
>>> torch.equal(labels, inputs["label"])
True
"""
# All inputs ending with `_input_ids` (Transformers), `_sentence_embedding` (BoW), `_pixel_values` (CLIPModel)
# are considered to correspond to a feature
features = []
for column in inputs:
if column.endswith("_input_ids"):
prefix = column[: -len("input_ids")]
elif column.endswith("_sentence_embedding"):
prefix = column[: -len("sentence_embedding")]
elif column.endswith("_pixel_values"):
prefix = column[: -len("pixel_values")]
else:
continue
features.append({key[len(prefix) :]: value for key, value in inputs.items() if key.startswith(prefix)})
labels = inputs.get("label", None)
return features, labels
def evaluate(
self,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> dict[str, float]:
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
if isinstance(eval_dataset, DatasetDict) and isinstance(self.loss, dict):
eval_dataset = self.add_dataset_name_column(eval_dataset)
return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: bool | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
output = super().evaluation_loop(
dataloader=dataloader,
description=description,
prediction_loss_only=prediction_loss_only,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
# If the evaluator is not defined, we can just return the output
if self.evaluator is None:
return output
# If we are training and eval_dataset is a DatasetDict, then we should
# 1) only run the evaluator for the first dataset
# 2) prefix the metrics of that single run with "eval", rather than e.g. "eval_multi_nli"
if self.is_in_train and isinstance(self.eval_dataset, dict) and metric_key_prefix.startswith("eval_"):
if metric_key_prefix[5:] == list(self.eval_dataset.keys())[0]:
metric_key_prefix = "eval"
else:
return output
with nullcontext() if self.is_local_process_zero() else disable_logging(logging.INFO):
evaluator_metrics = self.evaluator(self.model)
if not isinstance(evaluator_metrics, dict):
evaluator_metrics = {"evaluator": evaluator_metrics}
# Prefix all keys with metric_key_prefix + '_'
for key in list(evaluator_metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
evaluator_metrics[f"{metric_key_prefix}_{key}"] = evaluator_metrics.pop(key)
output.metrics.update(evaluator_metrics)
return output
def _load_best_model(self) -> None:
# We want to ensure that this does not fail, and it may change if transformers updates how checkpoints are saved
# Loading the best model is only supported for `transformers`-based models
if not isinstance(self.model[0], Transformer):
logger.info("Could not load best model, as the model is not a `transformers`-based model.")
return
try:
if checkpoint := self.state.best_model_checkpoint:
step = checkpoint.rsplit("-", 1)[-1]
self.model.model_card_data.set_best_model_step(int(step))
except Exception:
pass
# Override the model with the `transformers`-based auto_model, and restore the original SentenceTransformers
# model with the loaded `transformers` model
full_model = self.model
self.model = self.model[0].auto_model
try:
return super()._load_best_model()
finally:
loaded_auto_model = self.model
self.model = full_model
self.model[0].auto_model = loaded_auto_model
def validate_column_names(self, dataset: Dataset, dataset_name: str | None = None) -> bool:
if overlap := set(dataset.column_names) & {"return_loss", "dataset_name"}:
raise ValueError(
f"The following column names are invalid in your {dataset_name + ' ' if dataset_name else ''}dataset: {list(overlap)}."
" Avoid using these column names, as they are reserved for internal use."
)
def get_batch_sampler(
self,
dataset: Dataset,
batch_size: int,
drop_last: bool,
valid_label_columns: list[str] | None = None,
generator: torch.Generator | None = None,
) -> BatchSampler | None:
"""
Returns the appropriate batch sampler based on the ``batch_sampler`` argument in ``self.args``.
This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the ``batch_sampler``
to create the :class:`torch.utils.data.DataLoader`.
.. note::
Override this method to provide a custom batch sampler.
Args:
dataset (Dataset): The dataset to sample from.
batch_size (int): Number of samples per batch.
drop_last (bool): If True, drop the last incomplete batch if the dataset size
is not divisible by the batch size.
valid_label_columns (List[str]): List of column names to check for labels.
The first column name from ``valid_label_columns`` found in the dataset will
be used as the label column.
generator (torch.Generator, optional): Optional random number generator for shuffling
the indices.
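Example (a hypothetical override, shown only to illustrate the hook; it is not part of the library)::

    from torch.utils.data import BatchSampler, SequentialSampler

    class InOrderTrainer(SentenceTransformerTrainer):
        def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
            # Ignore ``self.args.batch_sampler`` and always iterate the dataset in order.
            return BatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=drop_last)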
"""
if isinstance(dataset, IterableDataset):
if self.args.batch_sampler != BatchSamplers.BATCH_SAMPLER:
logger.warning("When using an IterableDataset, you cannot specify a batch sampler.")
return None
if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES:
return NoDuplicatesBatchSampler(
dataset=dataset,
batch_size=batch_size,
drop_last=drop_last,
valid_label_columns=valid_label_columns,
generator=generator,
)
if self.args.batch_sampler == BatchSamplers.GROUP_BY_LABEL:
return GroupByLabelBatchSampler(
dataset=dataset,
batch_size=batch_size,
drop_last=drop_last,
valid_label_columns=valid_label_columns,
)
if self.args.batch_sampler == BatchSamplers.BATCH_SAMPLER:
return DefaultBatchSampler(
SubsetRandomSampler(range(len(dataset)), generator=generator),
batch_size=batch_size,
drop_last=drop_last,
)
def get_multi_dataset_batch_sampler(
self,
dataset: ConcatDataset,
batch_samplers: list[BatchSampler],
generator: torch.Generator | None = None,
seed: int | None = 0,
) -> BatchSampler:
"""
Returns the appropriate multi-dataset batch sampler based on the ``multi_dataset_batch_sampler`` argument
in ``self.args``. This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the
``batch_sampler`` to create the :class:`torch.utils.data.DataLoader`.
.. note::
Override this method to provide a custom multi-dataset batch sampler.
Args:
dataset (ConcatDataset): The concatenation of all datasets.
batch_samplers (List[BatchSampler]): List of batch samplers for each dataset in the concatenated dataset.
generator (torch.Generator, optional): Optional random number generator for shuffling the indices.
seed (int, optional): Optional seed for the random number generator
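Example (illustrative only; the choice of sampler below is an assumption, not a recommendation)::

    from sentence_transformers.training_args import (
        MultiDatasetBatchSamplers,
        SentenceTransformerTrainingArguments,
    )

    # With multiple training datasets, sample from them in a round-robin fashion rather than
    # proportionally to their sizes.
    args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
    )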
"""
if self.args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.ROUND_ROBIN:
return RoundRobinBatchSampler(
dataset=dataset,
batch_samplers=batch_samplers,
generator=generator,
seed=seed,
)
if self.args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.PROPORTIONAL:
return ProportionalBatchSampler(
dataset=dataset,
batch_samplers=batch_samplers,
generator=generator,
seed=seed,
)
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].
Will use the batch sampler returned by :meth:`get_batch_sampler` (and :meth:`get_multi_dataset_batch_sampler`
for multi-dataset training), or no sampler if `train_dataset` is an :class:`~datasets.IterableDataset`.
Subclass and override this method if you want to inject some custom behavior.
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
generator = torch.Generator()
if self.args.seed:
generator.manual_seed(self.args.seed)
dataloader_params = {
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
"prefetch_factor": self.args.dataloader_prefetch_factor,
}
if isinstance(train_dataset, IterableDataset):
dataloader_params.update(
{
"batch_size": self.args.train_batch_size,
"drop_last": self.args.dataloader_drop_last,
}
)
elif isinstance(train_dataset, IterableDatasetDict):
raise ValueError(
"Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
)
elif isinstance(train_dataset, DatasetDict):
for dataset_name, dataset in train_dataset.items():
self.validate_column_names(dataset, dataset_name=dataset_name)
if isinstance(dataset, IterableDataset):
raise ValueError(
"Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
)
if isinstance(self.loss, dict):
train_dataset = self.add_dataset_name_column(train_dataset)
batch_samplers = [
self.get_batch_sampler(
dataset,
batch_size=self.args.train_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
for dataset in train_dataset.values()
]
train_dataset = ConcatDataset(train_dataset.values())
batch_sampler = self.get_multi_dataset_batch_sampler(
dataset=train_dataset,
batch_samplers=batch_samplers,
generator=generator,
seed=self.args.seed,
)
dataloader_params["batch_sampler"] = batch_sampler
elif isinstance(train_dataset, Dataset):
self.validate_column_names(train_dataset)
batch_sampler = self.get_batch_sampler(
train_dataset,
batch_size=self.args.train_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
dataloader_params["batch_sampler"] = batch_sampler
else:
raise ValueError(
"Unsupported `train_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for training."
)
# If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
# cause issues with multi-dataset training, so we want to set this to False.
# For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True there.
self.accelerator.even_batches = False
self._train_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
return self._train_dataloader
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].
Subclass and override this method if you want to inject some custom behavior.
Args:
eval_dataset (`torch.utils.data.Dataset`, *optional*):
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
by the `model.forward()` method are automatically removed. It must implement `__len__`.
"""
if eval_dataset is None and self.eval_dataset is None:
# Prevent errors if the evaluator is set but no eval_dataset is provided
if self.evaluator is not None:
return DataLoader([])
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
data_collator = self.data_collator
generator = torch.Generator()
if self.args.seed:
generator.manual_seed(self.args.seed)
dataloader_params = {
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
"prefetch_factor": self.args.dataloader_prefetch_factor,
}
if isinstance(eval_dataset, IterableDataset):
dataloader_params.update(
{
"batch_size": self.args.eval_batch_size,
"drop_last": self.args.dataloader_drop_last,
}
)
elif isinstance(eval_dataset, IterableDatasetDict):
raise ValueError(
"Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
)
elif isinstance(eval_dataset, DatasetDict):
for dataset in eval_dataset.values():
if isinstance(dataset, IterableDataset):
raise ValueError(
"Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
)
if isinstance(self.loss, dict):
eval_dataset = self.add_dataset_name_column(eval_dataset)
batch_samplers = [
self.get_batch_sampler(
dataset,
batch_size=self.args.eval_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
for dataset in eval_dataset.values()
]
eval_dataset = ConcatDataset(eval_dataset.values())
batch_sampler = self.get_multi_dataset_batch_sampler(
dataset=eval_dataset,
batch_samplers=batch_samplers,
generator=generator,
seed=self.args.seed,
)
dataloader_params["batch_sampler"] = batch_sampler
elif isinstance(eval_dataset, Dataset):
batch_sampler = self.get_batch_sampler(
eval_dataset,
batch_size=self.args.eval_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
dataloader_params["batch_sampler"] = batch_sampler
else:
raise ValueError(
"Unsupported `eval_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for evaluation."
)
# If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
# cause issues with multi-dataset training, so we want to set this to False during training.
# For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True here.
self.accelerator.even_batches = True
return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
"""
Returns the test [`~torch.utils.data.DataLoader`].
Subclass and override this method if you want to inject some custom behavior.
Args:
test_dataset (`torch.utils.data.Dataset`):
The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. It must implement `__len__`.
"""
data_collator = self.data_collator
generator = torch.Generator()
if self.args.seed:
generator.manual_seed(self.args.seed)
dataloader_params = {
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
"prefetch_factor": self.args.dataloader_prefetch_factor,
}
if isinstance(test_dataset, IterableDataset):
dataloader_params.update(
{
"batch_size": self.args.eval_batch_size,
"drop_last": self.args.dataloader_drop_last,
}
)
elif isinstance(test_dataset, IterableDatasetDict):
raise ValueError(
"Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
)
elif isinstance(test_dataset, DatasetDict):
for dataset_name, dataset in test_dataset.items():
self.validate_column_names(dataset, dataset_name=dataset_name)
if isinstance(dataset, IterableDataset):
raise ValueError(
"Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
)
if isinstance(self.loss, dict):
test_dataset = self.add_dataset_name_column(test_dataset)
batch_samplers = [
self.get_batch_sampler(
dataset,
batch_size=self.args.eval_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
for dataset in test_dataset.values()
]
test_dataset = ConcatDataset(test_dataset.values())
batch_sampler = self.get_multi_dataset_batch_sampler(
dataset=test_dataset,
batch_samplers=batch_samplers,
generator=generator,
seed=self.args.seed,
)
dataloader_params["batch_sampler"] = batch_sampler
elif isinstance(test_dataset, Dataset):
self.validate_column_names(test_dataset)
batch_sampler = self.get_batch_sampler(
test_dataset,
batch_size=self.args.eval_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
dataloader_params["batch_sampler"] = batch_sampler
else:
raise ValueError(
"Unsupported `test_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for testing."
)
# If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
# cause issues with multi-dataset training, so we want to set this to False.
# For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True there.
self.accelerator.even_batches = False
return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
def _save(self, output_dir: str | None = None, state_dict=None) -> None:
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
self.model.save_pretrained(output_dir, safe_serialization=self.args.save_safetensors)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
def _load_from_checkpoint(self, checkpoint_path: str) -> None:
from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(checkpoint_path, trust_remote_code=self.model.trust_remote_code)
self.model.load_state_dict(loaded_model.state_dict())
def create_model_card(
self,
language: str | None = None,
license: str | None = None,
tags: str | list[str] | None = None,
model_name: str | None = None,
finetuned_from: str | None = None,
tasks: str | list[str] | None = None,
dataset_tags: str | list[str] | None = None,
dataset: str | list[str] | None = None,
dataset_args: str | list[str] | None = None,
**kwargs,
) -> None:
if not self.is_world_process_zero():
return
if language:
self.model.model_card_data.set_language(language)
if license:
self.model.model_card_data.set_license(license)
if tags:
self.model.model_card_data.add_tags(tags)
self.model._create_model_card(self.args.output_dir, model_name=model_name)
def get_optimizer_cls_and_kwargs(
self, args: SentenceTransformerTrainingArguments, model: SentenceTransformer | None = None
) -> tuple[Any, Any]:
"""
We have to override the optimizer_grouped_parameters because the Trainer superclass bases it on the `model`
itself, but the SentenceTransformer losses can have weights that should be updated as well, e.g.
SoftmaxLoss (see #2872).
This method requires `transformers` >= 4.43.0.
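Example (illustrative; ``model`` is assumed to be an already-loaded :class:`~sentence_transformers.SentenceTransformer`)::

    from sentence_transformers.losses import SoftmaxLoss

    # SoftmaxLoss owns a classifier head whose weights must be included in the optimizer,
    # which is why the parameter groups below are built from the loss rather than the model.
    loss = SoftmaxLoss(
        model,
        sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
        num_labels=3,
    )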
"""
if isinstance(self.loss, dict):
loss_model = nn.Sequential(OrderedDict(self.loss))
else:
loss_model = self.loss
optimizer_cls, optimizer_kwargs = super().get_optimizer_cls_and_kwargs(args, loss_model)
# If the kwargs were not overridden by the super() call, then we should override them here so that the potential
# weights in the loss(es) can also be updated.
if not {"params", "model", "optimizer_dict"} & set(optimizer_kwargs.keys()):
decay_parameters = self.get_decay_parameter_names(loss_model)
optimizer_kwargs["optimizer_dict"] = [
{
"params": [
p for n, p in loss_model.named_parameters() if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in loss_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
return optimizer_cls, optimizer_kwargs