1 file changed: +8 -1 lines changed
Original file line number | Diff line number | Diff line change
6666
6767if TYPE_CHECKING :
6868 import xgrammar as xgr
69+ import xgrammar .kernels .apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501
6970
7071 from vllm .model_executor .model_loader .tensorizer import TensorizerConfig
7172 from vllm .v1 .core .sched .output import SchedulerOutput
7273else :
7374 xgr = LazyLoader ("xgr" , globals (), "xgrammar" )
75+ xgr_torch_compile = LazyLoader (
76+ "xgr_torch_compile" , globals (),
77+ "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile" )
7478
7579logger = init_logger (__name__ )
7680
@@ -1103,7 +1107,10 @@ def apply_grammar_bitmask(
11031107 # so we receive it in that format.
11041108 grammar_bitmask = torch .from_numpy (grammar_bitmask )
11051109
1106- xgr .apply_token_bitmask_inplace (
1110+ # Force use of the torch.compile implementation from xgrammar to work
1111+ # around issues with the Triton kernel in concurrent structured output
1112+ # scenarios. See PR #19565 and issues #19493, #18376 for details.
1113+ xgr_torch_compile .apply_token_bitmask_inplace_torch_compile (
11071114 logits ,
11081115 grammar_bitmask .to (self .device , non_blocking = True ),
11091116 indices = out_indices ,
You can’t perform that action at this time.
0 commit comments