From 521adb2121e33142d4e28df8148ed93c03b7f32e Mon Sep 17 00:00:00 2001 From: sgugger Date: Tue, 17 Nov 2020 07:59:30 -0500 Subject: [PATCH] Fix init for MT5 --- src/transformers/__init__.py | 3 ++- src/transformers/utils/dummy_pt_objects.py | 5 ----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 65ad1bbfcd8667..f41a8d8f763472 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -145,6 +145,7 @@ from .models.mbart import MBartConfig from .models.mmbt import MMBTConfig from .models.mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig, MobileBertTokenizer +from .models.mt5 import MT5Config from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer from .models.pegasus import PegasusConfig from .models.phobert import PhobertTokenizer @@ -498,7 +499,7 @@ MobileBertPreTrainedModel, load_tf_weights_in_mobilebert, ) - from .models.mt5 import MT5Config, MT5ForConditionalGeneration, MT5Model + from .models.mt5 import MT5ForConditionalGeneration, MT5Model from .models.openai import ( OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, OpenAIGPTDoubleHeadsModel, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 596992e54a2de4..b0e81bd8cbc105 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1361,11 +1361,6 @@ def load_tf_weights_in_mobilebert(*args, **kwargs): requires_pytorch(load_tf_weights_in_mobilebert) -class MT5Config: - def __init__(self, *args, **kwargs): - requires_pytorch(self) - - class MT5ForConditionalGeneration: def __init__(self, *args, **kwargs): requires_pytorch(self)