Move kv cache scales from k/v_proj.output_scale to self_attn.k/v_scale #133

Merged · 2 commits · Aug 15, 2024
src/compressed_tensors/compressors/model_compressor.py (50 additions, 0 deletions)
@@ -240,6 +240,56 @@ def compress(
compressed_state_dict
)

        # HACK (mgoin): Post-process step for kv cache scales to take the
        # k/v_proj module `output_scale` parameters, and store them in the
        # parent attention module as `k_scale` and `v_scale`
        #
        # Example:
        #     Replace `model.layers.0.self_attn.k_proj.output_scale`
        #     with `model.layers.0.self_attn.k_scale`
        if self.quantization_config.kv_cache_scheme is not None:
            # HACK (mgoin): We assume the quantized modules in question
            # will be k_proj and v_proj since those are the default targets.
            # We check that both of these modules have output activation
            # quantization, and additionally check that q_proj doesn't.
            q_proj_has_no_quant_output = 0
            k_proj_has_quant_output = 0
            v_proj_has_quant_output = 0
            for name, module in model.named_modules():
                if not hasattr(module, "quantization_scheme"):
                    continue
                out_act = module.quantization_scheme.output_activations
                if name.endswith(".q_proj") and out_act is None:
                    q_proj_has_no_quant_output += 1
                elif name.endswith(".k_proj") and out_act is not None:
                    k_proj_has_quant_output += 1
                elif name.endswith(".v_proj") and out_act is not None:
                    v_proj_has_quant_output += 1

            assert (
                q_proj_has_no_quant_output > 0
                and k_proj_has_quant_output > 0
                and v_proj_has_quant_output > 0
            )
            assert (
                q_proj_has_no_quant_output
                == k_proj_has_quant_output
                == v_proj_has_quant_output
            )

            # Move all .k/v_proj.output_scale parameters to .k/v_scale
            working_state_dict = {}
            for key in compressed_state_dict.keys():
if key.endswith(".k_proj.output_scale"):
new_key = key.replace(".k_proj.output_scale", ".k_scale")
working_state_dict[new_key] = compressed_state_dict[key]
elif key.endswith(".v_proj.output_scale"):
new_key = key.replace(".v_proj.output_scale", ".v_scale")
working_state_dict[new_key] = compressed_state_dict[key]
else:
working_state_dict[key] = compressed_state_dict[key]
compressed_state_dict = working_state_dict

        # HACK: Override the dtype_byte_size function in transformers to
        # support float8 types. Fix is posted upstream
        # https://github.com/huggingface/transformers/pull/30488
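For illustration, here is a standalone sketch of the key-renaming step above applied to a toy state dict; the layer path and scale values below are made up, not taken from a real checkpoint:

```python
# Toy demonstration of the post-processing above: move k/v_proj output scales
# onto the parent attention module. Paths and values are illustrative only.
import torch

toy_state_dict = {
    "model.layers.0.self_attn.k_proj.weight": torch.zeros(8, 8),
    "model.layers.0.self_attn.k_proj.output_scale": torch.tensor(0.02),
    "model.layers.0.self_attn.v_proj.output_scale": torch.tensor(0.03),
}

renamed = {}
for key, value in toy_state_dict.items():
    if key.endswith(".k_proj.output_scale"):
        renamed[key.replace(".k_proj.output_scale", ".k_scale")] = value
    elif key.endswith(".v_proj.output_scale"):
        renamed[key.replace(".v_proj.output_scale", ".v_scale")] = value
    else:
        renamed[key] = value

# The scales now live on the attention module itself:
#   model.layers.0.self_attn.k_scale
#   model.layers.0.self_attn.v_scale
print(sorted(renamed.keys()))
```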
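The dtype_byte_size override referenced in the comment above is collapsed in this diff view. As a hedged sketch (not necessarily the PR's exact code), a monkey-patch along these lines would let transformers compute shard sizes for float8 dtypes, assuming the stock parsing fails on dtype names such as torch.float8_e4m3fn that do not end in digits:

```python
# Sketch of a possible dtype_byte_size override for float8 support.
# Assumption: the stock transformers parser cannot extract a bit width from
# names like "torch.float8_e4m3fn", so we allow a trailing suffix after it.
import re

import torch
import transformers.modeling_utils as modeling_utils

def patched_dtype_byte_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    # e.g. "torch.float8_e4m3fn" -> bit width 8 -> 1 byte per element
    bit_search = re.search(r"[^\d](\d+)(_.*)?$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8

# Patch the module-level symbol transformers uses when sharding checkpoints.
modeling_utils.dtype_byte_size = patched_dtype_byte_size
```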