11 changes: 0 additions & 11 deletions tests/v1/test_deferred_writer.py
@@ -196,17 +196,6 @@ def test_nwor_immediate_mode_skips_window():
assert manager.get_mode() == "immediate"


def test_scv_vectorized_mask_matches_reference():
metadata = _make_metadata([1, 2, 3, 4], [4])
sampled = torch.tensor([[1, 2, 0, 4]], dtype=torch.int32)

runner = GPUModelRunner.__new__(GPUModelRunner)
runner._scv_mode = "adaptive"

mask = runner._build_nwor_acceptance_mask(metadata, sampled)
assert mask.tolist() == [True, True, False, False]


def test_commit_failure_triggers_fallback_metrics():
manager = DeferredWriteManager()
assert manager.begin_window([1])
3 changes: 0 additions & 3 deletions vllm/envs.py
@@ -200,7 +200,6 @@
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_DISABLE_NWOR: bool = False
VLLM_NWOR_MODE: str = "stage"
VLLM_SCV_MODE: str = "off"
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
@@ -1316,8 +1315,6 @@ def get_vllm_port() -> int | None:
"VLLM_DISABLE_NWOR": lambda: bool(int(os.getenv("VLLM_DISABLE_NWOR", "0"))),
# Select NWOR mode: "stage" (default) or "immediate" to bypass staging.
"VLLM_NWOR_MODE": lambda: os.getenv("VLLM_NWOR_MODE", "stage"),
# Speculative chunk verify mode: "off" (default), "graph", or "adaptive".
"VLLM_SCV_MODE": lambda: os.getenv("VLLM_SCV_MODE", "off"),
# Used to force set up loopback IP
"VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
# Used to set the process name prefix for vLLM processes.
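For reference, a minimal sketch of the getter pattern used for these variables, assuming only the names visible in the diff above; the standalone helper below is hypothetical and not part of vllm/envs.py:

import os

# Hypothetical helper mirroring the lambda-getter style in vllm/envs.py above.
def get_nwor_mode() -> str:
    # "stage" (default) stages writes in a deferred window; "immediate" bypasses staging.
    return os.getenv("VLLM_NWOR_MODE", "stage")

if __name__ == "__main__":
    # e.g. VLLM_NWOR_MODE=immediate python sketch.py -> "immediate"
    print(get_nwor_mode())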
279 changes: 10 additions & 269 deletions vllm/v1/worker/gpu_model_runner.py
@@ -5,7 +5,6 @@
import itertools
import time
from collections import defaultdict
from dataclasses import dataclass
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
@@ -510,8 +509,6 @@ def __init__(
# Cached outputs.
self._deferred_write_manager = DeferredWriteManager(mode=envs.VLLM_NWOR_MODE)
self._latest_nwor_window_metrics: dict[str, int | str] | None = None
self._scv_mode = envs.VLLM_SCV_MODE.lower()
self._scv_graph_executor: SCVGraphExecutor | None = None
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
self.transfer_event = torch.cuda.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
@@ -521,14 +518,6 @@
pin_memory=self.pin_memory,
)

def _scv_enabled(self) -> bool:
if not hasattr(self, "_scv_mode"):
self._scv_mode = envs.VLLM_SCV_MODE.lower()
if self._scv_mode not in ("off", "graph", "adaptive"):
logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode)
self._scv_mode = "off"
return self._scv_mode != "off"

def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@@ -2327,15 +2316,6 @@ def _build_nwor_acceptance_mask(
target_device = spec_decode_metadata.draft_token_ids.device
work_device = sampled_token_ids.device

if self._scv_enabled():
mask = self._scv_vectorized_mask(
spec_decode_metadata, sampled_token_ids, total_tokens, work_device
)
if mask is not None:
if mask.device != target_device:
mask = mask.to(device=target_device)
return mask

draft_ids = spec_decode_metadata.draft_token_ids
if draft_ids.device != work_device:
draft_ids = draft_ids.to(device=work_device)
@@ -2356,9 +2336,16 @@
row = row.to(dtype=draft_ids.dtype)

draft_slice = draft_ids[start:end]
comparison = (row == draft_slice)
prefix = torch.cumprod(comparison.to(torch.int32), dim=0)
mask_work[start:end] = prefix.to(torch.bool)
comparison = (row == draft_slice).flatten()

if bool(comparison.all().item()):
accepted = draft_count
else:
reject = torch.nonzero(~comparison, as_tuple=False)
accepted = int(reject[0, 0].item()) if reject.numel() > 0 else draft_count

if accepted > 0:
mask_work[start : start + accepted] = True
start = end

if start != total_tokens:
@@ -2368,130 +2355,6 @@
return mask_work
return mask_work.to(device=target_device)

def _scv_vectorized_mask(
self,
spec_decode_metadata: SpecDecodeMetadata,
sampled_token_ids: torch.Tensor,
total_tokens: int,
device: torch.device,
) -> torch.Tensor | None:
draft_ids = spec_decode_metadata.draft_token_ids
max_spec_len = spec_decode_metadata.max_spec_len
num_draft_tensor = torch.tensor(
spec_decode_metadata.num_draft_tokens,
device=device,
dtype=torch.int32,
)
if draft_ids.device != device:
draft_ids = draft_ids.to(device=device)

cu = spec_decode_metadata.cu_num_draft_tokens.to(device=device)

if hasattr(self, "_scv_mode") and self._scv_mode == "graph":
executor = getattr(self, "_scv_graph_executor", None)
if executor is None:
executor = SCVGraphExecutor(device)
self._scv_graph_executor = executor
mask = executor.run(
spec_decode_metadata, sampled_token_ids, total_tokens
)
if mask is not None:
return mask

if hasattr(self, "_scv_mode") and self._scv_mode == "adaptive":
mask = self._scv_compute_mask(
draft_ids,
num_draft_tensor,
cu,
sampled_token_ids,
max_spec_len,
total_tokens,
)
self._scv_update_controller(spec_decode_metadata, mask)
return mask

mask = self._scv_compute_mask(
draft_ids,
num_draft_tensor,
cu,
sampled_token_ids,
max_spec_len,
total_tokens,
)
return mask

@staticmethod
def _scv_compute_mask(
draft_ids: torch.Tensor,
num_draft_tokens: torch.Tensor,
cu_num_draft_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
max_spec_len: int,
total_tokens: int,
) -> torch.Tensor:
device = draft_ids.device
indices = torch.arange(total_tokens, device=device, dtype=torch.int32)
req_idx = torch.bucketize(indices, cu_num_draft_tokens)
prev_cu = torch.cat([cu_num_draft_tokens.new_zeros(1), cu_num_draft_tokens[:-1]])
pos_in_req = indices - prev_cu[req_idx]

gathered = sampled_token_ids[req_idx, pos_in_req]
comparison = gathered == draft_ids

max_val = max_spec_len + 1
values = torch.where(
~comparison,
(pos_in_req + 1).to(torch.int32),
torch.full_like(pos_in_req, max_val, dtype=torch.int32),
)

accepted = torch.full(
(num_draft_tokens.numel(),),
max_val,
device=device,
dtype=torch.int32,
)
accepted.scatter_reduce_(0, req_idx, values, reduce="amin")
accepted = torch.where(
accepted == max_val,
num_draft_tokens,
accepted - 1,
)
accepted_broadcast = accepted[req_idx]
mask_flat = pos_in_req < accepted_broadcast
return mask_flat

def _scv_update_controller(
self,
spec_decode_metadata: SpecDecodeMetadata,
mask: torch.Tensor,
) -> None:
target_ratio = 0.6
alpha = 0.2
accepted = int(mask.sum().item())
total = max(mask.numel(), 1)
ratio = accepted / total
prev = getattr(self, "_scv_accept_ratio", target_ratio)
new_ratio = (1 - alpha) * prev + alpha * ratio
self._scv_accept_ratio = new_ratio

speculative_config = getattr(self, "speculative_config", None)
if speculative_config is None or not hasattr(speculative_config, "num_speculative_tokens"):
return

base_k = speculative_config.num_speculative_tokens
k_min = max(1, base_k // 4)
k_max = max(1, base_k * 2)

if new_ratio < target_ratio * 0.8:
new_k = max(k_min, base_k - 1)
elif new_ratio > target_ratio * 1.2:
new_k = min(k_max, base_k + 1)
else:
new_k = base_k

speculative_config.num_speculative_tokens = new_k

def _bookkeeping_sync(
self,
scheduler_output: "SchedulerOutput",
@@ -4973,125 +4836,3 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()
@dataclass
class _SCVGraphEntry:
num_reqs: int
max_spec_len: int
total_tokens: int
sampled_shape: tuple[int, int]
sampled_dtype: torch.dtype
draft_dtype: torch.dtype
device: torch.device

def __post_init__(self):
self.sampled_buffer = torch.empty(
self.sampled_shape, device=self.device, dtype=self.sampled_dtype
)
self.draft_buffer = torch.empty(
(self.total_tokens,), device=self.device, dtype=self.draft_dtype
)
self.num_tokens_buffer = torch.empty(
(self.num_reqs,), device=self.device, dtype=torch.int32
)
self.cu_buffer = torch.empty(
(self.num_reqs,), device=self.device, dtype=torch.int32
)
self.mask_buffer = torch.empty(
(self.total_tokens,), device=self.device, dtype=torch.bool
)
self.graph = torch.cuda.CUDAGraph()
self._captured = False

def capture(self):
if self._captured:
return
mask = GPUModelRunner._scv_compute_mask(
self.draft_buffer,
self.num_tokens_buffer,
self.cu_buffer,
self.sampled_buffer,
self.max_spec_len,
self.total_tokens,
)
self.mask_buffer.copy_(mask)
torch.cuda.synchronize()
with torch.cuda.graph(self.graph):
mask = GPUModelRunner._scv_compute_mask(
self.draft_buffer,
self.num_tokens_buffer,
self.cu_buffer,
self.sampled_buffer,
self.max_spec_len,
self.total_tokens,
)
self.mask_buffer.copy_(mask)
self._captured = True

def run(self):
if not self._captured:
self.capture()
self.graph.replay()
return self.mask_buffer


class SCVGraphExecutor:
def __init__(self, device: torch.device):
self.device = device
self.entries: dict[tuple[Any, ...], _SCVGraphEntry] = {}
self.enabled = torch.cuda.is_available()

def run(
self,
spec_decode_metadata: SpecDecodeMetadata,
sampled_token_ids: torch.Tensor,
total_tokens: int,
) -> torch.Tensor | None:
if not self.enabled:
return None
num_reqs = len(spec_decode_metadata.num_draft_tokens)
max_spec_len = spec_decode_metadata.max_spec_len
key = (
num_reqs,
max_spec_len,
sampled_token_ids.shape[1],
total_tokens,
sampled_token_ids.dtype,
)
entry = self.entries.get(key)
need_capture = False
if entry is None:
entry = _SCVGraphEntry(
num_reqs=num_reqs,
max_spec_len=max_spec_len,
total_tokens=total_tokens,
sampled_shape=sampled_token_ids[:, :max_spec_len].shape,
sampled_dtype=sampled_token_ids.dtype,
draft_dtype=spec_decode_metadata.draft_token_ids.dtype,
device=self.device,
)
self.entries[key] = entry
need_capture = True
try:
sampled_view = sampled_token_ids[:, :max_spec_len]
entry.sampled_buffer.copy_(sampled_view)
draft_ids = spec_decode_metadata.draft_token_ids.to(self.device)
entry.draft_buffer.zero_()
entry.draft_buffer[: draft_ids.numel()].copy_(draft_ids)
num_tokens_tensor = torch.tensor(
spec_decode_metadata.num_draft_tokens,
device=self.device,
dtype=torch.int32,
)
entry.num_tokens_buffer.copy_(num_tokens_tensor)
cu_tensor = spec_decode_metadata.cu_num_draft_tokens.to(
device=self.device, dtype=torch.int32
)
entry.cu_buffer.copy_(cu_tensor)
if need_capture:
entry.capture()
return entry.run()
except RuntimeError as exc:
logger.warning("SCV graph execution disabled: %s", exc)
self.enabled = False
self.entries.clear()
return None
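
The replacement logic in _build_nwor_acceptance_mask above accepts, per request, the longest prefix of draft tokens that matches the sampled tokens and rejects everything from the first mismatch onward. A minimal standalone sketch of that prefix rule, reusing the example from the removed test (draft [1, 2, 3, 4] vs. sampled [1, 2, 0, 4]); the function name and flat 1-D tensors are illustrative simplifications, not the runner's actual API:

import torch

def prefix_acceptance_mask(draft_ids: torch.Tensor, sampled_ids: torch.Tensor) -> torch.Tensor:
    """Accept draft tokens up to (and excluding) the first mismatch."""
    comparison = (sampled_ids == draft_ids).flatten()
    mask = torch.zeros_like(comparison, dtype=torch.bool)
    if bool(comparison.all()):
        accepted = comparison.numel()
    else:
        # Index of the first mismatching position.
        accepted = int(torch.nonzero(~comparison, as_tuple=False)[0, 0].item())
    mask[:accepted] = True
    return mask

# Same example as the removed test: draft [1, 2, 3, 4], sampled [1, 2, 0, 4].
draft = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
sampled = torch.tensor([1, 2, 0, 4], dtype=torch.int32)
print(prefix_acceptance_mask(draft, sampled).tolist())  # [True, True, False, False]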