
Commit c7f56f2

vmpuri and Jack-Khuu authored Aug 19, 2024
Fix tokenization of chat interfaces (#1035)
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
1 parent 1566512 commit c7f56f2

3 files changed: +88 −35
 

api/api.py (+27 −21)

@@ -10,6 +10,8 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union

+import torch
+
 from build.utils import device_sync

 from generate import Generator, GeneratorArgs
@@ -222,7 +224,6 @@ def __init__(self, *args, **kwargs):
         """

         super().__init__(*args, **kwargs)
-        self.start_pos = 0
         self.max_seq_length = (
             self.model.config.max_seq_length
             + self.speculative_builder_args.speculate_k
@@ -257,20 +258,25 @@ def chunked_completion(self, completion_request: CompletionRequest):
         CompletionResponseChunk objects in response to completion_request as tokens are generated.

         """
-        device_sync(device=self.builder_args.device)

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())

         idx = 0
-        buffer = []
-        encoded = self.encode_tokens(
-            completion_request.messages[-1].get("content"),
-            bos=True,
-            device=self.builder_args.device,
+        tokens = self.chat_formatter.encode_dialog_prompt(
+            dialog=[
+                {"role": message["role"], "content": message["content"]}
+                for message in completion_request.messages
+            ]
         )
+
+        encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
+        print(self.tokenizer.decode(tokens))
+
+        start_pos = 0
+
         generator_args = GeneratorArgs(
-            completion_request.messages[-1].get("content"),
+            None,
             max_new_tokens=(
                 int(completion_request.max_tokens)
                 if completion_request.max_tokens
@@ -279,33 +285,39 @@ def chunked_completion(self, completion_request: CompletionRequest):
             encoded_prompt=encoded,
             temperature=float(completion_request.temperature),
             chat_mode=False,
+            sequential_prefill=True,
         )

         def callback(x, *, done_generating=False):
             return self._callback(
                 x,
-                buffer=buffer,
+                buffer=None,
                 done_generating=done_generating,
             )

+        device_sync(device=self.builder_args.device)
+
         # Process each token, metrics tuple yielded by Generator.generate.
         for y, _ in self.generate(
-            self.model,
-            encoded,
-            generator_args.max_new_tokens,
+            model=self.model,
+            prompt=encoded,
+            max_new_tokens=generator_args.max_new_tokens,
             draft_model=self.draft_model,
             speculate_k=generator_args.speculate_k,
             chat_mode=generator_args.chat_mode,
             callback=callback,
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
             sequential_prefill=generator_args.sequential_prefill,
-            start_pos=self.start_pos,
+            start_pos=start_pos,
             max_seq_length=self.max_seq_length,
             seed=int(completion_request.seed),
         ):
             if y is None:
                 continue
+            elif y.item() == self.tokenizer.eos_id:
+                # Stop generation if the EOS token is generated.
+                break

             # Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
             content = "".join(
@@ -330,7 +342,7 @@ def callback(x, *, done_generating=False):
                 system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
-            self.start_pos += y.size(0)
+            start_pos += y.size(0)
             idx += 1

         # Yield an ending chunk indicating the generation has completed.
@@ -369,10 +381,4 @@ def sync_completion(self, request: CompletionRequest):
         )

     def _callback(self, x, *, buffer, done_generating):
-        period_id = self.tokenizer.encode(".")[0]
-        buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
-        if (
-            self.is_llama3_model
-            and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
-        ):
-            buffer = buffer[:-1]  # drop the eot_id from the output buffer
+        pass
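
Editor's note: for orientation, a minimal sketch of the prompt-encoding path that chunked_completion now follows. The full message history is run through the chat formatter, turned into a tensor, and a per-request start_pos replaces the old generator-level counter. The helper name encode_request_prompt is illustrative and not part of the commit.

import torch

def encode_request_prompt(chat_formatter, messages, device):
    # Encode the *entire* message history (system/user/assistant turns),
    # not just the last user message, so prior turns keep their role markers.
    tokens = chat_formatter.encode_dialog_prompt(
        dialog=[{"role": m["role"], "content": m["content"]} for m in messages]
    )
    # The generator consumes a 1-D int tensor on the target device; the
    # generation position is tracked per request, not on the generator object.
    encoded = torch.tensor(tokens, dtype=torch.int, device=device)
    start_pos = 0
    return encoded, start_pos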

generate.py (+55 −12)

@@ -9,6 +9,8 @@
 import os
 import textwrap
 import time
+
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -28,24 +30,33 @@
 from cli import add_arguments_for_verb, arg_init, check_args
 from utils.device_info import get_device_info

-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
-

-class ChatFormat:
+class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer

-    def encode_header(self, message) -> List[int]:
+    @abstractmethod
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        raise NotImplementedError()
+
+
+class Llama3ChatFormatter(_ChatFormatter):
+    """Format a chat prompt using special tokens to demarcate roles and messages.
+
+    Refer to the LLaMA3 documentation for more details https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
+
+    """
+
+    def encode_header(self, role) -> List[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
-        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
+        tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
         tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
         tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
         return tokens

     def encode_message(self, message) -> List[int]:
-        tokens = self.encode_header(message)
+        tokens = self.encode_header(message.role)
         tokens.extend(
             self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
         )
@@ -62,9 +73,37 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
         return tokens


+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+
+
+class Llama2ChatFormatter(_ChatFormatter):
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        tokens = self.tokenizer.encode(f"{B_INST} ")
+        first_message = True  # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
+        for message in dialog:
+            content = message["content"].strip()
+            if message["role"] == "system":
+                encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
+                first_message = False
+            elif message["role"] == "user":
+                encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
+                    f"{B_INST if first_message else ''} {content} {E_INST} "
+                )
+                first_message = True
+            elif message["role"] == "assistant":
+                encoded = self.tokenizer.encode(f"{content}\n\n") + [
+                    self.tokenizer.eos_id()
+                ]
+            tokens += encoded
+        return tokens
+
+
 @dataclass
 class GeneratorArgs:
-    prompt: str = "torchchat is pronounced torch-chat and is so cool because"
+    prompt: Optional[str] = (
+        None  # When passed into the Generator, this will be used as the system prompt
+    )
     encoded_prompt: Optional[torch.Tensor] = None
     chat_mode: bool = False
     gui_mode: bool = False
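
Editor's note: a usage sketch of the formatter hierarchy introduced above. The names tokenizer and is_llama3_model stand in for the corresponding Generator attributes; they are assumptions of the example, not new API.

dialog = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is torchchat?"},
]

# Pick the formatter the same way Generator.__init__ now does: LLaMA 3 models
# use tiktoken special header/eot tokens, everything else falls back to the
# [INST]/<<SYS>> style of LLaMA 2.
formatter = (
    Llama3ChatFormatter(tokenizer)
    if is_llama3_model
    else Llama2ChatFormatter(tokenizer)
)
tokens = formatter.encode_dialog_prompt(dialog)  # flat List[int]

The caller is responsible for turning the token list into a tensor, which is what api/api.py now does in chunked_completion.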
@@ -188,7 +227,7 @@ def __init__(
         ))
         # fmt: on
         # raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")
-
+        self.system_prompt = generator_args.prompt
         self.tokenizer = _initialize_tokenizer(self.tokenizer_args)

         # Right now the assumption is only llama3 uses tiktokenizer and it
@@ -200,6 +239,11 @@ def __init__(
             logging.debug(
                 "Llama3 model detected in chat mode. Using updated sentence schemas"
             )
+        self.chat_formatter = (
+            Llama3ChatFormatter(self.tokenizer)
+            if self.is_llama3_model
+            else Llama2ChatFormatter(self.tokenizer)
+        )

         self.builder_args.setup_caches = False
         self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -641,8 +685,7 @@ def chat(
             )
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
-                if self.is_llama3_model:
-                    self.chat_formatter = ChatFormat(self.tokenizer)
+
         else:
             max_seq_length = min(
                 encoded.size(0) + generator_args.max_new_tokens,
@@ -685,7 +728,7 @@ def chat(
                     prompt, bos=True, device=self.builder_args.device
                 )
             else:
-                if self.system_prompt is not None:
+                if self.system_prompt:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [
                             {"role": "system", "content": self.system_prompt},

server.py (+6 −2)

@@ -6,6 +6,10 @@

 import json

+import logging
+
+logger = logging.getLogger(__name__)
+
 from dataclasses import asdict
 from typing import Dict, List, Union

@@ -21,7 +25,7 @@
 OPENAI_API_VERSION = "v1"


-def create_app(args):
+def create_app(args):  # noqa: C901
     """
     Creates a flask app that can be used to serve the model as a chat API.
     """
@@ -69,7 +73,7 @@ def chunk_processor(chunked_completion_generator):
             for chunk in chunked_completion_generator:
                 if (next_tok := chunk.choices[0].delta.content) is None:
                     next_tok = ""
-                print(next_tok, end="")
+                print(next_tok, end="", flush=True)
                 yield json.dumps(_del_none(asdict(chunk)))

         return Response(
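
Editor's note: the flush=True addition matters because the server mirrors streamed tokens to stdout as they arrive; without an explicit flush, Python may buffer the partial line (especially when stdout is redirected) and print it only at the end. A standalone illustration of the behavior, not part of the commit:

import time

def echo_stream(fragments, flush):
    for frag in fragments:
        # Without flush=True the fragments can sit in the stdout buffer and
        # appear all at once instead of token by token.
        print(frag, end="", flush=flush)
        time.sleep(0.2)
    print()

echo_stream(["Hello", ",", " strea", "ming", " world!"], flush=True)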
