@@ -9,6 +9,8 @@
import os
import textwrap
import time
+
+from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
@@ -28,24 +30,33 @@
from cli import add_arguments_for_verb, arg_init, check_args
from utils.device_info import get_device_info

-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
-

-class ChatFormat:
+class _ChatFormatter(ABC):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

-    def encode_header(self, message) -> List[int]:
+    @abstractmethod
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        raise NotImplementedError()
+
+
+class Llama3ChatFormatter(_ChatFormatter):
+    """Format a chat prompt using special tokens to demarcate roles and messages.
+
+    Refer to the LLaMA3 documentation for more details: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
+    """
+
+    def encode_header(self, role) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
-        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
+        tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(self, message) -> List[int]:
-        tokens = self.encode_header(message)
+        tokens = self.encode_header(message["role"])
        tokens.extend(
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        )
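
# A minimal sketch (not part of this patch): the text that Llama3ChatFormatter's
# encode_header/encode_message assemble for a single message, written as plain
# strings rather than token ids. The role and content values are made up for
# illustration.
def _header_text(role: str) -> str:
    # mirrors encode_header: start/end header special tokens around the role,
    # followed by a blank line
    return f"<|start_header_id|>{role}<|end_header_id|>\n\n"

print(_header_text("user") + "Tell me a joke.")  # encode_message appends the stripped content
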
@@ -62,9 +73,37 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
        return tokens


+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+
+
+class Llama2ChatFormatter(_ChatFormatter):
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        tokens = self.tokenizer.encode(f"{B_INST} ")
+        # Controls placement of the B_INST token. The behavior is a bit unusual:
+        # the system prompt gets B_INST but the first user message does not,
+        # every later user message does, and if there is no system prompt the
+        # first user message gets it instead.
+        first_message = True
+        for message in dialog:
+            content = message["content"].strip()
+            if message["role"] == "system":
+                encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
+                first_message = False
+            elif message["role"] == "user":
+                encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
+                    f"{B_INST if first_message else ''} {content} {E_INST}"
+                )
+                first_message = True
+            elif message["role"] == "assistant":
+                encoded = self.tokenizer.encode(f"{content}\n\n") + [
+                    self.tokenizer.eos_id()
+                ]
+            tokens += encoded
+        return tokens
+
+
@dataclass
class GeneratorArgs:
-    prompt: str = "torchchat is pronounced torch-chat and is so cool because"
+    prompt: Optional[str] = (
+        None  # When passed into the Generator, this will be used as the system prompt
+    )
    encoded_prompt: Optional[torch.Tensor] = None
    chat_mode: bool = False
    gui_mode: bool = False
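
# A minimal sketch (not part of this patch): roughly the prompt text that
# Llama2ChatFormatter stitches together for a (system, user) dialog, before
# tokenization. "<s>" stands in for the BOS token id prepended to the user
# turn; the dialog contents are made up for illustration.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
system, user = "You are a helpful assistant.", "Tell me a joke."
prompt_text = (
    f"{B_INST} "                      # leading [INST] emitted before the message loop
    + f"{B_SYS}\n{system}\n{E_SYS}"   # system message wrapped in <<SYS>> markers
    + "<s>"                           # stand-in for the bos id prepended to the user turn
    + f" {user} {E_INST}"             # B_INST is suppressed right after a system prompt
)
print(prompt_text)
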
@@ -188,7 +227,7 @@ def __init__(
        ))
        # fmt: on
        # raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")
-
+        self.system_prompt = generator_args.prompt
        self.tokenizer = _initialize_tokenizer(self.tokenizer_args)

        # Right now the assumption is only llama3 uses tiktokenizer and it
@@ -200,6 +239,11 @@ def __init__(
            logging.debug(
                "Llama3 model detected in chat mode. Using updated sentence schemas"
            )
+        self.chat_formatter = (
+            Llama3ChatFormatter(self.tokenizer)
+            if self.is_llama3_model
+            else Llama2ChatFormatter(self.tokenizer)
+        )

        self.builder_args.setup_caches = False
        self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -641,8 +685,7 @@ def chat(
            )
            if get_system_prompt == "y" or get_system_prompt == "Y":
                self.system_prompt = input("What is your system prompt? \n")
-                if self.is_llama3_model:
-                    self.chat_formatter = ChatFormat(self.tokenizer)
+
        else:
            max_seq_length = min(
                encoded.size(0) + generator_args.max_new_tokens,
@@ -685,7 +728,7 @@ def chat(
                    prompt, bos=True, device=self.builder_args.device
                )
            else:
-                if self.system_prompt is not None:
+                if self.system_prompt:
                    encoded = self.chat_formatter.encode_dialog_prompt(
                        [
                            {"role": "system", "content": self.system_prompt},