
Commit 3ef9e39

Fix BertSdpaSelfAttention quantization (#648)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** `BertSdpaSelfAttention` quantization was failing.

Issue: Decorators (e.g., `@deprecate_kwarg` in BERT) wrap methods and change their free variables. The modified AST needs `__class__` for `super()`, but the decorated wrapper has different freevars (e.g., `'additional_message'`, `'func'`).

Fix: Search the decorator's closure to find the actual undecorated method with matching free variables, then use its globals and closure for bytecode reconstruction.

## Usage

```
cd examples/chained_optimizations
bash scripts/1_prune.sh
bash scripts/2_int8_quantize.sh
```

## Testing

Added unit test: `test_kv_quant_bert`

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: NA

## Additional Information

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
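For context, here is a minimal, self-contained sketch of the failure mode described above. The decorator and class names are hypothetical stand-ins, not the transformers or modelopt implementations: a zero-argument `super()` call makes the compiler record `__class__` in the undecorated method's `co_freevars`, while the decorated wrapper only exposes the decorator's own free variables, so the original function has to be recovered from the wrapper's `__closure__`.

```python
import functools
import types


def deprecate_kwarg(old_name):  # hypothetical stand-in for the real decorator
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            kwargs.pop(old_name, None)  # swallow the deprecated kwarg
            return func(*args, **kwargs)

        return wrapper

    return decorator


class Base:
    def forward(self):
        return "base"


class Child(Base):
    @deprecate_kwarg("position_ids")
    def forward(self):
        # Zero-arg super() adds an implicit '__class__' free variable.
        return super().forward() + " -> child"


decorated = Child.forward
print(decorated.__code__.co_freevars)  # e.g. ('func', 'old_name'): the wrapper's freevars

# Recover the undecorated method from the wrapper's closure, as the fix does.
undecorated = None
for cell in decorated.__closure__ or ():
    obj = cell.cell_contents
    if isinstance(obj, types.FunctionType) and obj.__name__ == "forward":
        undecorated = obj
        break

assert undecorated is not None
print(undecorated.__code__.co_freevars)  # ('__class__',): what the rebuilt AST needs
```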
1 parent c6c9905 commit 3ef9e39

File tree

3 files changed: +86 −12 lines


modelopt/torch/quantization/plugins/attention.py

Lines changed: 30 additions & 11 deletions
```diff
@@ -271,25 +271,44 @@ def _create_quantized_class_from_ast(
     # The exec path can handle decorated methods, but the safety compliance disallows exec
     closure = original_method.__closure__
     globals = original_method.__globals__
+
     if method_code.co_freevars != original_method.__code__.co_freevars:
         warn(f"{new_class_name}.{method_name} is a decorated method. Ignoring the decorator!")

+        # Search for the actual undecorated method in the closure
+        actual_method = None
+        if original_method.__closure__:
+            for closure_item in original_method.__closure__:
+                if (
+                    not hasattr(closure_item, "cell_contents")
+                    or closure_item.cell_contents is None
+                ):
+                    continue
+                item = closure_item.cell_contents
+                if isinstance(item, types.FunctionType) and item.__name__ == method_name:
+                    # Check if this is the method with the right freevars
+                    if all(var in item.__code__.co_freevars for var in method_code.co_freevars):
+                        actual_method = item
+                        globals = item.__globals__
+                        break
+
+        if actual_method is None:
+            raise ValueError(
+                f"Cannot find undecorated method {method_name} with required free variables "
+                f"{method_code.co_freevars} in closure"
+            )
+
+        # Build closure from actual method
+        assert actual_method.__closure__ is not None, (
+            "Actual method must have closure for freevars"
+        )
         new_closure = ()
         for freevar in method_code.co_freevars:
-            assert freevar in original_method.__closure__
+            assert freevar in actual_method.__code__.co_freevars
             new_closure += (
-                original_method.__closure__[  # type: ignore[index]
-                    original_method.__code__.co_freevars.index(freevar)
-                ],
+                actual_method.__closure__[actual_method.__code__.co_freevars.index(freevar)],
             )
         closure = new_closure
-        for closure_item in original_method.__closure__:  # type: ignore[union-attr]
-            item = closure_item.cell_contents
-            if isinstance(item, types.FunctionType) and item.__name__ == method_name:
-                globals = item.__globals__
-                break
-        else:
-            raise ValueError(f"Cannot find the original method {method_name} in the closure")

     # Create a new class method from bytecode
     new_method = types.FunctionType(method_code, globals=globals, closure=closure)
```
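As a side note on the final `types.FunctionType(...)` call above, here is a tiny standalone sketch (hypothetical names, unrelated to the modelopt code) of rebuilding a function from an existing code object with an explicitly supplied closure, which is the mechanism the patched code relies on once the right globals and closure cells have been found.

```python
import types


def outer():
    secret = 41

    def inner():
        return secret + 1

    return inner


original = outer()
code = original.__code__  # code object with co_freevars == ('secret',)

# One fresh cell per free variable, in co_freevars order.
new_closure = tuple(types.CellType(100) for _ in code.co_freevars)

rebuilt = types.FunctionType(code, globals=original.__globals__, closure=new_closure)
print(original())  # 42
print(rebuilt())   # 101, same bytecode but a different closure
```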

tests/_test_utils/torch/transformers_models.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -85,6 +85,24 @@ def get_tiny_qwen3_moe(**config_kwargs) -> "Qwen3MoeForCausalLM":
     return tiny_qwen3_moe


+def get_tiny_bert(**config_kwargs) -> "BertForQuestionAnswering":
+    set_seed(SEED)
+
+    kwargs = {
+        "hidden_size": 32,
+        "intermediate_size": 32,
+        "num_hidden_layers": 2,
+        "num_attention_heads": 16,
+        "num_key_value_heads": 2,
+        "max_position_embeddings": 32,
+        "vocab_size": 32,
+    }
+    kwargs.update(**config_kwargs)
+    tiny_bert = BertForQuestionAnswering(BertConfig(**kwargs))
+
+    return tiny_bert
+
+
 def get_tiny_llama(**config_kwargs) -> LlamaForCausalLM:
     set_seed(SEED)
     kwargs = {
```
tests/unit/torch/quantization/plugins/test_attention_quant.py

Lines changed: 38 additions & 1 deletion
```diff
@@ -17,7 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from _test_utils.torch.transformers_models import get_tiny_llama, get_tiny_t5
+from _test_utils.torch.transformers_models import get_tiny_bert, get_tiny_llama, get_tiny_t5

 import modelopt.torch.quantization as mtq
 from modelopt.torch.quantization.plugins.huggingface import _QuantAttention
@@ -72,6 +72,7 @@ def forward(self, hidden_states, **kwargs):
 )
 def test_kv_quant_hf(model_getter, attn_cls):
     model_test = model_getter()
+    print(model_test)
     input_ids = torch.randint(0, model_test.config.vocab_size, (1, 4))
     if getattr(model_test.config, "is_encoder_decoder", False):
         kwargs = {"decoder_input_ids": input_ids}
@@ -110,3 +111,39 @@ def test_kv_quant_hf(model_getter, attn_cls):
     if attn_cls is not None:
         _QuantAttention.is_compatible_attention = original_is_compatible_attention
         mtq.unregister(attn_cls)
+
+
+def test_kv_quant_bert():
+    """Test KV cache quantization on BERT model with decorated attention."""
+    model_test = get_tiny_bert()
+    input_ids = torch.randint(0, model_test.config.vocab_size, (1, 8))
+    attention_mask = torch.ones_like(input_ids)
+
+    # Run forward pass before quantization
+    model_test(input_ids, attention_mask=attention_mask)
+
+    # Quantize with KV cache quantization
+    mtq.quantize(
+        model_test,
+        kv_cache_config,
+        lambda model: model(input_ids, attention_mask=attention_mask),
+    )
+
+    # BERT attention modules are at encoder.layer.X.attention.self
+    found_quantized_attention = False
+    for name, module in model_test.named_modules():
+        if "attention.self" in name or name.endswith(".self"):
+            if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
+                found_quantized_attention = True
+                # Verify quantizers were calibrated
+                assert module.k_bmm_quantizer.amax is not None, f"k_bmm not calibrated in {name}"
+                assert module.v_bmm_quantizer.amax is not None, f"v_bmm not calibrated in {name}"
+                assert module.q_bmm_quantizer.amax is not None, f"q_bmm not calibrated in {name}"
+
+    assert found_quantized_attention, "No quantized attention modules found in BERT model"
+
+    # Run forward pass after quantization to ensure it works
+    output = model_test(input_ids, attention_mask=attention_mask)
+    assert output is not None
+    assert output.start_logits is not None
+    assert output.end_logits is not None
```