From ec6616c9760aa3d94d23ceec7dc7cdf01606537b Mon Sep 17 00:00:00 2001
From: wulixuan
Date: Fri, 19 Jan 2024 14:25:05 +0800
Subject: [PATCH] feat: support Model Yuan2.0

---
 docs/model_support.md            |  4 ++++
 fastchat/conversation.py         | 29 ++++++++++++++++++++++++
 fastchat/model/model_adapter.py  | 39 ++++++++++++++++++++++++++++++--
 fastchat/model/model_registry.py |  6 +++++
 4 files changed, 76 insertions(+), 2 deletions(-)

diff --git a/docs/model_support.md b/docs/model_support.md
index 472cc0a4a..2bc45f477 100644
--- a/docs/model_support.md
+++ b/docs/model_support.md
@@ -62,6 +62,10 @@
 - [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
 - [Xwin-LM/Xwin-LM-7B-V0.1](https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1)
 - [IEITYuan/Yuan2-2B/51B/102B-hf](https://huggingface.co/IEITYuan)
+- [IEITYuan/Yuan2-2B-Janus-hf](https://huggingface.co/IEITYuan/Yuan2-2B-Janus-hf)
+- [IEITYuan/Yuan2-2B-hf](https://huggingface.co/IEITYuan/Yuan2-2B-hf)
+- [IEITYuan/Yuan2-51B-hf](https://huggingface.co/IEITYuan/Yuan2-51B-hf)
+- [IEITYuan/Yuan2-102B-hf](https://huggingface.co/IEITYuan/Yuan2-102B-hf)
 - Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)
 - Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a
   model above. To activate, must have `peft` in the model path. Note: If
diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 3d738def5..9955b3e00 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -31,6 +31,7 @@ class SeparatorStyle(IntEnum):
     CHATGLM3 = auto()
     DEEPSEEK_CHAT = auto()
     METAMATH = auto()
+    YUAN2 = auto()
 
 
 @dataclasses.dataclass
@@ -245,6 +246,18 @@ def get_prompt(self) -> str:
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.YUAN2:
+            seps = [self.sep, self.sep2]
+            ret = ""
+            if self.system_message:
+                ret += system_prompt + seps[1]
+            for _, message in self.messages:
+                if message:
+                    ret += message + "<n>"
+                else:
+                    ret += ""
+            ret = ret.rstrip("<n>") + seps[0]
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
@@ -1420,6 +1433,22 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+# Yuan2.0 chat template
+# source: https://huggingface.co/IEITYuan/Yuan2-2B-Janus-hf/blob/main/tokenizer_config.json#L6
+register_conv_template(
+    Conversation(
+        name="yuan2",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.YUAN2,
+        sep="<sep>",
+        sep2="\n",
+        stop_token_ids=[
+            77185,
+        ],  # "<eod>"
+        stop_str="<eod>",
+    )
+)
+
 # Solar-10.7B Chat Template
 # Reference: https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0/blob/main/tokenizer_config.json
 register_conv_template(
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 91881d214..445d480bc 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -2101,6 +2101,40 @@ def match(self, model_path: str):
     def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("deepseek-chat")
 
+class Yuan2Adapter(BaseModelAdapter):
+    """The model adapter for Yuan2.0"""
+
+    def match(self, model_path: str):
+        return "yuan2" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        # from_pretrained_kwargs["torch_dtype"] = torch.bfloat16
+        tokenizer = LlamaTokenizer.from_pretrained(
+            model_path,
+            add_eos_token=False,
+            add_bos_token=False,
+            eos_token="<eod>",
+            eod_token="<eod>",
+            sep_token="<sep>",
+            revision=revision,
+        )
+        tokenizer.add_tokens(
+            ["<sep>", "<pad>", "<mask>", "<predict>", "<FIM_SUFFIX>", "<FIM_PREFIX>", "<FIM_MIDDLE>", "<commit_before>",
+             "<commit_msg>", "<commit_after>", "<jupyter_start>", "<jupyter_text>", "<jupyter_code>",
+             "<jupyter_output>", "<empty_output>"], special_tokens=True)
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            # device_map='auto',
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
+        )
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("yuan2")
+
 
 class MetaMathAdapter(BaseModelAdapter):
     """The model adapter for MetaMath models"""
@@ -2132,7 +2166,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
     return get_conv_template("solar")
 
 
-class Yuan2Adapter(BaseModelAdapter):
+class YuanAdapter(BaseModelAdapter):
     """The model adapter for Yuan"""
 
     def match(self, model_path: str):
@@ -2248,10 +2282,11 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(PplxAIAdapter)
 register_model_adapter(DeepseekCoderAdapter)
 register_model_adapter(DeepseekChatAdapter)
+register_model_adapter(Yuan2Adapter)
 register_model_adapter(MetaMathAdapter)
 register_model_adapter(BagelAdapter)
 register_model_adapter(SolarAdapter)
-register_model_adapter(Yuan2Adapter)
+register_model_adapter(YuanAdapter)
 
 # After all adapters, try the default base adapter.
 register_model_adapter(BaseModelAdapter)
diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
index 0e8b650ff..361844fb0 100644
--- a/fastchat/model/model_registry.py
+++ b/fastchat/model/model_registry.py
@@ -27,6 +27,12 @@ def get_model_info(name: str) -> ModelInfo:
             name, "", "Register the description at fastchat/model/model_registry.py"
         )
 
+register_model_info(
+    ["IEITYuan/Yuan2-2B-Janus-hf", "IEITYuan/Yuan2-2B-hf", "IEITYuan/Yuan2-51B-hf", "IEITYuan/Yuan2-102B-hf"],
+    "IEIT-Yuan2",
+    "https://github.com/IEIT-Yuan/Yuan-2.0",
+    "Yuan2.0 is a new-generation fundamental large language model developed by IEIT System.",
+)
 
 register_model_info(
     ["mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"],
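
-- 
Usage sketch (reviewer note, not part of the patch): the snippet below
exercises the new "yuan2" conversation template to show how the
SeparatorStyle.YUAN2 branch assembles a prompt. It is a minimal
illustration assuming this patch is applied to a FastChat checkout; the
expected output is inferred from the get_prompt() logic in the diff
above, not captured from an actual run.

    from fastchat.conversation import get_conv_template

    # Fetch a copy of the template this patch registers.
    conv = get_conv_template("yuan2")

    # One user turn plus an empty assistant slot (the usual generation setup).
    conv.append_message(conv.roles[0], "Hello, who are you?")
    conv.append_message(conv.roles[1], None)

    # Non-empty messages are joined with the "<n>" newline marker, the
    # trailing marker is stripped, and sep ("<sep>") is appended; inference
    # code then stops generation at "<eod>" (token id 77185), per the
    # stop_token_ids / stop_str fields of the template.
    print(conv.get_prompt())  # Hello, who are you?<sep>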