From c4d7b08da8f21df4375f3dbcf62d1f300b8f1d6b Mon Sep 17 00:00:00 2001 From: Dexter Ju <5313281+dexterju27@users.noreply.github.com> Date: Fri, 11 Nov 2022 13:41:22 -0500 Subject: [PATCH] Allow flan-t5 models in ParlAI with fp16 improvement (#4875) * allow flan t5 models * enable fp16 --- parlai/agents/hugging_face/t5.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/parlai/agents/hugging_face/t5.py b/parlai/agents/hugging_face/t5.py index 893391dec71..721d6c8d125 100644 --- a/parlai/agents/hugging_face/t5.py +++ b/parlai/agents/hugging_face/t5.py @@ -41,8 +41,9 @@ def check_hf_version(v: Tuple[int, int]) -> bool: def build_t5(opt: Opt) -> T5ForConditionalGeneration: if not check_hf_version(HF_VERSION): raise RuntimeError('Must use transformers package >= 4.3 to use t5') + torch_dtype = torch.float16 if opt['fp16'] else torch.float32 return T5ForConditionalGeneration.from_pretrained( - opt['t5_model_arch'], dropout_rate=opt['t5_dropout'] + opt['t5_model_arch'], dropout_rate=opt['t5_dropout'], torch_dtype=torch_dtype ) @@ -86,7 +87,18 @@ def add_cmdline_args( '--t5-model-arch', type=str, default='t5-base', - choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"], + choices=[ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + "google/flan-t5-small", + "google/flan-t5-base", + "google/flan-t5-large", + "google/flan-t5-xl", + "google/flan-t5-xxl", + ], ) group.add_argument( '--t5-model-parallel',