diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index c67759c1d09a..00328f56b713 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -35,7 +35,7 @@
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, is_in_ray_actor
+from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor

 # yapf: enable
@@ -1625,13 +1625,13 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
         # values for non-H100/H200 GPUs.
         try:
             from vllm.platforms import current_platform
-            device_name = current_platform.get_device_name().lower()
+            device_memory = current_platform.get_device_total_memory()
         except Exception:
             # This is only used to set default_max_num_batched_tokens
-            device_name = "no-device"
+            device_memory = 0

-        if "h100" in device_name or "h200" in device_name:
-            # For H100 and H200, we use larger default values.
+        if device_memory >= 70 * GiB_bytes:
+            # For GPUs like H100 and MI300x, use larger default values.
             default_max_num_batched_tokens = {
                 UsageContext.LLM_CLASS: 16384,
                 UsageContext.OPENAI_API_SERVER: 8192,
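
For context, the change keys the larger scheduler defaults off total device memory instead of a device-name substring match, so accelerators other than H100/H200 (e.g. MI300-class GPUs) can also qualify. The threshold 70 * GiB_bytes is 70 * 2**30 bytes (about 75.2 GB), so roughly-80-GiB devices clear it while 40-GiB devices do not. Below is a minimal, self-contained sketch of that check, not the vLLM implementation itself; GiB_bytes here mirrors the 2**30 constant exposed by vllm.utils, and the memory sizes passed in are illustrative inputs rather than values queried from real hardware.

# Sketch of the memory-based default selection introduced by this diff.
# GiB_bytes mirrors vllm.utils; inputs below are illustrative, not queried.
GiB_bytes = 1 << 30

def uses_large_defaults(device_memory_bytes: int) -> bool:
    """True when a device qualifies for the larger scheduler defaults."""
    return device_memory_bytes >= 70 * GiB_bytes

print(uses_large_defaults(80 * GiB_bytes))  # True  (roughly H100/H200/MI300-class)
print(uses_large_defaults(40 * GiB_bytes))  # False (40 GiB-class devices)
print(uses_large_defaults(0))               # False (platform query failed)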