Support MatMulNBit op for ort 1.17 (#1327)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
mengniwang95 and pre-commit-ci[bot] committed Nov 10, 2023
1 parent 1beb435 commit 67a31ba
Showing 2 changed files with 117 additions and 38 deletions.
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/onnxrt.py
@@ -1690,7 +1690,7 @@ def _dump_model_op_stats(self, model, tune_cfg):

dtype_set = set()
for node in model.nodes():
if node.op_type == "MatMulFpQ4":
if node.op_type in ["MatMulFpQ4", "MatMulNBits"]:
optype = "MatMul"
else:
optype = node.op_type
153 changes: 116 additions & 37 deletions neural_compressor/adaptor/ox_utils/weight_only.py
@@ -36,6 +36,23 @@
ort = LazyImport("onnxruntime")
logger = logging.getLogger("neural_compressor")
ONNXRT116_VERSION = Version("1.16.0")
ONNXRT1161_VERSION = Version("1.16.1")


def get_blob_size(group_size, has_zp): # pragma: no cover
"""Get blob_size.
Args:
group_size (int): how many elements share one scale/zp
has_zp (bool): whether a zero_point is provided (not None)
"""
if Version(ort.__version__) > ONNXRT1161_VERSION:
blob_size = group_size // 2
elif has_zp:
blob_size = group_size // 2 + 4 + 1
else:
blob_size = group_size // 2 + 4
return blob_size
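
Note (editorial, not part of the diff): a quick sanity check of the blob sizes this helper yields for the common 4-bit, group_size=32 layout:

```python
# Illustrative check of get_blob_size for group_size=32 (4-bit packing assumed).
assert 32 // 2 == 16          # ort > 1.16.1 (MatMulNBits): packed weights only
assert 32 // 2 + 4 == 20      # older ort (MatMulFpQ4), no zero point: + 4-byte fp32 scale
assert 32 // 2 + 4 + 1 == 21  # older ort (MatMulFpQ4), with zero point: + 1 extra byte
```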


def make_matmul_weight_only_node(
@@ -54,49 +71,102 @@ def make_matmul_weight_only_node(
zero_point (array): zero point
Returns:
matmul_weight_only_node: MatMulFpQ4 node
new_inits: initializers of the MatMulFpQ4 node
matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
new_inits: initializers of the new node
"""
if zero_point is not None:
blob_size = group_size // 2 + 4 + 1
offset = 5
else:
blob_size = group_size // 2 + 4
offset = 4

blob_size = get_blob_size(group_size, zero_point is not None)
packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
for i in range(q_weight.shape[0]):
bf = struct.pack("f", scale[i])
packed[i][0] = bf[0]
packed[i][1] = bf[1]
packed[i][2] = bf[2]
packed[i][3] = bf[3]
q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
input_names = [node.input[0], q_weight_name]
new_inits = []
kwargs = {}

if Version(ort.__version__) > ONNXRT1161_VERSION:
op_type = "MatMulNBits"

# pack quantized weight
for i in range(q_weight.shape[0]):
for k in range(0, group_size, 2):
packed[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4
packed = np.reshape(packed, (-1, k_blocks, blob_size))

# build scale tensor
scale = np.reshape(scale, (-1, k_blocks)).astype("float32")
scale_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_scale", data_type=1, dims=scale.shape, vals=scale.tobytes(), raw=True
)
input_names.append(scale_tensor.name)
new_inits.append(scale_tensor)

# build zero_point tensor
if zero_point is not None:
packed[i][4] = zero_point[i]
if num_bits > 4:
packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8")
else:
packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
for i in range(zero_point.shape[0] // k_blocks):
for j in range(k_blocks):
idx = i * k_blocks + j
zp = zero_point[idx]
packed_zp[idx // 2] = (
((packed_zp[idx // 2] & 0x0F) | (zp << 4))
if (idx & 1)
else ((packed_zp[idx // 2] & 0xF0) | zp)
)

zp_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
)
input_names.append(zp_tensor.name)
new_inits.append(zp_tensor)

# set kwargs
kwargs["K"] = weight_shape[0]
kwargs["N"] = weight_shape[1]
kwargs["bits"] = num_bits
kwargs["block_size"] = group_size

else:
offset = 5 if zero_point is not None else 4
op_type = "MatMulFpQ4"

# pack quantized weight
for i in range(q_weight.shape[0]):
bf = struct.pack("f", scale[i])
packed[i][0] = bf[0]
packed[i][1] = bf[1]
packed[i][2] = bf[2]
packed[i][3] = bf[3]

if zero_point is not None:
packed[i][4] = zero_point[i]

packed[i][offset:] = np.bitwise_or(
q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits)
)
packed = packed.reshape(-1)

packed[i][offset:] = np.bitwise_or(
q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits)
# build shape tensor
shape_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64")
)
new_inits.append(shape_tensor)
input_names.append(shape_tensor.name)

# set kwargs
kwargs["blk_quant_type"] = 1 if zero_point is not None else 0

packed = packed.reshape(-1)
q_weight_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)),
name=q_weight_name,
data_type=2,
dims=packed.shape,
vals=packed.tobytes(),
raw=True,
)
shape_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64")
)
input_names = [node.input[0], q_weight_tensor.name, shape_tensor.name]
new_inits = [q_weight_tensor, shape_tensor]
new_inits.append(q_weight_tensor)

kwargs = {}
kwargs["blk_quant_type"] = 1 if zero_point is not None else 0
matmul_weight_only_node = onnx.helper.make_node(
"MatMulFpQ4",
op_type,
inputs=input_names,
outputs=node.output,
name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
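
Note (editorial, not part of the diff): the MatMulNBits branch above stores two 4-bit weights per byte, with the even-indexed element in the low nibble and the odd-indexed element in the high nibble. A minimal numpy sketch of that pack/unpack round trip, using hypothetical helper names:

```python
import numpy as np

def pack_uint4_pairs(q_row: np.ndarray) -> np.ndarray:
    """Pack a 1-D array of uint4 values (held in uint8) into bytes, two values per byte."""
    assert q_row.dtype == np.uint8 and q_row.size % 2 == 0
    return (q_row[0::2] | (q_row[1::2] << 4)).astype(np.uint8)

def unpack_uint4_pairs(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_uint4_pairs: recover the interleaved uint4 sequence."""
    out = np.empty(packed.size * 2, dtype=np.uint8)
    out[0::2] = packed & 0x0F         # even indices came from the low nibble
    out[1::2] = (packed >> 4) & 0x0F  # odd indices came from the high nibble
    return out

q = np.array([1, 15, 7, 0, 8, 3, 2, 9], dtype=np.uint8)  # example uint4 weights
assert np.array_equal(unpack_uint4_pairs(pack_uint4_pairs(q)), q)
```

The zero-point buffer in the diff is pre-filled with 136 (0x88), i.e. two nibbles of 8, the midpoint of the unsigned 4-bit range, so blocks without an explicit zero point fall back to a symmetric offset.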
@@ -257,8 +327,11 @@ def rtn_quantize(

weight = pad_tensor(weight, group_size, k_blocks)

if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32: # pragma: no cover
# currently MatMulFpQ4 only support 4 bits and 32 group_size
if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
): # pragma: no cover
# MatMulFpQ4 supports 4 bits and group_size 32 with ort 1.16.0 and 1.16.1
# MatMulNBits supports 4 bits and any power-of-two group_size with ort > 1.16.1
q_weight, scale, zp = quant_tensor(
weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
)
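
Note (editorial): the same version/bit-width/group-size gate recurs in apply_awq_scale, apply_awq_clip, and gptq_quantize below. A hedged sketch of how it could be factored into a single helper (hypothetical name, not part of the commit):

```python
from packaging.version import Version
import onnxruntime as ort

ONNXRT116_VERSION = Version("1.16.0")
ONNXRT1161_VERSION = Version("1.16.1")

def supports_uint_weight_only_matmul(num_bits: int, group_size: int) -> bool:
    """Sketch of the gate used above: True when the quantized MatMul can be lowered to
    MatMulNBits (ort > 1.16.1, 4 bits) or MatMulFpQ4 (ort >= 1.16.0, 4 bits, group_size 32)."""
    ort_version = Version(ort.__version__)
    if ort_version > ONNXRT1161_VERSION and num_bits == 4:
        return True  # MatMulNBits path (power-of-two group_size expected)
    return ort_version >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32  # MatMulFpQ4 path
```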
@@ -361,9 +434,11 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
weight = weight.T * scales
weight = pad_tensor(weight, group_size, (org_w_shape[0] + group_size - 1) // group_size).T

if (
if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
): # pragma: no cover
# MatMulFpQ4 supports 4 bits and group_size 32 with ort 1.16.0 and 1.16.1
# MatMulNBits supports 4 bits and any power-of-two group_size with ort > 1.16.1
q_weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint") / np.expand_dims(
scales, axis=-1
)
@@ -504,10 +579,11 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g
for i_s in range(10):
ratio = 1 - i_s / 100
weight = copy.deepcopy(org_weight)
if (
if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
): # pragma: no cover
# currently MatMulFpQ4 only support 4 bits and 32 group_size
# MatMulFpQ4 supports 4 bits and group_size 32 with ort 1.16.0 and 1.16.1
# MatMulNBits supports 4 bits and any power-of-two group_size with ort > 1.16.1
weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1))
else:
weight = qdq_tensor(weight, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
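
Note (editorial): the loop above sweeps clip ratios 1.0, 0.99, ..., 0.91 and, in the full function, keeps the ratio whose quantize-dequantize result best reproduces the layer's calibration outputs. A self-contained sketch of that selection loop, using a simple per-tensor uint4 stand-in (qdq4 is illustrative, not the commit's qdq_tensor) and comparing raw weights rather than layer outputs:

```python
import numpy as np

def qdq4(w: np.ndarray, ratio: float) -> np.ndarray:
    """Illustrative per-tensor uint4 quantize-dequantize with a clipped range."""
    maxv = float(np.max(np.abs(w))) * ratio
    scale = maxv / 7.5 if maxv > 0 else 1.0
    q = np.clip(np.round(w / scale) + 8, 0, 15)  # zero point 8, range [0, 15]
    return (q - 8) * scale

w = np.random.randn(64, 32).astype(np.float32)
best_ratio, best_err = 1.0, np.inf
for i_s in range(10):
    ratio = 1 - i_s / 100
    err = float(np.mean((w - qdq4(w, ratio)) ** 2))  # reconstruction error at this clip ratio
    if err < best_err:
        best_ratio, best_err = ratio, err
```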
@@ -571,9 +647,9 @@ def prepare_inputs(model, n_samples, dataloader):

if isinstance(data[0], dict):
inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()]))
elif isinstance(data[0], np.ndarray):
elif isinstance(data[0], np.ndarray): # pragma: no cover
inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]])]))
else:
else: # pragma: no cover
inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0])]))
return inputs, so
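
Note (editorial, an assumption about downstream usage not shown in the diff): the collected feed dicts and SessionOptions are typically consumed by an onnxruntime InferenceSession during calibration, along these lines:

```python
import onnxruntime as ort

# `model` and `dataloader` are assumed to be in scope; `model.model` is assumed to be
# an onnx ModelProto (as in neural_compressor's ONNX model wrapper).
inputs, so = prepare_inputs(model, n_samples=8, dataloader=dataloader)
session = ort.InferenceSession(model.model.SerializeToString(), so, providers=["CPUExecutionProvider"])
for feed in inputs:
    outputs = session.run(None, feed)  # each feed maps input names to numpy arrays
```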

@@ -982,8 +1058,11 @@ def gptq_quantize(

weight_tensor = model.get_initializer(node.input[1])
init_share_num = model.get_initializer_share_num(node.input[1])
if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32: # pragma: no cover
# currently MatMulFpQ4 only support 4 bits and 32 group_size
if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
): # pragma: no cover
# MatMulFpQ4 supports 4 bits and group_size 32 with ort 1.16.0 and 1.16.1
# MatMulNBits supports 4 bits and any power-of-two group_size with ort > 1.16.1
org_shape = weight.shape
k_blocks = (org_shape[0] + group_size - 1) // group_size
q_weight = pad_tensor(q_weight, group_size, k_blocks)
