
[Bug]: Structured Output not working with MistralTokenizer (vLLM 0.8.2, V1) #15551

Description

@hibukipanim

Despite #14625, which aims to support structured outputs with the Mistral tokenizer, it doesn't seem to work with vLLM 0.8.2, which includes that PR.

To reproduce:

vllm serve mistralai/Ministral-8B-Instruct-2410 --tokenizer_mode mistral --config_format mistral --load_format mistral --max-model-len 4096 --gpu-memory-utilization 0.95

(the memory-related args are just to make it run on an RTX 4090)

and a simplified version of the Structured Outputs example from the vLLM docs:

from pydantic import BaseModel

from openai import OpenAI

class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: str

json_schema = CarDescription.model_json_schema()

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="-",
)
completion = client.chat.completions.create(
    model="mistralai/Ministral-8B-Instruct-2410",
    messages=[
        {
            "role": "user",
            "content": "Generate a JSON with the brand, model and car_type of the most iconic car from the 90's",
        }
    ],
    extra_body={"guided_json": json_schema},
    seed=42
)
print(completion.choices[0].message.content)

It runs for a very long time and then outputs broken JSON, e.g.:

{
{ 	"bݨ" 	: 	"Ford", 	" 		il		od		od		od		od		od		od		od		od ...
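A quick check with the same Pydantic model confirms the response doesn't satisfy the requested schema (minimal sketch; in Pydantic v2, model_validate_json raises ValidationError for malformed JSON as well as schema violations):

from pydantic import ValidationError

content = completion.choices[0].message.content
try:
    car = CarDescription.model_validate_json(content)
    print("Valid structured output:", car)
except ValidationError as err:
    # With the truncated/garbled output above, this branch is taken,
    # i.e. the guided_json constraint was not actually enforced.
    print("Schema validation failed:", err)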

By the way, maybe it's expected, but the first request is blocked for a few seconds while the server prints:

[2025-03-26 16:13:24] INFO tekken.py:114: Adding special tokens <SPECIAL_20>, ..., <SPECIAL_999>
[2025-03-26 16:13:24] INFO tekken.py:293: Vocab size: 150000
[2025-03-26 16:13:24] INFO tekken.py:297: Cutting vocab to first 130072 tokens.
INFO 03-26 16:13:24 [loggers.py:80] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
/home/user/code/play-vllm/.venv/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2059: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(

INFO 03-26 16:13:34 [loggers.py:80] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 1 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
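
For reference, the same guided request can also be expressed through the offline API (rough, untested sketch following the GuidedDecodingParams usage from the vLLM structured-outputs docs), which could help tell whether the issue is specific to the OpenAI server path:

from pydantic import BaseModel

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: str

llm = LLM(
    model="mistralai/Ministral-8B-Instruct-2410",
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
    max_model_len=4096,
    gpu_memory_utilization=0.95,
)

# Same JSON schema constraint as the guided_json extra_body above.
guided = GuidedDecodingParams(json=CarDescription.model_json_schema())
params = SamplingParams(guided_decoding=guided, max_tokens=256)

outputs = llm.chat(
    [{"role": "user", "content": "Generate a JSON with the brand, model and car_type of the most iconic car from the 90's"}],
    sampling_params=params,
)
print(outputs[0].outputs[0].text)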
