|
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast

 import torch
@@ -125,6 +124,13 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
             for q_config in q_configs:
                 q_config["output_tensors"] = None

+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            if q_proj_q_config is not None:
+                q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,
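For context on what this hunk manipulates: below is a minimal sketch of a Quark `layer_quant_config`. The key names (`*q_proj`, `*k_proj`, `*v_proj`, `output_tensors`) are the ones used in the code above; the dtype values are illustrative, not taken from this PR. `from_config` keeps the k/v output quantization in `kv_cache_config` and then clears `output_tensors` on the matched entries; the new lines also clear it on `*q_proj` so that the q/k/v projections stay consistent.

```python
# Illustrative Quark config fragment; the dtype strings are assumptions.
layer_quant_config = {
    "*q_proj": {"weight": {"dtype": "fp8_e4m3"},
                "output_tensors": {"dtype": "fp8_e4m3"}},
    "*k_proj": {"weight": {"dtype": "fp8_e4m3"},
                "output_tensors": {"dtype": "fp8_e4m3"}},
    "*v_proj": {"weight": {"dtype": "fp8_e4m3"},
                "output_tensors": {"dtype": "fp8_e4m3"}},
}
kv_cache_group = ["*k_proj", "*v_proj"]

# What the existing loop plus the new q_proj handling amount to:
for name in kv_cache_group + ["*q_proj"]:
    if layer_quant_config.get(name) is not None:
        layer_quant_config[name]["output_tensors"] = None
```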
@@ -289,29 +295,30 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")

         # If no matches, return None
         return None

+    def has_fp8_layer_weights(self):
+        layer_quant_config = self.quant_config.get("layer_quant_config")
+        to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
+        return any([
+            'fp8' in cast(
+                str,
+                to_dict(
+                    to_dict(to_dict(layer_quant_config).get(layer_name)).get(
+                        "weight")).get("dtype"))
+            for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
+        ])
+

 class QuarkLinearMethod(LinearMethodBase):
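The rewritten `get_cache_scale` no longer derives projection names from `self.kv_cache_group`; it matches directly on the parameter name, and it also covers q_proj output scales and the attention-probability scale, which the removed branches never produced. A sketch of the resulting mapping, using hypothetical checkpoint parameter names in a `model.layers.<i>.self_attn.*` layout:

```python
# Hypothetical checkpoint parameter names; the right-hand sides follow the
# replace() rules added above. Any non-matching name still yields None.
expected = {
    "model.layers.0.self_attn.k_proj.output_scale":
    "model.layers.0.self_attn.attn.k_scale",
    "model.layers.0.self_attn.v_proj.output_scale":
    "model.layers.0.self_attn.attn.v_scale",
    "model.layers.0.self_attn.q_proj.output_scale":
    "model.layers.0.self_attn.attn.q_scale",
    "model.layers.0.self_attn.prob_output_scale":
    "model.layers.0.self_attn.attn.prob_scale",
}
```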
317 | 324 |
|
|
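The new `has_fp8_layer_weights` reads `layer_quant_config[<glob>]["weight"]["dtype"]` for the q/k/v projection globs and reports whether any of those dtype strings mentions fp8. A minimal sketch of a config it would report True for, assuming a dtype spelling such as `"fp8_e4m3"` (the exact string depends on the Quark export):

```python
# Illustrative fragment only; the dtype spelling is an assumption.
quant_config = {
    "layer_quant_config": {
        "*q_proj": {"weight": {"dtype": "fp8_e4m3"}},
        "*k_proj": {"weight": {"dtype": "fp8_e4m3"}},
        "*v_proj": {"weight": {"dtype": "fp8_e4m3"}},
    }
}
```

A `QuarkConfig` built from such a config would return True here, since the projections' weight dtypes contain the substring "fp8".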