Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
versioning issue (#4164)
Browse files Browse the repository at this point in the history
  • Loading branch information
klshuster authored Nov 12, 2021
1 parent 1874977 commit c7cce82
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
5 changes: 4 additions & 1 deletion parlai/agents/hugging_face/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
raise ImportError('Please run `pip install transformers`.')


HF_VERSION = float('.'.join(transformers.__version__.split('.')[:2]))
HF_VERSION = (
int(transformers.__version__.split('.')[0]),
int(transformers.__version__.split('.')[1]),
)


class HuggingFaceAgent:
Expand Down
10 changes: 9 additions & 1 deletion parlai/agents/hugging_face/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,16 @@
from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel


def check_hf_version(v: Tuple[int, int]) -> bool:
"""
Check that HF version is greater than 4.3
"""
main, sub = v
return main > 4 or (main == 4 and sub >= 3)


def build_t5(opt: Opt) -> T5ForConditionalGeneration:
if not HF_VERSION >= 4.3:
if not check_hf_version(HF_VERSION):
raise RuntimeError('Must use transformers package >= 4.3 to use t5')
return T5ForConditionalGeneration.from_pretrained(
opt['t5_model_arch'], dropout_rate=opt['t5_dropout']
Expand Down
5 changes: 2 additions & 3 deletions tests/nightly/gpu/test_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
try:
import transformers # noqa
from parlai.agents.hugging_face.hugging_face import HF_VERSION
from parlai.agents.hugging_face.t5 import TASK_CONFIGS
from parlai.agents.hugging_face.t5 import set_device
from parlai.agents.hugging_face.t5 import TASK_CONFIGS, check_hf_version, set_device

HF_AVAILABLE = HF_VERSION >= 4.3
HF_AVAILABLE = check_hf_version(HF_VERSION)
except ImportError:
TASK_CONFIGS = None
set_device = unittest.skip
Expand Down

0 comments on commit c7cce82

Please sign in to comment.