Skip to content

Commit e23aa25

Browse files
gshtraslulmer
authored and committed
[ROCm][FP8] Fix for adjustments needed only for fnuz (vllm-project#14689)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 8fcd9b8 commit e23aa25

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

vllm/model_executor/layers/quantization/kv_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
5050
# We prefer to use separate k_scale and v_scale if present
5151
k_scale = layer.k_scale.to("cpu").tolist()
5252
v_scale = layer.v_scale.to("cpu").tolist()
53-
if current_platform.is_rocm():
53+
if current_platform.is_fp8_fnuz():
5454
k_scale *= 2
5555
v_scale *= 2
5656
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@@ -66,7 +66,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
6666
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
6767
k_scale = scale_to_duplicate.to("cpu").tolist()
6868
v_scale = scale_to_duplicate.to("cpu").tolist()
69-
if current_platform.is_rocm():
69+
if current_platform.is_fp8_fnuz():
7070
k_scale *= 2
7171
v_scale *= 2
7272

0 commit comments

Comments
 (0)