PS: I know this bug report was AI-generated, but it is an issue I found, and it is described accurately.
Your current environment
NA
🐛 Describe the bug
vLLM Bug Report: Incorrect CPU Penalty Fallback Condition
🐛 Bug Summary
vLLM's penalty fallback mechanism checks `current_platform.is_cuda()` instead of `logits.is_cuda`, so CPU tensors are routed to CUDA-only operations on CUDA-capable systems.
📍 Location
File: vllm/_custom_ops.py
Function: apply_repetition_penalties()
Line: ~315
🔍 Problematic Code
```python
def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
                               output_mask: torch.Tensor,
                               repetition_penalties: torch.Tensor) -> None:
    """Apply repetition penalties to logits in-place."""
    if current_platform.is_cuda() and logits.is_contiguous():  # ❌ BUG HERE
        apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
                                        repetition_penalties)
    else:
        apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
                                         repetition_penalties)
```
🚨 Issue Description
The condition `current_platform.is_cuda()` checks whether the platform supports CUDA, not whether the tensors are on a CUDA device (the sketch below illustrates the distinction). This causes:
- CPU tensors on CUDA-capable systems → dispatched to the CUDA op → crash
- Blocked CPU offloading → the CPU can't handle sampling while the GPU handles model inference
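A minimal sketch of that distinction, using `torch.cuda.is_available()` as a stand-in for vLLM's `current_platform.is_cuda()`:
```python
import torch

# On a CUDA-capable machine:
x = torch.randn(4, device="cpu")
print(torch.cuda.is_available())  # True: the *platform* supports CUDA
print(x.is_cuda)                  # False: this *tensor* lives on the CPU
```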
💥 Error Reproduction
```python
# On a CUDA-capable system:
import torch
from vllm._custom_ops import apply_repetition_penalties

# Create CPU tensors
logits = torch.randn(1, 1000, device='cpu')
prompt_mask = torch.zeros(1, 1000, dtype=torch.bool, device='cpu')
output_mask = torch.zeros(1, 1000, dtype=torch.bool, device='cpu')
penalties = torch.tensor([1.1], device='cpu')

# This crashes even though all tensors are on CPU!
apply_repetition_penalties(logits, prompt_mask, output_mask, penalties)
```
Error:
```
NotImplementedError: Could not run '_C::apply_repetition_penalties_' with arguments from the 'CPU' backend.
'_C::apply_repetition_penalties_' is only available for these backends: [CUDA, Meta, ...]
```
✅ Correct Fix
```python
def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
                               output_mask: torch.Tensor,
                               repetition_penalties: torch.Tensor) -> None:
    """Apply repetition penalties to logits in-place."""
    if logits.is_cuda and logits.is_contiguous():  # ✅ FIXED: tensor device, not platform
        apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
                                        repetition_penalties)
    else:
        apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
                                         repetition_penalties)
```
🎯 Impact
- Blocks CPU offloading: the CPU cannot be used for sampling operations
- Prevents hybrid architectures: GPU for model inference + CPU for sampling
- Blocks memory optimization: GPU memory pressure cannot be reduced by offloading sampling
- Limits multi-model serving: reduces deployment flexibility
🔧 Workaround
Call the PyTorch fallback (the `apply_repetition_penalties_torch` named in the snippet above) directly for CPU tensors, so the CUDA dispatch is never taken:
```python
from vllm._custom_ops import apply_repetition_penalties_torch

if logits.device.type == "cpu" and torch.cuda.is_available():
    # The buggy dispatch would pick the CUDA kernel; use the PyTorch fallback.
    apply_repetition_penalties_torch(logits, prompt_mask, output_mask, penalties)
```
📊 Test Case
```python
def test_cpu_penalties_on_cuda_system():
    """Test that CPU penalties work on CUDA-capable systems."""
    import torch
    from vllm._custom_ops import apply_repetition_penalties

    # Create CPU tensors
    logits = torch.randn(1, 100, device='cpu')
    original = logits.clone()
    prompt_mask = torch.zeros(1, 100, dtype=torch.bool, device='cpu')
    output_mask = torch.zeros(1, 100, dtype=torch.bool, device='cpu')
    penalties = torch.tensor([1.1], device='cpu')

    # This should work but currently crashes
    apply_repetition_penalties(logits, prompt_mask, output_mask, penalties)

    # Both masks are empty, so no token is penalized: logits must be unchanged
    assert torch.equal(logits, original)
```
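For context, the fallback path needs to implement the usual repetition-penalty rule: penalized logits are divided by the penalty when positive and multiplied by it when negative. A minimal pure-PyTorch sketch of that rule (`repetition_penalty_reference` and its single combined `mask` argument are illustrative, not vLLM's actual API):
```python
import torch

def repetition_penalty_reference(logits: torch.Tensor, mask: torch.Tensor,
                                 penalties: torch.Tensor) -> None:
    """Apply the standard repetition penalty to logits in-place.

    mask flags the penalized token ids per sequence (e.g. the union of
    prompt_mask and output_mask); penalties holds one value per sequence.
    """
    p = penalties.unsqueeze(1).expand_as(logits)  # (num_seqs, vocab_size)
    penalized = torch.where(logits > 0, logits / p, logits * p)
    logits[:] = torch.where(mask, penalized, logits)
```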
🔗 Related Files
- vllm/_custom_ops.py (main bug)
- vllm/v1/sample/ops/penalties.py (calls the buggy function)
- vllm/model_executor/layers/utils.py (also calls the buggy function)