diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 8b24bfdadcf6f4..303b89f78192dc 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -563,7 +563,7 @@ def freeze_embeds(model): """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" model_type = model.config.model_type - if model_type == "t5": + if model_type in ["t5", "mt5"]: freeze_params(model.shared) for d in [model.encoder, model.decoder]: freeze_params(d.embed_tokens)