@@ -80,24 +80,18 @@ def test_module_path(self, dtype):
     def test_activation_prescaling(self):
         dtype = torch.bfloat16
         input = torch.randn(1, 128, dtype=dtype)
-        linear1 = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
-        linear2 = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
-        with torch.no_grad():
-            linear2.weight.copy_(linear1.weight)
-        original_output = linear2(input)
-        quantize_(linear1, get_config(group_size=128))
-        quantize_(linear2, get_config(group_size=128))
-        qw1 = linear1.weight
-        assert isinstance(qw1, SupportsActivationPreScaling), (
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
+        original_output = linear(input)
+        quantize_(linear, get_config(group_size=128))
+        qw = linear.weight
+        assert isinstance(qw, SupportsActivationPreScaling), (
             "Expected int4 tensor supports activation prescaling"
         )
-        assert qw1.act_pre_scale is None, "Default `act_pre_scale` is None"
-
+        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
         _ACT_PRE_SCALE = 2
-        manual_scaled_quantized = linear1(input * _ACT_PRE_SCALE)
-        qw2 = linear2.weight
-        qw2.act_pre_scale = _ACT_PRE_SCALE
-        auto_scaled_quantized = linear2(input)
+        manual_scaled_quantized = linear(input * _ACT_PRE_SCALE)
+        qw.act_pre_scale = _ACT_PRE_SCALE
+        auto_scaled_quantized = linear(input)
 
         # Making sure activation pre scaling is successfully applied to the activation.
         self.assertEqual(manual_scaled_quantized, auto_scaled_quantized)
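
The diff collapses the two-module setup into a single quantized linear, exercised twice: once with the scale folded into the input by the caller, and once with the scale attached to the weight via `act_pre_scale`. Below is a minimal plain-PyTorch sketch of the invariant being asserted; the `PreScaledWeight` wrapper is hypothetical and only illustrates the contract, it is not the torchao implementation.

import torch

class PreScaledWeight(torch.nn.Module):
    # Hypothetical stand-in for a quantized-weight wrapper that supports
    # activation pre-scaling: if `act_pre_scale` is set, the activation is
    # multiplied by it before the matmul.
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = weight
        self.act_pre_scale = None  # default: no pre-scaling, as the test asserts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.act_pre_scale is not None:
            x = x * self.act_pre_scale  # scale applied inside the op
        return torch.nn.functional.linear(x, self.weight)

layer = PreScaledWeight(torch.randn(256, 128))
x = torch.randn(1, 128)

manual = layer(x * 2)       # caller scales the activation by hand
layer.act_pre_scale = 2
auto = layer(x)             # op applies the same scale internally

torch.testing.assert_close(manual, auto)  # the equality the test checks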