18 changes: 15 additions & 3 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -23,11 +23,23 @@
]


@pytest.fixture(scope="module")
def llm():
@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
def llm(request):
Collaborator

I received a single test failure when running this test file on your branch. Below is the full log. It seems to indicate a guidance failure.

_________________________________________________________________________________________________________________________________________ test_json_with_any_whitespace_disabled[speculative] _________________________________________________________________________________________________________________________________________

llm = <weakproxy at 0x7f57833a1850 to LLM at 0x7f57919e88f0>

    @pytest.mark.skip_global_cleanup
    def test_json_with_any_whitespace_disabled(llm):
    
        class ResponseSchema(BaseModel):
            clarifying_question: str
            cost_per_serving: str
            calories: str
            type_dish_ids: str
            type_meal_ids: str
            product_ids: list[str]
            exclude_product_ids: list[str]
            allergen_ids: list[str]
            total_cooking_time: str
            kitchen_ids: str
            holiday_ids: str
    
        # Note: Without this setting, the response is sometimes full of `\n`
        # for some models. This option prevents that.
        guided_decoding_backend = 'xgrammar:disable-any-whitespace'
    
        schema = ResponseSchema.model_json_schema()
        guided_params = GuidedDecodingParams(json=schema,
                                             backend=\
                                               guided_decoding_backend)
        sampling_params = SamplingParams(max_tokens=2000,
                                         frequency_penalty=0,
                                         presence_penalty=-1.1,
                                         repetition_penalty=1.3,
                                         guided_decoding=guided_params)
    
        prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
                  "are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
                  "quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
        outputs = llm.generate(prompts=prompt,
                               sampling_params=sampling_params,
                               use_tqdm=True)
    
        assert outputs is not None
    
        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)
    
            generated_text = output.outputs[0].text
            assert generated_text is not None
            assert "\n" not in generated_text
    
            # Parse to verify it is valid JSON
>           parsed_json = json.loads(generated_text)

tests/entrypoints/llm/test_guided_generate.py:387: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/lib/python3.12/json/__init__.py:346: in loads
    return _default_decoder.decode(s)
/usr/lib/python3.12/json/decoder.py:337: in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <json.decoder.JSONDecoder object at 0x7f5923552570>, s = '{"clarifying_question": "", "cost_per_serving": ">$5", "calories": "}} Can you provide more context or clarify your q...onTemArray<MultipartUploadRequest<TemData,Versions.MLVersionInfo,urlMap>", "kitchen_ids": "$nil", "holiday_ids": "{}"}', idx = 0

    def raw_decode(self, s, idx=0):
        """Decode a JSON document from ``s`` (a ``str`` beginning with
        a JSON document) and return a 2-tuple of the Python
        representation and the index in ``s`` where the document ended.
    
        This can be used to decode a JSON document from a string that may
        have extraneous data at the end.
    
        """
        try:
>           obj, end = self.scan_once(s, idx)
E           json.decoder.JSONDecodeError: Invalid control character at: line 1 column 580 (char 579)

/usr/lib/python3.12/json/decoder.py:353: JSONDecodeError
-------------------------------------------------------------------------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------------------------------------------------------------------------
Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.55s/it, est. speed input: 8.79 toks/s, output: 136.43 toks/s]
========================================================================================================================================================== warnings summary ===========================================================================================================================================================
tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[autoregressive-outlines]
tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-outlines]
  /home/benchislett/Repos/vllm/.venv/lib/python3.12/site-packages/outlines/fsm/guide.py:110: UserWarning: Outlines' public *community-contributed* CFG structured generation is experimental. Please review https://dottxt-ai.github.io/outlines/latest/reference/generation/cfg#disclaimer
    warnings.warn(

tests/entrypoints/llm/test_guided_generate.py::test_validation_against_both_guided_decoding_options[autoregressive]
tests/entrypoints/llm/test_guided_generate.py::test_validation_against_both_guided_decoding_options[speculative]
  /home/benchislett/Repos/vllm/vllm/entrypoints/llm.py:462: DeprecationWarning: guided_options_request is deprecated, use SamplingParams.guided_decoding instead
    self._validate_and_add_requests(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================================================================= short test summary info =======================================================================================================================================================
FAILED tests/entrypoints/llm/test_guided_generate.py::test_json_with_any_whitespace_disabled[speculative] - json.decoder.JSONDecodeError: Invalid control character at: line 1 column 580 (char 579)

Contributor Author

It seems stochastic, but it works for me with seed=0

Collaborator

This tends to be the case with guidance failures: because a well-informed prompt is the ideal setup for guided decoding to succeed, the model often follows the schema even when guidance is not actually being enforced. That is why thorough testing is needed to catch cases where guidance silently fails to take effect. Guidance should guarantee adherence, so even a single failure is alarming to me.

Contributor Author

Agreed.
The problem is that generation can fail to stop after the closing "}", regardless of whether speculative decoding is used. As I understand it, this is fixed on main by specifying the seed.

My point is that this does not seem to be a problem with speculative decoding itself. Still, it is unexpected behaviour to me.

Collaborator

It does not seem clear to me that this is what is happening. To my understanding, when the matcher reaches a terminal state it will restrict the next token to be the end-of-sequence token, terminating the completion. Why would it fail to stop generating output?

The output from my above error report does not look like this is what happened:

{"clarifying_question": "", "cost_per_serving": ">$5", "calories": "}} Can you provide more context or clarify your q...onTemArray<MultipartUploadRequest<TemData,Versions.MLVersionInfo,urlMap>", "kitchen_ids": "$nil", "holiday_ids": "{}"}

This looks like a guidance failure, where the matcher state became out of sync with the completion, resulting in incorrect output.
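
For reference, a minimal sketch of the termination behaviour described above (illustrative pseudocode, not vLLM's or xgrammar's actual API; the matcher object and its is_terminated() / allowed_token_ids() methods are assumed names):

# Illustrative only: once the grammar matcher reaches a terminal state, the
# logits processor leaves only EOS legal, which ends the completion.
import torch

def apply_grammar_mask(matcher, logits: torch.Tensor,
                       eos_token_id: int) -> torch.Tensor:
    mask = torch.full_like(logits, float("-inf"))
    if matcher.is_terminated():  # assumed method name
        # Terminal state: only end-of-sequence is legal, so generation should
        # stop right after the closing "}" instead of sampling more text.
        mask[eos_token_id] = 0.0
    else:
        mask[list(matcher.allowed_token_ids())] = 0.0  # assumed method name
    return logits + mask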

Contributor Author

I checked the tests again, and guided output does appear to work correctly. The problem lies in the control characters (e.g., \t in this case) that the model can place inside string fields; xgrammar does not seem to reject them.

I also found that a model without speculative decoding can generate them as well, so I suppose this is not an issue introduced by this PR.

Collaborator

I don't understand what you mean. Does this imply that there is a bug with xgrammar?

Contributor Author

If generating control characters inside a JSON string field (output that can still be loaded with strict=False) counts as a bug, then yes.
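
For reference, the strict=False behaviour mentioned here is plain standard-library behaviour and is easy to reproduce (a minimal sketch; the tab-bearing payload below is made up):

import json

# A raw (unescaped) tab inside a JSON string, similar to what the model emitted.
raw = '{"clarifying_question": "line one\tline two"}'

try:
    json.loads(raw)  # default strict=True rejects raw control characters
except json.JSONDecodeError as err:
    print(err)       # Invalid control character at: line 1 column ...

print(json.loads(raw, strict=False))  # parses: {'clarifying_question': 'line one\tline two'}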


def get_llm_kwargs(mode: str):
if mode == "autoregressive":
return {}
return {
"speculative_config": {
"model": "Qwen/Qwen2.5-0.5B-Instruct",
"num_speculative_tokens": 3,
},
}

test_llm_kwargs = get_llm_kwargs(request.param)
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)
llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0, **test_llm_kwargs)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
19 changes: 13 additions & 6 deletions vllm/engine/async_llm_engine.py
@@ -13,7 +13,8 @@

import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VllmConfig)
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
@@ -511,7 +512,8 @@ async def add_request_async(
default_guided_backend=self.decoding_config.
guided_decoding_backend,
reasoning_backend=self.decoding_config.reasoning_backend,
model_config=self.model_config)
model_config=self.model_config,
speculative_config=self.speculative_config)

self._add_processed_request(
request_id=request_id,
@@ -536,9 +538,13 @@ async def collective_rpc_async(self,


async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str, reasoning_backend: Optional[str],
model_config: ModelConfig) -> SamplingParams:
sampling_params: SamplingParams,
tokenizer: AnyTokenizer,
default_guided_backend: str,
reasoning_backend: Optional[str],
model_config: ModelConfig,
speculative_config: Optional[SpeculativeConfig] = None
) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
@@ -564,7 +570,8 @@ async def build_guided_decoding_logits_processor_async(
guided_params=guided_decoding,
tokenizer=tokenizer,
reasoning_backend=reasoning_backend,
model_config=model_config)
model_config=model_config,
speculative_config=speculative_config)

if processor:
if sampling_params.logits_processors is None:
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -2102,7 +2102,7 @@ def _build_logits_processors(
tokenizer=tokenizer,
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)
speculative_config=self.speculative_config)
if processor:
logits_processors.append(processor)

2 changes: 2 additions & 0 deletions vllm/engine/multiprocessing/client.py
@@ -96,6 +96,7 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
self.vllm_config = engine_config
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
self.speculative_config = engine_config.speculative_config

# Create the tokenizer group.
self.tokenizer = init_tokenizer_from_configs(
@@ -620,6 +621,7 @@ async def _process_request(
else DecodingConfig.guided_decoding_backend),
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
speculative_config=self.speculative_config,
)

# 1) Create output queue for this requests.
22 changes: 14 additions & 8 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (
@@ -13,7 +13,7 @@
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.config import ModelConfig, SpeculativeConfig
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams

@@ -100,6 +100,7 @@ async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
speculative_config: Optional[SpeculativeConfig] = None,
reasoning_backend: str | None = None) -> LogitsProcessor | None:

reasoner = None
@@ -126,23 +127,27 @@
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
guided_params, tokenizer, model_config, reasoner,
speculative_config)
if guided_params.backend_name == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)


def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoning_backend: str | None = None,
speculative_config: Optional[SpeculativeConfig] = None
) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)

reasoner = None
@@ -167,7 +172,8 @@ def get_local_guided_decoding_logits_processor(
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
guided_params, tokenizer, model_config, reasoner,
speculative_config)
if guided_params.backend_name == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
11 changes: 11 additions & 0 deletions vllm/model_executor/guided_decoding/guidance_logits_processors.py
@@ -36,6 +36,7 @@ def __init__(
self.tokenizer_name = tokenizer.name_or_path
self.new_sampling = False
self.initialized = False
self.num_processed_tokens = 0

def _initialize(self):
if self.initialized:
@@ -69,7 +70,17 @@ def __call__(
# to avoid pickling ll_tokenizer and ll_interpreter
self._initialize()

if self.num_processed_tokens > 0 and self.num_processed_tokens >= len(
input_ids):
diff = self.num_processed_tokens - len(input_ids) + 1
self.ll_matcher.rollback(diff)
self.num_processed_tokens -= diff

if self.new_sampling and len(input_ids) > 0:
# The tokens are not truly consumed when the matcher is stopped,
# despite consume_token returning True. This is a workaround.
self.num_processed_tokens += 1 if not self.ll_matcher.is_stopped(
) else 0
self.ll_matcher.consume_token(input_ids[-1])
err = self.ll_matcher.get_error()
if err:
Expand Down
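
For readers following the speculative-decoding interaction, a toy sketch of the bookkeeping added in this hunk (ToyMatcher is hypothetical and only mimics the rollback/consume semantics of ll_matcher, not llguidance's real API; the is_stopped() special case from the diff is omitted for brevity):

class ToyMatcher:
    """Stand-in for ll_matcher: records consumed tokens, supports rollback."""

    def __init__(self):
        self.consumed: list[int] = []

    def consume_token(self, tok: int) -> None:
        self.consumed.append(tok)

    def rollback(self, n: int) -> None:
        # Undo the last n consumed tokens.
        del self.consumed[len(self.consumed) - n:]


def sync_and_consume(matcher: ToyMatcher, num_processed: int,
                     input_ids: list[int]) -> int:
    """Mirror the hunk above: when rejected draft tokens shrink input_ids,
    roll the matcher back so input_ids[-1] becomes the next token to consume."""
    if num_processed > 0 and num_processed >= len(input_ids):
        diff = num_processed - len(input_ids) + 1
        matcher.rollback(diff)
        num_processed -= diff
    matcher.consume_token(input_ids[-1])
    return num_processed + 1


# Example: the matcher has consumed 5 tokens, but 2 speculative tokens were
# rejected, so the engine now presents 4 accepted tokens (the last one is new).
m = ToyMatcher()
for tok in [1, 2, 3, 4, 5]:
    m.consume_token(tok)
processed = sync_and_consume(m, num_processed=5, input_ids=[1, 2, 3, 9])
assert m.consumed == [1, 2, 3, 9] and processed == 4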