41 changes: 30 additions & 11 deletions modelopt/torch/quantization/plugins/attention.py
@@ -271,25 +271,44 @@ def _create_quantized_class_from_ast(
     # The exec path can handle decorated methods, but the safety compliance disallows exec
     closure = original_method.__closure__
     globals = original_method.__globals__
 
     if method_code.co_freevars != original_method.__code__.co_freevars:
         warn(f"{new_class_name}.{method_name} is a decorated method. Ignoring the decorator!")
+
+        # Search for the actual undecorated method in the closure
+        actual_method = None
+        if original_method.__closure__:
+            for closure_item in original_method.__closure__:
+                if (
+                    not hasattr(closure_item, "cell_contents")
+                    or closure_item.cell_contents is None
+                ):
+                    continue
+                item = closure_item.cell_contents
+                if isinstance(item, types.FunctionType) and item.__name__ == method_name:
+                    # Check if this is the method with the right freevars
+                    if all(var in item.__code__.co_freevars for var in method_code.co_freevars):
+                        actual_method = item
+                        globals = item.__globals__
+                        break
+
+        if actual_method is None:
+            raise ValueError(
+                f"Cannot find undecorated method {method_name} with required free variables "
+                f"{method_code.co_freevars} in closure"
+            )
+
+        # Build closure from actual method
+        assert actual_method.__closure__ is not None, (
+            "Actual method must have closure for freevars"
+        )
         new_closure = ()
         for freevar in method_code.co_freevars:
-            assert freevar in original_method.__closure__
+            assert freevar in actual_method.__code__.co_freevars
             new_closure += (
-                original_method.__closure__[  # type: ignore[index]
-                    original_method.__code__.co_freevars.index(freevar)
-                ],
+                actual_method.__closure__[actual_method.__code__.co_freevars.index(freevar)],
             )
         closure = new_closure
-        for closure_item in original_method.__closure__:  # type: ignore[union-attr]
-            item = closure_item.cell_contents
-            if isinstance(item, types.FunctionType) and item.__name__ == method_name:
-                globals = item.__globals__
-                break
-        else:
-            raise ValueError(f"Cannot find the original method {method_name} in the closure")
 
     # Create a new class method from bytecode
     new_method = types.FunctionType(method_code, globals=globals, closure=closure)
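
For context on the change above: it relies on the fact that a typical Python decorator keeps the undecorated function alive as a cell in the wrapper's __closure__, so the real method (and its free variables) can be recovered by scanning those cells. A minimal sketch of that mechanism, using a hypothetical decorator and class (not taken from modelopt):

import types


def log_calls(func):  # hypothetical decorator standing in for whatever wraps the attention forward
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


class TinyAttention:
    @log_calls
    def forward(self, x):
        return x


decorated = TinyAttention.forward  # this is "wrapper", not the original "forward"
for cell in decorated.__closure__:
    item = cell.cell_contents
    if isinstance(item, types.FunctionType) and item.__name__ == "forward":
        print("recovered undecorated method:", item)
        print("its free variables:", item.__code__.co_freevars)
        break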
18 changes: 18 additions & 0 deletions tests/_test_utils/torch/transformers_models.py
@@ -85,6 +85,24 @@ def get_tiny_qwen3_moe(**config_kwargs) -> "Qwen3MoeForCausalLM":
     return tiny_qwen3_moe
 
 
+def get_tiny_bert(**config_kwargs) -> "BertForQuestionAnswering":
+    set_seed(SEED)
+
+    kwargs = {
+        "hidden_size": 32,
+        "intermediate_size": 32,
+        "num_hidden_layers": 2,
+        "num_attention_heads": 16,
+        "num_key_value_heads": 2,
+        "max_position_embeddings": 32,
+        "vocab_size": 32,
+    }
+    kwargs.update(**config_kwargs)
+    tiny_bert = BertForQuestionAnswering(BertConfig(**kwargs))
+
+    return tiny_bert
+
+
 def get_tiny_llama(**config_kwargs) -> LlamaForCausalLM:
     set_seed(SEED)
     kwargs = {
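
A note on the helper above (an observation, not part of the diff): BertConfig has no num_key_value_heads field, but transformers' PretrainedConfig stores unrecognized keyword arguments as plain attributes, so the extra key is harmless and simply unused by BERT attention. A quick check, assuming a recent transformers release:

from transformers import BertConfig

cfg = BertConfig(hidden_size=32, num_attention_heads=16, num_key_value_heads=2)
print(cfg.num_key_value_heads)  # stored as a plain attribute; BERT attention never reads it
print(cfg.hidden_size // cfg.num_attention_heads)  # per-head dimension of the tiny config (2)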
39 changes: 38 additions & 1 deletion tests/unit/torch/quantization/plugins/test_attention_quant.py
@@ -17,7 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from _test_utils.torch.transformers_models import get_tiny_llama, get_tiny_t5
+from _test_utils.torch.transformers_models import get_tiny_bert, get_tiny_llama, get_tiny_t5
 
 import modelopt.torch.quantization as mtq
 from modelopt.torch.quantization.plugins.huggingface import _QuantAttention
@@ -72,6 +72,7 @@ def forward(self, hidden_states, **kwargs):
 )
 def test_kv_quant_hf(model_getter, attn_cls):
     model_test = model_getter()
+    print(model_test)
     input_ids = torch.randint(0, model_test.config.vocab_size, (1, 4))
     if getattr(model_test.config, "is_encoder_decoder", False):
         kwargs = {"decoder_input_ids": input_ids}
@@ -110,3 +111,39 @@ def test_kv_quant_hf(model_getter, attn_cls):
     if attn_cls is not None:
         _QuantAttention.is_compatible_attention = original_is_compatible_attention
     mtq.unregister(attn_cls)
+
+
+def test_kv_quant_bert():
+    """Test KV cache quantization on BERT model with decorated attention."""
+    model_test = get_tiny_bert()
+    input_ids = torch.randint(0, model_test.config.vocab_size, (1, 8))
+    attention_mask = torch.ones_like(input_ids)
+
+    # Run forward pass before quantization
+    model_test(input_ids, attention_mask=attention_mask)
+
+    # Quantize with KV cache quantization
+    mtq.quantize(
+        model_test,
+        kv_cache_config,
+        lambda model: model(input_ids, attention_mask=attention_mask),
+    )
+
+    # BERT attention modules are at encoder.layer.X.attention.self
+    found_quantized_attention = False
+    for name, module in model_test.named_modules():
+        if "attention.self" in name or name.endswith(".self"):
+            if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
+                found_quantized_attention = True
+                # Verify quantizers were calibrated
+                assert module.k_bmm_quantizer.amax is not None, f"k_bmm not calibrated in {name}"
+                assert module.v_bmm_quantizer.amax is not None, f"v_bmm not calibrated in {name}"
+                assert module.q_bmm_quantizer.amax is not None, f"q_bmm not calibrated in {name}"
+
+    assert found_quantized_attention, "No quantized attention modules found in BERT model"
+
+    # Run forward pass after quantization to ensure it works
+    output = model_test(input_ids, attention_mask=attention_mask)
+    assert output is not None
+    assert output.start_logits is not None
+    assert output.end_logits is not None
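
After mtq.quantize returns, the calibrated ranges can be inspected directly. A small sketch built only from the attributes the test above already asserts on (it assumes the same tiny BERT and kv_cache_config as in the test):

for name, module in model_test.named_modules():
    if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
        # amax is the calibrated absolute-maximum statistic each quantizer collected during calibration
        print(name, module.k_bmm_quantizer.amax, module.v_bmm_quantizer.amax)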