File tree Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -81,8 +81,8 @@ def forward_decode(
8181 blocksparse_head_sliding_step = blocksparse_head_sliding_step )
8282
8383 if "fp8" in kv_cache_dtype :
84- key_cache = key_cache .view (torch . float8_e4m3fnuz )
85- value_cache = value_cache .view (torch . float8_e4m3fnuz )
84+ key_cache = key_cache .view (current_platform . fp8_dtype () )
85+ value_cache = value_cache .view (current_platform . fp8_dtype () )
8686
8787 if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1 :
8888 # use blocksparse paged attention
Original file line number Diff line number Diff line change @@ -479,8 +479,8 @@ def forward(
479479 )
480480
481481 if self .kv_cache_dtype .startswith ("fp8" ):
482- key_cache = key_cache .view (torch . float8_e4m3fnuz )
483- value_cache = value_cache .view (torch . float8_e4m3fnuz )
482+ key_cache = key_cache .view (current_platform . fp8_dtype () )
483+ value_cache = value_cache .view (current_platform . fp8_dtype () )
484484
485485 if not attn_metadata .use_cascade :
486486 cu_seqlens_q = attn_metadata .query_start_loc
You can’t perform that action at this time.
0 commit comments