
Commit 23afac1

dllehr-amd (Doug Lehr) authored and committed
Aiter mha fp8 fix (vllm-project#24991)
Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
1 parent: 0c141a6 · commit: 23afac1
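The fix replaces a hardcoded torch.float8_e4m3fnuz with vLLM's platform hook current_platform.fp8_dtype(). The fnuz variant of FP8 E4M3 is specific to AMD MI300-series (gfx94x) GPUs, while other hardware uses the OCP torch.float8_e4m3fn format, so deferring to the platform layer presumably picks the correct fp8 representation everywhere. A minimal sketch of that kind of per-platform dispatch (the gfx94x check and helper name here are assumptions for illustration, not vLLM's exact implementation):

import torch

def fp8_dtype_for_arch(gcn_arch: str) -> torch.dtype:
    # Hypothetical sketch: MI300-series (gfx94x) GPUs use the fnuz FP8
    # variant (no infinities, single NaN encoding); most other hardware
    # uses the OCP E4M3 format. vLLM's real dispatch lives in its
    # platform layer (vllm/platforms/) and may differ in detail.
    if gcn_arch.startswith("gfx94"):
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn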

File tree

2 files changed: 4 additions & 4 deletions


vllm/attention/ops/rocm_aiter_paged_attn.py

Lines changed: 2 additions & 2 deletions
@@ -81,8 +81,8 @@ def forward_decode(
                 blocksparse_head_sliding_step=blocksparse_head_sliding_step)
 
         if "fp8" in kv_cache_dtype:
-            key_cache = key_cache.view(torch.float8_e4m3fnuz)
-            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())
 
         if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
             # use blocksparse paged attention

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 2 additions & 2 deletions
@@ -479,8 +479,8 @@ def forward(
        )
 
        if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fnuz)
-            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())
 
        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
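For context, Tensor.view(dtype) in both hunks reinterprets the KV cache's raw bytes in the target fp8 dtype without copying; every FP8 format is one byte per element, so the shape is unchanged. A standalone PyTorch illustration (not vLLM code; the zero-filled buffer is just a stand-in for a cache block):

import torch

raw_cache = torch.zeros(4, 16, dtype=torch.uint8)   # stand-in for a KV-cache block
fp8_view = raw_cache.view(torch.float8_e4m3fn)      # zero-copy byte reinterpretation
assert fp8_view.shape == raw_cache.shape            # 1 byte per element, shape kept
assert fp8_view.data_ptr() == raw_cache.data_ptr()  # same storage, no copy made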
