From 75581c1384cf880d4cd68a8519aafd604d835e66 Mon Sep 17 00:00:00 2001
From: algoroxyolo
Date: Tue, 6 Feb 2024 19:03:05 +0000
Subject: [PATCH 1/3] yaya

---
 code/ChatHaruhi/Mixtral.py      | 53 ++++++++++++++++++++++++++
 code/ChatHaruhi/Qwen118k2GPT.py | 50 +++++++------------------
 code/ChatHaruhi/phi.py          | 66 +++++++++++++++++++++++++++++++++
 code/ChatHaruhi/qwen.py         | 60 ++++++++++++++++++++++++++++++
 4 files changed, 193 insertions(+), 36 deletions(-)
 create mode 100644 code/ChatHaruhi/Mixtral.py
 create mode 100644 code/ChatHaruhi/phi.py
 create mode 100644 code/ChatHaruhi/qwen.py

diff --git a/code/ChatHaruhi/Mixtral.py b/code/ChatHaruhi/Mixtral.py
new file mode 100644
index 0000000..1543d3c
--- /dev/null
+++ b/code/ChatHaruhi/Mixtral.py
@@ -0,0 +1,53 @@
+from .BaseLLM import BaseLLM
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import LlamaTokenizer, MixtralForCausalLM
+import bitsandbytes, flash_attn
+tokenizer_LLaMA = None
+model_LLaMA = None
+
+def initialize_Mixtral():
+    global model_LLaMA, tokenizer_LLaMA
+
+    if model_LLaMA is None:
+        model_LLaMA = MixtralForCausalLM.from_pretrained(
+            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
+
+    if tokenizer_LLaMA is None:
+        tokenizer_LLaMA = LlamaTokenizer.from_pretrained('NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO', trust_remote_code=True)
+
+    return model_LLaMA, tokenizer_LLaMA
+
+def LLaMA_tokenizer(text):
+    return len(tokenizer_LLaMA.encode(text))
+
+class ChatMixtral(BaseLLM):
+    def __init__(self, model="Mixtral"):
+        super(ChatMixtral, self).__init__()
+        self.model, self.tokenizer = initialize_Mixtral()
+        self.messages = ""
+
+    def initialize_message(self):
+        self.messages = ""
+
+    def ai_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def system_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def user_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def get_response(self):
+        with torch.no_grad():
+            input_ids = self.tokenizer(self.messages, return_tensors="pt").input_ids.to("cuda")
+            generated_ids = self.model.generate(input_ids, max_new_tokens=750, temperature=0.8, repetition_penalty=1.1, do_sample=True, eos_token_id=self.tokenizer.eos_token_id)
+            response = self.tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+            return response
+
+    def print_prompt(self):
+        print(self.messages)
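A minimal usage sketch for the ChatMixtral wrapper added above; illustration only, assuming the code/ChatHaruhi package layout from this patch, a GPU with enough memory for the Nous-Hermes-2-Mixtral-8x7B-DPO checkpoint, and hypothetical prompt text:

    from ChatHaruhi.Mixtral import ChatMixtral

    chatbot = ChatMixtral()           # loads the Mixtral checkpoint on first use

    chatbot.initialize_message()      # reset the flat string prompt buffer
    chatbot.system_message("You are Haruhi Suzumiya. Stay in character.")   # hypothetical system prompt
    chatbot.user_message("Kyon: Do you really believe in aliens?")          # hypothetical user turn

    reply = chatbot.get_response()    # generate() on CUDA; decodes only the tokens after the prompt
    print(reply)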
"silk-road/ChatHaruhi_RolePlaying_qwen_7b", # use_fast=True, trust_remote_code=True ) return model_qwen, tokenizer_qwen -def Qwen_tokenizer(text): + +def LLaMA_tokenizer(text): return len(tokenizer_qwen.encode(text)) class Qwen118k2GPT(BaseLLM): - def __init__(self, model): + def __init__(self, model="qwen-118k"): super(Qwen118k2GPT, self).__init__() - global model_qwen, tokenizer_qwen - if model == "Qwen/Qwen-1_8B-Chat": - tokenizer_qwen = AutoTokenizer.from_pretrained( - "Qwen/Qwen-1_8B-Chat", - trust_remote_code=True - ) - model_qwen = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen-1_8B-Chat", - device_map="auto", - trust_remote_code=True - ).eval() - self.model = model_qwen - self.tokenizer = tokenizer_qwen - elif "silk-road/" in model : - self.model, self.tokenizer = initialize_Qwen2LORA(model) - else: - raise Exception("Unknown Qwen model") + self.model, self.tokenizer = initialize_Qwen2LORA() self.messages = "" def initialize_message(self): self.messages = "" def ai_message(self, payload): - self.messages = "AI: " + self.messages + "\n " + payload + self.messages = self.messages + "\n " + payload def system_message(self, payload): - self.messages = "SYSTEM PROMPT: " + self.messages + "\n " + payload + self.messages = self.messages + "\n " + payload def user_message(self, payload): - self.messages = "User: " + self.messages + "\n " + payload + self.messages = self.messages + "\n " + payload def get_response(self): with torch.no_grad(): @@ -80,5 +60,3 @@ def get_response(self): def print_prompt(self): print(type(self.messages)) print(self.messages) - - diff --git a/code/ChatHaruhi/phi.py b/code/ChatHaruhi/phi.py new file mode 100644 index 0000000..f5ce20e --- /dev/null +++ b/code/ChatHaruhi/phi.py @@ -0,0 +1,66 @@ +import torch +from .BaseLLM import BaseLLM +from transformers import AutoTokenizer, PhiForCausalLM +tokenizer_phi = None +model_phi = None +# Load model directly +def initialize_phi(): + global model_phi, tokenizer_phi + + if model_phi is None: + model_phi = PhiForCausalLM.from_pretrained( + "cognitivecomputations/dolphin-2_6-phi-2", + local_files_only=True, + torch_dtype=torch.float16, + device_map="auto", + ) + + if tokenizer_phi is None: + tokenizer_phi = AutoTokenizer.from_pretrained( + "cognitivecomputations/dolphin-2_6-phi-2", + local_files_only=True, + use_fast=True, + ) + + + + + return model_phi, tokenizer_phi + +def LLaMA_tokenizer(text): + return len(tokenizer_phi.encode(text)) + +class Chatphi(BaseLLM): + def __init__(self, model="phi"): + super(Chatphi, self).__init__() + self.model, self.tokenizer = initialize_phi() + self.messages = "" + + def initialize_message(self): + self.messages = "" + + def ai_message(self, payload): + self.messages = self.messages + "\n " + payload + + def system_message(self, payload): + self.messages = self.messages + "\n " + payload + + def user_message(self, payload): + self.messages = self.messages + "\n " + payload + + def get_response(self): + with torch.no_grad(): + # Prepare the model input with attention mask + inputs = self.tokenizer(self.messages, return_tensors="pt", padding=True, truncation=True) + attention_mask = inputs['attention_mask'] + + # Generate the model output using the prepared input and attention mask + outputs = self.model.generate(input_ids=inputs['input_ids'], attention_mask=attention_mask, max_length=114514) + response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] + + return response + + + def print_prompt(self): + print(type(self.messages)) + print(self.messages) diff --git 
diff --git a/code/ChatHaruhi/qwen.py b/code/ChatHaruhi/qwen.py
new file mode 100644
index 0000000..30d804f
--- /dev/null
+++ b/code/ChatHaruhi/qwen.py
@@ -0,0 +1,60 @@
+import torch
+from .BaseLLM import BaseLLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import pdb
+tokenizer_qwen = None
+model_qwen = None
+# Load model directly
+def initialize_qwen():
+    global model_qwen, tokenizer_qwen
+
+    if model_qwen is None:
+        model_qwen = AutoModelForCausalLM.from_pretrained(
+            "Qwen/Qwen-7B-Chat",
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+
+    if tokenizer_qwen is None:
+        tokenizer_qwen = AutoTokenizer.from_pretrained(
+            "Qwen/Qwen-7B-Chat",
+            use_fast=True,
+            trust_remote_code=True
+        )
+
+
+
+
+    return model_qwen, tokenizer_qwen
+
+def LLaMA_tokenizer(text):
+    return len(tokenizer_qwen.encode(text))
+
+class ChatQwen(BaseLLM):
+    def __init__(self, model="qwen7b"):
+        super(ChatQwen, self).__init__()
+        self.model, self.tokenizer = initialize_qwen()
+        self.messages = ""
+
+    def initialize_message(self):
+        self.messages = ""
+
+    def ai_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def system_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def user_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def get_response(self):
+        with torch.no_grad():
+            response, history = self.model.chat(self.tokenizer, self.messages, history=[])
+            # print(response)
+            return response
+
+    def print_prompt(self):
+        print(type(self.messages))
+        print(self.messages)
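Unlike the Mixtral and phi wrappers, ChatQwen delegates generation to the chat() helper that Qwen/Qwen-7B-Chat ships via trust_remote_code, which applies the model's chat template and returns the reply together with an updated history. A minimal sketch of that underlying call; the prompts are illustrative:

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    name = "Qwen/Qwen-7B-Chat"
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        name, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
    )

    # First turn: history=None starts a fresh conversation.
    response, history = model.chat(tokenizer, "Hello, who are you?", history=None)

    # Passing history back keeps conversation state. ChatQwen.get_response instead always
    # sends history=[] because its accumulated self.messages string already holds the whole prompt.
    response, history = model.chat(tokenizer, "And what can you do?", history=history)
    print(response)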
From 2e5d10e5d64dc5ba4d4d113e1f0bdb7472428d69 Mon Sep 17 00:00:00 2001
From: algoroxyolo
Date: Wed, 7 Feb 2024 05:34:51 +0000
Subject: [PATCH 2/3] @

---
 code/ChatHaruhi/ChatHaruhi.py |  67 +++++++++---------
 code/personality_tests.py     | 125 +++++++++++++---------------------
 2 files changed, 82 insertions(+), 110 deletions(-)

diff --git a/code/ChatHaruhi/ChatHaruhi.py b/code/ChatHaruhi/ChatHaruhi.py
index 85f7bca..382974a 100644
--- a/code/ChatHaruhi/ChatHaruhi.py
+++ b/code/ChatHaruhi/ChatHaruhi.py
@@ -90,6 +90,16 @@ def __init__(self, system_prompt = None, \
             self.llm, self.tokenizer = self.get_models(llm)
         elif "llama" in llm:
             self.llm, self.tokenizer = self.get_models(llm)
+        elif "phi" in llm:
+            self.llm, self.tokenizer = self.get_models(llm)
+        elif "Mixtral" in llm:
+            self.llm, self.tokenizer = self.get_models(llm)
+        elif "Qwen-118k" in llm:
+            self.llm, self.tokenizer = self.get_models(llm)
+        elif "mistral" in llm:
+            self.llm, self.tokenizer = self.get_models(llm)
+        elif "openChat" in llm:
+            self.llm, self.tokenizer = self.get_models(llm)
         else:
             print(f'warning! undefined llm {llm}, use openai instead.')
             self.llm, self.tokenizer = self.get_models('openai')
@@ -306,30 +316,26 @@ def get_models(self, model_name):
             from .llama2 import ChatLLaMA
             return (ChatLLaMA(), tiktokenizer)
         elif "qwen" in model_name:
-            if model_name == "qwen118k_raw":
-                from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
-                return (Qwen118k2GPT(model = "Qwen/Qwen-1_8B-Chat"), Qwen_tokenizer)
-            from huggingface_hub import HfApi
-            from huggingface_hub.hf_api import ModelFilter
-            qwen_api = HfApi()
-            qwen_models = qwen_api.list_models(
-                filter = ModelFilter(model_name=model_name),
-                author = "silk-road"
-            )
-            qwen_models_id = []
-            for qwen_model in qwen_models:
-                qwen_models_id.append(qwen_model.id)
-                # print(model.id)
-            if "silk-road/" + model_name in qwen_models_id:
-                from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
-                return (Qwen118k2GPT(model = "silk-road/" + model_name), Qwen_tokenizer)
-            else:
-                print(f'warning! undefined model {model_name}, use openai instead.')
-                from .LangChainGPT import LangChainGPT
-                return (LangChainGPT(), tiktokenizer)
+            from .qwen import ChatQwen
+            return (ChatQwen(), tiktokenizer)
             # print(models_id)
+        elif model_name == "phi":
+            from .phi import Chatphi
+            return (Chatphi(), tiktokenizer)
+        elif "Mixtral" in model_name:
+            from .Mixtral import ChatMixtral
+            return (ChatMixtral(), tiktokenizer)
+        elif model_name == "Qwen-118k":
+            from .Qwen118k2GPT import Qwen118k2GPT
+            return (Qwen118k2GPT(), tiktokenizer)
+        elif "mistral" in model_name:
+            from .mistral import ChatMistral
+            return (ChatMistral(), tiktokenizer)
+        elif "openChat" in model_name:
+            from .openChat import ChatOpenChat
+            return (ChatOpenChat(), tiktokenizer)
         else:
-            print(f'warning! undefined model {model_name}, use openai instead.')
+            print(f'warning! undefined model {model_name}, use openai instead.')
             from .LangChainGPT import LangChainGPT
             return (LangChainGPT(), tiktokenizer)
@@ -448,14 +454,11 @@ def chat(self, text, role, nth_test):
 
         # add system prompt
         self.llm.initialize_message()
-
-        if not 'no_description' in self.llm_type.split('='):
-            self.llm.system_message(self.system_prompt)
-
+        self.llm.system_message(self.system_prompt)
+
+        # add story
         query = self.get_query_string(text, role)
-        if not 'no_retrieve' in self.llm_type.split('='):
-            # add story
-            self.add_story( query )
+        self.add_story( query )
 
         # add history
         self.add_history()
@@ -467,10 +470,12 @@ def chat(self, text, role, nth_test):
         response_raw = self.llm.get_response()
 
         response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)
-
+
         # record dialogue history
         self.dialogue_history.append((query, response))
 
+
+
         return response
 
     def get_query_string(self, text, role):
@@ -485,7 +490,7 @@ def add_story(self, query):
             return
 
         query_vec = self.embedding(query)
-
+
         stories = self.db.search(query_vec, self.k_search)
 
         story_string = self.story_prefix_prompt
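With the dispatch above, the new local backends are selected purely by the llm string handed to ChatHaruhi. A minimal construction sketch; the role name and dialogue text are illustrative, and it assumes the corresponding checkpoint can be loaded on the local GPU:

    from ChatHaruhi import ChatHaruhi

    # Any of the strings handled above works here: "phi", "Mixtral", "Qwen-118k", "mistral", "openChat", ...
    chatbot = ChatHaruhi(role_name="haruhi", llm="Mixtral")

    # chat() builds the system prompt, retrieves story snippets, and calls the local model;
    # nth_test comes from the signature shown above, 0 being an illustrative value.
    reply = chatbot.chat(text="Do you really believe in aliens?", role="Kyon", nth_test=0)
    print(reply)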
diff --git a/code/personality_tests.py b/code/personality_tests.py
index fbeed36..ed013b4 100644
--- a/code/personality_tests.py
+++ b/code/personality_tests.py
@@ -33,7 +33,7 @@
 
 # Added choices for the agent_llm argument
 parser.add_argument('--agent_llm', type=str, default='gpt-3.5-turbo', 
-                    choices=['gpt-3.5-turbo', 'openai', 'GLMPro', 'ChatGLM2GPT',"qwen118k_raw","llama2"], 
+                    choices=['gpt-3.5-turbo', 'openChat', 'mistral', 'ChatGLM2GPT',"qwen-118k","llama2","Mixtral"], 
                     help='agent LLM (gpt-3.5-turbo)')
 
 # Added choices for the evaluator_llm argument
@@ -71,10 +67,6 @@ def load_questionnaire(questionnaire_name):
     # read this jsonl file
     with open(q_path, 'r', encoding='utf-8') as f:
         questionnaire = json.load(f)
-    
-    if questionnaire_name not in dims_dict:
-        dims_dict[questionnaire_name] = [ _['cat_name'] for _ in questionnaire['categories'] ]
-
     return questionnaire
 
 def subsample_questionnaire(questionnaire, n=20):
@@ -149,8 +145,6 @@ def split_list(input_list, n=4):
 
 def build_character_agent(character_code, agent_type, agent_llm):
     from ChatHaruhi import ChatHaruhi
-    agent_type_args = agent_type.split('=', 1)
-
     if agent_llm.startswith('gpt-'):
         if agent_llm.startswith('gpt-3.5'):
             agent_llm = 'gpt-3.5-turbo-1106'
             agent_llm = 'gpt-4-1106-preview'
 
         os.environ["OPENAI_API_KEY"] = config['openai_apikey']
-        
-        if agent_type_args[0] == 'ChatHaruhi':
+        
+        if agent_type == 'ChatHaruhi':
             character_agent = ChatHaruhi(role_name = character_info[character_code]["agent"]["ChatHaruhi"], llm = 'openai')
-        elif agent_type_args[0] == 'RoleLLM':
+        elif agent_type == 'RoleLLM':
             character_agent = ChatHaruhi( role_from_hf = f'silk-road/ChatHaruhi-from-RoleLLM/{character_info[character_code]["agent"]["RoleLLM"]}', llm = 'openai', embedding = 'bge_en')
             character_agent.role_name = 'RoleLLM/' + character_info[character_code]["agent"]["RoleLLM"]
 
         character_agent.llm.model = agent_llm
-        character_agent.llm_type = agent_llm # just to set different keys for cache
-
     else:
-        if agent_type_args[0] == 'ChatHaruhi':
+        if agent_type == 'ChatHaruhi':
             os.environ["OPENAI_API_KEY"] = config['openai_apikey']
             character_agent = ChatHaruhi(role_name = character_info[character_code]["agent"]["ChatHaruhi"], llm = agent_llm)
             #character_agent.llm.chat.temperature = 0
-
-    if len(agent_type_args) > 1:
-        character_agent.llm_type = character_agent.llm_type + '=' + agent_type_args[1]
 
     return character_agent
@@ -198,9 +187,9 @@ def interview(character_agent, questionnaire, experimenter, questionnaire_prompts, language, query_style, nth_test):
 
     elif query_style.startswith('choose'):
         q = questionnaire_prompts["rpa_choose_prefix"][language].replace('', question[f'origin_{language}']) + ' ' + questionnaire_prompts["rpa_choice_instruction"][language]
 
-        if query_style == 'choosecot2':
+        if query_style == 'choosecot':
             if language == 'en':
-                q = q.replace('Please answer with the number only, without anything else.', 'Please think step by step. Start by sharing your thoughts, then proceed to present the number.')
+                q = q.replace('Please answer with the number only, without anything else.', 'Please give your reasons.')
             else:
                 q = q.replace('请你只回答这一个数字,不要说其他内容。', '请给出你的理由。')
@@ -256,10 +245,8 @@ def assess(character_aliases, experimenter, questionnaire_results, questionnaire
 
     global previous_file_path
 
-    previous_file_path_cp = previous_file_path.replace('../results/assessment/', '../results/assessment_cp/')
-    if os.path.exists(previous_file_path_cp):
-        previous_file_path = previous_file_path_cp
-    
+    previous_file_path = previous_file_path.replace('../results/assessment/', '../results/assessment_cp/')
+    
     if True and os.path.exists(previous_file_path):
 
         with open(previous_file_path, 'r') as f:
@@ -302,9 +289,6 @@
     else:
         need_convert = questionnaire_results
-    
-    
-    
     if 'adjoption' in eval_args:
         # split need_convert based on dimension
@@ -345,70 +329,55 @@
         else:
             sys_prompt = (questionnaire_metadata["prompts"]["convert_to_choice"]['en'] + '\n' + questionnaire_metadata["prompts"]["llm_choice_instruction"]['en']).replace('', character_name)
 
-        # control batch size
-        from utils import count_tokens
-        
-        if evaluator_llm == 'gpt-3.5' and count_tokens(json.dumps(need_convert, indent=4, ensure_ascii=False), evaluator_llm) > 15500:
-            
-            need_convert_list = [ {str(j+1): need_convert[str(j+1)] for j in range(i, i+30)} for i in range(0, len(need_convert), 30)]
-        else:
-            need_convert_list = [ need_convert ]
+        user_input = json.dumps(need_convert, indent=4, ensure_ascii=False)
 
-        for need_convert in need_convert_list:
-            user_input = json.dumps(need_convert, indent=4, ensure_ascii=False)
+        if 'anonymous' in eval_args:
+            for a in character_aliases:
+                sys_prompt = sys_prompt.replace(a, '')
+                user_input = user_input.replace(a, '')
+            sys_prompt = sys_prompt.replace(experimenter, '')
+            user_input = user_input.replace(experimenter, '')
 
-            if 'anonymous' in eval_args:
-                for a in character_aliases:
-                    sys_prompt = sys_prompt.replace(a, '')
-                    user_input = user_input.replace(a, '')
-                sys_prompt = sys_prompt.replace(experimenter, '')
-                user_input = user_input.replace(experimenter, '')
+        from utils import string2json_ensure_keys
+        
 
-            from utils import string2json_ensure_keys
+        if evaluator_llm.startswith('gpt'):
+            # call llm to convert to choices
+            converted_choices = get_response_json([string2json_ensure_keys], sys_prompt = sys_prompt, inputs = user_input, model=evaluator_llm)
+        else:
+            from utils import string2json_ensure_choice_format
+            sys_prompt = sys_prompt + '\n===OUTPUT EXAMPLE===\n{\n  \"1\": 1,\n  ...\n  \"9\": 0\n}===My Input Is==='
+            
+            converted_choices = get_response_json([string2json_ensure_choice_format, string2json_ensure_keys], sys_prompt = sys_prompt, inputs = user_input, model=evaluator_llm)
+        
 
-            if evaluator_llm.startswith('gpt'):
-                # call llm to convert to choices
-                try:
-                    converted_choices = get_response_json([string2json_ensure_keys], sys_prompt = sys_prompt, inputs = user_input, model=evaluator_llm)
-                except:
-                    import pdb; pdb.set_trace()
-                
-            else:
-                from utils import string2json_ensure_choice_format
-                sys_prompt = sys_prompt + '\n===OUTPUT EXAMPLE===\n{\n  \"1\": 1,\n  ...\n  \"9\": 0\n}===My Input Is==='
 
-                converted_choices = get_response_json([string2json_ensure_choice_format, string2json_ensure_keys], sys_prompt = sys_prompt, inputs = user_input, model=evaluator_llm)
 
+        if 'adjoption' in eval_args:
+            # convert 'negative' question choices. I.e. strongly extraverted (5) -> strongly disagree (1).
+            for idx, choice in converted_choices.items():
+                dim = idx2dimension[idx]
+                category = idx2category[idx]
 
+                if questionnaire_name == '16Personalities':
+                    category = 'positive' if category == dim[0] else 'negative'
+                
+                if category == 'negative' and choice != 'x':
+                    converted_choices[idx] = questionnaire_metadata['range'][0] + questionnaire_metadata['range'][1] - float(choice)
+        
+        assert( len(need_convert.keys() - converted_choices.keys()) == 0 )
+        
 
-            if 'adjoption' in eval_args:
-                # convert 'negative' question choices. I.e. strongly extraverted (5) -> strongly disagree (1).
-                for idx, choice in converted_choices.items():
-                    dim = idx2dimension[idx]
-                    category = idx2category[idx]
-
-                    if questionnaire_name == '16Personalities':
-                        category = 'positive' if category == dim[0] else 'negative'
-                    
-                    if category == 'negative' and choice != 'x':
-                        converted_choices[idx] = questionnaire_metadata['range'][0] + questionnaire_metadata['range'][1] - float(choice)
-            
-            
-            
-            assert( len(need_convert.keys() - converted_choices.keys()) == 0 )
-            
-            choices.update(converted_choices)
+        choices.update(converted_choices)
 
     for idx, choice in choices.items():
-        if choice == 'x' or choice is None:
+        if choice == 'x':
             choice = (questionnaire_metadata['range'][0] + questionnaire_metadata['range'][1] ) / 2
-        
+
         choice = float(choice)
-        
-
-
         dim = idx2dimension[idx]
         category = idx2category[idx]
@@ -431,7 +400,6 @@
             dim_responses = [r for i, r in enumerate(questionnaire_results) if questionnaire[i]['dimension'] == dim]
 
             if nth_test > 0:
-                random.seed(nth_test)
                 random.shuffle(dim_responses)
 
             eval_setting = eval_args[2]
@@ -556,9 +524,8 @@
         from api_16personality import submit_16personality_api
 
-        pred = submit_16personality_api(answers)
-        
+        pred = submit_16personality_api(answers)
         #assessment_results = { dim: {'score': pred[dim]['score'][dim[0]]} for dim in dims }
 
         for dim in dims:
@@ -880,7 +847,8 @@ def calculate_measured_alignment(preds, labels, questionnaire_name, labels_pdb):
 
     agent_types = list(set([ rpa[1] for rpa in preds.keys()]))
 
-    dims = dims_dict[questionnaire_name]
+    dims = dims_dict[questionnaire_name]
+    questionnaire_metadata = load_questionnaire(questionnaire_name)
 
@@ -932,7 +900,6 @@
                 full_correct = False
 
             sum_mse_each_dim[a][dim] += ((pred_score - label_score) / range_span) ** 2
-            
             sum_mae_each_dim[a][dim] += abs((pred_score - label_score) / range_span)
 
             if full_correct:

From 9c9e91e90b9b749c27e54c6371d1f9f74c46aa4a Mon Sep 17 00:00:00 2001
From: algoroxyolo
Date: Wed, 7 Feb 2024 06:07:09 +0000
Subject: [PATCH 3/3] yay

---
 code/ChatHaruhi/mistral.py | 53 ++++++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)
 create mode 100644 code/ChatHaruhi/mistral.py
diff --git a/code/ChatHaruhi/mistral.py b/code/ChatHaruhi/mistral.py
new file mode 100644
index 0000000..27e8a51
--- /dev/null
+++ b/code/ChatHaruhi/mistral.py
@@ -0,0 +1,53 @@
+from .BaseLLM import BaseLLM
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import bitsandbytes, flash_attn
+tokenizer_LLaMA = None
+model_LLaMA = None
+
+def initialize_Mistral():
+    global model_LLaMA, tokenizer_LLaMA
+
+    if model_LLaMA is None:
+        model_LLaMA = AutoModelForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.2",
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
+
+    if tokenizer_LLaMA is None:
+        tokenizer_LLaMA = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", trust_remote_code=True)
+
+    return model_LLaMA, tokenizer_LLaMA
+
+def LLaMA_tokenizer(text):
+    return len(tokenizer_LLaMA.encode(text))
+
+class ChatMistral(BaseLLM):
+    def __init__(self, model="Mistral"):
+        super(ChatMistral, self).__init__()
+        self.model, self.tokenizer = initialize_Mistral()
+        self.messages = ""
+
+    def initialize_message(self):
+        self.messages = "[INST]"
+
+    def ai_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def system_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def user_message(self, payload):
+        self.messages = self.messages + "\n " + payload
+
+    def get_response(self):
+        with torch.no_grad():
+            encodeds = self.tokenizer.encode(self.messages+"[/INST]", return_tensors="pt")
+            generated_ids = self.model.generate(encodeds, max_new_tokens=2000, do_sample=True)
+            decoded = self.tokenizer.batch_decode(generated_ids)
+
+            return decoded[0].split("[/INST]")[1]
+
+    def print_prompt(self):
+        print(self.messages)
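For clarity, a sketch of the prompt string that ChatMistral assembles and how the reply is recovered; the message text and the stand-in model output are purely illustrative:

    # The buffer starts as "[INST]" (initialize_message), every *_message() call appends
    # "\n " + payload, and get_response() appends "[/INST]" before generation.
    messages = "[INST]"
    messages += "\n " + "You are Haruhi Suzumiya. Stay in character."   # system_message (hypothetical text)
    messages += "\n " + "Kyon: Do you really believe in aliens?"        # user_message (hypothetical text)
    prompt = messages + "[/INST]"
    # -> "[INST]\n You are ...\n Kyon: ...[/INST]"

    # batch_decode returns prompt + completion in one string, so the reply is taken
    # as everything after the first "[/INST]" marker, exactly as get_response() does.
    decoded = prompt + " Of course I do! That's why the SOS Brigade exists.</s>"  # stand-in for model output
    reply = decoded.split("[/INST]")[1]
    print(reply)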