diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c1869ef76f0055..214e7a9789d2c7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -238,9 +238,10 @@ class Trainer: `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally sets the seed of the RNGs used. - eval_dataset (`torch.utils.data.Dataset`, *optional*): + eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the - `model.forward()` method are automatically removed. + `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each + dataset prepending the dictionary key to the metric name. tokenizer ([`PreTrainedTokenizerBase`], *optional*): The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an @@ -2040,7 +2041,15 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for metrics = None if self.control.should_evaluate: - metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + if isinstance(self.eval_dataset, dict): + for eval_dataset_name, eval_dataset in self.eval_dataset.items(): + metrics = self.evaluate( + eval_dataset=eval_dataset, + ignore_keys=ignore_keys_for_eval, + metric_key_prefix=f"eval_{eval_dataset_name}", + ) + else: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) self._report_to_hp_search(trial, self.state.global_step, metrics) if self.control.should_save: