Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Bugfix] Change kv scaling factor by param json on nvidia gpu (vllm-p…
Browse files Browse the repository at this point in the history
…roject#11688)

Signed-off-by: bjmsong <bjmsong@126.com>
Co-authored-by: bjmsong <bjmsong@126.com>
2 people authored and Ubuntu committed Jan 19, 2025
1 parent bad4b52 commit 0e4640b
Showing 5 changed files with 14 additions and 9 deletions.
5 changes: 3 additions & 2 deletions vllm/model_executor/models/exaone.py
Original file line number Diff line number Diff line change
@@ -606,8 +606,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
5 changes: 3 additions & 2 deletions vllm/model_executor/models/granite.py
Original file line number Diff line number Diff line change
@@ -545,8 +545,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
5 changes: 3 additions & 2 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
@@ -452,8 +452,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
5 changes: 3 additions & 2 deletions vllm/model_executor/models/solar.py
Original file line number Diff line number Diff line change
@@ -565,8 +565,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
3 changes: 2 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
@@ -1136,7 +1136,8 @@ def load_model(self) -> None:
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))

if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
or current_platform.is_cuda()):
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.

0 comments on commit 0e4640b

Please sign in to comment.