diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index dbcdad07e4de..66f9ab22a789 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -75,6 +75,18 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
             self.forward = self.forward_native
         elif current_platform.is_cpu():
             self.forward = self.forward_cpu
+        elif (
+            logprobs_mode not in ("processed_logits", "processed_logprobs")
+            and current_platform.is_rocm()
+            and envs.VLLM_ROCM_USE_AITER
+        ):
+            import aiter.ops.sampling  # noqa: F401
+
+            self.aiter_ops = torch.ops.aiter
+            logger.info_once(
+                "Using aiter sampler on ROCm (lazy import, sampling-only)."
+            )
+            self.forward = self.forward_hip
         else:
             self.forward = self.forward_native
 
@@ -120,9 +132,10 @@ def forward_cuda(
                 "PyTorch-native implementation."
             )
             return self.forward_native(logits, generators, k, p)
-        assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
-            "FlashInfer does not support returning logits/logprobs"
-        )
+        assert self.logprobs_mode not in (
+            "processed_logits",
+            "processed_logprobs",
+        ), "FlashInfer does not support returning logits/logprobs"
         # flashinfer sampling functions expect contiguous logits.
         # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
         # because of slicing operation in logits_processor.
@@ -167,6 +180,64 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
 
         return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
 
+    def forward_hip(
+        self,
+        logits: torch.Tensor,
+        generators: dict[int, torch.Generator],
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Optimized ROCm/aiter path (same structure as forward_cuda)."""
+        if (k is None and p is None) or generators:
+            if generators:
+                logger.warning_once(
+                    "aiter sampler does not support per-request generators; "
+                    "falling back to PyTorch-native."
+                )
+            return self.forward_native(logits, generators, k, p)
+        assert self.logprobs_mode not in (
+            "processed_logits",
+            "processed_logprobs",
+        ), "aiter sampler does not support returning logits/logprobs."
+        return self.aiter_sample(logits, k, p, generators), None
+
+    def aiter_sample(
+        self,
+        logits: torch.Tensor,
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
+        generators: dict[int, torch.Generator],
+    ) -> torch.Tensor:
+        """Sample from logits using aiter ops."""
+        use_top_k = k is not None
+        use_top_p = p is not None
+        # Joint k+p path
+        if use_top_p and use_top_k:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
+                probs,
+                None,
+                *_to_tensor_scalar_tuple(k),
+                *_to_tensor_scalar_tuple(p),
+                deterministic=True,
+            )
+            return next_token_ids.view(-1)
+        # Top-p only path
+        elif use_top_p:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
+                probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
+            )
+            return next_token_ids.view(-1)
+        # Top-k only path
+        elif use_top_k:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            renorm_probs = self.aiter_ops.top_k_renorm_probs(
+                probs, *_to_tensor_scalar_tuple(k)
+            )
+            return torch.multinomial(renorm_probs, num_samples=1).view(-1)
+        raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
+
 
 def apply_top_k_top_p(
     logits: torch.Tensor,
@@ -300,3 +371,10 @@ def flashinfer_sample(
     )
 
     return next_token_ids.view(-1)
+
+
+def _to_tensor_scalar_tuple(x):
+    if isinstance(x, torch.Tensor):
+        return (x, 0)
+    else:
+        return (None, x)
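A note on the _to_tensor_scalar_tuple helper added at the bottom of the file: the aiter sampling ops appear to follow the flashinfer-style convention of taking each threshold as a (tensor, scalar) pair, so the helper packs a per-request tensor into the first slot (with a placeholder scalar of 0) and a batch-uniform Python scalar into the second (with None in the tensor slot). Illustrative values, not taken from the patch:

    _to_tensor_scalar_tuple(torch.tensor([20, 50]))  # -> (tensor([20, 50]), 0): per-request top-k
    _to_tensor_scalar_tuple(0.9)                     # -> (None, 0.9): one top-p for the whole batch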
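For reviewers, a minimal PyTorch-native sketch of the filtering semantics the three aiter paths are assumed to implement (top-k renormalization, top-p tail truncation, and the joint path in k-then-p order). The function name and the scalar k/p parameters are illustrative only; the real kernels fuse these steps on ROCm and take the (tensor, scalar) pairs shown above:

    from typing import Optional

    import torch


    def reference_topk_topp_sample(
        logits: torch.Tensor, k: Optional[int], p: Optional[float]
    ) -> torch.Tensor:
        # Mirror aiter_sample: softmax in fp32, filter, then one draw per row.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if k is not None:
            # top_k_renorm_probs analogue: zero everything below the k-th
            # largest probability per row (ties may keep extras), renormalize.
            kth = probs.topk(k, dim=-1).values[..., -1:]
            probs = torch.where(probs >= kth, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)
        if p is not None:
            # Top-p analogue: sort descending, drop the tail whose preceding
            # cumulative mass already exceeds p, then renormalize.
            sorted_probs, idx = probs.sort(dim=-1, descending=True)
            cum = sorted_probs.cumsum(dim=-1)
            sorted_probs[(cum - sorted_probs) > p] = 0.0
            probs = torch.zeros_like(probs).scatter_(-1, idx, sorted_probs)
            probs = probs / probs.sum(dim=-1, keepdim=True)
        return torch.multinomial(probs, num_samples=1).view(-1)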