diff --git a/tests/v1/test_deferred_writer.py b/tests/v1/test_deferred_writer.py index 16b65a08b7bb..91496757fe69 100644 --- a/tests/v1/test_deferred_writer.py +++ b/tests/v1/test_deferred_writer.py @@ -196,6 +196,17 @@ def test_nwor_immediate_mode_skips_window(): assert manager.get_mode() == "immediate" +def test_scv_vectorized_mask_matches_reference(): + metadata = _make_metadata([1, 2, 3, 4], [4]) + sampled = torch.tensor([[1, 2, 0, 4]], dtype=torch.int32) + + runner = GPUModelRunner.__new__(GPUModelRunner) + runner._scv_mode = "adaptive" + + mask = runner._build_nwor_acceptance_mask(metadata, sampled) + assert mask.tolist() == [True, True, False, False] + + def test_commit_failure_triggers_fallback_metrics(): manager = DeferredWriteManager() assert manager.begin_window([1]) diff --git a/vllm/envs.py b/vllm/envs.py index f876a0765496..5336660dd1be 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -200,6 +200,7 @@ VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_DISABLE_NWOR: bool = False VLLM_NWOR_MODE: str = "stage" + VLLM_SCV_MODE: str = "off" VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False @@ -1315,6 +1316,8 @@ def get_vllm_port() -> int | None: "VLLM_DISABLE_NWOR": lambda: bool(int(os.getenv("VLLM_DISABLE_NWOR", "0"))), # Select NWOR mode: "stage" (default) or "immediate" to bypass staging. "VLLM_NWOR_MODE": lambda: os.getenv("VLLM_NWOR_MODE", "stage"), + # Speculative chunk verify mode: "off" (default), "graph", or "adaptive". + "VLLM_SCV_MODE": lambda: os.getenv("VLLM_SCV_MODE", "off"), # Used to force set up loopback IP "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), # Used to set the process name prefix for vLLM processes. 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0909d0f8dd0a..b84256dec815 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,6 +5,7 @@ import itertools import time from collections import defaultdict +from dataclasses import dataclass from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy @@ -509,6 +510,8 @@ def __init__( # Cached outputs. self._deferred_write_manager = DeferredWriteManager(mode=envs.VLLM_NWOR_MODE) self._latest_nwor_window_metrics: dict[str, int | str] | None = None + self._scv_mode = envs.VLLM_SCV_MODE.lower() + self._scv_graph_executor: SCVGraphExecutor | None = None self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( @@ -518,6 +521,14 @@ def __init__( pin_memory=self.pin_memory, ) + def _scv_enabled(self) -> bool: + if not hasattr(self, "_scv_mode"): + self._scv_mode = envs.VLLM_SCV_MODE.lower() + if self._scv_mode not in ("off", "graph", "adaptive"): + logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode) + self._scv_mode = "off" + return self._scv_mode != "off" + def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -2316,6 +2327,15 @@ def _build_nwor_acceptance_mask( target_device = spec_decode_metadata.draft_token_ids.device work_device = sampled_token_ids.device + if self._scv_enabled(): + mask = self._scv_vectorized_mask( + spec_decode_metadata, sampled_token_ids, total_tokens, work_device + ) + if mask is not None: + if mask.device != target_device: + mask = mask.to(device=target_device) + return mask + draft_ids = spec_decode_metadata.draft_token_ids if draft_ids.device != work_device: draft_ids = draft_ids.to(device=work_device) @@ -2336,16 +2356,9 @@ def _build_nwor_acceptance_mask( row = row.to(dtype=draft_ids.dtype) draft_slice = 
draft_ids[start:end] - comparison = (row == draft_slice).flatten() - - if bool(comparison.all().item()): - accepted = draft_count - else: - reject = torch.nonzero(~comparison, as_tuple=False) - accepted = int(reject[0, 0].item()) if reject.numel() > 0 else draft_count - - if accepted > 0: - mask_work[start : start + accepted] = True + comparison = (row == draft_slice) + prefix = torch.cumprod(comparison.to(torch.int32), dim=0) + mask_work[start:end] = prefix.to(torch.bool) start = end if start != total_tokens: @@ -2355,6 +2368,130 @@ def _build_nwor_acceptance_mask( return mask_work return mask_work.to(device=target_device) + def _scv_vectorized_mask( + self, + spec_decode_metadata: SpecDecodeMetadata, + sampled_token_ids: torch.Tensor, + total_tokens: int, + device: torch.device, + ) -> torch.Tensor | None: + draft_ids = spec_decode_metadata.draft_token_ids + max_spec_len = spec_decode_metadata.max_spec_len + num_draft_tensor = torch.tensor( + spec_decode_metadata.num_draft_tokens, + device=device, + dtype=torch.int32, + ) + if draft_ids.device != device: + draft_ids = draft_ids.to(device=device) + + cu = spec_decode_metadata.cu_num_draft_tokens.to(device=device) + + if hasattr(self, "_scv_mode") and self._scv_mode == "graph": + executor = getattr(self, "_scv_graph_executor", None) + if executor is None: + executor = SCVGraphExecutor(device) + self._scv_graph_executor = executor + mask = executor.run( + spec_decode_metadata, sampled_token_ids, total_tokens + ) + if mask is not None: + return mask + + if hasattr(self, "_scv_mode") and self._scv_mode == "adaptive": + mask = self._scv_compute_mask( + draft_ids, + num_draft_tensor, + cu, + sampled_token_ids, + max_spec_len, + total_tokens, + ) + self._scv_update_controller(spec_decode_metadata, mask) + return mask + + mask = self._scv_compute_mask( + draft_ids, + num_draft_tensor, + cu, + sampled_token_ids, + max_spec_len, + total_tokens, + ) + return mask + + @staticmethod + def _scv_compute_mask( + draft_ids: 
torch.Tensor, + num_draft_tokens: torch.Tensor, + cu_num_draft_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + max_spec_len: int, + total_tokens: int, + ) -> torch.Tensor: + device = draft_ids.device + indices = torch.arange(total_tokens, device=device, dtype=torch.int32) + req_idx = torch.bucketize(indices, cu_num_draft_tokens, right=True) + prev_cu = torch.cat([cu_num_draft_tokens.new_zeros(1), cu_num_draft_tokens[:-1]]) + pos_in_req = indices - prev_cu[req_idx] + + gathered = sampled_token_ids[req_idx, pos_in_req] + comparison = gathered == draft_ids + + max_val = max_spec_len + 1 + values = torch.where( + ~comparison, + (pos_in_req + 1).to(torch.int32), + torch.full_like(pos_in_req, max_val, dtype=torch.int32), + ) + + accepted = torch.full( + (num_draft_tokens.numel(),), + max_val, + device=device, + dtype=torch.int32, + ) + accepted.scatter_reduce_(0, req_idx, values, reduce="amin") + accepted = torch.where( + accepted == max_val, + num_draft_tokens, + accepted - 1, + ) + accepted_broadcast = accepted[req_idx] + mask_flat = pos_in_req < accepted_broadcast + return mask_flat + + def _scv_update_controller( + self, + spec_decode_metadata: SpecDecodeMetadata, + mask: torch.Tensor, + ) -> None: + target_ratio = 0.6 + alpha = 0.2 + accepted = int(mask.sum().item()) + total = max(mask.numel(), 1) + ratio = accepted / total + prev = getattr(self, "_scv_accept_ratio", target_ratio) + new_ratio = (1 - alpha) * prev + alpha * ratio + self._scv_accept_ratio = new_ratio + + speculative_config = getattr(self, "speculative_config", None) + if speculative_config is None or not hasattr(speculative_config, "num_speculative_tokens"): + return + + current_k = speculative_config.num_speculative_tokens + base_k = self.__dict__.setdefault("_scv_base_spec_tokens", current_k) + k_max = max(1, base_k * 2) + + if new_ratio < target_ratio * 0.8: + new_k = max(max(1, base_k // 4), current_k - 1) + elif new_ratio > target_ratio * 1.2: + new_k = min(k_max, current_k + 1) + else: + new_k = current_k + + speculative_config.num_speculative_tokens = new_k + 
def _bookkeeping_sync( self, scheduler_output: "SchedulerOutput", @@ -4836,3 +4973,125 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: self.transfer_event.record() self.transfer_event.synchronize() return pinned.tolist() +@dataclass +class _SCVGraphEntry: + num_reqs: int + max_spec_len: int + total_tokens: int + sampled_shape: tuple[int, int] + sampled_dtype: torch.dtype + draft_dtype: torch.dtype + device: torch.device + + def __post_init__(self): + self.sampled_buffer = torch.empty( + self.sampled_shape, device=self.device, dtype=self.sampled_dtype + ) + self.draft_buffer = torch.empty( + (self.total_tokens,), device=self.device, dtype=self.draft_dtype + ) + self.num_tokens_buffer = torch.empty( + (self.num_reqs,), device=self.device, dtype=torch.int32 + ) + self.cu_buffer = torch.empty( + (self.num_reqs,), device=self.device, dtype=torch.int32 + ) + self.mask_buffer = torch.empty( + (self.total_tokens,), device=self.device, dtype=torch.bool + ) + self.graph = torch.cuda.CUDAGraph() + self._captured = False + + def capture(self): + if self._captured: + return + mask = GPUModelRunner._scv_compute_mask( + self.draft_buffer, + self.num_tokens_buffer, + self.cu_buffer, + self.sampled_buffer, + self.max_spec_len, + self.total_tokens, + ) + self.mask_buffer.copy_(mask) + torch.cuda.synchronize() + with torch.cuda.graph(self.graph): + mask = GPUModelRunner._scv_compute_mask( + self.draft_buffer, + self.num_tokens_buffer, + self.cu_buffer, + self.sampled_buffer, + self.max_spec_len, + self.total_tokens, + ) + self.mask_buffer.copy_(mask) + self._captured = True + + def run(self): + if not self._captured: + self.capture() + self.graph.replay() + return self.mask_buffer + + +class SCVGraphExecutor: + def __init__(self, device: torch.device): + self.device = device + self.entries: dict[tuple[Any, ...], _SCVGraphEntry] = {} + self.enabled = torch.cuda.is_available() + + def run( + self, + spec_decode_metadata: SpecDecodeMetadata, + 
sampled_token_ids: torch.Tensor, + total_tokens: int, + ) -> torch.Tensor | None: + if not self.enabled: + return None + num_reqs = len(spec_decode_metadata.num_draft_tokens) + max_spec_len = spec_decode_metadata.max_spec_len + key = ( + num_reqs, + max_spec_len, + sampled_token_ids.shape[1], + total_tokens, + sampled_token_ids.dtype, + ) + entry = self.entries.get(key) + need_capture = False + if entry is None: + entry = _SCVGraphEntry( + num_reqs=num_reqs, + max_spec_len=max_spec_len, + total_tokens=total_tokens, + sampled_shape=sampled_token_ids[:, :max_spec_len].shape, + sampled_dtype=sampled_token_ids.dtype, + draft_dtype=spec_decode_metadata.draft_token_ids.dtype, + device=self.device, + ) + self.entries[key] = entry + need_capture = True + try: + sampled_view = sampled_token_ids[:, :max_spec_len] + entry.sampled_buffer.copy_(sampled_view) + draft_ids = spec_decode_metadata.draft_token_ids.to(self.device) + entry.draft_buffer.zero_() + entry.draft_buffer[: draft_ids.numel()].copy_(draft_ids) + num_tokens_tensor = torch.tensor( + spec_decode_metadata.num_draft_tokens, + device=self.device, + dtype=torch.int32, + ) + entry.num_tokens_buffer.copy_(num_tokens_tensor) + cu_tensor = spec_decode_metadata.cu_num_draft_tokens.to( + device=self.device, dtype=torch.int32 + ) + entry.cu_buffer.copy_(cu_tensor) + if need_capture: + entry.capture() + return entry.run() + except RuntimeError as exc: + logger.warning("SCV graph execution disabled: %s", exc) + self.enabled = False + self.entries.clear() + return None