diff --git a/parlai/agents/hugging_face/README.md b/parlai/agents/hugging_face/README.md index 2a670da9fce..f242c21b8e4 100644 --- a/parlai/agents/hugging_face/README.md +++ b/parlai/agents/hugging_face/README.md @@ -3,7 +3,10 @@ We offer wrappers for generative transformers from [Hugging Face's transformers repository](https://github.com/huggingface/transformers) for fine-tuning and evaluating in ParlAI. ## GPT2 -To use GPT2, run your command with the flag: `-m hugging_face/gpt2`. +To use GPT2, run your command with the flag: `-m hugging_face/gpt2`. If you want to use a model other +than the default English GPT2 (small, medium, large and xl versions), you can use `-m hugging_face/gpt2 --model_name <model_name>`, +where `<model_name>` can be any GPT2 model hosted on Hugging Face, such as **anonymous-german-nlp/german-gpt2** +or **indonesian-nlp/gpt2**. ### Examples **Talk to GPT2 large in interactive mode, with beam size 10, 3-gram beam blocking, and minimum beam length 25:** diff --git a/parlai/agents/hugging_face/dict.py b/parlai/agents/hugging_face/dict.py index 2a22b8551d0..ed83c3bda7a 100644 --- a/parlai/agents/hugging_face/dict.py +++ b/parlai/agents/hugging_face/dict.py @@ -140,23 +140,26 @@ def get_tokenizer(self, opt): """ Instantiate tokenizer. 
""" - model_sz = opt["gpt2_size"] - if model_sz == "small": - model_key = "gpt2" - elif model_sz == "distilgpt2": - model_key = "distilgpt2" - else: - model_key = f"gpt2-{model_sz}" - # check if datapath has the files that hugging face code looks for - hf_dir = os.path.join(opt["datapath"], "hf", model_key) - if all( - PathManager.exists(os.path.join(hf_dir, file_name)) - for file_name in ["merges.txt", "vocab.json"] - ): - fle_key = PathManager.get_local_path(hf_dir, recursive=True) - + if opt["model_name"]: + fle_key = opt["model_name"] else: - fle_key = model_key + model_sz = opt["gpt2_size"] + if model_sz == "small": + model_key = "gpt2" + elif model_sz == "distilgpt2": + model_key = "distilgpt2" + else: + model_key = f"gpt2-{model_sz}" + # check if datapath has the files that hugging face code looks for + hf_dir = os.path.join(opt["datapath"], "hf", model_key) + if all( + PathManager.exists(os.path.join(hf_dir, file_name)) + for file_name in ["merges.txt", "vocab.json"] + ): + fle_key = PathManager.get_local_path(hf_dir, recursive=True) + + else: + fle_key = model_key return GPT2Tokenizer.from_pretrained(fle_key) def add_additional_special_tokens(self, additional_special_tokens: List[str]): diff --git a/parlai/agents/hugging_face/gpt2.py b/parlai/agents/hugging_face/gpt2.py index 6adcf466793..b4a7afcd30f 100644 --- a/parlai/agents/hugging_face/gpt2.py +++ b/parlai/agents/hugging_face/gpt2.py @@ -58,23 +58,26 @@ def __init__(self, opt, dict): def _init_from_pretrained(self, opt): # load model - model_sz = opt["gpt2_size"] - if model_sz == "small": - model_key = "gpt2" - elif model_sz == "distilgpt2": - model_key = "distilgpt2" + if opt["model_name"]: + fle_key = opt["model_name"] else: - model_key = f"gpt2-{model_sz}" - - # check if datapath has the files that hugging face code looks for - hf_dir = os.path.join(opt["datapath"], "hf", model_key) - if all( - PathManager.exists(os.path.join(hf_dir, file_name)) - for file_name in ["pytorch_model.bin", 
"config.json"] - ): - fle_key = PathManager.get_local_path(hf_dir, recursive=True) - else: - fle_key = model_key + model_sz = opt["gpt2_size"] + if model_sz == "small": + model_key = "gpt2" + elif model_sz == "distilgpt2": + model_key = "distilgpt2" + else: + model_key = f"gpt2-{model_sz}" + + # check if datapath has the files that hugging face code looks for + hf_dir = os.path.join(opt["datapath"], "hf", model_key) + if all( + PathManager.exists(os.path.join(hf_dir, file_name)) + for file_name in ["pytorch_model.bin", "config.json"] + ): + fle_key = PathManager.get_local_path(hf_dir, recursive=True) + else: + fle_key = model_key return GPT2Model.from_pretrained(fle_key) def forward(self, input, encoder_state, incr_state=None): @@ -237,6 +240,12 @@ def add_cmdline_args( cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None ) -> ParlaiParser: agent = parser.add_argument_group("Gpt2 Args") + agent.add_argument( + "--model-name", + type=str, + default=None, + help="Any GPT-2 model names.", + ) agent.add_argument( "--gpt2-size", type=str,