Skip to content

Commit 0cbc8e6

Browse files
committed
fix name bug
Signed-off-by: xin3he <xin3.he@intel.com>
1 parent 73dc12c commit 0cbc8e6

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

test/3x/torch/test_config.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ def test_quantize_rtn_from_dict_beginner(self):
         from neural_compressor.torch import quantize
 
         quant_config = {
-            "RTN": {
+            "rtn": {
                 "weight_dtype": "nf4",
                 "weight_bits": 4,
                 "weight_group_size": 32,
@@ -127,7 +127,7 @@ def test_quantize_rtn_from_dict_advance(self):
 
         fp32_model = build_simple_torch_model()
         quant_config = {
-            "RTN": {
+            "rtn": {
                 "global": {
                     "weight_dtype": "nf4",
                     "weight_bits": 4,
@@ -188,7 +188,7 @@ def test_config_from_dict(self):
         from neural_compressor.torch import RTNConfig
 
         quant_config = {
-            "RTN": {
+            "rtn": {
                 "global": {
                     "weight_dtype": "nf4",
                     "weight_bits": 4,
@@ -202,7 +202,7 @@ def test_config_from_dict(self):
                 },
             }
         }
-        config = RTNConfig.from_dict(quant_config["RTN"])
+        config = RTNConfig.from_dict(quant_config["rtn"])
         self.assertIsNotNone(config.local_config)
 
     def test_config_to_dict(self):
@@ -219,15 +219,15 @@ def test_same_type_configs_addition(self):
         from neural_compressor.torch import RTNConfig
 
         quant_config1 = {
-            "RTN": {
+            "rtn": {
                 "weight_dtype": "nf4",
                 "weight_bits": 4,
                 "weight_group_size": 32,
             },
         }
-        q_config = RTNConfig.from_dict(quant_config1["RTN"])
+        q_config = RTNConfig.from_dict(quant_config1["rtn"])
         quant_config2 = {
-            "RTN": {
+            "rtn": {
                 "global": {
                     "weight_bits": 8,
                     "weight_group_size": 32,
@@ -240,48 +240,48 @@ def test_same_type_configs_addition(self):
                 },
             }
         }
-        q_config2 = RTNConfig.from_dict(quant_config2["RTN"])
+        q_config2 = RTNConfig.from_dict(quant_config2["rtn"])
         q_config3 = q_config + q_config2
         q3_dict = q_config3.to_dict()
-        for op_name, op_config in quant_config2["RTN"]["local"].items():
+        for op_name, op_config in quant_config2["rtn"]["local"].items():
             for attr, val in op_config.items():
                 self.assertEqual(q3_dict["local"][op_name][attr], val)
-        self.assertNotEqual(q3_dict["global"]["weight_bits"], quant_config2["RTN"]["global"]["weight_bits"])
+        self.assertNotEqual(q3_dict["global"]["weight_bits"], quant_config2["rtn"]["global"]["weight_bits"])
 
     def test_diff_types_configs_addition(self):
         from neural_compressor.torch import GPTQConfig, RTNConfig
 
         quant_config1 = {
-            "RTN": {
+            "rtn": {
                 "weight_dtype": "nf4",
                 "weight_bits": 4,
                 "weight_group_size": 32,
             },
         }
-        q_config = RTNConfig.from_dict(quant_config1["RTN"])
+        q_config = RTNConfig.from_dict(quant_config1["rtn"])
         d_config = GPTQConfig(double_quant_bits=4)
         combined_config = q_config + d_config
         combined_config_d = combined_config.to_dict()
         logger.info(combined_config)
-        self.assertTrue("RTN" in combined_config_d)
+        self.assertTrue("rtn" in combined_config_d)
         self.assertIn("gptq", combined_config_d)
 
     def test_composable_config_addition(self):
         from neural_compressor.torch import GPTQConfig, RTNConfig
 
         quant_config1 = {
-            "RTN": {
+            "rtn": {
                 "weight_dtype": "nf4",
                 "weight_bits": 4,
                 "weight_group_size": 32,
             },
         }
-        q_config = RTNConfig.from_dict(quant_config1["RTN"])
+        q_config = RTNConfig.from_dict(quant_config1["rtn"])
         d_config = GPTQConfig(double_quant_bits=4)
         combined_config = q_config + d_config
         combined_config_d = combined_config.to_dict()
         logger.info(combined_config)
-        self.assertTrue("RTN" in combined_config_d)
+        self.assertTrue("rtn" in combined_config_d)
         self.assertIn("gptq", combined_config_d)
         combined_config2 = combined_config + d_config
         combined_config3 = combined_config + combined_config2

0 commit comments

Comments (0)