22 changes: 12 additions & 10 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -325,6 +325,18 @@ def _ensure_ctx(self):
raise ValueError(
"Invalid configuration for xgrammar logits processor")

def accept(self, token_ids: int) -> bool:
if self.ctx is None:
self._ensure_ctx()
if len(self.matchers) == 0:
self.matchers = [
xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
]
self.token_bitmask = xgr.allocate_token_bitmask(
self.batch_size, self.tokenizer_info.vocab_size)
return self.matchers[0].accept_token(
token_ids) or self.matchers[0].is_terminated()

def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:

@@ -345,15 +357,6 @@ def __call__(self, input_ids: list[int],
self.token_bitmask = xgr.allocate_token_bitmask(
self.batch_size, self.tokenizer_info.vocab_size)

if not self.prefilled:
# Have not sampled a token yet
self.prefilled = True
else:
for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
sampled_token = input_ids[-1]
assert self.matchers[i].accept_token(sampled_token)

for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
# @ubospica: ideally, fill_next_token_bitmask should be
@@ -402,5 +405,4 @@ def clone(self) -> XGrammarLogitsProcessor:
new_processor.batch_size = self.batch_size
# Reset prefilled state for new sequence
new_processor.prefilled = False

return new_processor
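
For reference, here is a minimal sketch of how the new `accept()` hook is meant to be driven from the sampling side: validate the sampled token against the grammar and, on rejection, constrain the logits through `__call__` and pick a replacement. The `validate_or_resample` helper and the greedy fallback are illustrative assumptions, not part of this diff.

```python
import torch

# Hypothetical driver code (not in this PR): check a sampled token against the
# grammar via accept(); on rejection, apply the grammar mask through __call__
# and choose a grammar-valid replacement.
def validate_or_resample(processor, sampled_id: int,
                         logits_row: torch.Tensor) -> int:
    # accept() lazily builds the GrammarMatcher and token bitmask, then
    # advances the FSM with the sampled token (or reports termination).
    if processor.accept(sampled_id):
        return sampled_id
    # Rejected: __call__ fills the next-token bitmask and masks invalid logits.
    constrained = processor([sampled_id], logits_row)
    resampled = int(torch.argmax(constrained))  # greedy pick keeps the sketch short
    assert processor.accept(resampled)
    return resampled
```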
61 changes: 46 additions & 15 deletions vllm/model_executor/layers/logits_processor.py
@@ -21,6 +21,27 @@
envs.VLLM_LOGITS_PROCESSOR_THREADS)


def accept_grammar(token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
mask: torch.Tensor = None) -> torch.Tensor:
if mask is None:
mask = torch.ones_like(token_ids, dtype=torch.bool)

accept = torch.ones_like(token_ids, dtype=torch.bool)
for seq_group in sampling_metadata.seq_groups:
logits_processors = seq_group.sampling_params.logits_processors
if logits_processors:
for row_idx, logits_processor in zip(seq_group.sample_indices,
logits_processors):
tkid = token_ids[row_idx].item()
tkmask = mask[row_idx].item()
# Feed the token to the FSM only when this row's mask entry is set.
if tkmask:
accept[row_idx] = accept[row_idx].item(
) and logits_processor.accept(tkid)
return accept


class LogitsProcessor(nn.Module):
"""Process logits and apply logits processors from sampling metadata.

@@ -52,13 +73,12 @@ def __init__(self,
# Whether to use gather or all-gather to gather the logits.
self.use_all_gather = current_platform.use_all_gather()

def forward(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
def forward(self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
skip_grammar: bool = False) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
@@ -77,6 +97,9 @@ def forward(
if self.scale != 1.0:
logits *= self.scale

if skip_grammar:
return logits

# Apply logits processors (if any).
if sampling_metadata is not None and \
sampling_metadata.seq_groups is not None:
@@ -138,12 +161,15 @@ def _prune_hidden_states(
return hidden_states


def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
def _apply_logits_processors(logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
mask: torch.Tensor = None,
accept_last: bool = True) -> torch.Tensor:
found_logits_processors = False
logits_processed = 0

if mask is None:
mask = torch.ones(size=(logits.shape[0], ), dtype=torch.bool)
logits_row_ids_and_logits_row_futures = []
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
@@ -154,6 +180,9 @@ def _apply_logits_processors(

for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
if not mask[logits_row_idx].item():
continue

logits_row = logits[logits_row_idx]
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
@@ -164,12 +193,12 @@ def _apply_logits_processors(
_logits_processor_threadpool.submit(
_apply_logits_processors_single_seq, logits_row,
logits_processors, past_tokens_ids,
prompt_tokens_ids)))
prompt_tokens_ids, accept_last)))
else:
logits[logits_row_idx] = \
_apply_logits_processors_single_seq(
logits_row, logits_processors, past_tokens_ids,
prompt_tokens_ids)
prompt_tokens_ids, accept_last)

logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
@@ -184,13 +213,15 @@ def _apply_logits_processors(


def _apply_logits_processors_single_seq(logits_row, logits_processors,
past_tokens_ids,
prompt_tokens_ids) -> torch.Tensor:
past_tokens_ids, prompt_tokens_ids,
accept_last) -> torch.Tensor:
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
logits_row)
else:
if accept_last and len(past_tokens_ids) > 0:
logits_processor.accept(past_tokens_ids[-1])
logits_row = logits_processor(past_tokens_ids, logits_row)
return logits_row
52 changes: 52 additions & 0 deletions vllm/model_executor/models/interfaces.py
@@ -12,6 +12,9 @@
QuantizationConfig)
from vllm.utils import supports_kw

from .. import SamplingMetadata
from ..layers.logits_processor import _apply_logits_processors, accept_grammar
from ..layers.sampler import SamplerOutput
from .interfaces_base import is_pooling_model

if TYPE_CHECKING:
@@ -221,6 +224,55 @@ def forward(
...


class SupportsSampleV2:

def compute_logits_v2(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head,
hidden_states,
sampling_metadata,
skip_grammar=True)
return logits

def samplev2(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
# first pass: sample from the unconstrained logits
next_tokens: SamplerOutput = self.sampler(logits, sampling_metadata)

# check if the sampled tokens fit the grammars
tks = torch.tensor(
[o.samples[0].output_token for o in next_tokens.outputs])
accepted = accept_grammar(tks, sampling_metadata)
need_resample = torch.logical_not(accepted)
if accepted.all():
return next_tokens
# Resample the rejected rows: apply the grammar bitmask through the
# logits processors (restricted to rows in need_resample), then sample again.
logits = _apply_logits_processors(logits, sampling_metadata,
need_resample, False)
new_next_tokens: SamplerOutput = self.sampler(logits,
sampling_metadata)

for i, replace in enumerate(need_resample.tolist()):
if replace:
next_tokens.outputs[i] = new_next_tokens.outputs[i]

tks = torch.tensor(
[o.samples[0].output_token for o in next_tokens.outputs])
# The matcher advances on the new token only for rows whose first-round token was rejected.
accepted = accept_grammar(tks, sampling_metadata, need_resample)
assert accepted.all()
return next_tokens


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
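
Pulling the pieces together, `samplev2()` samples once without grammar constraints, asks each sequence's grammar whether the token is legal, and resamples only the rejected rows with the grammar mask applied. A condensed, self-contained sketch of that flow (the per-row `processors` list and the greedy resample are assumptions for illustration; the real path goes through vLLM's `Sampler` and `SamplingMetadata`):

```python
import torch

def sample_with_grammar(logits: torch.Tensor, processors: list) -> torch.Tensor:
    """logits: [batch, vocab]; processors: one grammar processor per row."""
    probs = torch.softmax(logits, dim=-1)
    tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

    # First pass: which sampled tokens do the grammars accept?
    accepted = torch.tensor(
        [p.accept(int(t)) for p, t in zip(processors, tokens)])
    if accepted.all():
        return tokens

    # Second pass: only for rejected rows, apply the grammar mask and resample.
    for i in torch.nonzero(~accepted).flatten().tolist():
        constrained = processors[i]([], logits[i].clone())
        tokens[i] = int(torch.argmax(constrained))  # greedy here for brevity
        # The matcher advances only now, since the first-round token was rejected.
        assert processors[i].accept(int(tokens[i]))
    return tokens
```

With the toy `EvenTokenProcessor` above, every id this function returns is even, which mirrors the closing `assert accepted.all()` guarantee in `samplev2()`.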
4 changes: 2 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -48,7 +48,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
@@ -433,7 +433,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loaded_params


class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsSampleV2):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
5 changes: 3 additions & 2 deletions vllm/model_executor/models/qwen.py
@@ -31,7 +31,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -335,7 +335,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loaded_params


class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA,
SupportsSampleV2):
packed_modules_mapping = {
"c_attn": ["c_attn"],
"gate_up_proj": [
4 changes: 2 additions & 2 deletions vllm/model_executor/models/qwen2.py
@@ -52,7 +52,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
@@ -405,7 +405,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loaded_params


class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsSampleV2):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
20 changes: 14 additions & 6 deletions vllm/worker/model_runner.py
@@ -38,6 +38,7 @@
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.interfaces import SupportsSampleV2
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
@@ -1785,8 +1786,12 @@ def execute_model(
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states

logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
if isinstance(self.model, SupportsSampleV2):
logits = self.model.compute_logits_v2(
hidden_or_intermediate_states, model_input.sampling_metadata)
else:
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)

if not self.is_driver_worker:
return []
model_input.async_callback()

# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if isinstance(self.model, SupportsSampleV2):
output: SamplerOutput = self.model.samplev2(
logits=logits, sampling_metadata=model_input.sampling_metadata)
else:
output = self.model.sample(
logits=logits, sampling_metadata=model_input.sampling_metadata)

if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):