diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index bc156223953e..8a912a4d6482 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -325,6 +325,18 @@ def _ensure_ctx(self):
             raise ValueError(
                 "Invalid configuration for xgrammar logits processor")
 
+    def accept(self, token_ids: int) -> bool:
+        if self.ctx is None:
+            self._ensure_ctx()
+        if len(self.matchers) == 0:
+            self.matchers = [
+                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
+            ]
+            self.token_bitmask = xgr.allocate_token_bitmask(
+                self.batch_size, self.tokenizer_info.vocab_size)
+        return self.matchers[0].accept_token(
+            token_ids) or self.matchers[0].is_terminated()
+
     def __call__(self, input_ids: list[int],
                  scores: torch.Tensor) -> torch.Tensor:
@@ -345,15 +357,6 @@ def __call__(self, input_ids: list[int],
             self.token_bitmask = xgr.allocate_token_bitmask(
                 self.batch_size, self.tokenizer_info.vocab_size)
 
-        if not self.prefilled:
-            # Have not sampled a token yet
-            self.prefilled = True
-        else:
-            for i, matcher in enumerate(self.matchers):
-                if not matcher.is_terminated():
-                    sampled_token = input_ids[-1]
-                    assert self.matchers[i].accept_token(sampled_token)
-
         for i, matcher in enumerate(self.matchers):
             if not matcher.is_terminated():
                 # @ubospica: ideally, fill_next_token_bitmask should be
@@ -402,5 +405,4 @@ def clone(self) -> XGrammarLogitsProcessor:
         new_processor.batch_size = self.batch_size
         # Reset prefilled state for new sequence
         new_processor.prefilled = False
-
         return new_processor
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 4a359725bad0..690f5fe05b2a 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -21,6 +21,27 @@
     envs.VLLM_LOGITS_PROCESSOR_THREADS)
 
 
+def accept_grammar(token_ids: torch.Tensor,
+                   sampling_metadata: SamplingMetadata,
+                   mask: torch.Tensor = None) -> torch.Tensor:
+    if mask is None:
+        mask = torch.ones_like(token_ids, dtype=torch.bool)
+
+    accept = torch.ones_like(token_ids, dtype=torch.bool)
+    for seq_group in sampling_metadata.seq_groups:
+        logits_processors = seq_group.sampling_params.logits_processors
+        if logits_processors:
+            for row_idx, logits_processor in zip(seq_group.sample_indices,
+                                                 logits_processors):
+                tkid = token_ids[row_idx].item()
+                tkmask = mask[row_idx].item()
+                # the FSM accepts the token only when its mask entry is set
+                if tkmask:
+                    accept[row_idx] = accept[row_idx].item(
+                    ) and logits_processor.accept(tkid)
+    return accept
+
+
 class LogitsProcessor(nn.Module):
     """Process logits and apply logits processors from sampling metadata.
@@ -52,13 +73,12 @@ def __init__(self,
         # Whether to use gather or all-gather to gather the logits.
         self.use_all_gather = current_platform.use_all_gather()
 
-    def forward(
-        self,
-        lm_head: VocabParallelEmbedding,
-        hidden_states: torch.Tensor,
-        sampling_metadata: Optional[SamplingMetadata] = None,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[torch.Tensor]:
+    def forward(self,
+                lm_head: VocabParallelEmbedding,
+                hidden_states: torch.Tensor,
+                sampling_metadata: Optional[SamplingMetadata] = None,
+                embedding_bias: Optional[torch.Tensor] = None,
+                skip_grammar: bool = False) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
@@ -77,6 +97,9 @@ def forward(
         if self.scale != 1.0:
             logits *= self.scale
 
+        if skip_grammar:
+            return logits
+
         # Apply logits processors (if any).
         if sampling_metadata is not None and \
            sampling_metadata.seq_groups is not None:
@@ -138,12 +161,15 @@ def _prune_hidden_states(
     return hidden_states
 
 
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
+def _apply_logits_processors(logits: torch.Tensor,
+                             sampling_metadata: SamplingMetadata,
+                             mask: torch.Tensor = None,
+                             accept_last: bool = True) -> torch.Tensor:
     found_logits_processors = False
     logits_processed = 0
+
+    if mask is None:
+        mask = torch.ones(size=(logits.shape[0], ), dtype=torch.bool)
     logits_row_ids_and_logits_row_futures = []
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
@@ -154,6 +180,9 @@ def _apply_logits_processors(
             for seq_id, logits_row_idx in zip(seq_ids,
                                               seq_group.sample_indices):
+                if not mask[logits_row_idx].item():
+                    continue
+
                 logits_row = logits[logits_row_idx]
                 past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
                 prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
@@ -164,12 +193,12 @@ def _apply_logits_processors(
                             _logits_processor_threadpool.submit(
                                 _apply_logits_processors_single_seq,
                                 logits_row, logits_processors, past_tokens_ids,
-                                prompt_tokens_ids)))
+                                prompt_tokens_ids, accept_last)))
                 else:
                     logits[logits_row_idx] = \
                         _apply_logits_processors_single_seq(
                             logits_row, logits_processors, past_tokens_ids,
-                            prompt_tokens_ids)
+                            prompt_tokens_ids, accept_last)
 
                 logits_processed += len(seq_group.sample_indices) + len(
                     seq_group.prompt_logprob_indices)
@@ -184,13 +213,15 @@ def _apply_logits_processors(
 
 def _apply_logits_processors_single_seq(logits_row, logits_processors,
-                                        past_tokens_ids,
-                                        prompt_tokens_ids) -> torch.Tensor:
+                                        past_tokens_ids, prompt_tokens_ids,
+                                        accept_last) -> torch.Tensor:
     for logits_processor in logits_processors:
         parameters = inspect.signature(logits_processor).parameters
         if len(parameters) == 3:
             logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
                                           logits_row)
         else:
+            if accept_last and len(past_tokens_ids) > 0:
+                logits_processor.accept(past_tokens_ids[-1])
             logits_row = logits_processor(past_tokens_ids, logits_row)
     return logits_row
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index c77324bab59c..acaa0326e96c 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -12,6 +12,9 @@
                                                QuantizationConfig)
 from vllm.utils import supports_kw
 
+from .. import SamplingMetadata
+from ..layers.logits_processor import _apply_logits_processors, accept_grammar
+from ..layers.sampler import SamplerOutput
 from .interfaces_base import is_pooling_model
 
 if TYPE_CHECKING:
@@ -221,6 +224,55 @@ def forward(
         ...
 
+class SupportsSampleV2:
+
+    def compute_logits_v2(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head,
+                                       hidden_states,
+                                       sampling_metadata,
+                                       skip_grammar=True)
+        return logits
+
+    def samplev2(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        # first sampling pass over the unconstrained logits
+        next_tokens: SamplerOutput = self.sampler(logits, sampling_metadata)
+
+        # check whether the sampled tokens satisfy the grammars
+        tks = torch.tensor(
+            [o.samples[0].output_token for o in next_tokens.outputs])
+        accepted = accept_grammar(tks, sampling_metadata)
+        need_resample = torch.logical_not(accepted)
+        if accepted.all():
+            return next_tokens
+        # Resample: if a sampled token is invalid, sample again,
+        # but first apply the grammar bitmask to the logits.
+        # The logits processors are applied only to the rows that
+        # need resampling.
+        logits = _apply_logits_processors(logits, sampling_metadata,
+                                          need_resample, False)
+        new_next_tokens: SamplerOutput = self.sampler(logits,
+                                                      sampling_metadata)
+
+        for i, replace in enumerate(need_resample.tolist()):
+            if replace:
+                next_tokens.outputs[i] = new_next_tokens.outputs[i]
+
+        tks = torch.tensor(
+            [o.samples[0].output_token for o in next_tokens.outputs])
+        # the matcher accepts the new token only for rows that were resampled
+        accepted = accept_grammar(tks, sampling_metadata, need_resample)
+        assert accepted.all()
+        return next_tokens
+
+
 # We can't use runtime_checkable with ClassVar for issubclass checks
 # so we need to treat the class as an instance and use isinstance instead
 @runtime_checkable
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 81b5d9bda9ac..4939f5e99f98 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -48,7 +48,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -433,7 +433,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsSampleV2):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index a33739a8eef9..ed9abd4f8dc6 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -31,7 +31,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -335,7 +335,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
+class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA,
+                      SupportsSampleV2):
     packed_modules_mapping = {
         "c_attn": ["c_attn"],
         "gate_up_proj": [
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index c4d02e5ddeb1..88038761cb65 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -52,7 +52,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -405,7 +405,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsSampleV2):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 473bd901b5b2..1c98b2d38fd4 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -38,6 +38,7 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.model_executor.models.interfaces import SupportsSampleV2
 from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs, MultiModalPlaceholderMap,
@@ -1785,8 +1786,12 @@ def execute_model(
                         torch.tensor(model_forward_time +
                                      orig_model_forward_time))
             return hidden_or_intermediate_states
 
-        logits = self.model.compute_logits(hidden_or_intermediate_states,
-                                           model_input.sampling_metadata)
+        if isinstance(self.model, SupportsSampleV2):
+            logits = self.model.compute_logits_v2(
+                hidden_or_intermediate_states, model_input.sampling_metadata)
+        else:
+            logits = self.model.compute_logits(hidden_or_intermediate_states,
+                                               model_input.sampling_metadata)
 
         if not self.is_driver_worker:
             return []
@@ -1795,10 +1800,13 @@ def execute_model(
             model_input.async_callback()
 
         # Sample the next token.
-        output: SamplerOutput = self.model.sample(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
+        if isinstance(self.model, SupportsSampleV2):
+            output: SamplerOutput = self.model.samplev2(
+                logits=logits, sampling_metadata=model_input.sampling_metadata)
+        else:
+            output = self.model.sample(
+                logits=logits, sampling_metadata=model_input.sampling_metadata)
+
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time
                 and output is not None):
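
The sketch below is a minimal, standalone illustration of the sample-then-validate-then-resample flow that `samplev2` introduces in this patch: sample once from the unconstrained logits, ask the grammar to accept each sampled token, and only for the rejected rows apply the grammar mask and sample again. The toy vocabulary, the `accept`/`apply_grammar_mask` helpers, and the sampling routine are illustrative stand-ins, not vLLM or xgrammar APIs.

```python
# Standalone sketch of the samplev2 flow; all names here are hypothetical.
import torch

VOCAB = ["a", "b", "0", "1", "2"]
ALLOWED = {2, 3, 4}  # toy grammar: only digit tokens are valid


def accept(token_id: int) -> bool:
    """Toy stand-in for the grammar matcher's accept() check."""
    return token_id in ALLOWED


def apply_grammar_mask(logits: torch.Tensor) -> torch.Tensor:
    """Toy stand-in for applying the grammar bitmask to the logits."""
    masked = logits.clone()
    for tok in range(len(VOCAB)):
        if tok not in ALLOWED:
            masked[:, tok] = float("-inf")
    return masked


def sample(logits: torch.Tensor) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


def sample_v2(logits: torch.Tensor) -> torch.Tensor:
    # 1) sample from the unconstrained logits (fast path)
    tokens = sample(logits)
    # 2) ask the grammar whether each sampled token is acceptable
    accepted = torch.tensor([accept(t.item()) for t in tokens])
    if accepted.all():
        return tokens
    # 3) only for rejected rows: mask invalid tokens and resample
    retry = sample(apply_grammar_mask(logits))
    tokens = torch.where(accepted, tokens, retry)
    assert all(accept(t.item()) for t in tokens)
    return tokens


if __name__ == "__main__":
    torch.manual_seed(0)
    batch_logits = torch.randn(4, len(VOCAB))
    print([VOCAB[t] for t in sample_v2(batch_logits)])
```

The point of the two-pass design is that the grammar mask is computed and applied only for rows whose first sample was rejected, so the common case, where the unconstrained sample already satisfies the grammar, stays cheap.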