Skip to content

Commit d662144

Browse files
author
Doug Lehr
committed
Add check for float8 type in aiter mha
Change the fp8 kv-cache check in rocm_aiter_fa.py to account for both the float8_e4m3fnuz and float8_e4m3fn data types. Signed-off-by: Doug Lehr <douglehr@amd.com>
1 parent f4d6eb9 commit d662144

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

vllm/attention/ops/rocm_aiter_paged_attn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def forward_decode(
8181
blocksparse_head_sliding_step=blocksparse_head_sliding_step)
8282

8383
if "fp8" in kv_cache_dtype:
84-
key_cache = key_cache.view(torch.float8_e4m3fnuz)
85-
value_cache = value_cache.view(torch.float8_e4m3fnuz)
84+
key_cache = key_cache.view(current_platform.fp8_dtype())
85+
value_cache = value_cache.view(current_platform.fp8_dtype())
8686

8787
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
8888
# use blocksparse paged attention

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,8 +479,8 @@ def forward(
479479
)
480480

481481
if self.kv_cache_dtype.startswith("fp8"):
482-
key_cache = key_cache.view(torch.float8_e4m3fnuz)
483-
value_cache = value_cache.view(torch.float8_e4m3fnuz)
482+
key_cache = key_cache.view(current_platform.fp8_dtype())
483+
value_cache = value_cache.view(current_platform.fp8_dtype())
484484

485485
if not attn_metadata.use_cascade:
486486
cu_seqlens_q = attn_metadata.query_start_loc

0 commit comments

Comments (0)