[Frontend] Represent tokens with identifiable strings #6626

Merged on Jul 25, 2024 (43 commits)
Changes from 37 commits

Commits
29ffa8a  Escape any non-ASCII encodable characters for top log probs (ezliu, Jul 16, 2024)
1b757de  format (ezliu, Jul 17, 2024)
2d8665f  Force workflow run (ezliu, Jul 17, 2024)
4289604  Return token ID (ezliu, Jul 18, 2024)
c706ed4  trigger workflows (ezliu, Jul 21, 2024)
2ceeb8e  format (ezliu, Jul 21, 2024)
3a2bcfa  Add arg to argparser (ezliu, Jul 21, 2024)
b3eb5d2  format (ezliu, Jul 21, 2024)
2d7d253  Merge remote-tracking branch 'upstream/main' (ezliu, Jul 22, 2024)
d55bfb6  format (ezliu, Jul 22, 2024)
81add92  Be consistent with initial token, too (ezliu, Jul 22, 2024)
a2f150d  Merge remote-tracking branch 'upstream/main' (ezliu, Jul 23, 2024)
9521884  fix merge (ezliu, Jul 23, 2024)
0446488  format (ezliu, Jul 23, 2024)
2f6ebe9  Add test for token ids (ezliu, Jul 24, 2024)
4eea860  refactor fixture (ezliu, Jul 24, 2024)
f47ff66  format (ezliu, Jul 24, 2024)
6e63012  fix type (ezliu, Jul 24, 2024)
9a8d850  Test against the keys (ezliu, Jul 24, 2024)
3e83182  Discard first token (ezliu, Jul 24, 2024)
5120421  Compare text without special tokens (ezliu, Jul 24, 2024)
b8d353b  format (ezliu, Jul 24, 2024)
a6f9027  Avoid mutating list (ezliu, Jul 24, 2024)
cfb20dd  Add chat test (ezliu, Jul 24, 2024)
4d5b968  Missing comma (ezliu, Jul 24, 2024)
1df84b1  remove unsupported arg (ezliu, Jul 24, 2024)
ed5d694  grab correct attr (ezliu, Jul 24, 2024)
489d31b  Support token_ids for chat (ezliu, Jul 24, 2024)
5c6f463  debug (ezliu, Jul 24, 2024)
b538953  Add tests (ezliu, Jul 24, 2024)
d9fe2ba  pass workflows (ezliu, Jul 24, 2024)
b180231  formatting suggestion (ezliu, Jul 24, 2024)
7a81668  format (ezliu, Jul 24, 2024)
4c5a59a  format (ezliu, Jul 24, 2024)
75d9880  ruff (ezliu, Jul 24, 2024)
fb43a24  yapf (ezliu, Jul 24, 2024)
fef0abc  Update prompt to solicit emoji responses (ezliu, Jul 24, 2024)
58ee7e9  Choose multiple of 4 bytes (ezliu, Jul 24, 2024)
c0c52dd  Move tests to diff file to prevent OOM (ezliu, Jul 24, 2024)
543378e  Add fixture dependencies (ezliu, Jul 24, 2024)
3f293b3  Merge branch 'evan/dev' (ezliu, Jul 24, 2024)
ac09ede  workflows (ezliu, Jul 24, 2024)
df6fed5  fix import order (ezliu, Jul 24, 2024)
48 changes: 45 additions & 3 deletions tests/entrypoints/openai/test_chat.py
@@ -9,6 +9,8 @@
import torch
from openai import BadRequestError

from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
@@ -21,8 +23,10 @@


@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811
args = [
def default_server_args(
zephyr_lora_files, # noqa: F811
zephyr_lora_added_tokens_files): # noqa: F811
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
@@ -42,7 +46,17 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811
"128",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
def server_with_return_tokens_as_token_ids_flag(default_server_args):
args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
yield remote_server


@@ -840,3 +854,31 @@ async def test_long_seed(client: openai.AsyncOpenAI):

assert ("greater_than_equal" in exc_info.value.message
or "less_than_equal" in exc_info.value.message)


@pytest.mark.asyncio
async def test_return_tokens_as_token_ids_completion(
server_with_return_tokens_as_token_ids_flag):
client = server_with_return_tokens_as_token_ids_flag.get_async_client()
response = await client.chat.completions.create(
model=MODEL_NAME,
# Include Unicode characters to test for dividing a single
# character across multiple tokens: πŸŽ‰ is [28705, 31862] for the
# Zephyr tokenizer
messages=[{
"role": "system",
"content": "You like to respond in only emojis, like πŸŽ‰"
}, {
"role": "user",
"content": "Please write some emojis: πŸ±πŸΆπŸŽ‰"
}],
temperature=0,
max_tokens=10,
logprobs=True)

text = response.choices[0].message.content
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
token_ids = []
for logprob_content in response.choices[0].logprobs.content:
token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
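
The comments in this test point out that the πŸŽ‰ emoji maps to more than one token ID ([28705, 31862] under the Zephyr tokenizer), so an individual logprob entry may not correspond to a cleanly decodable character. A minimal sketch for inspecting that split, assuming the Zephyr checkpoint these tests use, might look like the following (not part of the PR):

from vllm.transformers_utils.tokenizer import get_tokenizer

# Assumed model name; the tests refer to it only as MODEL_NAME.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

# The emoji is expected to span several IDs (the test cites [28705, 31862]).
token_ids = tokenizer.encode("πŸŽ‰", add_special_tokens=False)
print(token_ids)

# Decoding the IDs together recovers the emoji.
print(tokenizer.decode(token_ids))

# Decoding one ID at a time may yield a partial or placeholder string,
# which is the ambiguity the "token_id:<id>" representation avoids.
for token_id in token_ids:
    print(token_id, repr(tokenizer.decode([token_id])))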
50 changes: 47 additions & 3 deletions tests/entrypoints/openai/test_completion.py
@@ -55,8 +55,9 @@ def zephyr_pa_files():


@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
args = [
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
@@ -85,7 +86,17 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
"128",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
def server_with_return_tokens_as_token_ids_flag(default_server_args):
args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
yield remote_server


@@ -682,3 +693,36 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
prompt="Give an example string that fits this regex",
extra_body=dict(guided_regex=sample_regex,
guided_json=sample_json_schema))


@pytest.mark.asyncio
async def test_return_tokens_as_token_ids_completion(
server_with_return_tokens_as_token_ids_flag):
client = server_with_return_tokens_as_token_ids_flag.get_async_client()

completion = await client.completions.create(
model=MODEL_NAME,
# Include Unicode characters to test for dividing a single
# character across multiple tokens: πŸŽ‰ is [28705, 31862] for the
# Zephyr tokenizer
prompt="Say 'Hello, world! πŸŽ‰'",
echo=True,
temperature=0,
max_tokens=10,
logprobs=1)

text = completion.choices[0].text
token_strs = completion.choices[0].logprobs.tokens
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Check that the token representations are consistent between raw tokens
# and top_logprobs
# Slice off the first one, because there's no scoring associated with BOS
top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
top_logprob_keys = [
next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
]
assert token_strs[1:] == top_logprob_keys

# Check that decoding the tokens gives the expected text
tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
assert text == tokenizer.decode(tokens, skip_special_tokens=True)
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -264,6 +264,7 @@ def run_server(args, llm_engine=None):
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
engine,
@@ -272,6 +273,7 @@
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
6 changes: 6 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -128,6 +128,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as"
"strings of the form 'token_id:{token_id}' so that tokens that"
"are not JSON-encodable can be identified.")

parser = AsyncEngineArgs.add_cli_args(parser)

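
Once the server is started with this flag, a client can recover exact token IDs from the logprob fields, the same way the new tests do. A sketch of that client-side flow is below; the launch command, server URL, API key, and model name are assumptions, not part of this PR.

# Assumed launch command for the OpenAI-compatible server with the new flag:
#   python -m vllm.entrypoints.openai.api_server \
#       --model HuggingFaceH4/zephyr-7b-beta --return-tokens-as-token-ids
import openai

# Placeholder endpoint and key for a locally running server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Say 'Hello, world! πŸŽ‰'",
    max_tokens=10,
    logprobs=1,
)

# With --return-tokens-as-token-ids, each token string has the form
# "token_id:<int>", so it is always JSON-encodable and maps back to an ID.
token_strs = completion.choices[0].logprobs.tokens
token_ids = [int(t.removeprefix("token_id:")) for t in token_strs]
print(token_ids)

The tests in this PR then pass such IDs to tokenizer.decode(..., skip_special_tokens=True) to confirm the round trip matches the returned text.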
23 changes: 16 additions & 7 deletions vllm/entrypoints/openai/serving_chat.py
@@ -50,13 +50,15 @@ def __init__(
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)

self.response_role = response_role

@@ -522,11 +524,14 @@ def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=(token := self._get_decoded_token(p[1], p[0],
tokenizer)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")))
ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1],
p[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
@@ -546,14 +551,18 @@ def _create_chat_logprobs(
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=step_top_logprobs[token_id].decoded_token,
token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer,
self.return_tokens_as_token_ids),
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(
19 changes: 15 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
@@ -51,13 +51,15 @@
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)

async def create_completion(self, request: CompletionRequest,
raw_request: Request):
@@ -430,12 +432,17 @@ def _create_completion_logprobs(
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id, tokenizer)
token = self._get_decoded_token(
step_top_logprobs[token_id],
token_id,
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
@@ -448,7 +455,11 @@
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
14 changes: 9 additions & 5 deletions vllm/entrypoints/openai/serving_engine.py
@@ -68,6 +68,7 @@ def __init__(
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__()

@@ -102,6 +103,7 @@ def __init__(
prompt_adapter_num_virtual_tokens=num_virtual_tokens))

self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids

async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
@@ -384,11 +386,13 @@ def _log_inputs(
)

@staticmethod
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
) -> str:
def _get_decoded_token(logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
return_as_token_id: bool = False) -> str:
if return_as_token_id:
return f"token_id:{token_id}"

if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)
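
As a summary of the wrapped diff above, the precedence that the updated _get_decoded_token helper applies could be sketched as follows (simplified stand-in names; only what the diff itself shows, not the full vLLM signature):

def get_decoded_token(logprob, token_id, tokenizer, return_as_token_id=False):
    # 1. The server-wide --return-tokens-as-token-ids flag wins: always emit
    #    the identifiable "token_id:<int>" form.
    if return_as_token_id:
        return f"token_id:{token_id}"
    # 2. Otherwise reuse the text already attached to the Logprob entry.
    if logprob.decoded_token is not None:
        return logprob.decoded_token
    # 3. Fall back to decoding the single ID with the tokenizer.
    return tokenizer.decode(token_id)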