Chat tokenization fixes in generate.py & API #1035


Merged
merged 2 commits on Aug 19, 2024
48 changes: 27 additions & 21 deletions api/api.py
@@ -10,6 +10,8 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from build.utils import device_sync

from generate import Generator, GeneratorArgs
@@ -222,7 +224,6 @@ def __init__(self, *args, **kwargs):
"""

super().__init__(*args, **kwargs)
self.start_pos = 0
self.max_seq_length = (
self.model.config.max_seq_length
+ self.speculative_builder_args.speculate_k
@@ -257,20 +258,25 @@ def chunked_completion(self, completion_request: CompletionRequest):
CompletionResponseChunk objects in response to completion_request as tokens are generated.

"""
device_sync(device=self.builder_args.device)

# Initialize counters for chunk responses and encode the prompt.
id = str(uuid.uuid4())

idx = 0
buffer = []
encoded = self.encode_tokens(
completion_request.messages[-1].get("content"),
bos=True,
device=self.builder_args.device,
tokens = self.chat_formatter.encode_dialog_prompt(
dialog=[
{"role": message["role"], "content": message["content"]}
for message in completion_request.messages
]
)

encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
print(self.tokenizer.decode(tokens))
Contributor

Just checking that this is an intentional print?

Contributor Author

Yes - this is intentional: it prints the prompt so that the full prompt can be tracked entirely from the server side.

However, this surfaces a larger issue in the generate/API stack: the print statements should be replaced with a logger so that users can choose not to emit these debug messages.
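
A minimal sketch of that direction (the helper function and logger setup below are illustrative, not code from this PR):

import logging

logger = logging.getLogger(__name__)

def log_prompt(tokenizer, tokens):
    # Emit the decoded prompt at DEBUG level instead of printing it unconditionally,
    # so servers can silence it through standard logging configuration.
    logger.debug("Full prompt: %s", tokenizer.decode(tokens))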


start_pos = 0

generator_args = GeneratorArgs(
completion_request.messages[-1].get("content"),
None,
max_new_tokens=(
int(completion_request.max_tokens)
if completion_request.max_tokens
@@ -279,33 +285,39 @@ def chunked_completion(self, completion_request: CompletionRequest):
encoded_prompt=encoded,
temperature=float(completion_request.temperature),
chat_mode=False,
sequential_prefill=True,
)

def callback(x, *, done_generating=False):
return self._callback(
x,
buffer=buffer,
buffer=None,
done_generating=done_generating,
)

device_sync(device=self.builder_args.device)

# Process each token, metrics tuple yielded by Generator.generate.
for y, _ in self.generate(
self.model,
encoded,
generator_args.max_new_tokens,
model=self.model,
prompt=encoded,
max_new_tokens=generator_args.max_new_tokens,
draft_model=self.draft_model,
speculate_k=generator_args.speculate_k,
chat_mode=generator_args.chat_mode,
callback=callback,
temperature=generator_args.temperature,
top_k=generator_args.top_k,
sequential_prefill=generator_args.sequential_prefill,
start_pos=self.start_pos,
start_pos=start_pos,
max_seq_length=self.max_seq_length,
seed=int(completion_request.seed),
):
if y is None:
continue
elif y.item() == self.tokenizer.eos_id:
# Stop generation if the EOS token is generated.
break

# Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
content = "".join(
@@ -330,7 +342,7 @@ def callback(x, *, done_generating=False):
system_fingerprint=self.system_fingerprint,
)
yield chunk_response
self.start_pos += y.size(0)
start_pos += y.size(0)
idx += 1

# Yield an ending chunk indicating the generation has completed.
@@ -369,10 +381,4 @@ def sync_completion(self, request: CompletionRequest):
)

def _callback(self, x, *, buffer, done_generating):
period_id = self.tokenizer.encode(".")[0]
buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
if (
self.is_llama3_model
and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
):
buffer = buffer[:-1] # drop the eot_id from the output buffer
pass
Contributor

Why is this a pass again?

Contributor Author

The callback function is only used by generate() in the CLI interactive chat to print results to stdout. I copied it over naively from the original generate.py when refactoring, so it landed in the OpenAI API code where it isn't actually used.

67 changes: 55 additions & 12 deletions generate.py
@@ -9,6 +9,8 @@
import os
import textwrap
import time

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
@@ -28,24 +30,33 @@
from cli import add_arguments_for_verb, arg_init, check_args
from utils.device_info import get_device_info

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


class ChatFormat:
class _ChatFormatter(ABC):
def __init__(self, tokenizer):
self.tokenizer = tokenizer

def encode_header(self, message) -> List[int]:
@abstractmethod
def encode_dialog_prompt(self, dialog) -> List[int]:
raise NotImplementedError()


class Llama3ChatFormatter(_ChatFormatter):
"""Format a chat prompt using special tokens to demarcate roles and messages.

Refer to the LLaMA3 documentation for more details https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3

"""

def encode_header(self, role) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens

def encode_message(self, message) -> List[int]:
tokens = self.encode_header(message)
tokens = self.encode_header(message.role)
tokens.extend(
self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
)
@@ -62,9 +73,37 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
return tokens


B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


class Llama2ChatFormatter(_ChatFormatter):
def encode_dialog_prompt(self, dialog) -> List[int]:
tokens = self.tokenizer.encode(f"{B_INST} ")
first_message = True # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
for message in dialog:
content = message["content"].strip()
if message["role"] == "system":
encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
first_message = False
elif message["role"] == "user":
encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
f"{B_INST if first_message else ''} {content} {E_INST} "
)
first_message = True
elif message["role"] == "assistant":
encoded = self.tokenizer.encode(f"{content}\n\n") + [
self.tokenizer.eos_id()
]
tokens += encoded
return tokens


@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
prompt: Optional[str] = (
None # When passed into the Generator, this will be used as the system prompt
)
encoded_prompt: Optional[torch.Tensor] = None
chat_mode: bool = False
gui_mode: bool = False
@@ -188,7 +227,7 @@ def __init__(
))
# fmt: on
# raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")

self.system_prompt = generator_args.prompt
self.tokenizer = _initialize_tokenizer(self.tokenizer_args)

# Right now the assumption is only llama3 uses tiktokenizer and it
@@ -200,6 +239,11 @@ def __init__(
logging.debug(
"Llama3 model detected in chat mode. Using updated sentence schemas"
)
self.chat_formatter = (
Llama3ChatFormatter(self.tokenizer)
if self.is_llama3_model
else Llama2ChatFormatter(self.tokenizer)
)

self.builder_args.setup_caches = False
self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -641,8 +685,7 @@ def chat(
)
if get_system_prompt == "y" or get_system_prompt == "Y":
self.system_prompt = input("What is your system prompt? \n")
if self.is_llama3_model:
self.chat_formatter = ChatFormat(self.tokenizer)

else:
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
@@ -685,7 +728,7 @@ def chat(
prompt, bos=True, device=self.builder_args.device
)
else:
if self.system_prompt is not None:
if self.system_prompt:
encoded = self.chat_formatter.encode_dialog_prompt(
[
{"role": "system", "content": self.system_prompt},
8 changes: 6 additions & 2 deletions server.py
@@ -6,6 +6,10 @@

import json

import logging

logger = logging.getLogger(__name__)

from dataclasses import asdict
from typing import Dict, List, Union

@@ -21,7 +25,7 @@
OPENAI_API_VERSION = "v1"


def create_app(args):
def create_app(args): # noqa: C901
"""
Creates a flask app that can be used to serve the model as a chat API.
"""
@@ -69,7 +73,7 @@ def chunk_processor(chunked_completion_generator):
for chunk in chunked_completion_generator:
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
print(next_tok, end="")
print(next_tok, end="", flush=True)
yield json.dumps(_del_none(asdict(chunk)))

return Response(