From daddaf7af0ef03b74351cf32817d51abdba9138b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 29 Jan 2021 08:11:22 -0800 Subject: [PATCH] correctly handle mt5 (#9879) --- examples/seq2seq/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 8b24bfdadcf6f4..303b89f78192dc 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -563,7 +563,7 @@ def freeze_embeds(model): """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" model_type = model.config.model_type - if model_type == "t5": + if model_type in ["t5", "mt5"]: freeze_params(model.shared) for d in [model.encoder, model.decoder]: freeze_params(d.embed_tokens)