Skip to content

Commit

Permalink
[WIP] Trainer supporting evaluation on multiple datasets (#19158)
Browse files Browse the repository at this point in the history
* support for multiple eval datasets

* support multiple datasets in seq2seq trainer

* add documentation

* update documentation

* make fixup

* revert option for multiple compute_metrics

* revert option for multiple compute_metrics

* revert added empty line
  • Loading branch information
timbmg authored Sep 23, 2022
1 parent 49629e7 commit 905635f
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,10 @@ class Trainer:
`torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
sets the seed of the RNGs used.
eval_dataset (`torch.utils.data.Dataset`, *optional*):
eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
`model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
dataset prepending the dictionary key to the metric name.
tokenizer ([`PreTrainedTokenizerBase`], *optional*):
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
Expand Down Expand Up @@ -2040,7 +2041,15 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for

metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
if isinstance(self.eval_dataset, dict):
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
metrics = self.evaluate(
eval_dataset=eval_dataset,
ignore_keys=ignore_keys_for_eval,
metric_key_prefix=f"eval_{eval_dataset_name}",
)
else:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

if self.control.should_save:
Expand Down

0 comments on commit 905635f

Please sign in to comment.