diff --git a/applications/DeepSpeed-Chat/inference/chatbot.py b/applications/DeepSpeed-Chat/inference/chatbot.py index 38b900d7d..5a4e36895 100644 --- a/applications/DeepSpeed-Chat/inference/chatbot.py +++ b/applications/DeepSpeed-Chat/inference/chatbot.py @@ -10,7 +10,7 @@ import os import json from transformers import pipeline, set_seed -from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer +from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM def parse_args(): @@ -43,9 +43,10 @@ def get_generator(path): tokenizer.pad_token = tokenizer.eos_token model_config = AutoConfig.from_pretrained(path) - model = OPTForCausalLM.from_pretrained(path, - from_tf=bool(".ckpt" in path), - config=model_config).half() + model_class = AutoModelForCausalLM.from_config(model_config) + model = model_class.from_pretrained(path, + from_tf=bool(".ckpt" in path), + config=model_config).half() model.config.end_token_id = tokenizer.eos_token_id model.config.pad_token_id = model.config.eos_token_id