From 905635f5d36d3b2c8407005fa36601f5d99b03bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Baumg=C3=A4rtner?= Date: Fri, 23 Sep 2022 15:14:53 +0200 Subject: [PATCH] [WIP] Trainer supporting evaluation on multiple datasets (#19158) * support for multiple eval datasets * support multiple datasets in seq2seq trainer * add documentation * update documentation * make fixup * revert option for multiple compute_metrics * revert option for multiple compute_metrics * revert added empty line --- src/transformers/trainer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c1869ef76f0055..214e7a9789d2c7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -238,9 +238,10 @@ class Trainer: `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally sets the seed of the RNGs used. - eval_dataset (`torch.utils.data.Dataset`, *optional*): + eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the - `model.forward()` method are automatically removed. + `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each + dataset prepending the dictionary key to the metric name. tokenizer ([`PreTrainedTokenizerBase`], *optional*): The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an @@ -2040,7 +2041,15 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for metrics = None if self.control.should_evaluate: - metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + if isinstance(self.eval_dataset, dict): + for eval_dataset_name, eval_dataset in self.eval_dataset.items(): + metrics = self.evaluate( + eval_dataset=eval_dataset, + ignore_keys=ignore_keys_for_eval, + metric_key_prefix=f"eval_{eval_dataset_name}", + ) + else: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) self._report_to_hp_search(trial, self.state.global_step, metrics) if self.control.should_save: