
Commit d4b0f0f

modify op_type for set_local in 3.x API (#1773)
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
1 parent 1cb844b commit d4b0f0f

File tree

4 files changed: +37 / −14 lines

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 3 additions & 3 deletions

@@ -340,13 +340,13 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
         quant_config = SmoothQuantConfig(alpha=args.alpha, folding=True)

         if re.search("gpt", user_model.config.model_type):
-            quant_config.set_local("add", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+            quant_config.set_local(torch.add, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
     else:
         from neural_compressor.torch.quantization import get_default_static_config, StaticQuantConfig

         quant_config = get_default_static_config()
         if re.search("gpt", user_model.config.model_type):
-            quant_config.set_local("add", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+            quant_config.set_local(torch.add, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

     from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
     from tqdm import tqdm
@@ -397,7 +397,7 @@ def run_fn(model):
     #     print("Int8 model loading does not support WeightOnlyQuant now.")
     #     pass
     # else:
-    #     user_model, _ = get_user_model()
+    user_model, _ = get_user_model()

     if args.accuracy:
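For context, the example-script change above switches op-type fallback from the string "add" to the operator object torch.add. Below is a minimal sketch of that usage, assuming the neural_compressor 3.x PyTorch API exactly as imported in this diff and in the new tests; the toy model and calibration function are hypothetical stand-ins, not code from the repository.

import torch
from neural_compressor.torch.quantization import SmoothQuantConfig, quantize

class ToyNet(torch.nn.Module):  # hypothetical stand-in for user_model
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 3)

    def forward(self, x):
        # torch.add is the op type forced back to fp32 below
        return torch.add(self.fc(x), 1.0)

fp32_model = ToyNet()
example_inputs = torch.randn(1, 3)

def run_fn(model):
    # single calibration pass; real scripts iterate a dataloader
    model(example_inputs)

quant_config = SmoothQuantConfig(alpha=0.5, folding=True)
# After this commit, set_local takes the operator object (torch.add)
# rather than the string "add" when falling back by op type.
quant_config.set_local(torch.add, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)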

neural_compressor/torch/algorithms/static_quant/utility.py

Lines changed: 8 additions & 10 deletions

@@ -90,9 +90,10 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
         for i, op_name in enumerate(op):
             for ops, _ in op_infos_from_cfgs.items():
                 if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name:
-                    ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
-                    tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
-                    break
+                    if op_infos_from_cfgs[ops]["op_type"] in unify_op_type_mapping_ipex:
+                        ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
+                        tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
+                        break
     user_cfg = tmp_user_cfg
     for op_name in user_cfg:
         inc_op_cfg = user_cfg[op_name]
@@ -291,15 +292,12 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
             map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
             if "class" in ipex_op_type:  # "<class 'torch.nn.modules.activation.ReLU'>"
                 op_type = ipex_op_type.split("'")[1]
-                op_name_info.append((module_fqn, eval(op_type)))
+                op_name_info.append((module_fqn, eval(op_type).__name__))
             elif "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
                 method = ipex_op_type.split("'")[1]
-                op_type = getattr(
-                    torch._C._TensorBase if ipex_ver.release < Version("2.2") else torch._C.TensorBase, method
-                )
-                op_name_info.append((module_fqn, op_type))
-            else:
-                op_name_info.append((module_fqn, op_type))
+                op_name_info.append((module_fqn, method))
+            elif "Convolution" in ipex_op_type:  # "Convolution_Relu"
+                op_name_info.append((module_fqn, "Conv2d"))
             else:
                 re_flag = False
                 for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
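The utility change above normalizes IPEX's verbose op-type strings to plain names before they are matched against the user config. Here is a standalone, simplified sketch of that parsing for illustration only; the real code uses eval(op_type).__name__ for the "class" case and also consults unify_op_type_mapping_ipex and its regex patterns in a later branch.

def normalize_ipex_op_type(ipex_op_type: str) -> str:
    # Simplified illustration of the mapping performed in the diff above.
    if "class" in ipex_op_type:        # "<class 'torch.nn.modules.activation.ReLU'>"
        return ipex_op_type.split("'")[1].rsplit(".", 1)[-1]
    if "method" in ipex_op_type:       # "<method 'add' of 'torch._C._TensorBase' objects>"
        return ipex_op_type.split("'")[1]
    if "Convolution" in ipex_op_type:  # fused ops such as "Convolution_Relu"
        return "Conv2d"
    return ipex_op_type

assert normalize_ipex_op_type("<class 'torch.nn.modules.activation.ReLU'>") == "ReLU"
assert normalize_ipex_op_type("<method 'add' of 'torch._C._TensorBase' objects>") == "add"
assert normalize_ipex_op_type("Convolution_Relu") == "Conv2d"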

test/3x/torch/quantization/test_smooth_quant.py

Lines changed: 15 additions & 0 deletions

@@ -55,6 +55,21 @@ def test_smooth_quant_auto(self):
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"

+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_smooth_quant_fallback(self):
+        fp32_model = copy.deepcopy(model)
+        quant_config = get_default_sq_config()
+        example_inputs = torch.randn([1, 3])
+        # fallback by op_type
+        quant_config.set_local(torch.nn.Linear, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
+        for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
+            if op_info["op_type"] == "<class 'torch.nn.modules.linear.Linear'>":
+                dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
+                assert dtype == "torch.float32", "Failed to fallback linear op, please check!"
+
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
         "act_sym, act_algo, alpha, folding, scale_sharing",

test/3x/torch/quantization/test_static_quant.py

Lines changed: 11 additions & 1 deletion

@@ -63,19 +63,29 @@ def test_static_quant_fallback(self):
         quant_config = get_default_static_config()
         example_inputs = self.input
         # fallback by op_type
-        quant_config.set_local(torch.nn.modules.linear.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
         assert q_model is not None, "Quantization failed!"

+        for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
+            if op_info["op_type"] == "<class 'torch.nn.modules.linear.Linear'>":
+                dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
+                assert dtype == "torch.float32", "Failed to fallback linear op, please check!"
+
         # fallback by op_name
         quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
         assert q_model is not None, "Quantization failed!"

+        for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
+            if op_info["fqn"] == "fc1":
+                dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
+                assert dtype == "torch.float32", "Failed to fallback fc1 layer, please check!"
+
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
         "act_sym, act_algo",
