diff --git a/modelopt/torch/quantization/plugins/attention.py b/modelopt/torch/quantization/plugins/attention.py
index 4761e9c35..643774da7 100644
--- a/modelopt/torch/quantization/plugins/attention.py
+++ b/modelopt/torch/quantization/plugins/attention.py
@@ -271,25 +271,44 @@ def _create_quantized_class_from_ast(
         # The exec path can handle decorated methods, but the safety compliance disallows exec
         closure = original_method.__closure__
         globals = original_method.__globals__
+
         if method_code.co_freevars != original_method.__code__.co_freevars:
             warn(f"{new_class_name}.{method_name} is a decorated method. Ignoring the decorator!")
+            # Search for the actual undecorated method in the closure
+            actual_method = None
+            if original_method.__closure__:
+                for closure_item in original_method.__closure__:
+                    if (
+                        not hasattr(closure_item, "cell_contents")
+                        or closure_item.cell_contents is None
+                    ):
+                        continue
+                    item = closure_item.cell_contents
+                    if isinstance(item, types.FunctionType) and item.__name__ == method_name:
+                        # Check if this is the method with the right freevars
+                        if all(var in item.__code__.co_freevars for var in method_code.co_freevars):
+                            actual_method = item
+                            globals = item.__globals__
+                            break
+
+            if actual_method is None:
+                raise ValueError(
+                    f"Cannot find undecorated method {method_name} with required free variables "
+                    f"{method_code.co_freevars} in closure"
+                )
+
+            # Build closure from actual method
+            assert actual_method.__closure__ is not None, (
+                "Actual method must have closure for freevars"
+            )
 
             new_closure = ()
             for freevar in method_code.co_freevars:
-                assert freevar in original_method.__closure__
+                assert freevar in actual_method.__code__.co_freevars
                 new_closure += (
-                    original_method.__closure__[  # type: ignore[index]
-                        original_method.__code__.co_freevars.index(freevar)
-                    ],
+                    actual_method.__closure__[actual_method.__code__.co_freevars.index(freevar)],
                 )
             closure = new_closure
-            for closure_item in original_method.__closure__:  # type: ignore[union-attr]
-                item = closure_item.cell_contents
-                if isinstance(item, types.FunctionType) and item.__name__ == method_name:
-                    globals = item.__globals__
-                    break
-            else:
-                raise ValueError(f"Cannot find the original method {method_name} in the closure")
 
         # Create a new class method from bytecode
         new_method = types.FunctionType(method_code, globals=globals, closure=closure)
diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py
index 5c145c1e9..927ed0ea4 100644
--- a/tests/_test_utils/torch/transformers_models.py
+++ b/tests/_test_utils/torch/transformers_models.py
@@ -85,6 +85,24 @@ def get_tiny_qwen3_moe(**config_kwargs) -> "Qwen3MoeForCausalLM":
     return tiny_qwen3_moe
 
 
+def get_tiny_bert(**config_kwargs) -> "BertForQuestionAnswering":
+    set_seed(SEED)
+
+    kwargs = {
+        "hidden_size": 32,
+        "intermediate_size": 32,
+        "num_hidden_layers": 2,
+        "num_attention_heads": 16,
+        "num_key_value_heads": 2,
+        "max_position_embeddings": 32,
+        "vocab_size": 32,
+    }
+    kwargs.update(**config_kwargs)
+    tiny_bert = BertForQuestionAnswering(BertConfig(**kwargs))
+
+    return tiny_bert
+
+
 def get_tiny_llama(**config_kwargs) -> LlamaForCausalLM:
     set_seed(SEED)
     kwargs = {
diff --git a/tests/unit/torch/quantization/plugins/test_attention_quant.py b/tests/unit/torch/quantization/plugins/test_attention_quant.py
index 098d8a62e..832bc8a2e 100644
--- a/tests/unit/torch/quantization/plugins/test_attention_quant.py
+++ b/tests/unit/torch/quantization/plugins/test_attention_quant.py
@@ -17,7 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from _test_utils.torch.transformers_models import get_tiny_llama, get_tiny_t5
+from _test_utils.torch.transformers_models import get_tiny_bert, get_tiny_llama, get_tiny_t5
 
 import modelopt.torch.quantization as mtq
 from modelopt.torch.quantization.plugins.huggingface import _QuantAttention
@@ -72,6 +72,7 @@ def forward(self, hidden_states, **kwargs):
 )
 def test_kv_quant_hf(model_getter, attn_cls):
     model_test = model_getter()
+    print(model_test)
     input_ids = torch.randint(0, model_test.config.vocab_size, (1, 4))
     if getattr(model_test.config, "is_encoder_decoder", False):
         kwargs = {"decoder_input_ids": input_ids}
@@ -110,3 +111,39 @@ def test_kv_quant_hf(model_getter, attn_cls):
     if attn_cls is not None:
         _QuantAttention.is_compatible_attention = original_is_compatible_attention
         mtq.unregister(attn_cls)
+
+
+def test_kv_quant_bert():
+    """Test KV cache quantization on BERT model with decorated attention."""
+    model_test = get_tiny_bert()
+    input_ids = torch.randint(0, model_test.config.vocab_size, (1, 8))
+    attention_mask = torch.ones_like(input_ids)
+
+    # Run forward pass before quantization
+    model_test(input_ids, attention_mask=attention_mask)
+
+    # Quantize with KV cache quantization
+    mtq.quantize(
+        model_test,
+        kv_cache_config,
+        lambda model: model(input_ids, attention_mask=attention_mask),
+    )
+
+    # BERT attention modules are at encoder.layer.X.attention.self
+    found_quantized_attention = False
+    for name, module in model_test.named_modules():
+        if "attention.self" in name or name.endswith(".self"):
+            if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
+                found_quantized_attention = True
+                # Verify quantizers were calibrated
+                assert module.k_bmm_quantizer.amax is not None, f"k_bmm not calibrated in {name}"
+                assert module.v_bmm_quantizer.amax is not None, f"v_bmm not calibrated in {name}"
+                assert module.q_bmm_quantizer.amax is not None, f"q_bmm not calibrated in {name}"
+
+    assert found_quantized_attention, "No quantized attention modules found in BERT model"
+
+    # Run forward pass after quantization to ensure it works
+    output = model_test(input_ids, attention_mask=attention_mask)
+    assert output is not None
+    assert output.start_logits is not None
+    assert output.end_logits is not None
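
Note for reviewers (not part of the patch): the minimal, self-contained sketch below illustrates the decorator/closure mechanics the attention.py change relies on. All names in it (trace_calls, make_class, Greeter, greet, prefix) are invented for illustration and do not exist in modelopt. It shows why a decorated method's code object carries the wrapper's free variables rather than the method body's, and how the undecorated function (and a closure compatible with re-created bytecode) can be recovered from the wrapper's closure cells, which is the idea the patched bytecode path in _create_quantized_class_from_ast now applies.

import types


def trace_calls(fn):
    # A typical decorator: the returned wrapper closes over `fn`, so the
    # wrapper's co_freevars is ('fn',) rather than the method body's freevars.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def make_class():
    prefix = "hello"  # free variable captured by the undecorated method

    class Greeter:
        @trace_calls
        def greet(self, name):
            return f"{prefix}, {name}"

    return Greeter


Greeter = make_class()
decorated = Greeter.greet  # this is `wrapper`, not the original `greet`
print(decorated.__code__.co_freevars)  # ('fn',)

# The undecorated function is still reachable through the wrapper's closure cells,
# which is where the patched bytecode path now searches for it.
undecorated = next(
    cell.cell_contents
    for cell in decorated.__closure__
    if isinstance(cell.cell_contents, types.FunctionType)
    and cell.cell_contents.__name__ == "greet"
)
print(undecorated.__code__.co_freevars)  # ('prefix',)

# Re-created bytecode needs a closure built from the undecorated function's cells,
# not from the wrapper's.
rebuilt = types.FunctionType(
    undecorated.__code__, globals=undecorated.__globals__, closure=undecorated.__closure__
)
print(rebuilt(None, "world"))  # hello, world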