diff --git a/vllm/config.py b/vllm/config.py
index 92e887e08639..1aadf2c25b43 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1619,13 +1619,12 @@ def _verify_args(self) -> None:
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
-        device_capability = current_platform.get_device_capability()
-        if (current_platform.is_rocm() and device_capability is not None
-                and device_capability < (9, 4)):
+
+        if not current_platform.use_custom_allreduce():
             self.disable_custom_all_reduce = True
             logger.info(
                 "Disabled the custom all-reduce kernel because it is not "
-                "supported on AMD GPUs older than MI300X.")
+                "supported on the current platform.")
         if self.ray_workers_use_nsight and not self.use_ray:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 28505fca10df..0576022be448 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -308,6 +308,10 @@ def supports_fp8(cls) -> bool:
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         return True
 
+    @classmethod
+    def use_custom_allreduce(cls) -> bool:
+        return True
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 36db70681a19..b6f6029de9c8 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -379,6 +379,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         """
         return False
 
+    @classmethod
+    def use_custom_allreduce(cls) -> bool:
+        """
+        Returns whether custom allreduce is supported on the current platform.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 0bedd80e5ecf..d18b7c26f7ec 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -302,3 +302,10 @@ def fp8_dtype(cls) -> torch.dtype:
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         # V1 support on AMD gpus is experimental
         return True
+
+    @classmethod
+    def use_custom_allreduce(cls) -> bool:
+        # We only enable custom allreduce for the MI300 series
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        supported_archs = ['gfx94']
+        return any(gfx in gcn_arch for gfx in supported_archs)
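
For reference, a minimal standalone sketch (not part of the patch) of the dispatch pattern this diff introduces: a conservative base-class default that platforms override, with the ROCm override matching the MI300-series `gcnArchName` by substring. The hard-coded `gcn_arch` value is an assumption standing in for `torch.cuda.get_device_properties(0).gcnArchName`, which requires real hardware.

```python
# Sketch of the platform hook added by this patch; runnable without a GPU.

class Platform:
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # Conservative default: platforms must opt in explicitly.
        return False


class CudaPlatform(Platform):
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # Custom allreduce is supported on all NVIDIA GPUs.
        return True


class RocmPlatform(Platform):
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # MI300-series GPUs report gcnArchName values such as
        # "gfx942:sramecc+:xnack-", so the "gfx94" substring check
        # matches the whole family.
        gcn_arch = "gfx942:sramecc+:xnack-"  # assumed sample value
        return any(gfx in gcn_arch for gfx in ["gfx94"])


# Mirrors the new check in ParallelConfig._verify_args: when the hook
# returns False, the custom all-reduce kernel is disabled.
for platform in (CudaPlatform, RocmPlatform, Platform):
    enabled = platform.use_custom_allreduce()
    print(f"{platform.__name__}: custom all-reduce "
          f"{'enabled' if enabled else 'disabled'}")
```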