diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 73d5181ec36401..5513b58bef94b9 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -163,9 +163,11 @@ def prediction_step(
         if "attention_mask" in inputs:
             gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
+        if "global_attention_mask" in inputs:
+            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
 
         # prepare generation inputs
-        # some encoder-decoder models can have varying encder's and thus
+        # some encoder-decoder models can have varying encoder's and thus
         # varying model input names
         if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
             generation_inputs = inputs[self.model.encoder.main_input_name]
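For context, a minimal sketch of how this code path would be exercised: a preprocessing step for a model that consumes global attention (e.g. LED) can emit a `global_attention_mask` column, and with this change `Seq2SeqTrainer.prediction_step` forwards it to `generate()` during evaluation when `predict_with_generate=True`. The checkpoint name and the `"document"` field below are illustrative assumptions, not part of this diff.

```python
# Illustrative sketch (not part of this diff): preprocessing that emits a
# global_attention_mask, which Seq2SeqTrainer now passes through gen_kwargs.
from transformers import AutoTokenizer

# Assumed checkpoint; any model that accepts global_attention_mask in generate() works.
tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")

def preprocess(batch):
    # Tokenize the source texts (the "document" key is a hypothetical dataset column).
    model_inputs = tokenizer(batch["document"], truncation=True, max_length=4096)
    # Put global attention on the first token of every sequence, a common choice for LED.
    model_inputs["global_attention_mask"] = [
        [1] + [0] * (len(ids) - 1) for ids in model_inputs["input_ids"]
    ]
    return model_inputs

# After mapping `preprocess` over the eval dataset, "global_attention_mask" shows up
# in `inputs`, so the new check above copies it into gen_kwargs for generation.
```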