From b9f9f81d279385b5ae7962e6148eb73acd353aee Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Mon, 28 Apr 2025 17:40:00 -0500 Subject: [PATCH 01/11] Add VLLM_ROCM_USE_FP8_SCALES flag Signed-off-by: Randall Smith --- vllm/attention/backends/rocm_flash_attn.py | 2 +- vllm/envs.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8076c4791d3c..21d7de82b590 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -768,7 +768,7 @@ def forward( make_attn_mask=causal_mask) # type: ignore use_fp8_scales = (layer._q_scale and layer._k_scale and layer._v_scale and layer._prob_scale - and self.kv_cache_dtype == "fp8") + and envs.VLLM_ROCM_USE_FP8_SCALES) full_scales = ( layer._q_scale, layer._k_scale, layer._v_scale, layer._prob_scale) if use_fp8_scales else None diff --git a/vllm/envs.py b/vllm/envs.py index ea40bfff11b5..5214f1e50996 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -110,6 +110,7 @@ VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 + VLLM_ROCM_USE_FP8_SCALES: int = True def get_default_cache_root(): @@ -727,6 +728,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # limit will actually be zero-copy decoded. "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), + + # Use fp8 scales for ROCm FA + "VLLM_ROCM_USE_FP8_SCALES": + lambda: bool(int(os.getenv("VLLM_ROCM_USE_FP8_SCALES", "1"))), } # end-env-vars-definition From 9048aa55511ae17eef50f575801ebc4e63a1cde5 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 29 Apr 2025 22:40:18 +0000 Subject: [PATCH 02/11] lint Signed-off-by: Randall Smith --- vllm/attention/backends/rocm_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 21d7de82b590..bd0a556af00f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -768,7 +768,7 @@ def forward( make_attn_mask=causal_mask) # type: ignore use_fp8_scales = (layer._q_scale and layer._k_scale and layer._v_scale and layer._prob_scale - and envs.VLLM_ROCM_USE_FP8_SCALES) + and envs.VLLM_ROCM_USE_FP8_SCALES) full_scales = ( layer._q_scale, layer._k_scale, layer._v_scale, layer._prob_scale) if use_fp8_scales else None From 2f31d6b19d0aedde66544f510186a16e9d6f0853 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Thu, 8 May 2025 22:59:44 +0000 Subject: [PATCH 03/11] Use vllm config instead of env variable for fp8 scales option Signed-off-by: Randall Smith --- vllm/attention/backends/rocm_flash_attn.py | 9 ++++++++- vllm/config.py | 7 ++++++- vllm/engine/arg_utils.py | 4 ++++ vllm/envs.py | 5 ----- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index bd0a556af00f..fd01a79a2f67 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -16,6 +16,7 @@ CommonMetadataBuilder) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.platforms.rocm import use_rocm_custom_paged_attention @@ -766,9 +767,15 @@ def forward( query.dtype, 
seq_lens, make_attn_mask=causal_mask) # type: ignore + + vllm_config = get_current_vllm_config() + vllm_config_use_fp8_scales = ( + True if vllm_config is None or vllm_config.model_config + is None else vllm_config.model_config.use_fp8_scales) use_fp8_scales = (layer._q_scale and layer._k_scale and layer._v_scale and layer._prob_scale - and envs.VLLM_ROCM_USE_FP8_SCALES) + and vllm_config_use_fp8_scales) + full_scales = ( layer._q_scale, layer._k_scale, layer._v_scale, layer._prob_scale) if use_fp8_scales else None diff --git a/vllm/config.py b/vllm/config.py index 0bbf588fb3e8..c81c654ae26e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -397,6 +397,8 @@ class ModelConfig: available.\n - "vllm" will use the vLLM model implementation.\n - "transformers" will use the Transformers model implementation.""" + use_fp8_scales: bool = True + """If true, pass the fp8 scales to the ROCm Triton attention backend""" def compute_hash(self) -> str: """ @@ -4339,10 +4341,12 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): old_vllm_config = _current_vllm_config from vllm.compilation.counter import compilation_counter num_models_seen = compilation_counter.num_models_seen + was_raised = False try: _current_vllm_config = vllm_config yield except Exception: + was_raised = True raise else: logger.debug("enabled custom ops: %s", @@ -4363,7 +4367,8 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): " if you want it to be supported.", vllm_config.model_config.model) finally: - _current_vllm_config = old_vllm_config + if was_raised: + _current_vllm_config = old_vllm_config def get_current_vllm_config() -> VllmConfig: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 27af74e2e349..49800f271b8e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -378,6 +378,7 @@ class EngineArgs: override_generation_config: dict[str, Any] = \ get_field(ModelConfig, "override_generation_config") model_impl: str = ModelConfig.model_impl + use_fp8_scales: bool = ModelConfig.use_fp8_scales calculate_kv_scales: bool = CacheConfig.calculate_kv_scales @@ -493,6 +494,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: model_group.add_argument("--model-impl", choices=[f.value for f in ModelImpl], **model_kwargs["model_impl"]) + model_group.add_argument("--use-fp8-scales", + **model_kwargs["use_fp8_scales"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -904,6 +907,7 @@ def create_model_config(self) -> ModelConfig: override_generation_config=self.override_generation_config, enable_sleep_mode=self.enable_sleep_mode, model_impl=self.model_impl, + use_fp8_scales=self.use_fp8_scales, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/envs.py b/vllm/envs.py index a890278fce29..c8bb39ceb7b2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -111,7 +111,6 @@ VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 - VLLM_ROCM_USE_FP8_SCALES: int = True def get_default_cache_root(): @@ -737,10 +736,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # limit will actually be zero-copy decoded. 
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD": lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), - - # Use fp8 scales for ROCm FA - "VLLM_ROCM_USE_FP8_SCALES": - lambda: bool(int(os.getenv("VLLM_ROCM_USE_FP8_SCALES", "1"))), } # end-env-vars-definition From fdc428ba63a208c9e50bb1bdda95a108fa0159fc Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 21 May 2025 21:58:57 +0000 Subject: [PATCH 04/11] use override instead Signed-off-by: Randall Smith --- vllm/attention/backends/rocm_flash_attn.py | 6 +++--- vllm/config.py | 4 ++-- vllm/engine/arg_utils.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index fd01a79a2f67..2aac1c88f55a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -581,6 +581,7 @@ def __init__( logger.debug("Using naive (SDPA) attention in ROCmBackend") self.aiter_kv_scales_initialized = False + self.vllm_config = get_current_vllm_config() def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" @@ -768,10 +769,9 @@ def forward( seq_lens, make_attn_mask=causal_mask) # type: ignore - vllm_config = get_current_vllm_config() vllm_config_use_fp8_scales = ( - True if vllm_config is None or vllm_config.model_config - is None else vllm_config.model_config.use_fp8_scales) + True if self.vllm_config is None or self.vllm_config.model_config + is None else self.vllm_config.model_config.override_attention_dtype == "fp8") use_fp8_scales = (layer._q_scale and layer._k_scale and layer._v_scale and layer._prob_scale and vllm_config_use_fp8_scales) diff --git a/vllm/config.py b/vllm/config.py index a8b6f4577003..f22127d6f999 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -407,8 +407,8 @@ class ModelConfig: available.\n - "vllm" will use the vLLM model implementation.\n - "transformers" will use the Transformers model implementation.""" - use_fp8_scales: bool = True - """If true, pass the fp8 scales to the ROCm Triton attention backend""" + override_attention_dtype: str = "fp8" + """Override dtype for attention""" def compute_hash(self) -> str: """ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a4877ef828c4..4d3e27b33161 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -410,7 +410,7 @@ class EngineArgs: override_generation_config: dict[str, Any] = \ get_field(ModelConfig, "override_generation_config") model_impl: str = ModelConfig.model_impl - use_fp8_scales: bool = ModelConfig.use_fp8_scales + override_attention_dtype: str = ModelConfig.override_attention_dtype calculate_kv_scales: bool = CacheConfig.calculate_kv_scales @@ -527,8 +527,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: model_group.add_argument("--model-impl", choices=[f.value for f in ModelImpl], **model_kwargs["model_impl"]) - model_group.add_argument("--use-fp8-scales", - **model_kwargs["use_fp8_scales"]) + model_group.add_argument("--override-attention-dtype", + **model_kwargs["override_attention_dtype"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -912,7 +912,7 @@ def create_model_config(self) -> ModelConfig: override_generation_config=self.override_generation_config, enable_sleep_mode=self.enable_sleep_mode, model_impl=self.model_impl, - use_fp8_scales=self.use_fp8_scales, + override_attention_dtype=self.override_attention_dtype, ) def create_load_config(self) -> LoadConfig: 
From 44b18cec08c052453720bed85ea0aa28a1f1487f Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 21 May 2025 22:18:59 +0000 Subject: [PATCH 05/11] format Signed-off-by: Randall Smith --- vllm/attention/backends/rocm_flash_attn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 2aac1c88f55a..430fe3eaa6ef 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -770,8 +770,10 @@ def forward( make_attn_mask=causal_mask) # type: ignore vllm_config_use_fp8_scales = ( - True if self.vllm_config is None or self.vllm_config.model_config - is None else self.vllm_config.model_config.override_attention_dtype == "fp8") + True if self.vllm_config is None + or self.vllm_config.model_config is None else + self.vllm_config.model_config.override_attention_dtype + == "fp8") use_fp8_scales = (layer._q_scale and layer._k_scale and layer._v_scale and layer._prob_scale and vllm_config_use_fp8_scales) From 1bc79b7efd97088430ed729646499e0ac30f058c Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Thu, 22 May 2025 19:25:30 +0000 Subject: [PATCH 06/11] remove was_raised from set_current_vllm_config Signed-off-by: Randall Smith --- vllm/config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index f22127d6f999..971e82030f22 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4507,7 +4507,6 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): _current_vllm_config = vllm_config yield except Exception: - was_raised = True raise else: logger.debug("enabled custom ops: %s", @@ -4528,8 +4527,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): " if you want it to be supported.", vllm_config.model_config.model) finally: - if was_raised: - _current_vllm_config = old_vllm_config + _current_vllm_config = old_vllm_config def get_current_vllm_config() -> VllmConfig: From 5cec76f3cdee8cba50c12f935e38b7c188eae1ee Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Thu, 22 May 2025 19:28:02 +0000 Subject: [PATCH 07/11] remove was_raised Signed-off-by: Randall Smith --- vllm/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 971e82030f22..985df2d3553b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4502,7 +4502,6 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): old_vllm_config = _current_vllm_config from vllm.compilation.counter import compilation_counter num_models_seen = compilation_counter.num_models_seen - was_raised = False try: _current_vllm_config = vllm_config yield From 2c5ffb08683daa7c24e118f6538570809de956b1 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Thu, 22 May 2025 21:52:04 +0000 Subject: [PATCH 08/11] simplify and add warning Signed-off-by: Randall Smith --- vllm/attention/backends/rocm_flash_attn.py | 12 +++++------- vllm/config.py | 7 ++++++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 430fe3eaa6ef..bdcdc2793a5a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -581,7 +581,10 @@ def __init__( logger.debug("Using naive (SDPA) attention in ROCmBackend") self.aiter_kv_scales_initialized = False - self.vllm_config = get_current_vllm_config() + self.force_fp8_attention = ( + 
get_current_vllm_config is not None + and get_current_vllm_config().model_config.override_attention_dtype + == "fp8") def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" @@ -769,14 +772,9 @@ def forward( seq_lens, make_attn_mask=causal_mask) # type: ignore - vllm_config_use_fp8_scales = ( - True if self.vllm_config is None - or self.vllm_config.model_config is None else - self.vllm_config.model_config.override_attention_dtype - == "fp8") use_fp8_scales = (layer._q_scale and layer._k_scale and layer._v_scale and layer._prob_scale - and vllm_config_use_fp8_scales) + and self.force_fp8_attention) full_scales = ( layer._q_scale, layer._k_scale, layer._v_scale, diff --git a/vllm/config.py b/vllm/config.py index 985df2d3553b..8173834567a8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -407,7 +407,7 @@ class ModelConfig: available.\n - "vllm" will use the vLLM model implementation.\n - "transformers" will use the Transformers model implementation.""" - override_attention_dtype: str = "fp8" + override_attention_dtype: Optional[str] = None """Override dtype for attention""" def compute_hash(self) -> str: @@ -509,6 +509,11 @@ def __post_init__(self) -> None: from vllm.platforms import current_platform + if (self.override_attention_dtype is not None + and not current_platform.is_rocm()): + warnings.warn( + "override-attention-dtype is set but not using ROCm platform") + if (self.enable_sleep_mode and not current_platform.is_sleep_mode_available()): raise ValueError( From e7400c160e8f915258355b9d370f3c358a9fd59d Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Thu, 22 May 2025 22:04:08 +0000 Subject: [PATCH 09/11] set stacklevel for warning Signed-off-by: Randall Smith --- vllm/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 8173834567a8..823bc0f414f9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -512,7 +512,8 @@ def __post_init__(self) -> None: if (self.override_attention_dtype is not None and not current_platform.is_rocm()): warnings.warn( - "override-attention-dtype is set but not using ROCm platform") + "override-attention-dtype is set but not using ROCm platform", + stacklevel=2) if (self.enable_sleep_mode and not current_platform.is_sleep_mode_available()): From e135f78cbbce3fce0e69dfe9e18399d68c13e970 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Thu, 22 May 2025 22:09:37 +0000 Subject: [PATCH 10/11] fix typo Signed-off-by: Randall Smith --- vllm/attention/backends/rocm_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index bdcdc2793a5a..dcd66413b526 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -582,7 +582,7 @@ def __init__( self.aiter_kv_scales_initialized = False self.force_fp8_attention = ( - get_current_vllm_config is not None + get_current_vllm_config() is not None and get_current_vllm_config().model_config.override_attention_dtype == "fp8") From 7ad4a103b1fd7fe3d215d2c41595d10790ceac0c Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 3 Jun 2025 16:28:38 +0000 Subject: [PATCH 11/11] check if kv cache is fp8 Signed-off-by: Randall Smith --- vllm/attention/backends/rocm_flash_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 
dcd66413b526..f6444ddbc7e3 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -774,7 +774,8 @@ def forward( use_fp8_scales = (layer._q_scale and layer._k_scale and layer._v_scale and layer._prob_scale - and self.force_fp8_attention) + and (self.kv_cache_dtype == "fp8" + or self.force_fp8_attention)) full_scales = ( layer._q_scale, layer._k_scale, layer._v_scale,
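
Taken together, the series leaves the following gating in place for the ROCm Triton attention path: the fp8 scales are forwarded only when all four per-layer scales are present and either the KV cache dtype is fp8 or the user forced fp8 attention via --override-attention-dtype fp8. The function below is an illustrative sketch with assumed parameter names (scalar floats stand in for the per-layer scale tensors), not code from the patch.

# Illustration only: the final use_fp8_scales condition after PATCH 11.
from typing import Optional, Tuple


def select_fp8_scales(
    q_scale: float,
    k_scale: float,
    v_scale: float,
    prob_scale: float,
    kv_cache_dtype: str,
    force_fp8_attention: bool,
) -> Optional[Tuple[float, float, float, float]]:
    # Forward the scales only when all four are set (non-zero) and either
    # the KV cache is already fp8 or fp8 attention was forced by the user.
    use_fp8_scales = bool(q_scale and k_scale and v_scale and prob_scale
                          and (kv_cache_dtype == "fp8"
                               or force_fp8_attention))
    return (q_scale, k_scale, v_scale,
            prob_scale) if use_fp8_scales else None


print(select_fp8_scales(1.0, 1.0, 1.0, 1.0, "fp8", False))   # scales forwarded
print(select_fp8_scales(1.0, 1.0, 1.0, 1.0, "auto", False))  # None: kernel runs without scales
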