Revert "[Seq2SeqTrainer] Move import to init to make file self-contai…
Browse files Browse the repository at this point in the history
…ned (huggingface#8194)"

This reverts commit 4aec1b6.
fabiocapsouza authored Nov 15, 2020
1 parent 5923d99 commit 9d46aac
Showing 1 changed file with 12 additions and 13 deletions.
examples/seq2seq/seq2seq_trainer.py
@@ -20,6 +20,12 @@
 from transformers.trainer_pt_utils import get_tpu_sampler


+try:
+    from .utils import label_smoothed_nll_loss
+except ImportError:
+    from utils import label_smoothed_nll_loss
+
+
 logger = logging.get_logger(__name__)

 arg_to_scheduler = {
@@ -58,17 +64,6 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
                 f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
             )

-        if self.args.label_smoothing == 0:
-            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-        else:
-            # dynamically import label_smoothed_nll_loss
-            try:
-                from .utils import label_smoothed_nll_loss
-            except ImportError:
-                from utils import label_smoothed_nll_loss
-
-            self.loss_fn = label_smoothed_nll_loss
-
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
@@ -140,15 +135,19 @@ def _compute_loss(self, model, inputs, labels):
             if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                 # force training to ignore pad token
                 logits = model(**inputs, use_cache=False)[0]
-                loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
+
+                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+                loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
             else:
                 # compute usual loss via models
                 loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
         else:
             # compute label smoothed loss
             logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-            loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
+            loss, _ = label_smoothed_nll_loss(
+                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
+            )
         return loss, logits

     def compute_loss(self, model, inputs):
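Since the revert makes _compute_loss call label_smoothed_nll_loss directly from the module-level import, it may help to recall what that helper computes. The actual implementation lives in examples/seq2seq/utils.py and is not part of this diff; the following is only a minimal, fairseq-style sketch with the same signature the trainer assumes (targets are expected to be valid vocabulary ids, with the padding id passed as ignore_index):

import torch


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    # Sketch only: the real helper is defined in examples/seq2seq/utils.py.
    # lprobs: log-probabilities of shape (..., vocab_size), e.g. from log_softmax.
    # target: integer labels with one fewer dimension than lprobs.
    # epsilon: smoothing weight; 0.0 reduces to plain NLL.
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)    # log-prob of the gold token
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)    # uniform term over the vocabulary
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    return (1.0 - epsilon) * nll_loss + eps_i * smooth_loss, nll_loss


# Toy usage mirroring the trainer's label-smoothing branch (values are made up):
logits = torch.randn(2, 3, 5)                          # (batch, seq_len, vocab)
labels = torch.tensor([[1, 2, 0], [3, 4, 0]])          # 0 plays the role of pad_token_id
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
loss, nll = label_smoothed_nll_loss(lprobs, labels, 0.1, ignore_index=0)

With epsilon = 0 the smoothing term vanishes and the loss reduces to ordinary NLL (up to the reduction used), which is why the non-smoothed branch of _compute_loss can use torch.nn.CrossEntropyLoss directly.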
