
Commit b729072

Fix Bug #3 and nucleus collapse: Increase tau_q soft floor to 2.0
ROOT CAUSE: draft_q_soft_temp=0.50 was SHARPENING the distribution instead of softening it (dividing logits by tau < 1.0 amplifies them; at tau = 0.50 it doubles their magnitude). This caused the nucleus to collapse to 1-2 survivors → q ≈ 1.0 → acceptance stuck at ~0.7038 (the average p_target).

FIXES:

1. Config defaults (config.py, arg_utils.py):
   - draft_q_temp_offset: 0.15 → 0.25 (better dynamic range)
   - draft_q_soft_temp: 0.50 → 2.0 (softens instead of sharpens)
   At draft_temp=0.05:
   - Before: tau_q = max(0.05 + 0.15, 0.50) = 0.50 (2x sharper!)
   - After: tau_q = max(0.05 + 0.25, 2.0) = 2.0 (2x softer)

2. Force min_keep=2 in nucleus (eagle.py line 271):
   - Added keep_sorted[..., :2] = True
   - Prevents survivors=1 by construction (defensive programming)

3. Fix smoothing to be uniform over the kept set (eagle.py lines 275-287):
   - Before: mixed with the untempered baseline (wrong approach)
   - After: uniform distribution over survivors only (correct)
   - Prevents q from reaching exactly 1.0 in corner cases

4. Remove dead code (eagle.py line 322):
   - Deleted the unused self._current_sampling_metadata assignment
   - No longer needed with the draft-anchored approach (bug #2 fix)

Expected results:
- tau_q ≥ 2.0 at ultracold temps → softer distribution
- NUC_DEBUG: survivors in the hundreds/thousands (not 1-2)
- Q_DEBUG: q ∈ [0.5, 0.8] (not 0.98-1.0)
- Accept rate: dynamic range restored across the temperature sweep
1 parent 65f57a3 commit b729072
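
For quick reference, a minimal standalone sketch of the temperature arithmetic described in the commit message; the toy logits and the effective_draft_temp helper are illustrative, not names from the vLLM code.

# Minimal sketch (not vLLM code) of tau_q = max(draft_temp + offset, soft_floor)
# and of why dividing logits by tau < 1 sharpens while tau > 1 softens.
import torch

def effective_draft_temp(draft_temp: float, offset: float, soft_floor: float) -> float:
    return max(draft_temp + offset, soft_floor)

logits = torch.tensor([4.0, 2.0, 0.0, -2.0])  # toy ultracold draft logits

tau_old = effective_draft_temp(0.05, 0.15, 0.50)  # = 0.50, 2x sharper than raw
tau_new = effective_draft_temp(0.05, 0.25, 2.0)   # = 2.0, 2x softer than raw

for name, tau in [("raw", 1.0), ("old tau_q=0.50", tau_old), ("new tau_q=2.0", tau_new)]:
    top = torch.softmax(logits / tau, dim=-1).max().item()
    print(f"{name}: top prob = {top:.3f}")  # sharper -> nucleus collapses to ~1 token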

File tree

3 files changed: +18 -16 lines


vllm/engine/arg_utils.py

Lines changed: 2 additions & 2 deletions

@@ -376,8 +376,8 @@ class EngineArgs:
     draft_top_p: float = 0.95
     draft_top_k: int = 0
     # Draft-anchored adaptive temperature settings
-    draft_q_temp_offset: float = 0.15
-    draft_q_soft_temp: float = 0.50
+    draft_q_temp_offset: float = 0.25
+    draft_q_soft_temp: float = 2.0
     draft_mix_lambda_max: float = 0.05
     revision: Optional[str] = ModelConfig.revision
     code_revision: Optional[str] = ModelConfig.code_revision

vllm/v1/spec_decode/config.py

Lines changed: 4 additions & 4 deletions

@@ -28,8 +28,8 @@ class SpecDecodeOptConfig:
     draft_top_k: int = 0  # 0 = disabled

     # Draft-anchored adaptive temperature settings
-    draft_q_temp_offset: float = 0.15  # Offset added to draft_temp
-    draft_q_soft_temp: float = 0.50  # Soft floor to prevent ultra-cold collapse
+    draft_q_temp_offset: float = 0.25  # Offset added to draft_temp
+    draft_q_soft_temp: float = 2.0  # Soft floor to prevent ultra-cold collapse
     draft_mix_lambda_max: float = 0.05  # Tiny smoothing over baseline

     # Debug and profiling settings
@@ -95,12 +95,12 @@ def from_cli_args(cls, vllm_config) -> "SpecDecodeOptConfig":
         if hasattr(vllm_config, 'draft_q_temp_offset'):
             config.draft_q_temp_offset = vllm_config.draft_q_temp_offset
         else:
-            config.draft_q_temp_offset = float(os.environ.get('VLLM_DRAFT_Q_TEMP_OFFSET', '0.15'))
+            config.draft_q_temp_offset = float(os.environ.get('VLLM_DRAFT_Q_TEMP_OFFSET', '0.25'))

         if hasattr(vllm_config, 'draft_q_soft_temp'):
             config.draft_q_soft_temp = vllm_config.draft_q_soft_temp
         else:
-            config.draft_q_soft_temp = float(os.environ.get('VLLM_DRAFT_Q_SOFT_TEMP', '0.50'))
+            config.draft_q_soft_temp = float(os.environ.get('VLLM_DRAFT_Q_SOFT_TEMP', '2.0'))

         if hasattr(vllm_config, 'draft_mix_lambda_max'):
             config.draft_mix_lambda_max = vllm_config.draft_mix_lambda_max
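
As a usage note, the else-branches above fall back to environment variables when vllm_config does not carry the new fields; a small illustrative snippet follows, where the surrounding script is an assumption and only the variable names and defaults come from the diff.

# Illustrative only: mirrors the env-var fallback read in from_cli_args above.
import os

# Override the new defaults before engine construction, e.g. for a temperature sweep:
os.environ["VLLM_DRAFT_Q_TEMP_OFFSET"] = "0.30"
os.environ["VLLM_DRAFT_Q_SOFT_TEMP"] = "1.5"

# When vllm_config lacks the attributes, the config falls back to these reads:
offset = float(os.environ.get("VLLM_DRAFT_Q_TEMP_OFFSET", "0.25"))
soft_temp = float(os.environ.get("VLLM_DRAFT_Q_SOFT_TEMP", "2.0"))
print(offset, soft_temp)  # 0.3 1.5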

vllm/v1/spec_decode/eagle.py

Lines changed: 12 additions & 10 deletions

@@ -268,22 +268,27 @@ def _sample_draft_tokens(
         keep_sorted = torch.zeros_like(sp, dtype=torch.bool)
         keep_sorted[..., 0] = True
         keep_sorted[..., 1:] = csum[..., :-1] < top_p  # STRICTLY below (correct rule)
+        keep_sorted[..., :2] = True  # Force min_keep=2 (prevents survivors=1)
         keep = torch.zeros_like(p, dtype=torch.bool).scatter(-1, si, keep_sorted)
         x = torch.where(keep, x, torch.full_like(x, float("-inf")))

-        # Optional smoothing with untempered baseline
-        probs_full = torch.softmax(x, dim=-1)
+        # Optional smoothing over kept set (uniform mix)
         lam = float(getattr(self.opt_config, "draft_mix_lambda_max", 0.0) or 0.0)
         print(f"[SMOOTH_DEBUG] lambda_max from config: {lam}, will run smoothing: {lam > 0.0}",
               file=sys.stderr, flush=True)
+        logp_full = torch.log_softmax(x, dim=-1)
         if lam > 0.0:
-            base = torch.softmax(logits_f32, dim=-1)  # untempered baseline
-            probs_full = (1.0 - lam) * probs_full + lam * base
-            probs_full = probs_full / probs_full.sum(dim=-1, keepdim=True)
-            logp_full = torch.log(probs_full.clamp_min(1e-20))
+            kept = torch.isfinite(logp_full)
+            p = torch.exp(logp_full)
+            # Uniform over survivors only
+            u = kept.float() / kept.float().sum(dim=-1, keepdim=True).clamp_min(1.0)
+            p = (1.0 - lam) * p + lam * u
+            p = p * kept  # Ensure dropped stay at 0
+            logp_full = torch.log(p.clamp_min(1e-45))

         # Sample token and gather its logp
-        tok = torch.distributions.Categorical(probs=probs_full).sample()
+        cat = torch.distributions.Categorical(logits=logp_full)
+        tok = cat.sample()
         tok_logp = logp_full.gather(-1, tok.unsqueeze(-1)).squeeze(-1)

         # Debug logging
@@ -318,9 +323,6 @@ def propose(
         sampling_metadata: SamplingMetadata,
         mm_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.Tensor:
-        # Store sampling_metadata so _sample_draft_tokens() can access target temperature
-        self._current_sampling_metadata = sampling_metadata
-
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
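
For context, here is a standalone toy version of the kept-set path above (illustrative only: the toy logits, vocabulary size, and the draft_sample wrapper are assumptions, not the vLLM code), showing that min_keep=2 plus the uniform mix keeps q below 1.0 even at ultracold temperatures.

# Toy reproduction of the new sampling path: nucleus mask with forced min_keep=2,
# then uniform smoothing over the surviving tokens only.
import torch

def draft_sample(logits, tau_q, top_p=0.95, lam=0.05):
    x = logits / tau_q
    sp, si = torch.sort(torch.softmax(x, dim=-1), dim=-1, descending=True)
    csum = sp.cumsum(dim=-1)
    keep_sorted = torch.zeros_like(sp, dtype=torch.bool)
    keep_sorted[..., 0] = True
    keep_sorted[..., 1:] = csum[..., :-1] < top_p  # strictly-below nucleus rule
    keep_sorted[..., :2] = True                    # force min_keep=2
    keep = torch.zeros_like(sp, dtype=torch.bool).scatter(-1, si, keep_sorted)
    x = torch.where(keep, x, torch.full_like(x, float("-inf")))

    logp = torch.log_softmax(x, dim=-1)
    if lam > 0.0:                                  # uniform mix over survivors only
        kept = torch.isfinite(logp)
        p = torch.exp(logp)
        u = kept.float() / kept.float().sum(dim=-1, keepdim=True).clamp_min(1.0)
        p = ((1.0 - lam) * p + lam * u) * kept     # dropped tokens stay at 0
        logp = torch.log(p.clamp_min(1e-45))
    tok = torch.distributions.Categorical(logits=logp).sample()
    return tok, logp, keep.sum(dim=-1)

# Ultra-cold draft logits: at tau_q=0.5 only the forced min_keep saves a second
# survivor and q stays near 1.0; at tau_q=2.0 more tokens survive and q drops.
logits = torch.tensor([[8.0, 6.0, 5.0, 4.0, 1.0]])
for tau in (0.5, 2.0):
    tok, logp, survivors = draft_sample(logits, tau)
    print(f"tau_q={tau}: survivors={survivors.item()}, max q={logp.exp().max().item():.3f}")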
