Skip to content
11 changes: 11 additions & 0 deletions tests/v1/test_deferred_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,17 @@ def test_nwor_immediate_mode_skips_window():
assert manager.get_mode() == "immediate"


def test_scv_vectorized_mask_matches_reference():
    """Adaptive SCV path must reproduce the reference prefix-acceptance mask."""
    draft_tokens = [1, 2, 3, 4]
    metadata = _make_metadata(draft_tokens, [len(draft_tokens)])
    sampled = torch.tensor([[1, 2, 0, 4]], dtype=torch.int32)

    runner = GPUModelRunner.__new__(GPUModelRunner)
    runner._scv_mode = "adaptive"

    # Token 3 mismatches (0 != 3), so acceptance stops after the first two.
    expected = [True, True, False, False]
    assert runner._build_nwor_acceptance_mask(metadata, sampled).tolist() == expected


def test_commit_failure_triggers_fallback_metrics():
manager = DeferredWriteManager()
assert manager.begin_window([1])
Expand Down
3 changes: 3 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_DISABLE_NWOR: bool = False
VLLM_NWOR_MODE: str = "stage"
VLLM_SCV_MODE: str = "off"
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
Expand Down Expand Up @@ -1315,6 +1316,8 @@ def get_vllm_port() -> int | None:
"VLLM_DISABLE_NWOR": lambda: bool(int(os.getenv("VLLM_DISABLE_NWOR", "0"))),
# Select NWOR mode: "stage" (default) or "immediate" to bypass staging.
"VLLM_NWOR_MODE": lambda: os.getenv("VLLM_NWOR_MODE", "stage"),
# Speculative chunk verify mode: "off" (default), "graph", or "adaptive".
"VLLM_SCV_MODE": lambda: os.getenv("VLLM_SCV_MODE", "off"),
# Used to force set up loopback IP
"VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
# Used to set the process name prefix for vLLM processes.
Expand Down
279 changes: 269 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
import time
from collections import defaultdict
from dataclasses import dataclass
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
Expand Down Expand Up @@ -509,6 +510,8 @@ def __init__(
# Cached outputs.
self._deferred_write_manager = DeferredWriteManager(mode=envs.VLLM_NWOR_MODE)
self._latest_nwor_window_metrics: dict[str, int | str] | None = None
self._scv_mode = envs.VLLM_SCV_MODE.lower()
self._scv_graph_executor: SCVGraphExecutor | None = None
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
self.transfer_event = torch.cuda.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
Expand All @@ -518,6 +521,14 @@ def __init__(
pin_memory=self.pin_memory,
)

def _scv_enabled(self) -> bool:
if not hasattr(self, "_scv_mode"):
self._scv_mode = envs.VLLM_SCV_MODE.lower()
if self._scv_mode not in ("off", "graph", "adaptive"):
logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode)
self._scv_mode = "off"
return self._scv_mode != "off"

def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
Expand Down Expand Up @@ -2316,6 +2327,15 @@ def _build_nwor_acceptance_mask(
target_device = spec_decode_metadata.draft_token_ids.device
work_device = sampled_token_ids.device

if self._scv_enabled():
mask = self._scv_vectorized_mask(
spec_decode_metadata, sampled_token_ids, total_tokens, work_device
)
if mask is not None:
if mask.device != target_device:
mask = mask.to(device=target_device)
return mask

draft_ids = spec_decode_metadata.draft_token_ids
if draft_ids.device != work_device:
draft_ids = draft_ids.to(device=work_device)
Expand All @@ -2336,16 +2356,9 @@ def _build_nwor_acceptance_mask(
row = row.to(dtype=draft_ids.dtype)

draft_slice = draft_ids[start:end]
comparison = (row == draft_slice).flatten()

if bool(comparison.all().item()):
accepted = draft_count
else:
reject = torch.nonzero(~comparison, as_tuple=False)
accepted = int(reject[0, 0].item()) if reject.numel() > 0 else draft_count

if accepted > 0:
mask_work[start : start + accepted] = True
comparison = (row == draft_slice)
prefix = torch.cumprod(comparison.to(torch.int32), dim=0)
mask_work[start:end] = prefix.to(torch.bool)
start = end

if start != total_tokens:
Expand All @@ -2355,6 +2368,130 @@ def _build_nwor_acceptance_mask(
return mask_work
return mask_work.to(device=target_device)

def _scv_vectorized_mask(
    self,
    spec_decode_metadata: SpecDecodeMetadata,
    sampled_token_ids: torch.Tensor,
    total_tokens: int,
    device: torch.device,
) -> torch.Tensor | None:
    """Build the flat acceptance mask for draft tokens without a Python loop.

    In "graph" mode the CUDA-graph executor is tried first; on failure (or
    when it returns None) we fall back to the eager vectorized kernel.  In
    "adaptive" mode the resulting mask additionally feeds the
    acceptance-ratio controller.

    Returns:
        Boolean mask of shape ``(total_tokens,)`` on ``device``, or None if
        the caller should use the reference (loop-based) implementation.
    """
    # _scv_enabled() has already validated/normalized the mode.
    mode = getattr(self, "_scv_mode", "off")

    if mode == "graph":
        executor = getattr(self, "_scv_graph_executor", None)
        if executor is None:
            executor = SCVGraphExecutor(device)
            self._scv_graph_executor = executor
        mask = executor.run(spec_decode_metadata, sampled_token_ids, total_tokens)
        if mask is not None:
            return mask
        # Graph capture/replay unavailable: fall through to the eager path.

    # Only materialize device tensors once we know the eager path runs;
    # the original code paid these transfers even when the graph path hit.
    draft_ids = spec_decode_metadata.draft_token_ids
    if draft_ids.device != device:
        draft_ids = draft_ids.to(device=device)
    num_draft_tensor = torch.tensor(
        spec_decode_metadata.num_draft_tokens,
        device=device,
        dtype=torch.int32,
    )
    cu = spec_decode_metadata.cu_num_draft_tokens.to(device=device)

    mask = self._scv_compute_mask(
        draft_ids,
        num_draft_tensor,
        cu,
        sampled_token_ids,
        spec_decode_metadata.max_spec_len,
        total_tokens,
    )
    if mode == "adaptive":
        self._scv_update_controller(spec_decode_metadata, mask)
    return mask

@staticmethod
def _scv_compute_mask(
draft_ids: torch.Tensor,
num_draft_tokens: torch.Tensor,
cu_num_draft_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
max_spec_len: int,
total_tokens: int,
) -> torch.Tensor:
device = draft_ids.device
indices = torch.arange(total_tokens, device=device, dtype=torch.int32)
req_idx = torch.bucketize(indices, cu_num_draft_tokens)
prev_cu = torch.cat([cu_num_draft_tokens.new_zeros(1), cu_num_draft_tokens[:-1]])
pos_in_req = indices - prev_cu[req_idx]

gathered = sampled_token_ids[req_idx, pos_in_req]
comparison = gathered == draft_ids

max_val = max_spec_len + 1
values = torch.where(
~comparison,
(pos_in_req + 1).to(torch.int32),
torch.full_like(pos_in_req, max_val, dtype=torch.int32),
)

accepted = torch.full(
(num_draft_tokens.numel(),),
max_val,
device=device,
dtype=torch.int32,
)
accepted.scatter_reduce_(0, req_idx, values, reduce="amin")
accepted = torch.where(
accepted == max_val,
num_draft_tokens,
accepted - 1,
)
accepted_broadcast = accepted[req_idx]
mask_flat = pos_in_req < accepted_broadcast
return mask_flat

def _scv_update_controller(
self,
spec_decode_metadata: SpecDecodeMetadata,
mask: torch.Tensor,
) -> None:
target_ratio = 0.6
alpha = 0.2
accepted = int(mask.sum().item())
total = max(mask.numel(), 1)
ratio = accepted / total
prev = getattr(self, "_scv_accept_ratio", target_ratio)
new_ratio = (1 - alpha) * prev + alpha * ratio
self._scv_accept_ratio = new_ratio

speculative_config = getattr(self, "speculative_config", None)
if speculative_config is None or not hasattr(speculative_config, "num_speculative_tokens"):
return

base_k = speculative_config.num_speculative_tokens
k_min = max(1, base_k // 4)
k_max = max(1, base_k * 2)

if new_ratio < target_ratio * 0.8:
new_k = max(k_min, base_k - 1)
elif new_ratio > target_ratio * 1.2:
new_k = min(k_max, base_k + 1)
else:
new_k = base_k

speculative_config.num_speculative_tokens = new_k

def _bookkeeping_sync(
self,
scheduler_output: "SchedulerOutput",
Expand Down Expand Up @@ -4836,3 +4973,125 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()
@dataclass
class _SCVGraphEntry:
    """A captured CUDA graph plus the static I/O buffers it replays over.

    One entry exists per input-shape bucket; callers copy fresh inputs into
    the buffers and then replay the graph, reading the result from
    ``mask_buffer``.
    """

    num_reqs: int
    max_spec_len: int
    total_tokens: int
    sampled_shape: tuple[int, int]
    sampled_dtype: torch.dtype
    draft_dtype: torch.dtype
    device: torch.device

    def __post_init__(self) -> None:
        # CUDA graphs replay kernels over fixed addresses, so every input
        # and output lives in a pre-allocated static tensor.
        dev = self.device
        self.sampled_buffer = torch.empty(
            self.sampled_shape, device=dev, dtype=self.sampled_dtype
        )
        self.draft_buffer = torch.empty(
            (self.total_tokens,), device=dev, dtype=self.draft_dtype
        )
        self.num_tokens_buffer = torch.empty(
            (self.num_reqs,), device=dev, dtype=torch.int32
        )
        self.cu_buffer = torch.empty((self.num_reqs,), device=dev, dtype=torch.int32)
        self.mask_buffer = torch.empty(
            (self.total_tokens,), device=dev, dtype=torch.bool
        )
        self.graph = torch.cuda.CUDAGraph()
        self._captured = False

    def _compute_into_mask(self) -> None:
        # Run the vectorized kernel over the static buffers and store the
        # result into the static output buffer.
        result = GPUModelRunner._scv_compute_mask(
            self.draft_buffer,
            self.num_tokens_buffer,
            self.cu_buffer,
            self.sampled_buffer,
            self.max_spec_len,
            self.total_tokens,
        )
        self.mask_buffer.copy_(result)

    def capture(self) -> None:
        if self._captured:
            return
        # Warm-up eager pass (lets allocators settle), then synchronize
        # before recording the identical op sequence into the graph.
        self._compute_into_mask()
        torch.cuda.synchronize()
        with torch.cuda.graph(self.graph):
            self._compute_into_mask()
        self._captured = True

    def run(self) -> torch.Tensor:
        if not self._captured:
            self.capture()
        self.graph.replay()
        return self.mask_buffer


class SCVGraphExecutor:
    """Caches one captured CUDA graph per input-shape bucket for SCV masks."""

    def __init__(self, device: torch.device):
        self.device = device
        self.entries: dict[tuple[Any, ...], _SCVGraphEntry] = {}
        # Graph capture/replay requires a CUDA runtime.
        self.enabled = torch.cuda.is_available()

    def run(
        self,
        spec_decode_metadata: SpecDecodeMetadata,
        sampled_token_ids: torch.Tensor,
        total_tokens: int,
    ) -> torch.Tensor | None:
        """Copy inputs into the bucket's static buffers and replay its graph.

        Returns the static mask buffer, or None when graph execution is
        disabled (the caller then falls back to the eager path).
        """
        if not self.enabled:
            return None

        num_reqs = len(spec_decode_metadata.num_draft_tokens)
        max_spec_len = spec_decode_metadata.max_spec_len
        key = (
            num_reqs,
            max_spec_len,
            sampled_token_ids.shape[1],
            total_tokens,
            sampled_token_ids.dtype,
        )

        need_capture = key not in self.entries
        if need_capture:
            self.entries[key] = _SCVGraphEntry(
                num_reqs=num_reqs,
                max_spec_len=max_spec_len,
                total_tokens=total_tokens,
                sampled_shape=sampled_token_ids[:, :max_spec_len].shape,
                sampled_dtype=sampled_token_ids.dtype,
                draft_dtype=spec_decode_metadata.draft_token_ids.dtype,
                device=self.device,
            )
        entry = self.entries[key]

        try:
            entry.sampled_buffer.copy_(sampled_token_ids[:, :max_spec_len])
            draft_ids = spec_decode_metadata.draft_token_ids.to(self.device)
            entry.draft_buffer.zero_()  # zero-pad past the live draft tokens
            entry.draft_buffer[: draft_ids.numel()].copy_(draft_ids)
            entry.num_tokens_buffer.copy_(
                torch.tensor(
                    spec_decode_metadata.num_draft_tokens,
                    device=self.device,
                    dtype=torch.int32,
                )
            )
            entry.cu_buffer.copy_(
                spec_decode_metadata.cu_num_draft_tokens.to(
                    device=self.device, dtype=torch.int32
                )
            )
            if need_capture:
                entry.capture()
            return entry.run()
        except RuntimeError as exc:
            # Any capture/replay failure permanently disables the graph path
            # for this executor; callers fall back to the eager kernel.
            logger.warning("SCV graph execution disabled: %s", exc)
            self.enabled = False
            self.entries.clear()
            return None