[WIP] Trainer supporting evaluation on multiple datasets #19158

Merged · 9 commits · Sep 23, 2022
40 changes: 31 additions & 9 deletions src/transformers/trainer.py
@@ -238,9 +238,10 @@ class Trainer:
`torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
sets the seed of the RNGs used.
eval_dataset (`torch.utils.data.Dataset`, *optional*):
eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]], *optional*):
The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
`model.forward()` method are automatically removed. If it is a dictionary, each dataset in it will be
evaluated separately, with the dictionary key prepended to the metric names.
tokenizer ([`PreTrainedTokenizerBase`], *optional*):
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
@@ -252,9 +253,10 @@ class Trainer:
The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
be able to choose different architectures according to hyper parameters (such as layer count, sizes of
inner layers, dropout probabilities etc).
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
compute_metrics (`Union[Callable[[EvalPrediction], Dict], Dict[str, Callable[[EvalPrediction], Dict]]]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and return
a dictionary string to metric values.
a dictionary string to metric values. If `eval_dataset` is a dict, each dataset can be evaluated with a
separate compute function. The keys in `eval_dataset` and `compute_metrics` must match.
callbacks (List of [`TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](callback).
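
Taken together, the two docstring changes above describe the new dict-based interface. A minimal, hypothetical sketch of the intended usage (here `model`, `train_ds`, `wiki_ds`, `news_ds`, `wiki_metrics`, and `news_metrics` are placeholders assumed to exist, not names from this PR):

    from transformers import Trainer, TrainingArguments

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="out", evaluation_strategy="epoch"),
        train_dataset=train_ds,
        # Each dataset is evaluated separately; its key is prepended to the
        # metric names, e.g. "eval_wiki_loss" and "eval_news_loss".
        eval_dataset={"wiki": wiki_ds, "news": news_ds},
        # Optional matching dict of metric functions; keys must match eval_dataset.
        compute_metrics={"wiki": wiki_metrics, "news": news_metrics},
    )
    trainer.train()  # evaluation now runs once per dict entry at each epoch
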
@@ -2040,7 +2042,19 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):

metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
if isinstance(self.eval_dataset, dict):
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
evaluate_kwargs = {}
if isinstance(self.compute_metrics, dict):
evaluate_kwargs["compute_metrics"] = self.compute_metrics[eval_dataset_name]
metrics = self.evaluate(
eval_dataset=eval_dataset,
ignore_keys=ignore_keys_for_eval,
metric_key_prefix=f"eval_{eval_dataset_name}",
**evaluate_kwargs,
)
else:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

if self.control.should_save:
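
The loop above is essentially equivalent to calling `evaluate()` once per dataset by hand. A hedged sketch of the same idea outside the Trainer (dataset names are placeholders):

    # Roughly what the new branch does, written out manually.
    all_metrics = {}
    for name, ds in {"wiki": wiki_ds, "news": news_ds}.items():
        all_metrics.update(
            trainer.evaluate(eval_dataset=ds, metric_key_prefix=f"eval_{name}")
        )
    # e.g. keys like "eval_wiki_loss", "eval_news_loss", "eval_wiki_runtime", ...
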
@@ -2728,6 +2742,7 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
@@ -2744,6 +2759,9 @@ def evaluate(
Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*, defaults to `None`):
Pass a compute_metrics function if you wish to override `self.compute_metrics`. Useful when
`self.eval_dataset` holds multiple datasets that need separate metric functions.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
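
As documented above, the override is per-call. A hypothetical direct call using the keyword added in this diff, assuming a `trainer`, a tokenized `squad_ds`, and a `squad_metrics` function already exist:

    # Hypothetical one-off override of self.compute_metrics for a single run.
    metrics = trainer.evaluate(
        eval_dataset=squad_ds,
        compute_metrics=squad_metrics,   # takes precedence over self.compute_metrics
        metric_key_prefix="eval_squad",  # keys come back as "eval_squad_*"
    )
    print(metrics["eval_squad_loss"])
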
@@ -2759,6 +2777,8 @@ def evaluate(
self._memory_tracker.start()

eval_dataloader = self.get_eval_dataloader(eval_dataset)
compute_metrics = compute_metrics if compute_metrics is not None else self.compute_metrics

start_time = time.time()

eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
@@ -2767,9 +2787,10 @@
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
compute_metrics=compute_metrics,
)

total_batch_size = self.args.eval_batch_size * self.args.world_size
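
Note that `prediction_loss_only` now keys off the resolved `compute_metrics` argument rather than the instance attribute, so an evaluation without any metric function still works and simply reports loss and speed. A hedged illustration, assuming a `trainer` with no metric function and a placeholder `wiki_ds`:

    # With no compute_metrics resolved, only loss/speed metrics are returned,
    # e.g. "eval_loss", "eval_runtime", "eval_samples_per_second".
    metrics = trainer.evaluate(eval_dataset=wiki_ds)
    print(metrics.get("eval_loss"))
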
@@ -2861,6 +2882,7 @@ def evaluation_loop(
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
) -> EvalLoopOutput:
"""
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
@@ -3041,13 +3063,13 @@ def evaluation_loop(
all_inputs = nested_truncate(all_inputs, num_samples)

# Metrics!
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
if compute_metrics is not None and all_preds is not None and all_labels is not None:
if args.include_inputs_for_metrics:
metrics = self.compute_metrics(
metrics = compute_metrics(
EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
)
else:
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
metrics = compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
else:
metrics = {}
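
For reference, a metric function compatible with both call sites above might look like the following sketch (placeholder logic, not from the PR). When `args.include_inputs_for_metrics` is set, the inputs are also available on the `EvalPrediction`:

    import numpy as np

    def my_compute_metrics(eval_pred):
        # eval_pred.predictions / eval_pred.label_ids are always present;
        # eval_pred.inputs is populated only with include_inputs_for_metrics.
        preds = np.argmax(eval_pred.predictions, axis=-1)
        return {"accuracy": float((preds == eval_pred.label_ids).mean())}
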

12 changes: 9 additions & 3 deletions src/transformers/trainer_seq2seq.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import Dataset

from .deepspeed import is_deepspeed_zero3_enabled
from .trainer import Trainer
from .trainer_utils import PredictionOutput
from .trainer_utils import EvalPrediction, PredictionOutput
from .utils import logging


@@ -31,6 +31,7 @@ class Seq2SeqTrainer(Trainer):
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
**gen_kwargs
@@ -48,6 +49,9 @@ def evaluate(
Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*, defaults to `None`):
Pass a compute_metrics function if you wish to override `self.compute_metrics`. Useful when
`self.eval_dataset` holds multiple datasets that need separate metric functions.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
@@ -75,7 +79,9 @@ def evaluate(
)
self._gen_kwargs = gen_kwargs

return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
return super().evaluate(
eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, compute_metrics=compute_metrics
)

def predict(
self,
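
Since `Seq2SeqTrainer.evaluate()` simply forwards the new argument, the per-call override composes with the usual generation kwargs. A hypothetical sketch (the trainer, `summarization_ds`, and `rouge_metrics` are assumed to exist):

    # Hypothetical: per-call metric override alongside generation settings.
    metrics = trainer.evaluate(
        eval_dataset=summarization_ds,
        compute_metrics=rouge_metrics,   # forwarded to Trainer.evaluate()
        metric_key_prefix="eval_summ",
        max_length=128,                  # generation kwargs, as before
        num_beams=4,
    )
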