2424
# Please check each model card for how to generate these models.
2626
# high precision model, used for testing config deprecation warning
_HIGH_PRECISION_MODEL = "facebook/opt-125m"

# Deprecated single-linear checkpoints. Each entry is a
# (model_name, version, config_name) tuple; the tests unpack it in that order
# and use version/config_name to build the expected deprecation warning text.
_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [
    # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
    (
        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev",
        1,
        "Float8DynamicActivationFloat8WeightConfig",
    ),
    # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev
    (
        "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev",
        1,
        "Int4WeightOnlyConfig",
    ),
]
3144
3245_DEPRECATED_MODEL_INFO = [
3649 1 ,
3750 "Float8DynamicActivationFloat8WeightConfig" ,
3851 ),
52+ # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev
53+ (
54+ "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev" ,
55+ 1 ,
56+ "Int4WeightOnlyConfig" ,
57+ ),
3958]
4059
# Current (non-deprecated) single-linear checkpoints. Each entry is a
# (model_name, version, config_name) tuple, unpacked in that order by the tests.
_SINGLE_LINEAR_MODEL_INFO = [
    # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
    (
        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
        2,
        "Float8DynamicActivationFloat8WeightConfig",
    ),
    # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev
    (
        "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
        2,
        "Int4WeightOnlyConfig",
    ),
    # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev
    (
        "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
        2,
        "Int4WeightOnlyConfig",
    ),
]
4980
5081
5586 "Skipping the test in fbcode for now, not sure how to download from transformers" ,
5687)
5788class TestLoadAndRunCheckpoint (TestCase ):
58- def _test_single_linear_helper (self , model_name ):
89+ def _test_single_linear_helper (
90+ self , model_name , version , config_name , is_deprecated
91+ ):
5992 from huggingface_hub import hf_hub_download
6093
6194 downloaded_model = hf_hub_download (model_name , filename = "model.pt" )
@@ -69,8 +102,20 @@ def _test_single_linear_helper(self, model_name):
69102 model = torch .nn .Sequential (
70103 torch .nn .Linear (32 , 256 , dtype = torch .bfloat16 , device = "cuda" )
71104 )
72- with open (downloaded_model , "rb" ) as f :
105+
106+ with (
107+ open (downloaded_model , "rb" ) as f ,
108+ warnings .catch_warnings (record = True ) as caught_warnings ,
109+ ):
73110 model .load_state_dict (torch .load (f ), assign = True )
111+ if is_deprecated :
112+ assert any (
113+ f"Models quantized with version { version } of { config_name } is deprecated"
114+ in str (w .message )
115+ for w in caught_warnings
116+ ), (
117+ f"Didn't get expected warning message for deprecation for model: { model_name } "
118+ )
74119
75120 downloaded_example_inputs = hf_hub_download (
76121 model_name , filename = "model_inputs.pt"
@@ -84,17 +129,23 @@ def _test_single_linear_helper(self, model_name):
84129 output = model (* example_inputs )
85130 self .assertTrue (torch .equal (output , ref_output ))
86131
@common_utils.parametrize("model_info", _DEPRECATED_SINGLE_LINEAR_MODEL_INFO)
def test_deprecated_single_linear(self, model_info):
    """Run the single-linear checkpoint check on a deprecated checkpoint.

    ``model_info`` is a ``(model_name, version, config_name)`` tuple. The
    shared helper is called with ``is_deprecated=True`` so that it also
    asserts the checkpoint deprecation warning is raised while loading.
    """
    model_name, version, config_name = model_info
    self._test_single_linear_helper(
        model_name, version, config_name, is_deprecated=True
    )
90138
@common_utils.parametrize("model_info", _SINGLE_LINEAR_MODEL_INFO)
def test_single_linear(self, model_info):
    """Test that we can load and run the quantized linear checkpoint with saved sample input
    and match the saved output, to make sure there are no BC-breaking changes
    when we make changes to tensor subclass implementations.

    ``model_info`` is a ``(model_name, version, config_name)`` tuple. The
    shared helper is called with ``is_deprecated=False``, so no deprecation
    warning is required for these checkpoints.
    """
    model_name, version, config_name = model_info
    self._test_single_linear_helper(
        model_name, version, config_name, is_deprecated=False
    )
98149
99150 @common_utils .parametrize ("model_info" , _DEPRECATED_MODEL_INFO )
100151 def test_deprecated_hf_models (self , model_info ):
@@ -109,17 +160,23 @@ def test_deprecated_hf_models(self, model_info):
109160 torch_dtype = "bfloat16" ,
110161 device_map = "cuda:0" ,
111162 )
163+ # version mismatch check in config.py
112164 assert any (
113165 "Stored version is not the same as current default version of the config"
114166 in str (w .message )
115167 for w in caught_warnings
116- ), "Didn't get expected warning message for version mismatch"
168+ ), (
169+ f"Didn't get expected warning message for version mismatch for config { config_name } , model { model_name } "
170+ )
117171
172+ # checkpoint deprecation
118173 assert any (
119- f"Models quantized with version 1 of { config_name } is deprecated"
174+ f"Models quantized with version { version } of { config_name } is deprecated"
120175 in str (w .message )
121176 for w in caught_warnings
122- ), "Didn't get expected warning message for deprecation"
177+ ), (
178+ f"Didn't get expected warning message for deprecation for model { model_name } "
179+ )
123180 assert isinstance (quantized_model .config .quantization_config , TorchAoConfig )
124181 assert (
125182 quantized_model .config .quantization_config .quant_type .version == version
@@ -139,7 +196,8 @@ def test_deprecated_hf_models(self, model_info):
139196 return_tensors = "pt" ,
140197 ).to ("cuda" )
141198 generated_ids = quantized_model .generate (
142- ** inputs , max_new_tokens = 128 , temperature = 0
199+ ** inputs ,
200+ max_new_tokens = 128 ,
143201 )
144202
145203 downloaded_output = hf_hub_download (model_name , filename = "model_output.pt" )
@@ -153,6 +211,23 @@ def test_deprecated_hf_models(self, model_info):
153211 generated_ids , skip_special_tokens = True , clean_up_tokenization_spaces = False
154212 )
155213
214+ # make sure we throw warning for config deprecation
215+ with warnings .catch_warnings (record = True ) as caught_warnings :
216+ _ = AutoModelForCausalLM .from_pretrained (
217+ _HIGH_PRECISION_MODEL ,
218+ torch_dtype = "bfloat16" ,
219+ device_map = "cuda:0" ,
220+ quantization_config = quantized_model .config .quantization_config ,
221+ )
222+ # config version deprecation in quant_api.py
223+ assert any (
224+ f"Config Deprecation: version { version } of { config_name } is deprecated and will no longer be supported in a future release"
225+ in str (w .message )
226+ for w in caught_warnings
227+ ), (
228+ f"Didn't get expected warning message for version deprecation for config { config_name } , model { model_name } "
229+ )
230+
156231
# Expand the @common_utils.parametrize-decorated methods on the test class.
common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
158233
0 commit comments