diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py
index b8bf0001..9807cbec 100644
--- a/src/compressed_tensors/compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressor.py
@@ -240,6 +240,56 @@ def compress(
                 compressed_state_dict
             )
 
+        # HACK (mgoin): Post-process step for kv cache scales to take the
+        # k/v_proj module `output_scale` parameters, and store them in the
+        # parent attention module as `k_scale` and `v_scale`
+        #
+        # Example:
+        # Replace `model.layers.0.self_attn.k_proj.output_scale`
+        # with `model.layers.0.self_attn.k_scale`
+        if self.quantization_config.kv_cache_scheme is not None:
+            # HACK (mgoin): We assume the quantized modules in question
+            # will be k_proj and v_proj since those are the default targets.
+            # We check that both of these modules have output activation
+            # quantization, and additionally check that q_proj doesn't.
+            q_proj_has_no_quant_output = 0
+            k_proj_has_quant_output = 0
+            v_proj_has_quant_output = 0
+            for name, module in model.named_modules():
+                if not hasattr(module, "quantization_scheme"):
+                    continue
+                out_act = module.quantization_scheme.output_activations
+                if name.endswith(".q_proj") and out_act is None:
+                    q_proj_has_no_quant_output += 1
+                elif name.endswith(".k_proj") and out_act is not None:
+                    k_proj_has_quant_output += 1
+                elif name.endswith(".v_proj") and out_act is not None:
+                    v_proj_has_quant_output += 1
+
+            assert (
+                q_proj_has_no_quant_output > 0
+                and k_proj_has_quant_output > 0
+                and v_proj_has_quant_output > 0
+            )
+            assert (
+                q_proj_has_no_quant_output
+                == k_proj_has_quant_output
+                == v_proj_has_quant_output
+            )
+
+            # Move all .k/v_proj.output_scale parameters to .k/v_scale
+            working_state_dict = {}
+            for key in compressed_state_dict.keys():
+                if key.endswith(".k_proj.output_scale"):
+                    new_key = key.replace(".k_proj.output_scale", ".k_scale")
+                    working_state_dict[new_key] = compressed_state_dict[key]
+                elif key.endswith(".v_proj.output_scale"):
+                    new_key = key.replace(".v_proj.output_scale", ".v_scale")
+                    working_state_dict[new_key] = compressed_state_dict[key]
+                else:
+                    working_state_dict[key] = compressed_state_dict[key]
+            compressed_state_dict = working_state_dict
+
         # HACK: Override the dtype_byte_size function in transformers to
         # support float8 types. Fix is posted upstream
         # https://github.com/huggingface/transformers/pull/30488
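
For reference, the rename step in the diff can be sketched in isolation. The snippet below is a minimal, hypothetical standalone illustration (the helper name rename_kv_cache_scales and the toy state dict values are illustrative only, not part of the PR) of how `.k/v_proj.output_scale` keys end up as `.k_scale`/`.v_scale` on the parent attention module:

# Hypothetical standalone sketch of the post-process step above; not the PR's API.
def rename_kv_cache_scales(state_dict: dict) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        if key.endswith(".k_proj.output_scale"):
            # k_proj output scale moves up to the parent attention module
            renamed[key.replace(".k_proj.output_scale", ".k_scale")] = value
        elif key.endswith(".v_proj.output_scale"):
            # v_proj output scale moves up to the parent attention module
            renamed[key.replace(".v_proj.output_scale", ".v_scale")] = value
        else:
            # all other entries are passed through unchanged
            renamed[key] = value
    return renamed

# Toy example (values are made up):
state_dict = {
    "model.layers.0.self_attn.k_proj.output_scale": 0.021,
    "model.layers.0.self_attn.v_proj.output_scale": 0.034,
    "model.layers.0.self_attn.q_proj.weight": "...",
}
print(rename_kv_cache_scales(state_dict))
# {'model.layers.0.self_attn.k_scale': 0.021,
#  'model.layers.0.self_attn.v_scale': 0.034,
#  'model.layers.0.self_attn.q_proj.weight': '...'}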