Separate kv_scale into k_scale and v_scale (#25)
* Separate kv_scale into k_scale and v_scale

* New format and passing tests
mgoin authored Jul 23, 2024
1 parent 4b2092c commit 2cd265f
Showing 2 changed files with 18 additions and 12 deletions.
9 changes: 4 additions & 5 deletions auto_fp8/quantize.py
@@ -298,9 +298,8 @@ def quantize_activations(
     cleanup_memory()
 
     # Post-process step for kv cache scales to take the k/v module
-    # `output_scale` parameters, take the max of them, and store them in
-    # the parent attention module as `kv_scale`
-    # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
+    # `output_scale` parameters, and store them in the parent attention
+    # module as `k_scale` and `v_scale`
     if hasattr(quantize_config, "kv_cache_quant_layers"):
         # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
         # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
@@ -313,8 +312,8 @@
         k_proj = dict(model.named_modules())[k_proj_name]
         v_proj = dict(model.named_modules())[v_proj_name]
 
-        kv_scale = max(k_proj.output_scale, v_proj.output_scale)
-        parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)
+        parent_module.k_scale = torch.nn.Parameter(k_proj.output_scale, requires_grad=False)
+        parent_module.v_scale = torch.nn.Parameter(v_proj.output_scale, requires_grad=False)
 
         # Remove output_scale from k_proj and v_proj
         k_proj.output_scale = None
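For reference, the post-processing step above condenses to the following standalone sketch. The function name attach_kv_scales is illustrative, not part of AutoFP8; it assumes kv_cache_quant_layers is the ordered list of k_proj/v_proj module names described in the comment, and that each module carries an output_scale tensor from calibration.

    import torch

    def attach_kv_scales(model, kv_cache_quant_layers):
        # Pair the flat ordered list [layer0.k_proj, layer0.v_proj, ...]
        # into [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
        pairs = zip(kv_cache_quant_layers[::2], kv_cache_quant_layers[1::2])
        modules = dict(model.named_modules())
        for k_proj_name, v_proj_name in pairs:
            # Parent of "model.layers.N.self_attn.k_proj" is the attention module
            parent_module = modules[k_proj_name.rsplit(".", 1)[0]]
            k_proj = modules[k_proj_name]
            v_proj = modules[v_proj_name]
            # Store each projection's output_scale separately on the parent,
            # instead of fusing them with max() into a single kv_scale
            parent_module.k_scale = torch.nn.Parameter(k_proj.output_scale, requires_grad=False)
            parent_module.v_scale = torch.nn.Parameter(v_proj.output_scale, requires_grad=False)
            # Clear output_scale so it is not serialized on the projections
            k_proj.output_scale = None
            v_proj.output_scale = None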
21 changes: 14 additions & 7 deletions tests/test_auto_fp8.py
@@ -80,14 +80,21 @@ def test_kv_cache_static_quantization(model_id, target_size):
     model.save_quantized(quantized_model_dir)
 
     tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors")
-    proj_linear_count = 0
-    kv_scale_count = 0
+    k_proj_count = 0
+    v_proj_count = 0
+    k_scale_count = 0
+    v_scale_count = 0
     for name, _ in tensors.items():
-        if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
-            proj_linear_count += 1
-        if name.endswith("kv_scale"):
-            kv_scale_count += 1
-    assert proj_linear_count // 2 == kv_scale_count
+        if name.endswith(".k_proj.weight"):
+            k_proj_count += 1
+        if name.endswith(".v_proj.weight"):
+            v_proj_count += 1
+        if name.endswith(".k_scale"):
+            k_scale_count += 1
+        if name.endswith(".v_scale"):
+            v_scale_count += 1
+    assert k_proj_count == k_scale_count
+    assert v_proj_count == v_scale_count
 
     # Measure checkpoint size and cleanup
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
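As a quick sanity check of the new serialization format, the per-layer scales can be listed straight from a saved checkpoint. A sketch only: the file path and example tensor name are placeholders, and it assumes each scale is stored as a scalar tensor.

    import safetensors.torch

    tensors = safetensors.torch.load_file("quantized_model/model.safetensors")
    # Each attention layer now serializes two entries, e.g.
    # "model.layers.0.self_attn.k_scale" and the matching ".v_scale",
    # instead of a single fused "kv_scale" tensor.
    for name, tensor in tensors.items():
        if name.endswith(".k_scale") or name.endswith(".v_scale"):
            print(name, tensor.item())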
