Fix ONNXRT WOQ failed with None model_path (#1411)
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
yuwenzho authored Nov 23, 2023
1 parent ab8c9f0 commit 4fcfdf8
Showing 2 changed files with 47 additions and 12 deletions.
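
The failure this commit fixes: each weight-only quantization (WOQ) helper called os.path.dirname(model.model_path) directly, which fails with a TypeError when the model was passed in as an in-memory onnx.ModelProto and ONNXModel.model_path is therefore None. A minimal sketch of the guard the diff introduces in each helper; the wrapper function below is illustrative, not part of the patch:

import os
from onnx import numpy_helper

def weight_to_array(model, weight_tensor):
    # model_path is None for models built or loaded purely in memory; fall back to ""
    # because such models have no external-data files that need a directory to resolve.
    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
    return numpy_helper.to_array(weight_tensor, base_dir=base_dir)
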
24 changes: 12 additions & 12 deletions neural_compressor/adaptor/ox_utils/weight_only.py
@@ -317,6 +317,7 @@ def rtn_quantize(
model: fake quantized ONNXModel
"""
model = model if isinstance(model, BaseModel) else ONNXModel(model)
base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
new_nodes = []
remove_nodes = []
for node in model.nodes():
@@ -326,7 +327,7 @@
and weight_config.get(node.name, {}) != "fp32"
):
weight_tensor = model.get_initializer(node.input[1])
weight = numpy_helper.to_array(weight_tensor, base_dir=os.path.dirname(model.model_path)).copy()
weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy()
if len(weight.shape) != 2:
continue

@@ -406,6 +407,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
new_added_mul_nodes = []
replace_input = []
updated_nodes = []
base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""

for parent, nodes in absorb_pairs.items():
if any([node.input[0] not in output_dicts for node in nodes]):
@@ -439,7 +441,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
if weight_config.get(node.name, {}) == "fp32":
continue

weight = numpy_helper.to_array(model.get_initializer(node.input[1]), os.path.dirname(model.model_path))
weight = numpy_helper.to_array(model.get_initializer(node.input[1]), base_dir)
if len(weight.shape) != 2:
continue

@@ -481,7 +483,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,

init_share_num = model.get_initializer_share_num(node.input[1])
weight_tensor = model.get_initializer(node.input[1])
tensor = numpy_helper.to_array(weight_tensor, os.path.dirname(model.model_path))
tensor = numpy_helper.to_array(weight_tensor, base_dir)

tensor = tensor.T * best_scale
tensor = (tensor.T).astype("float32")
@@ -501,9 +503,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
model.input_name_to_nodes[nodes[0].input[0]]
) == len(nodes):
for idx in [1, 2]:
tensor = numpy_helper.to_array(
model.get_initializer(parent.input[idx]), os.path.dirname(model.model_path)
)
tensor = numpy_helper.to_array(model.get_initializer(parent.input[idx]), base_dir)
new_tensor = tensor / np.reshape(best_scale, (1, -1))
model.set_initializer(parent.input[idx], new_tensor.astype(tensor.dtype), raw=True)
updated_nodes.append(parent.name)
@@ -516,7 +516,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
): # pragma: no cover
for inp in parent.input:
if model.get_initializer(inp) is not None:
tensor = numpy_helper.to_array(model.get_initializer(inp), os.path.dirname(model.model_path))
tensor = numpy_helper.to_array(model.get_initializer(inp), base_dir)
new_tensor = tensor / np.reshape(best_scale, (1, -1))
model.set_initializer(inp, new_tensor.astype(tensor.dtype), raw=True)
updated_nodes.append(parent.name)
@@ -525,7 +525,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
elif parent.op_type in ["Conv", "FusedConv"] and len(model.input_name_to_nodes[nodes[0].input[0]]) == len(
nodes
): # pragma: no cover
tensor = numpy_helper.to_array(model.get_initializer(parent.input[2]), os.path.dirname(model.model_path))
tensor = numpy_helper.to_array(model.get_initializer(parent.input[2]), base_dir)
new_tensor = tensor / np.reshape(best_scale, (1, -1))
model.set_initializer(parent.input[2], new_tensor.astype(tensor.dtype), raw=True)
updated_nodes.append(parent.name)
@@ -563,6 +563,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,

def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, group_size, scheme):
"""Apply clip for weight by checking mse."""
base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
ratios = {}
for parent, nodes in absorb_pairs.items():
if any([node.input[0] not in output_dicts for node in nodes]):
@@ -581,9 +582,7 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g
group_size = weight_config[node.name]["group_size"]
scheme = weight_config[node.name]["scheme"]

org_weight = numpy_helper.to_array(
model.get_initializer(node.input[1]), base_dir=os.path.dirname(model.model_path)
)
org_weight = numpy_helper.to_array(model.get_initializer(node.input[1]), base_dir=base_dir)
org_w_shape = org_weight.shape # ic, oc
group_size = group_size if group_size != -1 else org_w_shape[0]
org_out = np.matmul(inp, org_weight) # n_token, oc
@@ -992,6 +991,7 @@ def gptq_quantize(
model: fake quantized ONNXModel
"""
model = model if isinstance(model, BaseModel) else ONNXModel(model)
base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
output_dicts = {}

inputs, so = prepare_inputs(model, n_samples, dataloader)
@@ -1037,7 +1037,7 @@ def gptq_quantize(
and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
):
weight = numpy_helper.to_array(
model.get_initializer(model.get_node(node.name).input[1]), os.path.dirname(model.model_path)
model.get_initializer(model.get_node(node.name).input[1]), base_dir
).copy()
if len(weight.shape) != 2:
continue
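Background on why base_dir is threaded through these calls at all (context, not part of the patch): when an ONNX model is saved with external data, each initializer stores only a file location relative to the model file, and onnx.numpy_helper.to_array needs that directory to read the raw bytes back. A short sketch, using a hypothetical model file:

import os
import onnx
from onnx import numpy_helper

model_path = "big_model.onnx"  # hypothetical model saved with external data
model = onnx.load(model_path, load_external_data=False)
base_dir = os.path.dirname(os.path.abspath(model_path))
for init in model.graph.initializer:
    # Resolves any external-data reference (e.g. "weights.bin") relative to base_dir.
    arr = numpy_helper.to_array(init, base_dir=base_dir)
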
35 changes: 35 additions & 0 deletions test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py
@@ -10,6 +10,7 @@
from transformers import AutoTokenizer

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize, gptq_quantize, rtn_quantize
from neural_compressor.utils.constant import FP32


@@ -386,6 +387,40 @@ def fake_eval(model, eval_result_lst):
woq_model = self._test_woq_tune_common(self.llama_model, self.llama_dataloader, partial_fake_eval)
self.assertEqual(self._count_woq_matmul(woq_model), 155)

def test_woq_with_ModelProto_input(self):
from neural_compressor.model.onnx_model import ONNXModel

q4_node_config = {}
template_config_q4 = {"bits": 4, "group_size": 32, "scheme": "sym"}
template_config_fp32 = "fp32"
for node in self.gptj_model.graph.node:
if node.op_type in ["MatMul"]:
if any(ONNXModel(self.gptj_model).get_initializer(i) is not None for i in node.input):
q4_node_config[node.name] = template_config_q4
else:
q4_node_config[node.name] = template_config_fp32

q_model = rtn_quantize(self.gptj_model, q4_node_config)
for data, _ in self.gptj_dataloader:
q_out = Inference(q_model.model, data)
org_out = Inference(self.gptj_model, data)
for q, org in zip(q_out, org_out):
self.assertTrue((np.abs(q - org) < 0.5).all())

q_model = gptq_quantize(self.gptj_model, self.gptj_dataloader, q4_node_config)
for data, _ in self.gptj_dataloader:
q_out = Inference(q_model.model, data)
org_out = Inference(self.gptj_model, data)
for q, org in zip(q_out, org_out):
self.assertTrue((np.abs(q - org) < 0.5).all())

q_model = awq_quantize(self.gptj_model, self.gptj_dataloader, q4_node_config)
for data, _ in self.gptj_dataloader:
q_out = Inference(q_model.model, data)
org_out = Inference(self.gptj_model, data)
for q, org in zip(q_out, org_out):
self.assertTrue((np.abs(q - org) < 0.5).all())


if __name__ == "__main__":
unittest.main()
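
Expressed outside the test harness, the scenario the new test covers looks roughly like the sketch below; the model file name and the 4-bit config are illustrative assumptions, not taken from the commit. The point is that a raw onnx.ModelProto, which has no path on disk, can now be fed to the WOQ entry points directly:

import onnx
from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize

model_proto = onnx.load("gptj.onnx")  # hypothetical file; any in-memory ModelProto works
weight_config = {
    node.name: {"bits": 4, "group_size": 32, "scheme": "sym"}
    for node in model_proto.graph.node
    if node.op_type == "MatMul"
}
# Before this fix, the call below failed because ONNXModel(model_proto).model_path is None.
q_model = rtn_quantize(model_proto, weight_config)
onnx.save(q_model.model, "gptj_woq.onnx")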
