Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions vllm/v1/structured_output/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,31 +149,37 @@ def grammar_bitmask(
# NOTE: This outer loop can likely be parallelized to improve
# performance of bitmask generation for large batches.
for req_id, _ in ordered_seq:
request = requests[req_id].structured_output_request
if TYPE_CHECKING:
assert request is not None
assert request.grammar is not None
request = requests[req_id]
structured_output_request = request.structured_output_request

apply_bitmask = (
request.reasoning_ended if self.reasoner is not None else True
) # noqa: E501
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask: bool = True
if self.reasoner is not None:
if structured_output_request.reasoning_ended is None:
structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
apply_bitmask = structured_output_request.reasoning_ended

state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
for i, token in enumerate(req_tokens):
if apply_bitmask and not request.grammar.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor,
cumulative_index)
if apply_bitmask and not \
structured_output_request.grammar.is_terminated():
structured_output_request.grammar.fill_bitmask(
bitmask_tensor, cumulative_index)
if token is not None:
# In order to generate the correct bitmask for each
# position in the speculative sequence, we advance
# the FSM state for each speculative token and rollback
# to restore the previous state when we are finished.
assert request.grammar.accept_tokens(req_id, [token])
assert structured_output_request.grammar.accept_tokens(
req_id, [token])
state_advancements += 1
cumulative_index += 1
if state_advancements > 0:
request.grammar.rollback(state_advancements)
structured_output_request.grammar.rollback(state_advancements)

if cumulative_index < bitmask_tensor.shape[0]:
bitmask_tensor = bitmask_tensor[:cumulative_index]
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/structured_output/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class StructuredOutputRequest:
sampling_params: SamplingParams
_grammar: Optional[Union[Future[StructuredOutputGrammar],
StructuredOutputGrammar]] = None
reasoning_ended: bool = False
reasoning_ended: Optional[bool] = None

def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports
Expand Down