Add global_attention_mask to gen_kwargs (#16485)
If `global_attention_mask` is found in the model's inputs (it is used by certain models, such as LED), the `prediction_step` method of `Seq2SeqTrainer` adds it to `gen_kwargs`, which are passed to `model.generate()`. This allows the global attention to be set properly during generation. A sketch of the resulting behavior follows below.
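A minimal sketch of what this change enables, using LED directly rather than the trainer (the checkpoint name, input text, and generation settings are illustrative, not part of the commit): when `global_attention_mask` is present in the inputs, it is forwarded through `gen_kwargs` to `model.generate()`, mirroring what `Seq2SeqTrainer.prediction_step` now does.

```python
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("A long document to summarize ...", return_tensors="pt")
# LED expects global attention on at least the first token (<s>).
inputs["global_attention_mask"] = torch.zeros_like(inputs["input_ids"])
inputs["global_attention_mask"][:, 0] = 1

gen_kwargs = {"max_length": 64, "num_beams": 2}  # illustrative settings
# The core of this commit: if the inputs contain a global_attention_mask,
# add it to gen_kwargs so it reaches model.generate().
if "global_attention_mask" in inputs:
    gen_kwargs["global_attention_mask"] = inputs["global_attention_mask"]

summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    **gen_kwargs,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```

Without this forwarding, generation would run with LED's default (all-local) attention, so predictions computed during evaluation could differ from those produced with the intended global attention pattern.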