"""
Eagle + Structured Output FSM Validation Fix

ISSUE: When using Eagle speculative decoding with structured output (tool calling),
       vLLM crashes with an AssertionError when the FSM rejects a spec token that
       is present in the scheduled_spec_decode_tokens list.

ERROR: "Failed to advance FSM for request ... for tokens XXX. Please file an issue."
       Followed by: AssertionError at vllm/v1/structured_output/__init__.py line 263

OBSERVED BEHAVIOR:
- Eagle generates speculative tokens for upcoming iterations
- The scheduler validates these tokens via grammar.validate_tokens() and stores the valid prefix
- During model execution, grammar_bitmask() validates the same tokens again
- accept_tokens() sometimes returns False even for tokens in the scheduled list
- The assertion crashes the entire engine instead of handling the mismatch gracefully
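
In pseudocode, the two validation sites look roughly like this (simplified
sketch; the method names are the ones referenced above, and the exact call
shapes may differ across vLLM versions):

    # scheduler: checks the spec tokens without a net FSM state change
    valid_prefix = grammar.validate_tokens(spec_tokens)
    scheduled_spec_decode_tokens[req_id] = valid_prefix

    # model execution: re-walks the same tokens while building bitmasks
    for token in scheduled_spec_decode_tokens[req_id]:
        assert grammar.accept_tokens(req_id, [token])  # <-- crash site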

ROOT CAUSE UNKNOWN - Possible explanations:
1. Race condition or state mismatch between validation and bitmask generation
2. Bug in xgrammar rollback functionality
3. Interaction between Eagle + structured output + penalties causing state corruption
4. Concurrency issue with shared grammar state

SOLUTION:
Replace the assertion with a defensive conditional check:
- If the token is valid according to the FSM, accept it and advance the state
- If the token is invalid, log a debug message but continue the loop
- Still fill bitmasks for all tokens to maintain the correct array size
- This makes the code resilient to FSM state mismatches

This is a defensive fix that prevents the crash without fully understanding the root cause.
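
In sketch form (grammar stands for structured_output_request.grammar; the
full patched loop is in the code below):

    # before: assert grammar.accept_tokens(req_id, [token])
    if grammar.accept_tokens(req_id, [token]):
        state_advancements += 1
    else:
        logger.debug("Grammar rejected spec token ...")  # log and keep going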

UPSTREAM STATUS: Not fixed in vLLM upstream (bug exists since PR #18879, May 2025)
UPSTREAMABLE: Yes - this defensive approach should be contributed upstream

CLEANUP: When upstream fixes this bug:
1. Delete this file
2. Remove eagle_structured_output_fix from patch_config.json
3. Remove registration from plugin.py
4. Reinstall plugin
"""

from vllm.logger import init_logger

logger = init_logger(__name__)


def create_patched_grammar_bitmask():
    """
    Factory function that creates the patched grammar_bitmask method.

    The patched method replaces the assertion with a conditional check so
    that a spec token the FSM rejects is skipped instead of crashing, even
    when that token survived scheduler validation.
    """
    # TYPE_CHECKING mirrors the upstream code: the asserts guarded by it
    # below are for type checkers only and never run at runtime.
    from typing import TYPE_CHECKING

    def grammar_bitmask(
        self,
        requests: dict,
        structured_output_request_ids: list[str],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ):
        """
        Patched version that handles FSM rejection of scheduled spec tokens.

        Changes from upstream:
        - Replace the assertion with a defensive conditional check
        - If a spec token is rejected by the FSM, log and continue (don't crash)
        - Keep filling bitmasks for all tokens to maintain the array size
        - Makes the code resilient to FSM state mismatches
        """
        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = (
                self.vllm_config.speculative_config.num_speculative_tokens
            )

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = self.backend.allocate_token_bitmask(
                max_batch_size * (1 + max_num_spec_tokens)
            )
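            # Illustrative sizing (example numbers, not from this repo):
            # max_num_seqs=256 with num_speculative_tokens=2 allocates
            # 256 * (1 + 2) = 768 bitmask rows.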

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the GPU runner.
        cumulative_index = 0
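        # Row layout (an example with 2 spec tokens per request): rows run
        # [req0_spec0, req0_spec1, req0_bonus, req1_spec0, ...] in request order.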

        # Optimized parallel filling of bitmasks for
        # non-spec, large-batch-size cases
        if (
            len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
            and max_num_spec_tokens == 0
        ):
            promises = []
            batch = []
            for req_id in structured_output_request_ids:
                request = requests[req_id]
                structured_output_request = request.structured_output_request
                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None

                apply_bitmask = self.should_fill_bitmask(request)
                batch.append(
                    (structured_output_request.grammar, cumulative_index, apply_bitmask)
                )
                if len(batch) == self.fill_bitmask_parallel_batch_size:
                    promises.append(self._async_submit_fill_bitmask(batch))
                    batch = []

                cumulative_index += 1
            if batch:
                promises.append(self._async_submit_fill_bitmask(batch))

            # Wait for all bitmask filling tasks to complete.
            for promise in promises:
                promise.result()
        else:
            # Fall back to serial filling of bitmasks for small-batch-size cases
            for req_id in structured_output_request_ids:
                request = requests[req_id]
                structured_output_request = request.structured_output_request

                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None
                apply_bitmask = self.should_fill_bitmask(request)

                state_advancements = 0
                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
                # Walk each spec token plus a trailing None for the bonus /
                # non-speculative position.
                for i, token in enumerate(req_tokens + [None]):
                    self._fill_bitmasks(
                        [
                            (
                                structured_output_request.grammar,
                                cumulative_index,
                                apply_bitmask,
                            )
                        ]
                    )

                    # ============================================================
                    # MANTLE FIX: Replace assertion with conditional check
                    # ============================================================
                    if (
                        apply_bitmask
                        and token is not None
                        and not structured_output_request.grammar.is_terminated()
                    ):
                        # ORIGINAL (causes crash):
                        # assert structured_output_request.grammar.accept_tokens(
                        #     req_id, [token]
                        # )
                        # state_advancements += 1

                        # FIXED (defensive approach - no crash):
                        # Only advance state if the token is accepted by the grammar.
                        # If rejected, keep looping so bitmasks are filled for all
                        # tokens (apply_grammar_bitmask expects the exact array size).
                        if structured_output_request.grammar.accept_tokens(
                            req_id, [token]
                        ):
                            state_advancements += 1
                        else:
                            # Token rejected by the FSM even though it is in the
                            # scheduled list. The root cause is unknown (FSM state
                            # mismatch, xgrammar bug, etc.), so handle it gracefully
                            # instead of crashing.
                            logger.debug(
                                "Grammar rejected spec token %s for request %s. "
                                "This indicates an FSM state mismatch in Eagle + "
                                "structured output. Continuing without advancing "
                                "grammar state.",
                                token,
                                req_id,
                            )
                            # Fall through to the next token (don't break) so the
                            # bitmask array keeps the size apply_grammar_bitmask
                            # expects.
                    # ============================================================
                    # END MANTLE FIX
                    # ============================================================

                    cumulative_index += 1
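                # accept_tokens() above advanced the FSM only so each
                # speculative position's bitmask could be computed; roll the
                # state back so actual acceptance happens after sampling.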
                if state_advancements > 0:
                    structured_output_request.grammar.rollback(state_advancements)

        bitmask_tensor = self._grammar_bitmask
        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

    return grammar_bitmask
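

# How the patch is assumed to be applied (a sketch, not verbatim from this
# repo's plugin.py; the import path matches the traceback above but may move
# between vLLM versions):
#
#     from vllm.v1.structured_output import StructuredOutputManager
#
#     StructuredOutputManager.grammar_bitmask = create_patched_grammar_bitmask()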