Skip to content

Commit

Permalink
Put Seq2SeqTrainer in the main lib
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Dec 21, 2020
1 parent 194ca71 commit fe7960b
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 18 deletions.
27 changes: 20 additions & 7 deletions examples/seq2seq/finetune_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,16 @@
from typing import Optional

import transformers
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_training_args import Seq2SeqTrainingArguments
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
HfArgumentParser,
MBartTokenizer,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.trainer_utils import EvaluationStrategy, is_main_process
from transformers.training_args import ParallelMode
from utils import (
Expand Down Expand Up @@ -273,13 +280,12 @@ def main():
)
trainer = Seq2SeqTrainer(
model=model,
config=config,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
compute_metrics=compute_metrics_fn,
data_args=data_args,
tokenizer=tokenizer,
)

all_metrics = {}
Expand Down Expand Up @@ -310,7 +316,9 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")

metrics = trainer.evaluate(metric_key_prefix="val")
metrics = trainer.evaluate(
metric_key_prefix="val", max_target_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
)
metrics["val_n_objs"] = data_args.n_val
metrics["val_loss"] = round(metrics["val_loss"], 4)

Expand All @@ -322,7 +330,12 @@ def main():
if training_args.do_predict:
logger.info("*** Predict ***")

test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
test_output = trainer.predict(
test_dataset=test_dataset,
metric_key_prefix="test",
max_target_length=data_args.test_max_target_length,
num_beams=data_args.eval_beams,
)
metrics = test_output.metrics
metrics["test_n_objs"] = data_args.n_test

Expand Down
6 changes: 3 additions & 3 deletions examples/seq2seq/test_finetune_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest
from unittest.mock import patch

from transformers import BertTokenizer, EncoderDecoderModel
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.file_utils import is_apex_available, is_datasets_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
Expand All @@ -31,8 +31,7 @@
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer
from .finetune_trainer import main


set_seed(42)
Expand Down Expand Up @@ -228,6 +227,7 @@ def _compute_metrics(pred):
compute_metrics=_compute_metrics,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
)

# start training
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@
)
from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed
from .training_args import TrainingArguments
from .training_args_seq2seq import Seq2SeqTrainingArguments
from .training_args_tf import TFTrainingArguments
from .utils import logging

Expand Down Expand Up @@ -688,6 +689,7 @@
# Trainer
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
from .trainer_seq2seq import Seq2SeqTrainer
else:
from .utils.dummy_pt_objects import *

Expand Down
18 changes: 10 additions & 8 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,12 @@ def __init__(
)
self.use_apex = True

# Label smoothing
if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
else:
self.label_smoother = None

self.state = TrainerState()
self.control = TrainerControl()
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
Expand Down Expand Up @@ -693,12 +699,6 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

# Label smoothing
if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
else:
self.label_smoother = None

# Train!
if is_torch_tpu_available():
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
Expand Down Expand Up @@ -1575,11 +1575,13 @@ def prediction_step(
else:
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None and "labels" in inputs:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
if isinstance(outputs, dict):
loss = outputs["loss"].mean().detach()
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None
Expand Down
247 changes: 247 additions & 0 deletions src/transformers/trainer_seq2seq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from packaging import version
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from torch.utils.data.dataset import Dataset

from .file_utils import is_torch_tpu_available
from .trainer import Trainer
from .trainer_pt_utils import LabelSmoother, get_tpu_sampler
from .trainer_utils import PredictionOutput
from .training_args import ParallelMode
from .utils import logging


if version.parse(torch.__version__) >= version.parse("1.6"):
from torch.cuda.amp import autocast


logger = logging.get_logger(__name__)


class Seq2SeqTrainer(Trainer):
def __init__(self, *args, ignore_pad_token_for_loss=None, **kwargs):
super().__init__(*args, **kwargs)
# TODO: When BART uses -100 everywhere, this can be removed entirely.
if ignore_pad_token_for_loss is not None:
warnings.warn(
"Passing `ignore_pad_token_for_loss` is deprecated. Your model should always use -100 to mark tokens "
"to ignore in the loss."
)
if self.label_smoother is not None:
self.label_smoother = LabelSmoother(
epsilon=self.args.label_smoothing_factor,
ignore_index=ignore_pad_token_for_loss,
)

def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler:
self.train_dataset.make_sortish_sampler(
self.args.per_device_train_batch_size,
distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
)

return (
RandomSampler(self.train_dataset)
if self.args.local_rank == -1
else DistributedSampler(self.train_dataset)
)

def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
max_target_length: Optional[int] = None,
num_beams: Optional[int] = None,
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
(pass it to the init :obj:`compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
:obj:`__len__` method.
ignore_keys (:obj:`Lst[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is "eval" (default)
max_target_length (:obj:`int`, `optional`):
The maximum target length to use when predicting with the generate method.
num_beams (:obj:`int`, `optional`):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
self._max_target_length = max_target_length
self._num_beams = num_beams
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

def predict(
self,
test_dataset: Dataset,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
max_target_length: Optional[int] = None,
num_beams: Optional[int] = None,
) -> PredictionOutput:
"""
Run prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
will also return metrics, like in :obj:`evaluate()`.
Args:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
ignore_keys (:obj:`Lst[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is "eval" (default)
max_target_length (:obj:`int`, `optional`):
The maximum target length to use when predicting with the generate method.
num_beams (:obj:`int`, `optional`):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
.. note::
If your predictions or labels have different sequence length (for instance because you're doing dynamic
padding in a token classification task) the predictions will be padded (on the right) to allow for
concatenation into one array. The padding index is -100.
Returns: `NamedTuple` A namedtuple with the following keys:
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
- label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels).
"""
self._max_target_length = max_target_length
self._num_beams = num_beams
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using obj:`inputs`.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
labels (each being optional).
"""

if not self.args.predict_with_generate or prediction_loss_only:
return super()(self, model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)

has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)

gen_kwargs = {
"max_length": self._max_target_length
if self._max_target_length is not None
else self.model.config.max_length,
"num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
}

generated_tokens = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

with torch.no_grad():
if self.use_amp:
with autocast():
outputs = model(**inputs)
else:
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
else:
loss = None

if self.args.prediction_loss_only:
return (loss, None, None)

labels = inputs["labels"]
if labels.shape[-1] < gen_kwargs["max_length"]:
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

return (loss, generated_tokens, labels)

def _pad_tensors_to_max_len(self, tensor, max_length):
if self.tokenizer is None:
raise ValueError(
f"Tensor need to be padded to `max_length={max_length}` but no tokenzier was passed when creating "
"this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
)
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)

padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, : tensor.shape[-1]] = tensor
return padded_tensor
Loading

0 comments on commit fe7960b

Please sign in to comment.