Skip to content

Commit

Permalink
add GemmaConversationFormatter
Browse files Browse the repository at this point in the history
  • Loading branch information
runninglsy committed Jul 29, 2024
1 parent d95b8b4 commit 5dc43d6
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions ovis/model/conversation_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,70 @@ def format_query(self, query, generation_preface=""):
}], generation_preface=generation_preface)

return prompt, input_ids


class GemmaConversationFormatter(ConversationFormatter):
    """Render conversations in the Gemma chat layout.

    The prompt starts with ``<bos>``. Each turn is written as the role header
    (``<start_of_turn>user\\n`` or ``<start_of_turn>model\\n``), the stripped
    message, and ``<end_of_turn>\\n``. Gemma has no system role, so a
    conversation whose first turn is ``"system"`` is rejected.
    """

    support_tokenizer_types = ['GemmaTokenizer', 'GemmaTokenizerFast']

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        # Gemma does not support system prompt
        self.from2role = {
            "human": "<start_of_turn>user\n",
            "gpt": "<start_of_turn>model\n",
        }
        # Token count of the model role header; computed lazily in `format`.
        self.gpt_token_num = None
        self.im_end = "<end_of_turn>\n"
        self.bos_token = "<bos>"
        # Token ids of `bos_token`; computed lazily in `format`.
        self.bos_token_ids = None

    def format(self, conversations: List[Dict], generation_preface=None):
        """Convert a conversation into a prompt string, input ids and labels.

        Args:
            conversations: Turns as dicts with a ``"from"`` key (``"human"``
                or ``"gpt"``) and a ``"value"`` key holding the message text.
                The list is not modified.
            generation_preface: When not ``None``, an extra ``"gpt"`` turn
                holding this text is added at the end and left open (no
                ``<end_of_turn>``) so that generation can continue from it.

        Returns:
            A tuple ``(prompt, input_ids, labels)``. ``prompt`` is the full
            text; ``input_ids`` and ``labels`` are ``torch.long`` tensors of
            equal length. Labels are ``self.ignore_index`` everywhere except
            on the content of ``"gpt"`` turns.

        Raises:
            ValueError: If the first turn comes from ``"system"``.
            KeyError: If a turn's ``"from"`` is not in ``self.from2role``.
        """
        if self.gpt_token_num is None:
            self.gpt_token_num = len(
                self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids
            )

        if self.bos_token_ids is None:
            self.bos_token_ids = self.tokenizer(
                self.bos_token, add_special_tokens=False
            ).input_ids

        if conversations and conversations[0]["from"] == "system":
            raise ValueError("Gemma does not support system prompt")

        # Work on a shallow copy: appending the preface turn to the caller's
        # list would leak an extra "gpt" turn into every later use of it.
        conversations = list(conversations)
        if generation_preface is not None:
            conversations.append({
                "from": "gpt",
                "value": generation_preface,
            })

        prompt_parts = [self.bos_token]
        # Copy so that extending `input_ids` never touches the cached BOS ids.
        input_ids = list(self.bos_token_ids)
        labels = [self.ignore_index] * len(input_ids)
        last_index = len(conversations) - 1
        for i, conversation in enumerate(conversations):
            frm = conversation["from"]
            text = self.from2role[frm] + conversation["value"].strip()
            # The preface turn (always last) stays open for generation.
            if i < last_index or generation_preface is None:
                text += self.im_end
            prompt_parts.append(text)

            token_ids = self._tokenize_with_image_symbol(text)
            input_ids.extend(token_ids)

            label_ids = [self.ignore_index] * len(token_ids)
            if frm == "gpt":
                # learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
                label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
            labels.extend(label_ids)

        prompt = "".join(prompt_parts)

        # Sanity checks: per-turn tokenization must agree with tokenizing the
        # whole prompt at once, and every token must have a label.
        assert self._tokenize_with_image_symbol(prompt) == input_ids
        assert len(input_ids) == len(labels)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.long)

        return prompt, input_ids, labels

    def format_query(self, query, generation_preface=""):
        """Format a single user query for inference.

        Wraps ``query`` in one ``"human"`` turn and delegates to `format`.
        With the default ``generation_preface=""`` the prompt ends with an
        open model turn header.

        Returns:
            A tuple ``(prompt, input_ids)``; the labels from `format` are
            discarded.
        """
        prompt, input_ids, _ = self.format([{
            "from": "human",
            "value": query
        }], generation_preface=generation_preface)

        return prompt, input_ids

0 comments on commit 5dc43d6

Please sign in to comment.