diff --git a/tests/v1/test_deferred_writer.py b/tests/v1/test_deferred_writer.py new file mode 100644 index 000000000000..16b65a08b7bb --- /dev/null +++ b/tests/v1/test_deferred_writer.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.v1.kv_cache.deferred import DeferredWriteManager, ShouldFallback +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + + +def _make_metadata(draft_token_ids: list[int], per_request: list[int]) -> SpecDecodeMetadata: + total = len(draft_token_ids) + cu = torch.tensor(per_request, dtype=torch.int32) + cu = torch.cumsum(cu, dim=0) + return SpecDecodeMetadata( + draft_token_ids=torch.tensor(draft_token_ids, dtype=torch.int32), + num_draft_tokens=list(per_request), + cu_num_draft_tokens=cu, + target_logits_indices=torch.zeros(total, dtype=torch.int32), + bonus_logits_indices=torch.zeros(len(per_request), dtype=torch.int32), + logits_indices=torch.zeros(total + len(per_request), dtype=torch.int32), + ) + + +def test_deferred_manager_commit_partial_acceptance(): + manager = DeferredWriteManager() + assert manager.begin_window([2]) + + writes: list[tuple[torch.Tensor, torch.Tensor]] = [] + + def writer(key, value, key_cache, value_cache, slot_mapping, *_): + writes.append((key.clone(), slot_mapping.clone())) + + key = torch.arange(4, dtype=torch.float32).view(2, 1, 2) + value = torch.arange(4, dtype=torch.float32).view(2, 1, 2) + slot_mapping = torch.tensor([3, 7], dtype=torch.int32) + key_cache = torch.empty_like(key) + value_cache = torch.empty_like(value) + + manager.stage_layer( + layer_id="layer0", + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=slot_mapping, + kv_cache_dtype="fp16", + k_scale=None, + v_scale=None, + writer=writer, + ) + + mask = torch.tensor([True, False]) + manager.commit(mask) + + assert len(writes) == 1 + committed_key, committed_slots = writes[0] + assert committed_key.shape[0] == 1 + assert committed_slots.tolist() == [3] + window_metrics = manager.pop_last_window_metrics() + assert window_metrics == { + "mode": "stage", + "committed": 1, + "rejected": 1, + "fallback": 0, + } + + +def test_deferred_manager_cancel_flush_writes_all(): + manager = DeferredWriteManager() + assert manager.begin_window([1, 1]) + + writes: list[tuple[str, torch.Tensor]] = [] + + def writer(key, value, *_args): # pragma: no cover - signature compatibility + writes.append(("commit", key.clone())) + + key = torch.randn(1, 1, 2) + value = torch.randn(1, 1, 2) + slot_mapping = torch.tensor([5], dtype=torch.int32) + key_cache = torch.empty_like(key) + value_cache = torch.empty_like(value) + + manager.stage_layer( + layer_id="layer0", + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=slot_mapping, + kv_cache_dtype="fp16", + k_scale=None, + v_scale=None, + writer=writer, + ) + manager.stage_layer( + layer_id="layer1", + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=slot_mapping, + kv_cache_dtype="fp16", + k_scale=None, + v_scale=None, + writer=writer, + ) + + manager.cancel_and_flush("test_cancel") + assert len(writes) == 2 + assert all(tensor.shape[0] == 1 for _tag, tensor in writes) + window_metrics = manager.pop_last_window_metrics() + assert window_metrics is not None + assert window_metrics.get("fallback") == 1 + + +def 
test_build_acceptance_mask_matches_expected(): + metadata = _make_metadata([10, 11, 20], [2, 1]) + sampled = torch.tensor( + [ + [10, 99, 0], # second token rejected + [20, 0, 0], + ], + dtype=torch.int32, + ) + + runner = GPUModelRunner.__new__(GPUModelRunner) + mask = runner._build_nwor_acceptance_mask(metadata, sampled) + expected = torch.tensor([True, False, True], dtype=torch.bool) + assert torch.equal(mask.cpu(), expected) + + +def test_nwor_disabled_env(monkeypatch): + monkeypatch.setenv("VLLM_DISABLE_NWOR", "1") + + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.speculative_config = object() + runner._deferred_write_manager = DeferredWriteManager() + + metadata = _make_metadata([1, 2], [2]) + runner._maybe_begin_nwor_window(metadata) + + assert not runner._deferred_write_manager.window_active + + +def test_fp8_staging_slices_quant_scales(): + manager = DeferredWriteManager() + assert manager.begin_window([2]) + + recorded: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [] + + def writer(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, v_scale): + recorded.append((key.clone(), value.clone(), slot_mapping.clone(), k_scale.clone() if k_scale is not None else None)) + + key = torch.arange(4, dtype=torch.float32).view(2, 1, 2) + value = torch.arange(4, dtype=torch.float32).view(2, 1, 2) + slot_mapping = torch.tensor([3, 7], dtype=torch.int32) + key_cache = torch.empty_like(key, dtype=torch.uint8) + value_cache = torch.empty_like(value, dtype=torch.uint8) + k_scale = torch.tensor([0.5, 0.7], dtype=torch.float32) + v_scale = torch.tensor([0.6, 0.9], dtype=torch.float32) + + manager.stage_layer( + layer_id="layer0", + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=slot_mapping, + kv_cache_dtype="fp8", + k_scale=k_scale, + v_scale=v_scale, + writer=writer, + ) + + manager.commit(torch.tensor([True, False])) + + assert len(recorded) == 1 + committed_key, committed_value, slots, committed_k_scale = recorded[0] + assert committed_key.shape[0] == 1 + assert torch.equal(slots, torch.tensor([3], dtype=torch.int32)) + assert committed_k_scale is None or committed_k_scale.shape[0] == 1 + window_metrics = manager.pop_last_window_metrics() + assert window_metrics == { + "mode": "stage", + "committed": 1, + "rejected": 1, + "fallback": 0, + } + + +def test_nwor_immediate_mode_skips_window(): + manager = DeferredWriteManager(mode="immediate") + assert not manager.begin_window([2]) + assert manager.get_mode() == "immediate" + + +def test_commit_failure_triggers_fallback_metrics(): + manager = DeferredWriteManager() + assert manager.begin_window([1]) + + key = torch.randn(1, 1, 2) + value = torch.randn(1, 1, 2) + slot_mapping = torch.tensor([0], dtype=torch.int32) + key_cache = torch.empty_like(key) + value_cache = torch.empty_like(value) + + def writer(*_args, **_kwargs): + raise RuntimeError("forced failure") + + manager.stage_layer( + layer_id="layer0", + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=slot_mapping, + kv_cache_dtype="fp16", + k_scale=None, + v_scale=None, + writer=writer, + ) + + with pytest.raises(ShouldFallback): + manager.commit(torch.tensor([True])) + + window_metrics = manager.pop_last_window_metrics() + assert window_metrics is not None + assert window_metrics.get("fallback") == 1 diff --git a/vllm/envs.py b/vllm/envs.py index c3686477d88d..f876a0765496 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -198,6 +198,8 @@ 
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False + VLLM_DISABLE_NWOR: bool = False + VLLM_NWOR_MODE: str = "stage" VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False @@ -1309,6 +1311,10 @@ def get_vllm_port() -> int | None: "VLLM_DISABLE_PAD_FOR_CUDAGRAPH": lambda: bool( int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0")) ), + # Disable No-Write-On-Reject staging for speculative decoding if set to 1. + "VLLM_DISABLE_NWOR": lambda: bool(int(os.getenv("VLLM_DISABLE_NWOR", "0"))), + # Select NWOR mode: "stage" (default) or "immediate" to bypass staging. + "VLLM_NWOR_MODE": lambda: os.getenv("VLLM_NWOR_MODE", "stage"), # Used to force set up loopback IP "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), # Used to set the process name prefix for vLLM processes. diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 3617294bd621..e71390a669d9 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -121,6 +121,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if weight_name not in name: continue name = name.replace(weight_name, param_name) + if name not in params_dict: + logger.debug("Skipping unmatched weight %s", name) + break param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -130,6 +133,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if get_pp_group().world_size == 1 and "embed_tokens." in name: continue + if name not in params_dict: + logger.debug("Skipping unmatched weight %s", name) + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fb5ff499de2c..03c7c46b80f2 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -34,6 +34,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import cdiv +from vllm.v1.kv_cache import record_or_write_kv_cache from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -533,15 +534,22 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
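+            # NWOR: route this write through record_or_write_kv_cache so it
+            # can be staged and dropped for rejected draft tokens; with no
+            # active staging window it falls through to an immediate write.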
- reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer_id = getattr( + layer, + "layer_name", + getattr(layer, "layer_id", layer.__class__.__name__), + ) + record_or_write_kv_cache( + writer=reshape_and_cache_flash, + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) if self.kv_cache_dtype.startswith("fp8"): diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 0fa71afa62ee..848615862238 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -50,6 +50,7 @@ infer_global_hyperparameters, split_decodes_and_prefills, ) +from vllm.v1.kv_cache import record_or_write_kv_cache from vllm.v1.kv_cache_interface import AttentionSpec FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 @@ -922,20 +923,25 @@ def forward( if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + # NOTE(woosuk): key/value are padded while slot_mapping is not. 
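+            # Split the combined kv_cache so the deferred writer receives the
+            # same key/value cache views that reshape_and_cache_flash used to
+            # take inline.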
+ key_cache = kv_cache[:, 0] + value_cache = kv_cache[:, 1] + layer_id = getattr( + layer, + "layer_name", + getattr(layer, "layer_id", layer.__class__.__name__), + ) + record_or_write_kv_cache( + writer=torch.ops._C_cache_ops.reshape_and_cache_flash, + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 2595851e5042..2b53a70411aa 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -33,6 +33,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, ) +from vllm.v1.kv_cache import record_or_write_kv_cache from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -810,15 +811,22 @@ def forward( assert self.attn_type == AttentionType.DECODER key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer_id = getattr( + layer, + "layer_name", + getattr(layer, "layer_id", layer.__class__.__name__), + ) + record_or_write_kv_cache( + writer=torch.ops._C_cache_ops.reshape_and_cache_flash, + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) # View out the block_size dim diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index cce43b220da7..bde183c2e269 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -21,6 +21,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, ) +from vllm.v1.kv_cache import record_or_write_kv_cache from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 @@ -506,15 +507,22 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
- torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer_id = getattr( + layer, + "layer_name", + getattr(layer, "layer_id", layer.__class__.__name__), + ) + record_or_write_kv_cache( + writer=torch.ops._C_cache_ops.reshape_and_cache_flash, + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) if self.kv_cache_dtype.startswith("fp8"): diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index 14184944934f..5871e932e9ac 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -18,6 +18,7 @@ RocmAttentionMetadata, RocmAttentionMetadataBuilder, ) +from vllm.v1.kv_cache import record_or_write_kv_cache logger = init_logger(__name__) @@ -150,15 +151,22 @@ def forward( if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer_id = getattr( + layer, + "layer_name", + getattr(layer, "layer_id", layer.__class__.__name__), + ) + record_or_write_kv_cache( + writer=ops.reshape_and_cache_flash, + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) if self.kv_cache_dtype.startswith("fp8"): diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index aab90cfd1fe0..19f1ff31e1c4 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -24,6 +24,7 @@ CommonAttentionMetadata, split_decodes_and_prefills, ) +from vllm.v1.kv_cache import record_or_write_kv_cache from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -391,15 +392,22 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
- ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer_id = getattr( + layer, + "layer_name", + getattr(layer, "layer_id", layer.__class__.__name__), + ) + record_or_write_kv_cache( + writer=ops.reshape_and_cache_flash, + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) num_actual_tokens = attn_metadata.num_actual_tokens diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 9d1d007a08e4..66e8cea95fa6 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -30,6 +30,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, ) +from vllm.v1.kv_cache import record_or_write_kv_cache from vllm.v1.kv_cache_interface import AttentionSpec if current_platform.is_cuda_alike(): @@ -323,15 +324,22 @@ def forward( # triton kernel does not support uint8 kv_cache # (because some explicit casts (e.g. float8_e4m3fnuz) # are not supported) - triton_reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer_id = getattr( + layer, + "layer_name", + getattr(layer, "layer_id", layer.__class__.__name__), + ) + record_or_write_kv_cache( + writer=triton_reshape_and_cache_flash, + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) if self.kv_cache_dtype.startswith("fp8"): diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 41c543c18adc..deb975a35729 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -23,6 +23,7 @@ split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache import record_or_write_kv_cache try: from xformers import ops as xops @@ -366,15 +367,22 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
- ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer_id = getattr( + layer, + "layer_name", + getattr(layer, "layer_id", layer.__class__.__name__), + ) + record_or_write_kv_cache( + writer=ops.reshape_and_cache_flash, + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) num_actual_tokens = attn_metadata.num_actual_tokens diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index cbbdf48c6e0c..20fe6875e8a9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -918,6 +918,7 @@ def update_from_output( pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits kv_connector_output = model_runner_output.kv_connector_output + nwor_stats = model_runner_output.nwor_metrics outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None @@ -1077,7 +1078,7 @@ def update_from_output( finished_req_ids.clear() if ( - stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + stats := self.make_stats(spec_decoding_stats, kv_connector_stats, nwor_stats) ) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: @@ -1244,6 +1245,7 @@ def make_stats( self, spec_decoding_stats: SpecDecodingStats | None = None, kv_connector_stats: KVConnectorStats | None = None, + nwor_stats: dict[str, Any] | None = None, ) -> SchedulerStats | None: if not self.log_stats: return None @@ -1257,6 +1259,7 @@ def make_stats( spec_decoding_stats=spec_decoding_stats, num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, + nwor_stats=nwor_stats, ) def make_spec_decoding_stats( diff --git a/vllm/v1/kv_cache/__init__.py b/vllm/v1/kv_cache/__init__.py new file mode 100644 index 000000000000..5fda6cb3af13 --- /dev/null +++ b/vllm/v1/kv_cache/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Deferred KV cache staging utilities for NWOR (No-Write-On-Reject).""" + +from .deferred import ( # noqa: F401 + DeferredWriteManager, + get_global_deferred_manager, + record_or_write_kv_cache, + ShouldFallback, + set_global_deferred_manager, +) + +__all__ = [ + "DeferredWriteManager", + "get_global_deferred_manager", + "record_or_write_kv_cache", + "ShouldFallback", + "set_global_deferred_manager", +] diff --git a/vllm/v1/kv_cache/deferred.py b/vllm/v1/kv_cache/deferred.py new file mode 100644 index 000000000000..8d91a9e4fed1 --- /dev/null +++ b/vllm/v1/kv_cache/deferred.py @@ -0,0 +1,460 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Deferred KV cache staging for No-Write-On-Reject (NWOR).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Optional, Sequence + +import torch +from torch import Tensor + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: # pragma: no cover - optional import + from torch._subclasses.fake_tensor import FakeTensor # type: ignore +except Exception: # 
pragma: no cover - fallback for older PyTorch + FakeTensor = () + +try: # pragma: no cover - optional import + from torch._C import _is_fake_tensor as torch_is_fake_tensor # type: ignore +except Exception: # pragma: no cover - fallback when helper missing + torch_is_fake_tensor = None + + +class ShouldFallback(RuntimeError): + """Raised when the deferred writer must abandon staging.""" + + +@dataclass +class _LayerEntry: + layer_id: str + start: int + length: int + key_source: Tensor + value_source: Tensor + slot_mapping: Tensor + key_cache: Tensor + value_cache: Tensor + kv_cache_dtype: str + k_scale: Optional[Tensor] + v_scale: Optional[Tensor] + writer: Callable[[Tensor, Tensor, Tensor, Tensor, Tensor, str, Optional[Tensor], Optional[Tensor]], None] + + +_global_manager: Optional["DeferredWriteManager"] = None + + +def get_global_deferred_manager() -> Optional["DeferredWriteManager"]: + return _global_manager + + +def set_global_deferred_manager(manager: Optional["DeferredWriteManager"]) -> None: + global _global_manager + _global_manager = manager + + +def _is_fake_tensor(tensor: Tensor) -> bool: + if isinstance(tensor, FakeTensor) or tensor.__class__.__name__ == "FakeTensor": + return True + if torch_is_fake_tensor is not None: + try: + return bool(torch_is_fake_tensor(tensor)) + except TypeError: # pragma: no cover - defensive + return False + return False + + +def _tensor_has_storage(tensor: Tensor) -> bool: + if not isinstance(tensor, Tensor): + return False + if tensor.is_meta: + return False + if _is_fake_tensor(tensor): + return False + try: + tensor.data_ptr() + except (RuntimeError, ValueError): + return False + return True + + +def _in_restricted_context() -> bool: + try: # pragma: no cover - torch.compile path + import torch._dynamo as dynamo # type: ignore + + if dynamo.is_compiling(): + return True + except (ImportError, AttributeError): # pragma: no cover - optional + pass + + if not torch.cuda.is_available(): + return False + try: + return torch.cuda.is_current_stream_capturing() + except (RuntimeError, AttributeError): # pragma: no cover - defensive + return False + + +def _ensure_int32_slots(slot_mapping: Tensor, device: torch.device) -> Tensor: + if slot_mapping.dtype != torch.int32 or slot_mapping.device != device: + slot_mapping = slot_mapping.to(device=device, dtype=torch.int32, copy=False) + if not slot_mapping.is_contiguous(): + slot_mapping = slot_mapping.contiguous() + return slot_mapping + + +def _slice_scale(scale: Optional[Tensor], indices: Tensor) -> Optional[Tensor]: + if scale is None: + return None + if scale.ndim == 0: + return scale + if scale.shape[0] == 0: + return scale + first_dim = scale.shape[0] + target = int(indices.numel()) + if first_dim == target: + return torch.index_select(scale, 0, indices) + # Some implementations append an extra sentinel slot; ignore it. + if first_dim == target + 1: + return torch.index_select(scale[:-1], 0, indices) + # Default: return the original scale (per-layer scale etc.). 
+ return scale + + +class DeferredWriteManager: + """Stages KV writes until acceptance is known.""" + + SUPPORTED_MODES = {"stage", "immediate"} + + def __init__(self, *, mode: str = "stage") -> None: + self._window_active = False + self._num_draft_tokens: list[int] = [] + self._expected_tokens = 0 + self._staged_tokens = 0 + self._entries: list[_LayerEntry] = [] + self._fallback_reason: Optional[str] = None + self._metrics = { + "windows": 0, + "tokens_staged": 0, + "tokens_committed": 0, + "tokens_rejected": 0, + "tokens_fallback": 0, + "fallbacks": 0, + } + self._mode = self._validate_mode(mode) + self._last_window_metrics: dict[str, int | str] | None = None + + # ---------------------------------------------------------------------- + # Lifecycle + # ---------------------------------------------------------------------- + @property + def window_active(self) -> bool: + return self._window_active + + @property + def fallback_reason(self) -> Optional[str]: + return self._fallback_reason + + def begin_window(self, num_draft_tokens: Sequence[int]) -> bool: + """Arm the manager for a new speculative decode window.""" + + if self._mode != "stage": + return False + + self._clear_window() + + total_tokens = sum(int(n) for n in num_draft_tokens) + if total_tokens <= 0: + return False + + if _in_restricted_context(): + self._record_fallback("cuda_graph_capture") + return False + + self._window_active = True + self._num_draft_tokens = [int(n) for n in num_draft_tokens] + self._expected_tokens = total_tokens + self._staged_tokens = 0 + self._entries.clear() + self._fallback_reason = None + self._last_window_metrics = None + self._metrics["windows"] += 1 + self._metrics["tokens_staged"] += total_tokens + return True + + def set_mode(self, mode: str) -> None: + self._mode = self._validate_mode(mode) + + def get_mode(self) -> str: + return self._mode + + def finish_step(self) -> None: + """Flush any pending data if the window did not complete.""" + + if self._window_active: + self.cancel_and_flush("incomplete_window") + + def get_metrics(self) -> dict[str, int | str]: + metrics = dict(self._metrics) + metrics["mode"] = self._mode + return metrics + + # ------------------------------------------------------------------ + # Staging + # ------------------------------------------------------------------ + def stage_layer( + self, + *, + layer_id: str, + key: Tensor, + value: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + kv_cache_dtype: str, + k_scale: Optional[Tensor], + v_scale: Optional[Tensor], + writer: Callable[[Tensor, Tensor, Tensor, Tensor, Tensor, str, Optional[Tensor], Optional[Tensor]], None], + ) -> bool: + if not self._window_active: + return False + + if not (_tensor_has_storage(key) and _tensor_has_storage(value)): + raise ShouldFallback("kv_slice_without_storage") + + if not (_tensor_has_storage(key_cache) and _tensor_has_storage(value_cache)): + raise ShouldFallback("kv_cache_not_materialized") + + slot_mapping = _ensure_int32_slots(slot_mapping, key.device) + + length = int(slot_mapping.shape[0]) + if length == 0: + return True + + if self._staged_tokens + length > self._expected_tokens: + raise ShouldFallback("staged_tokens_exceed_expected") + + entry = _LayerEntry( + layer_id=layer_id, + start=self._staged_tokens, + length=length, + key_source=key, + value_source=value, + slot_mapping=slot_mapping, + key_cache=key_cache, + value_cache=value_cache, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + writer=writer, + ) + 
self._entries.append(entry) + self._staged_tokens += length + return True + + # ------------------------------------------------------------------ + # Commit / Fallback + # ------------------------------------------------------------------ + def commit(self, accepted_mask: Tensor) -> None: + if not self._window_active: + return + + if accepted_mask.numel() != self._expected_tokens: + raise ShouldFallback("accepted_mask_mismatch") + + if accepted_mask.dtype != torch.bool: + accepted_mask = accepted_mask.to(dtype=torch.bool) + + committed_total = 0 + start = 0 + for entry in self._entries: + end = start + entry.length + layer_mask = accepted_mask[start:end] + if layer_mask.device != entry.key_source.device: + layer_mask = layer_mask.to(device=entry.key_source.device) + start = end + + if layer_mask.numel() != entry.length: + raise ShouldFallback("layer_mask_length_mismatch") + + if not layer_mask.any(): + continue + + indices = torch.nonzero(layer_mask, as_tuple=False).squeeze(1) + committed_total += int(indices.numel()) + + key_slice = torch.index_select(entry.key_source, 0, indices).contiguous() + value_slice = torch.index_select(entry.value_source, 0, indices).contiguous() + slot_slice = torch.index_select(entry.slot_mapping, 0, indices) + slot_slice = _ensure_int32_slots(slot_slice, entry.slot_mapping.device) + + k_scale_slice = _slice_scale(entry.k_scale, indices) + v_scale_slice = _slice_scale(entry.v_scale, indices) + + try: + entry.writer( + key_slice, + value_slice, + entry.key_cache, + entry.value_cache, + slot_slice, + entry.kv_cache_dtype, + k_scale_slice, + v_scale_slice, + ) + except Exception as exc: # pragma: no cover - propagate for upstream handling + reason = f"commit_failed:{entry.layer_id}" + self._record_fallback(reason) + self._flush_entries() + self._last_window_metrics = { + "mode": self._mode, + "committed": 0, + "rejected": self._expected_tokens, + "fallback": 1, + "reason": reason, + } + self._clear_window() + raise ShouldFallback(reason) from exc + + rejected = max(self._expected_tokens - committed_total, 0) + self._metrics["tokens_committed"] += committed_total + self._metrics["tokens_rejected"] += rejected + self._last_window_metrics = { + "mode": self._mode, + "committed": committed_total, + "rejected": rejected, + "fallback": 0, + } + self._clear_window() + + def cancel_and_flush(self, reason: str) -> None: + if not self._window_active: + return + self._record_fallback(reason) + self._flush_entries() + self._last_window_metrics = { + "mode": self._mode, + "committed": 0, + "rejected": 0, + "fallback": 1, + "reason": reason, + } + self._clear_window() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _flush_entries(self) -> None: + for entry in self._entries: + try: + entry.writer( + entry.key_source, + entry.value_source, + entry.key_cache, + entry.value_cache, + entry.slot_mapping, + entry.kv_cache_dtype, + entry.k_scale, + entry.v_scale, + ) + except Exception: # pragma: no cover - log and continue + logger.exception("NWOR fallback failed for layer %s", entry.layer_id) + if self._entries: + flushed_tokens = sum(e.length for e in self._entries) + self._metrics["tokens_fallback"] += flushed_tokens + + def _record_fallback(self, reason: str) -> None: + self._fallback_reason = reason + self._metrics["fallbacks"] += 1 + + def _clear_window(self) -> None: + self._window_active = False + self._num_draft_tokens.clear() + self._expected_tokens = 0 
+ self._staged_tokens = 0 + self._entries.clear() + + def _validate_mode(self, mode: str) -> str: + normalized = mode.lower() + if normalized not in self.SUPPORTED_MODES: + logger.warning("NWOR: unsupported mode '%s', defaulting to 'stage'", mode) + return "stage" + return normalized + + def pop_last_window_metrics(self) -> dict[str, int | str] | None: + metrics = self._last_window_metrics + self._last_window_metrics = None + return metrics + + +def record_or_write_kv_cache( + *, + writer: Callable[[Tensor, Tensor, Tensor, Tensor, Tensor, str, Optional[Tensor], Optional[Tensor]], None], + layer_id: str, + key: Tensor, + value: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + kv_cache_dtype: str, + k_scale: Optional[Tensor], + v_scale: Optional[Tensor], +) -> None: + manager = get_global_deferred_manager() + if manager is None: + writer( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + return + + try: + staged = manager.stage_layer( + layer_id=layer_id, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=slot_mapping, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + writer=writer, + ) + except ShouldFallback as exc: + manager.cancel_and_flush(str(exc)) + set_global_deferred_manager(None) + writer( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + return + + if not staged: + writer( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 8c5abae2ae65..796216717b9e 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -165,6 +165,24 @@ def log(self): *log_args, ) + if scheduler_stats.nwor_stats: + nwor = scheduler_stats.nwor_stats + mode = nwor.get("mode", "stage") + committed = nwor.get("committed", 0) + rejected = nwor.get("rejected", 0) + fallback = nwor.get("fallback", 0) + reason = nwor.get("reason") + extra = f", reason={reason}" if reason else "" + log_fn( + "Engine %03d: NWOR mode=%s committed=%s rejected=%s fallback=%s%s", + self.engine_index, + mode, + committed, + rejected, + fallback, + extra, + ) + self.spec_decoding_logging.log(log_fn=log_fn) self.kv_connector_logging.log(log_fn=log_fn) @@ -339,6 +357,44 @@ def __init__( counter_mm_cache_hits, engine_indexes, model_name ) + self.counter_nwor_committed = make_per_engine( + self._counter_cls( + name="vllm:nwor_committed_tokens", + documentation="Number of tokens committed via NWOR in this engine.", + labelnames=labelnames, + ), + engine_indexes, + model_name, + ) + self.counter_nwor_rejected = make_per_engine( + self._counter_cls( + name="vllm:nwor_rejected_tokens", + documentation="Number of draft tokens rejected by NWOR.", + labelnames=labelnames, + ), + engine_indexes, + model_name, + ) + self.counter_nwor_fallbacks = make_per_engine( + self._counter_cls( + name="vllm:nwor_fallbacks", + documentation="Number of NWOR fallbacks triggered.", + labelnames=labelnames, + ), + engine_indexes, + model_name, + ) + self.gauge_nwor_enabled = make_per_engine( + self._gauge_cls( + name="vllm:nwor_enabled", + documentation="Whether NWOR is active for this engine (1=yes, 0=no).", + multiprocess_mode="mostrecent", + labelnames=labelnames, + ), + engine_indexes, + model_name, + ) + # # Counters # @@ -744,6 +800,17 @@ def record( scheduler_stats.spec_decoding_stats, engine_idx ) + if 
scheduler_stats.nwor_stats is not None: + nwor = scheduler_stats.nwor_stats + committed = int(nwor.get("committed", 0)) + rejected = int(nwor.get("rejected", 0)) + fallback = int(nwor.get("fallback", 0)) + mode = nwor.get("mode", "stage") + self.counter_nwor_committed[engine_idx].inc(committed) + self.counter_nwor_rejected[engine_idx].inc(rejected) + self.counter_nwor_fallbacks[engine_idx].inc(fallback) + self.gauge_nwor_enabled[engine_idx].set(1 if mode == "stage" else 0) + if mm_cache_stats is not None: self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index a4a8ab32ad72..ab8cd06a00ec 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -154,6 +154,7 @@ class SchedulerStats: spec_decoding_stats: SpecDecodingStats | None = None kv_connector_stats: dict[str, Any] | None = None + nwor_stats: dict[str, Any] | None = None num_corrupted_reqs: int = 0 diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index c224555da6ca..176b742df607 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -132,6 +132,8 @@ class ModelRunnerOutput: # req_id -> num_nans_in_logits num_nans_in_logits: dict[str, int] | None = None + nwor_metrics: dict[str, int | str] | None = None + # ModelRunnerOutput wrapper for async scheduling. class AsyncModelRunnerOutput(ABC): @@ -162,4 +164,5 @@ class DraftTokenIds: prompt_logprobs_dict={}, pooler_output=[], num_nans_in_logits=None, + nwor_metrics=None, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5c2893bd0926..0909d0f8dd0a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -96,6 +96,11 @@ split_attn_metadata, ) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from vllm.v1.kv_cache import ( + DeferredWriteManager, + ShouldFallback, + set_global_deferred_manager, +) from vllm.v1.kv_cache_interface import ( AttentionSpec, ChunkedLocalAttentionSpec, @@ -502,6 +507,8 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. 
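+        # NWOR (No-Write-On-Reject): stages speculative-decode KV-cache writes
+        # and commits only the slots whose draft tokens are accepted. The mode
+        # comes from VLLM_NWOR_MODE; VLLM_DISABLE_NWOR skips staging entirely.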
+ self._deferred_write_manager = DeferredWriteManager(mode=envs.VLLM_NWOR_MODE) + self._latest_nwor_window_metrics: dict[str, int | str] | None = None self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( @@ -2238,6 +2245,116 @@ def _sample( self._update_states_after_model_execute(output_token_ids) return sampler_output + def _maybe_begin_nwor_window( + self, spec_decode_metadata: SpecDecodeMetadata | None + ) -> None: + set_global_deferred_manager(None) + + if envs.VLLM_DISABLE_NWOR: + self._latest_nwor_window_metrics = None + return + + self._deferred_write_manager.set_mode(envs.VLLM_NWOR_MODE) + self._latest_nwor_window_metrics = None + + if self._deferred_write_manager.get_mode() != "stage": + return + + if self.speculative_config is None or spec_decode_metadata is None: + return + + num_draft_tokens = spec_decode_metadata.num_draft_tokens + if not num_draft_tokens or sum(int(n) for n in num_draft_tokens) <= 0: + return + + if self._deferred_write_manager.begin_window(num_draft_tokens): + set_global_deferred_manager(self._deferred_write_manager) + + def _finalize_nwor_window( + self, + spec_decode_metadata: SpecDecodeMetadata | None, + sampled_token_ids: torch.Tensor | None, + ) -> None: + manager = self._deferred_write_manager + if not manager.window_active: + return + + try: + if spec_decode_metadata is None or sampled_token_ids is None: + manager.cancel_and_flush("missing_spec_metadata") + else: + mask = self._build_nwor_acceptance_mask( + spec_decode_metadata, sampled_token_ids + ) + if mask is None: + manager.cancel_and_flush("accept_mask_construction_failed") + else: + manager.commit(mask) + except ShouldFallback: + pass + finally: + self._latest_nwor_window_metrics = manager.pop_last_window_metrics() + set_global_deferred_manager(None) + + def _cleanup_nwor(self) -> None: + set_global_deferred_manager(None) + self._deferred_write_manager.finish_step() + pending = self._deferred_write_manager.pop_last_window_metrics() + if pending is not None and self._latest_nwor_window_metrics is None: + self._latest_nwor_window_metrics = pending + + def _build_nwor_acceptance_mask( + self, + spec_decode_metadata: SpecDecodeMetadata, + sampled_token_ids: torch.Tensor, + ) -> torch.Tensor | None: + num_draft_tokens = spec_decode_metadata.num_draft_tokens + total_tokens = sum(int(n) for n in num_draft_tokens) + if total_tokens <= 0: + return None + + target_device = spec_decode_metadata.draft_token_ids.device + work_device = sampled_token_ids.device + + draft_ids = spec_decode_metadata.draft_token_ids + if draft_ids.device != work_device: + draft_ids = draft_ids.to(device=work_device) + draft_ids = draft_ids.to(dtype=sampled_token_ids.dtype, copy=False) + + mask_work = torch.zeros(total_tokens, dtype=torch.bool, device=work_device) + + start = 0 + for req_idx, draft_count in enumerate(num_draft_tokens): + draft_count = int(draft_count) + if draft_count == 0: + continue + end = start + draft_count + row = sampled_token_ids[req_idx, :draft_count] + if row.device != work_device: + row = row.to(device=work_device) + if row.dtype != draft_ids.dtype: + row = row.to(dtype=draft_ids.dtype) + + draft_slice = draft_ids[start:end] + comparison = (row == draft_slice).flatten() + + if bool(comparison.all().item()): + accepted = draft_count + else: + reject = torch.nonzero(~comparison, as_tuple=False) + accepted = int(reject[0, 0].item()) if reject.numel() > 0 else draft_count + + if accepted > 0: + 
mask_work[start : start + accepted] = True + start = end + + if start != total_tokens: + return None + + if mask_work.device == target_device: + return mask_work + return mask_work.to(device=target_device) + def _bookkeeping_sync( self, scheduler_output: "SchedulerOutput", @@ -2467,215 +2584,232 @@ def execute_model( scheduler_output, num_input_tokens, intermediate_tensors ) - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( - num_scheduled_tokens == self.input_batch.num_reqs * max_query_len - ) - batch_descriptor = BatchDescriptor( - num_tokens=num_input_tokens, uniform_decode=uniform_decode - ) - cudagraph_runtime_mode, batch_descriptor = ( - self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) - ) - - # Set cudagraph mode to none if calc_kv_scales is true. - if attn_metadata is not None: - metadata_list = ( - attn_metadata.values() - if isinstance(attn_metadata, dict) - else [attn_metadata] - ) - if any( - getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list - ): - cudagraph_runtime_mode = CUDAGraphMode.NONE + self._maybe_begin_nwor_window(spec_decode_metadata) - # Run the model. - # Use persistent buffers for CUDA graphs. - with ( - set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices, - ), - record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, - ): - model_output = self._model_forward( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) + try: + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, uniform_decode=uniform_decode + ) + cudagraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) + ) - with record_function_or_nullcontext("Postprocess"): - if self.use_aux_hidden_state_outputs: - # True when EAGLE 3 is used. - hidden_states, aux_hidden_states = model_output - else: - # Common case. - hidden_states = model_output - aux_hidden_states = None - - if not self.broadcast_pp_output: - # Common case. - if not get_pp_group().is_last_rank: - # Return the intermediate tensors. - assert isinstance(hidden_states, IntermediateTensors) - hidden_states.kv_connector_output = kv_connector_output - return hidden_states - - if self.is_pooling_model: - # Return the pooling output. - output = self._pool( - hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + # Set cudagraph mode to none if calc_kv_scales is true. + if attn_metadata is not None: + metadata_list = ( + attn_metadata.values() + if isinstance(attn_metadata, dict) + else [attn_metadata] + ) + if any( + getattr(m, "enable_kv_scales_calculation", False) + for m in metadata_list + ): + cudagraph_runtime_mode = CUDAGraphMode.NONE + + # Run the model. + # Use persistent buffers for CUDA graphs. 
+ with ( + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ), + record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) + as kv_connector_output, + ): + model_output = self._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, ) - output.kv_connector_output = kv_connector_output - return output - sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states) - else: - # Rare case. - assert not self.is_pooling_model + with record_function_or_nullcontext("Postprocess"): + if self.use_aux_hidden_state_outputs: + # True when EAGLE 3 is used. + hidden_states, aux_hidden_states = model_output + else: + # Common case. + hidden_states = model_output + aux_hidden_states = None + + if not self.broadcast_pp_output: + # Common case. + if not get_pp_group().is_last_rank: + # Return the intermediate tensors. + assert isinstance(hidden_states, IntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + + if self.is_pooling_model: + # Return the pooling output. + output = self._pool( + hidden_states, + num_scheduled_tokens, + num_scheduled_tokens_np, + ) + output.kv_connector_output = kv_connector_output + return output + + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + else: + # Rare case. + assert not self.is_pooling_model + + if not get_pp_group().is_last_rank: + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) + } + get_pp_group().send_tensor_dict( + hidden_states.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + logits = None + else: + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + + model_output_broadcast_data = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() - if not get_pp_group().is_last_rank: - all_gather_tensors = { - "residual": not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 ) - } - get_pp_group().send_tensor_dict( - hidden_states.tensors, - all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors, - ) - logits = None - else: - sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] - model_output_broadcast_data = {} - if logits is not None: - model_output_broadcast_data["logits"] = logits.contiguous() + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + apply_grammar_bitmask( + scheduler_output, self.input_batch, logits, self.device + ) - model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( - model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 - ) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] + with 
record_function_or_nullcontext("Sample"): + sampler_output = self._sample(logits, spec_decode_metadata) - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask( - scheduler_output, self.input_batch, logits, self.device + self._finalize_nwor_window( + spec_decode_metadata, sampler_output.sampled_token_ids ) - with record_function_or_nullcontext("Sample"): - sampler_output = self._sample(logits, spec_decode_metadata) + def propose_draft_token_ids(sampled_token_ids): + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("Draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) - def propose_draft_token_ids(sampled_token_ids): - assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): - self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, - sampled_token_ids, - self.input_batch.sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, + use_padded_batch_for_eagle = ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch ) + effective_drafter_max_model_len = self.max_model_len + if effective_drafter_max_model_len is None: + effective_drafter_max_model_len = self.model_config.max_model_len + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len + is not None + ): + effective_drafter_max_model_len = ( + self.speculative_config.draft_model_config.max_model_len + ) + input_fits_in_drafter = spec_decode_common_attn_metadata and ( + spec_decode_common_attn_metadata.max_seq_len + + self.speculative_config.num_speculative_tokens + <= effective_drafter_max_model_len + ) + if use_padded_batch_for_eagle and input_fits_in_drafter: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. 
+ propose_draft_token_ids(sampler_output.sampled_token_ids) + + with record_function_or_nullcontext("Bookkeep"): + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + num_scheduled_tokens, + ) - use_padded_batch_for_eagle = ( - self.speculative_config - and self.speculative_config.use_eagle() - and not self.speculative_config.disable_padded_drafter_batch - ) - effective_drafter_max_model_len = self.max_model_len - if effective_drafter_max_model_len is None: - effective_drafter_max_model_len = self.model_config.max_model_len - if ( - self.speculative_config - and self.speculative_config.draft_model_config is not None - and self.speculative_config.draft_model_config.max_model_len is not None - ): - effective_drafter_max_model_len = ( - self.speculative_config.draft_model_config.max_model_len - ) - input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len - + self.speculative_config.num_speculative_tokens - <= effective_drafter_max_model_len - ) - if use_padded_batch_for_eagle and input_fits_in_drafter: - # EAGLE speculative decoding can use the GPU sampled tokens - # as inputs, and does not need to wait for bookkeeping to finish. - propose_draft_token_ids(sampler_output.sampled_token_ids) - - with record_function_or_nullcontext("Bookkeep"): - ( - num_nans_in_logits, - logprobs_lists, - valid_sampled_token_ids, - prompt_logprobs_dict, - req_ids_output_copy, - req_id_to_index_output_copy, - invalid_req_indices, - ) = self._bookkeeping_sync( - scheduler_output, - sampler_output, - logits, - hidden_states, - num_scheduled_tokens, - ) + if ( + self.speculative_config + and not use_padded_batch_for_eagle + and input_fits_in_drafter + ): + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. + propose_draft_token_ids(valid_sampled_token_ids) + + with record_function_or_nullcontext("EPLB"): + self.eplb_step() + + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits, + ) + output.nwor_metrics = self._latest_nwor_window_metrics + self._latest_nwor_window_metrics = None - if ( - self.speculative_config - and not use_padded_batch_for_eagle - and input_fits_in_drafter - ): - # ngram and other speculative decoding methods use the sampled - # tokens on the CPU, so they are run after bookkeeping. 
- propose_draft_token_ids(valid_sampled_token_ids) - - with record_function_or_nullcontext("EPLB"): - self.eplb_step() - - output = ModelRunnerOutput( - req_ids=req_ids_output_copy, - req_id_to_index=req_id_to_index_output_copy, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - num_nans_in_logits=num_nans_in_logits, - ) + if not self.use_async_scheduling: + return output - if not self.use_async_scheduling: - return output + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) - async_output = AsyncGPUModelRunnerOutput( - model_runner_output=output, - sampled_token_ids=sampler_output.sampled_token_ids, - invalid_req_indices=invalid_req_indices, - async_output_copy_stream=self.async_output_copy_stream, - ) + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) - # Save ref of sampled_token_ids CPU tensor if the batch contains - # any requests with sampling params that that require output ids. - self.input_batch.set_async_sampled_token_ids( - async_output.sampled_token_ids_cpu, - async_output.async_copy_ready_event, - ) + return async_output - return async_output + finally: + self._cleanup_nwor() def take_draft_token_ids(self) -> DraftTokenIds | None: if self._draft_token_ids is None: