Add global_attention_mask to gen_kwargs (#16485)
If global_attention_mask is found in the model's inputs (used by certain
models, such as LED) in the prediction_step method of Seq2SeqTrainer,
it is added to the gen_kwargs, which are passed to model.generate().
This allows us to properly set the global attention when decoding.
JohnGiorgi authored Apr 5, 2022
1 parent 9fd5e6b commit b33ab4e
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/trainer_seq2seq.py
@@ -163,9 +163,11 @@ def prediction_step(

if "attention_mask" in inputs:
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
if "global_attention_mask" in inputs:
gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

# prepare generation inputs
# some encoder-decoder models can have varying encder's and thus
# some encoder-decoder models can have varying encoder's and thus
# varying model input names
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
generation_inputs = inputs[self.model.encoder.main_input_name]
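For context, here is a minimal sketch of how global attention is typically supplied to an LED-style model at generation time; the checkpoint name, input text, and preprocessing below are illustrative assumptions, not part of this commit. With this change, Seq2SeqTrainer.prediction_step forwards a global_attention_mask found in the evaluation inputs to model.generate() in the same way, so a data collator only needs to emit that column for it to take effect during evaluation and prediction.

    # Illustrative sketch (assumed checkpoint and preprocessing, not part of this commit).
    import torch
    from transformers import LEDForConditionalGeneration, LEDTokenizer

    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

    inputs = tokenizer("A long document to summarize ...", return_tensors="pt")

    # LED expects global attention on at least the first token (e.g. <s>).
    global_attention_mask = torch.zeros_like(inputs["attention_mask"])
    global_attention_mask[:, 0] = 1

    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=256,
    )
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))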
