diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 925f9ac0a16e..3d90c9513683 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -103,12 +103,41 @@ def default_unquantized_gemm( return torch.nn.functional.linear(x, weight, bias) +def use_aiter_triton_gemm(n, m, k, dtype): + if ( + envs.VLLM_ROCM_USE_AITER == 0 + # MI300's - fp8nuz=True + or current_platform.is_fp8_fnuz() + or dtype not in [torch.float16, torch.bfloat16] + ): + return False + + # use hipblaslt for the larger GEMMs + if n > 2048 and m > 512: + return False + return ( + (m == 5120 and k == 2880) + or (m == 2880 and k == 4096) + or (m == 128 and k == 2880) + or (m == 640 and k == 2880) + or (m == 2880 and k == 512) + ) + + def rocm_unquantized_gemm_impl( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: from vllm.platforms.rocm import on_gfx9 + n = x.numel() / x.size(-1) + m = weight.shape[0] k = weight.shape[1] + + if use_aiter_triton_gemm(n, m, k, x.dtype): + from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 + + return gemm_a16w16(x, weight, bias) + use_skinny = ( envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() @@ -120,11 +149,8 @@ def rocm_unquantized_gemm_impl( return torch.nn.functional.linear(x, weight, bias) x_view = x.reshape(-1, x.size(-1)) - n = x_view.shape[0] - m = weight.shape[0] - cu_count = current_platform.get_cu_count() - if m > 8 and 0 < n <= 4: + cu_count = current_platform.get_cu_count() out = ops.wvSplitK(weight, x_view, cu_count, bias) return out.reshape(*x.shape[:-1], weight.shape[0]) elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None: @@ -133,7 +159,7 @@ def rocm_unquantized_gemm_impl( return torch.nn.functional.linear(x, weight, bias) -def rocm_unquantized_gemm_impl_fake( +def rocm_unquantized_gemm_fake( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: return x.new_empty((*x.shape[:-1], weight.shape[0])) @@ -145,13 +171,13 @@ def rocm_unquantized_gemm( weight: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias) + return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias) direct_register_custom_op( - op_name="rocm_unquantized_gemm_impl", + op_name="rocm_unquantized_gemm", op_func=rocm_unquantized_gemm_impl, - fake_impl=rocm_unquantized_gemm_impl_fake, + fake_impl=rocm_unquantized_gemm_fake, ) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 44f6824b5212..863e5654094c 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -25,12 +25,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.utils import rocm_unquantized_gemm from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv @@ -153,6 +155,7 @@ def __init__( self.layer_idx = layer_idx self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts) @@ -177,7 +180,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: x = sequence_parallel_chunk(x) - g = self.router(x) + if current_platform.is_rocm(): + g = rocm_unquantized_gemm( + self, x[:, : self.hidden_size], self.router.weight, self.router.bias + ) + else: + g = self.router(x) x = self.experts(hidden_states=x, router_logits=g) if self.is_sequence_parallel: