diff --git a/parlai/agents/hugging_face/hugging_face.py b/parlai/agents/hugging_face/hugging_face.py index f2cb5bca283..1e2c57a3e9c 100644 --- a/parlai/agents/hugging_face/hugging_face.py +++ b/parlai/agents/hugging_face/hugging_face.py @@ -17,7 +17,10 @@ raise ImportError('Please run `pip install transformers`.') -HF_VERSION = float('.'.join(transformers.__version__.split('.')[:2])) +HF_VERSION = ( +    int(transformers.__version__.split('.')[0]), +    int(transformers.__version__.split('.')[1]), +) class HuggingFaceAgent: diff --git a/parlai/agents/hugging_face/t5.py b/parlai/agents/hugging_face/t5.py index f4edfc3a8e8..b95257c298e 100644 --- a/parlai/agents/hugging_face/t5.py +++ b/parlai/agents/hugging_face/t5.py @@ -29,8 +29,16 @@ from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel +def check_hf_version(v: Tuple[int, int]) -> bool: +    """ +    Check that HF version is at least 4.3. +    """ +    main, sub = v +    return main > 4 or (main == 4 and sub >= 3) + + def build_t5(opt: Opt) -> T5ForConditionalGeneration: -    if not HF_VERSION >= 4.3: +    if not check_hf_version(HF_VERSION): raise RuntimeError('Must use transformers package >= 4.3 to use t5') return T5ForConditionalGeneration.from_pretrained( opt['t5_model_arch'], dropout_rate=opt['t5_dropout'] diff --git a/tests/nightly/gpu/test_t5.py b/tests/nightly/gpu/test_t5.py index d9c76e74c11..4f01443bed4 100644 --- a/tests/nightly/gpu/test_t5.py +++ b/tests/nightly/gpu/test_t5.py @@ -18,10 +18,9 @@ try: import transformers # noqa from parlai.agents.hugging_face.hugging_face import HF_VERSION -    from parlai.agents.hugging_face.t5 import TASK_CONFIGS -    from parlai.agents.hugging_face.t5 import set_device +    from parlai.agents.hugging_face.t5 import TASK_CONFIGS, check_hf_version, set_device -    HF_AVAILABLE = HF_VERSION >= 4.3 +    HF_AVAILABLE = check_hf_version(HF_VERSION) except ImportError: TASK_CONFIGS = None set_device = unittest.skip