diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 0d09ad7a5312..5006a7819c58 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -285,6 +285,7 @@ class Blip2PreTrainedModel(PreTrainedModel): r"language_model.decoder.embed_tokens.weight", ] _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"] + _keep_in_fp32_modules = ["wo"] def _init_weights(self, module): """Initialize the weights"""