@@ -279,7 +279,6 @@ def test_per_module_config_skip(self):
279279 ]
280280 self .assertTrue (tokenizer .decode (output [0 ], skip_special_tokens = True ) in EXPECTED_OUTPUT )
281281
282-
283282 @require_torchao_version_greater_or_equal ("0.13.0" )
284283 def test_module_fqn_to_config_regex_basic (self ):
285284 linear_config = Int8WeightOnlyConfig ()
@@ -311,7 +310,12 @@ def test_module_fqn_to_config_regex_fullmatch(self):
311310 linear1_config = Int8WeightOnlyConfig ()
312311 linear2_config = Float8WeightOnlyConfig ()
313312 # intentionally removing `j` after `q_proj` so it's not a full match
314- config = ModuleFqnToConfig ({r"re:model\.layers\.+\.self_attn\.q_pro" : linear1_config , "model.layers.3.self_attn.q_proj" : linear2_config })
313+ config = ModuleFqnToConfig (
314+ {
315+ r"re:model\.layers\.+\.self_attn\.q_pro" : linear1_config ,
316+ "model.layers.3.self_attn.q_proj" : linear2_config ,
317+ }
318+ )
315319 quant_config = TorchAoConfig (quant_type = config )
316320 quantized_model = AutoModelForCausalLM .from_pretrained (
317321 self .model_name ,
@@ -338,7 +342,13 @@ def test_module_fqn_to_config_regex_fullmatch(self):
338342 def test_module_fqn_to_config_regex_precedence (self ):
339343 linear1_config = Int8WeightOnlyConfig ()
340344 linear2_config = Float8WeightOnlyConfig ()
341- config = ModuleFqnToConfig ({r"re:model\.layers\..+\.self_attn\.q_proj" : None , "model.layers.3.self_attn.q_proj" : linear2_config , "_default" : linear1_config })
345+ config = ModuleFqnToConfig (
346+ {
347+ r"re:model\.layers\..+\.self_attn\.q_proj" : None ,
348+ "model.layers.3.self_attn.q_proj" : linear2_config ,
349+ "_default" : linear1_config ,
350+ }
351+ )
342352 quant_config = TorchAoConfig (quant_type = config )
343353 quantized_model = AutoModelForCausalLM .from_pretrained (
344354 self .model_name ,
0 commit comments