Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SLIM][AWQ] AWQ GEMM support #1362

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions python/mlc_chat/compiler/model/llama/llama_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
f"{attn}.v_proj.{quantize_suffix}",
],
functools.partial(
lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
lambda q, k, v, dtype: np.concatenate(
[q, k, v],
axis=1, # AWQ GEMM would transpose the weight
).astype(dtype),
dtype=mlc_param.dtype,
),
)
Expand All @@ -140,7 +143,10 @@ def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
f"{mlp}.up_proj.{quantize_suffix}",
],
functools.partial(
lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
lambda gate, up, dtype: np.concatenate(
[gate, up],
axis=1, # AWQ GEMM would transpose the weight
).astype(dtype),
dtype=mlc_param.dtype,
),
)
Expand Down
28 changes: 12 additions & 16 deletions python/mlc_chat/compiler/quantization/awq_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

from tvm import DataType, DataTypeCode, te, tir
from tvm import DataType, DataTypeCode, te, tir, topi
from tvm.relax.frontend import nn
from tvm.runtime import NDArray

Expand Down Expand Up @@ -138,16 +138,21 @@ def _dequantize(
self.num_elem_per_storage,
self.storage_dtype,
self.model_dtype,
out_shape,
[weight.shape[0], weight.shape[1] * self.num_elem_per_storage],
ft_reorder=True,
)
float_zeros = convert_uint_to_float(
zeros,
DataType(self.quantize_dtype).bits,
self.num_elem_per_storage,
self.storage_dtype,
self.model_dtype,
out_shape,
[zeros.shape[0], zeros.shape[1] * self.num_elem_per_storage],
ft_reorder=True,
)
float_weight = topi.transpose(float_weight)
float_zeros = topi.transpose(float_zeros)
scale = topi.transpose(scale)
return te.compute(
shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage]
if out_shape is None
Expand Down Expand Up @@ -177,23 +182,14 @@ def __init__( # pylint: disable=too-many-arguments
self.out_dtype = out_dtype
self.config = config
self.qweight = nn.Parameter(
(out_features, tir.ceildiv(in_features, config.num_elem_per_storage)),
config.storage_dtype,
(in_features, out_features // config.num_elem_per_storage), config.storage_dtype
)
self.qzeros = nn.Parameter(
(
out_features,
_calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage),
),
dtype=config.storage_dtype,
(in_features // config.group_size, out_features // config.num_elem_per_storage),
config.storage_dtype,
)
self.scales = nn.Parameter(
(
out_features,
_calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage)
* config.num_elem_per_storage,
),
config.model_dtype,
(in_features // config.group_size, out_features), config.model_dtype
)
if bias:
self.bias = nn.Parameter(
Expand Down
4 changes: 2 additions & 2 deletions python/mlc_chat/compiler/quantization/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr
storage_dtype="uint32",
model_dtype="float32",
),
"q4f16_awq": AWQQuantize(
name="q4f16_awq",
"q4f16_autoawq": AWQQuantize(
name="q4f16_autoawq",
kind="awq",
group_size=128,
quantize_dtype="int4",
Expand Down
7 changes: 6 additions & 1 deletion python/mlc_chat/compiler/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def convert_uint_to_float( # pylint: disable=too-many-arguments
storage_dtype: str,
model_dtype: str,
out_shape: Optional[List[tir.PrimExpr]] = None,
ft_reorder: Optional[bool] = False,
) -> te.Tensor:
"""Convert a quantized uint weight to an unquantized float weight."""
tir_bin_mask = tir.const((1 << bits) - 1, storage_dtype)
Expand All @@ -21,7 +22,11 @@ def convert_uint_to_float( # pylint: disable=too-many-arguments
fcompute=lambda i, j: tir.bitwise_and(
tir.shift_right(
weight[i, j // num_elem_per_storage],
((j % num_elem_per_storage) * bits).astype(storage_dtype),
(
((j % num_elem_per_storage) % 2 * 4 + (j % num_elem_per_storage) // 2) * bits
if ft_reorder
else (j % num_elem_per_storage) * bits
).astype(storage_dtype),
),
tir_bin_mask,
).astype(model_dtype),
Expand Down