diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 1fa6ed1892df..8844f191e1c4 100644 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -70,6 +70,11 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze() prev_output_tokens[:, 1:] = input_ids[:, :-1] + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) + return prev_output_tokens