[TRTLLM-6406, TRTLLM-5172] feat: Enable guided decoding with overlap scheduler #6000

Merged · 6 commits · Jul 17, 2025
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
@@ -54,8 +54,8 @@ void logitsBitmask(std::vector<torch::Tensor> const& logits, std::vector<torch::
bitmaskPtrsHost[i] = reinterpret_cast<uint64_t>(bitmask[i].data_ptr());
}

-auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA);
-auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA);
+auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
+auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);

auto stream = at::cuda::getCurrentCUDAStream(logits[0].get_device()).stream();
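Side note on the `non_blocking` change above: an asynchronous host-to-device copy is only truly asynchronous when the host tensor lives in pinned (page-locked) memory; otherwise the transfer silently falls back to a blocking copy. A minimal PyTorch sketch of the pattern (hypothetical buffer names, not the kernel's actual pointer tables):

```python
import torch

# Host-side staging buffer allocated in pinned memory so the copy can overlap
# with other work already enqueued on the current CUDA stream.
ptrs_host = torch.zeros(8, dtype=torch.int64, pin_memory=True)

# Enqueue the H2D copy; the call returns immediately on the host.
ptrs_device = ptrs_host.to("cuda", non_blocking=True)

# Kernels launched afterwards on the same stream see the copied data in order;
# synchronize only if the host needs to reuse or inspect the buffer right away.
torch.cuda.current_stream().synchronize()
```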

2 changes: 1 addition & 1 deletion docs/source/torch/features/feature_combination_matrix.md
@@ -15,4 +15,4 @@
| KV Cache Reuse | Yes | Yes | Yes | Untested | Untested | Untested | Yes | No | Yes | Yes | --- | | | |
| Slide Window Attention | Yes | Yes | Yes | Untested | Untested | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
| Logits Post Processor | No | Yes | Yes | No | Untested | No | No | No | Yes | Yes | Yes | Yes | --- | |
-| Guided Decoding | No | Yes | Yes | Untested | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
+| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
9 changes: 3 additions & 6 deletions examples/llm-api/llm_guided_decoding.py
@@ -7,12 +7,9 @@

def main():

-# Specify the guided decoding backend; xgrammar is supported currently.
-llm = LLM(
-model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-guided_decoding_backend='xgrammar',
-disable_overlap_scheduler=True # Not supported by xgrammar mode
-)
+# Specify the guided decoding backend; xgrammar and llguidance are supported currently.
+llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+guided_decoding_backend='xgrammar')

# An example from json-mode-eval
schema = '{"title": "WirelessAccessPoint", "type": "object", "properties": {"ssid": {"title": "SSID", "type": "string"}, "securityProtocol": {"title": "SecurityProtocol", "type": "string"}, "bandwidth": {"title": "Bandwidth", "type": "string"}}, "required": ["ssid", "securityProtocol", "bandwidth"]}'
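For reference, the rest of this example passes the schema to the request through `GuidedDecodingParams`; a minimal sketch of that usage follows the current llm-api, but the import path and field names are repeated from memory, so treat them as assumptions:

```python
from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams

# Constrain generation to the JSON schema above; with this PR the overlap
# scheduler can stay enabled while guided decoding is active.
sampling_params = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(json=schema),
)
output = llm.generate(
    "Generate a wireless access point configuration as JSON.",
    sampling_params=sampling_params,
)
print(output.outputs[0].text)
```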
11 changes: 3 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -21,6 +21,7 @@
from ..speculative import get_spec_decoder
from .config import PyTorchConfig
from .config_utils import is_mla, is_nemotron_hybrid
+from .guided_decoder import GuidedDecoder
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .llm_request import ExecutorResponse
from .model_engine import PyTorchModelEngine
@@ -414,19 +415,12 @@ def create_py_executor_instance(
start_worker,
sampler,
drafter,
+guided_decoder: Optional[GuidedDecoder] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

spec_config = model_engine.spec_config
-if mapping.is_last_pp_rank(
-) and executor_config.guided_decoding_config is not None:
-if spec_config is not None:
-raise ValueError(
-"Guided decoding is not supported with speculative decoding.")
-if not pytorch_backend_config.disable_overlap_scheduler:
-raise ValueError(
-"Guided decoding is not supported with overlap scheduler.")

logger.info(
f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
@@ -544,6 +538,7 @@ def create_py_executor_instance(
if spec_config is not None else 0,
kv_cache_transceiver=kv_cache_transceiver,
draft_model_engine=draft_model_engine,
+guided_decoder=guided_decoder,
start_worker=start_worker,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)

14 changes: 7 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -3,11 +3,11 @@

import torch

+from ..._utils import nvtx_range
from ...bindings.executor import GuidedDecodingConfig
from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory,
LLGuidanceMatcherFactory, XGrammarMatcherFactory)
from .scheduler import ScheduledRequests
-from .seq_slot_manager import SeqSlotManager


class GuidedDecoder:
@@ -49,12 +49,12 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig,
def bitmask_size(self) -> int:
return math.ceil(self.vocab_size_padded / 32)

-def build(self, scheduled_requests: ScheduledRequests,
-resource_manager: SeqSlotManager) -> None:
+@nvtx_range("GuidedDecoder.build")
+def build(self, scheduled_requests: ScheduledRequests) -> None:
for llm_req in scheduled_requests.all_requests():
if llm_req.guided_decoding_params is None:
continue
-slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
+slot = llm_req.py_seq_slot
if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
self.grammar_matchers[
slot] = self.grammar_matcher_factory.create(
@@ -75,8 +75,9 @@ def build(self, scheduled_requests: ScheduledRequests,
self.bitmask[slot].copy_(self.bitmask_host[slot],
non_blocking=True)

+@nvtx_range("GuidedDecoder.execute")
def execute(self, scheduled_requests: ScheduledRequests,
-logits: torch.Tensor, resource_manager: SeqSlotManager) -> None:
+logits: torch.Tensor) -> None:
assert logits.size(0) == len(scheduled_requests.context_requests) + len(
scheduled_requests.generation_requests)
torch.cuda.current_stream().wait_stream(self._stream)
@@ -88,8 +89,7 @@ def execute(self, scheduled_requests: ScheduledRequests,
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
continue
batched_logits.append(logits[i])
-slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
-batched_bitmask.append(self.bitmask[slot])
+batched_bitmask.append(self.bitmask[llm_req.py_seq_slot])

if len(batched_logits) > 0:
torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
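The new `@nvtx_range` decorators only add profiling markers around `build` and `execute`. Roughly, such a decorator looks like the sketch below (the real helper in `tensorrt_llm._utils` may differ; shown only to make the resulting Nsight traces easier to interpret):

```python
import functools

import torch


def nvtx_range(name: str):
    """Wrap a callable in an NVTX range so it appears as a named span in Nsight."""

    def decorator(fn):

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            torch.cuda.nvtx.range_push(name)
            try:
                return fn(*args, **kwargs)
            finally:
                torch.cuda.nvtx.range_pop()

        return wrapper

    return decorator
```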
22 changes: 0 additions & 22 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -25,7 +25,6 @@
from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
local_mpi_size, nvtx_range, release_gc,
torch_dtype_to_str, trace_func)
-from tensorrt_llm.bindings.executor import GuidedDecodingConfig
from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
@@ -57,7 +56,6 @@
from .config import LoadFormat, PyTorchConfig
from .config_utils import is_mla
from .cuda_graph_runner import DecodingCUDAGraphRunner
-from .guided_decoder import GuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .resource_manager import (BaseResourceManager, KVCacheManager,
ResourceManager, ResourceManagerType)
@@ -354,7 +352,6 @@ def __init__(
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
dist: Optional[MPIDist] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
-guided_decoding_config: Optional[GuidedDecodingConfig] = None,
lora_config: Optional[LoraConfig] = None,
is_draft_model: bool = False,
):
@@ -408,13 +405,6 @@ def __init__(
self.dtype = self.model.config.torch_dtype
self._init_model_capacity()

-self.guided_decoder: Optional[GuidedDecoder] = None
-if self.mapping.is_last_pp_rank(
-) and guided_decoding_config is not None:
-self.guided_decoder = GuidedDecoder(guided_decoding_config,
-self.batch_size,
-self.model.vocab_size_padded)

self._torch_compile_backend = None

try:
@@ -2170,18 +2160,6 @@ def capture_forward_fn(inputs: Dict[str, Any]):
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = maybe_graph.run(inputs)

-# Note: To overlap the CPU and GPU computation as much as possible,
-# guided_decoder.build should be called immediately after the launch of the single step;
-# while guided_decoder.execute should be called right before the samplings.
-# We can insert other CPU computation between them in the future.
-if self.mapping.is_last_pp_rank(
-) and self.guided_decoder is not None:
-seq_slot_manager = resource_manager.get_resource_manager(
-ResourceManagerType.SEQ_SLOT_MANAGER)
-self.guided_decoder.build(scheduled_requests, seq_slot_manager)
-self.guided_decoder.execute(scheduled_requests,
-outputs['logits'], seq_slot_manager)

self._execute_logit_post_processors(scheduled_requests, outputs)

return outputs
24 changes: 22 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -31,6 +31,7 @@

from ..distributed import Distributed
from ..speculative.drafter import Drafter
+from .guided_decoder import GuidedDecoder
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
LlmResponse, executor_request_to_llm_request)
@@ -204,6 +205,7 @@ def __init__(self,
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
draft_model_engine: Optional[ModelEngine] = None,
+guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True):
super(PyExecutor, self).__init__()
@@ -225,6 +227,7 @@ def __init__(self,
self.enable_attention_dp = model_engine.enable_attention_dp
self.sampler = sampler
self.drafter = drafter
+self.guided_decoder = guided_decoder
self.dist = dist
self.disable_overlap_scheduler = disable_overlap_scheduler

@@ -801,6 +804,12 @@ def _executor_loop_pp(self):
if self._need_return_logits(scheduled_batch):
logits_host = batch_outputs["logits"].to(
"cpu", non_blocking=True)

+if self.guided_decoder is not None:
+self.guided_decoder.build(scheduled_batch)
+self.guided_decoder.execute(
+scheduled_batch, batch_outputs['logits'])

sample_state = self._sample_async(
scheduled_batch, batch_outputs)
sample_state.host.logits = logits_host
@@ -975,6 +984,11 @@ def _executor_loop(self):

batch_outputs = self._forward_step(scheduled_batch)

+if self.guided_decoder is not None:
+self.guided_decoder.build(scheduled_batch)
+self.guided_decoder.execute(scheduled_batch,
+batch_outputs['logits'])

sample_state = self._sample_async(scheduled_batch,
batch_outputs)

@@ -1123,6 +1137,14 @@ def _executor_loop_overlap(self):
batch_outputs = self._forward_step(scheduled_batch,
previous_tensors_device)

+if self.previous_batch is not None:
+self._update_requests(self.previous_batch.sample_state)

+if self.guided_decoder is not None:
+self.guided_decoder.build(scheduled_batch)
+self.guided_decoder.execute(scheduled_batch,
+batch_outputs['logits'])

sample_state = self._sample_async(scheduled_batch,
batch_outputs)
assert sample_state is not None, "Sampling failed"
@@ -1156,8 +1178,6 @@ def _executor_loop_overlap(self):
self._terminate_ctx_finished_requests()

def _process_previous_batch(self):
-self._update_requests(self.previous_batch.sample_state)

if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
for req in self.previous_batch.ctx_transmission_reqs:
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
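The ordering in `_executor_loop_overlap` above is the heart of this PR: the previous iteration's sampled tokens are folded back into the requests (`_update_requests`) before the grammar matchers advance, so `build` runs its CPU-side grammar work while the current forward pass is still executing on the GPU, and `execute` applies the bitmask to the logits right before sampling. A condensed, simplified sketch of one loop iteration (not the full implementation):

```python
def overlap_step(executor, scheduled_batch, previous_tensors_device):
    """Condensed view of one _executor_loop_overlap iteration with guided decoding."""
    # Launch the forward pass for the current batch on the GPU.
    batch_outputs = executor._forward_step(scheduled_batch, previous_tensors_device)

    # Accept the tokens sampled in the previous iteration first, so the grammar
    # matchers are advanced to the right state before new bitmasks are built.
    if executor.previous_batch is not None:
        executor._update_requests(executor.previous_batch.sample_state)

    if executor.guided_decoder is not None:
        executor.guided_decoder.build(scheduled_batch)   # CPU-side grammar work
        executor.guided_decoder.execute(                 # mask logits on the GPU
            scheduled_batch, batch_outputs['logits'])

    # Sample with the (possibly masked) logits; results are consumed next iteration.
    return executor._sample_async(scheduled_batch, batch_outputs)
```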
15 changes: 14 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -24,6 +24,7 @@
create_py_executor_instance, instantiate_sampler, is_mla)
from .config import PyTorchConfig
from .config_utils import is_mla
+from .guided_decoder import GuidedDecoder
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor

@@ -237,7 +238,6 @@ def create_py_executor(
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=spec_config,
-guided_decoding_config=executor_config.guided_decoding_config,
lora_config=lora_config,
)

@@ -342,6 +342,17 @@ def create_py_executor(
sampler = instantiate_sampler(model_engine, executor_config,
pytorch_backend_config, mapping)

+guided_decoder: Optional[GuidedDecoder] = None
+if executor_config.guided_decoding_config is not None:
+if spec_config is not None:
+raise ValueError(
+"Guided decoding is not supported with speculative decoding.")
+if mapping.is_last_pp_rank():
+guided_decoder = GuidedDecoder(
+executor_config.guided_decoding_config,
+executor_config.max_batch_size,
+model_engine.model.vocab_size_padded)

resources = {}
estimating_kv_cache = False
kv_cache_creator = None
@@ -385,6 +396,7 @@ def create_py_executor(
start_worker=False,
sampler=sampler,
drafter=drafter,
+guided_decoder=guided_decoder,
lora_config=lora_config,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
)
@@ -427,6 +439,7 @@ def create_py_executor(
start_worker=False,
sampler=sampler,
drafter=drafter,
+guided_decoder=guided_decoder,
lora_config=lora_config,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold,
2 changes: 0 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -288,7 +288,6 @@ def test_guided_decoding(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
llm = LLM(self.MODEL_PATH,
guided_decoding_backend=backend,
-disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig())
with llm:
task = JsonModeEval(self.MODEL_NAME)
@@ -301,7 +300,6 @@ def test_guided_decoding_4gpus(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
with LLM(self.MODEL_PATH,
guided_decoding_backend=backend,
-disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(),
tensor_parallel_size=2,
pipeline_parallel_size=2) as llm:
@@ -23,10 +23,7 @@ def temp_extra_llm_api_options_file(request):
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
-extra_llm_api_options_dict = {
-"guided_decoding_backend": "xgrammar",
-"disable_overlap_scheduler": True,
-}
+extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}

with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)