int8wo Embedding Quant (#1167)
Summary: Adds int8 weight-only embedding quantization to torchAO; on our llama benchmark it improves inference from 107.8 to 108.5 tok/s on an A100.

The expected API is:

quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))
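
A minimal end-to-end sketch of that API (illustrative only; the toy model and names below are assumptions, not part of this commit):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

# toy model with an embedding table followed by a linear layer
model = nn.Sequential(nn.Embedding(4096, 128), nn.Linear(128, 128))

# quantize only the nn.Embedding weights to int8, groupwise with group_size=64;
# the filter_fn keeps quantize_ away from the nn.Linear
quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, nn.Embedding))

tokens = torch.randint(0, 4096, (1, 6))
out = model[0](tokens)  # lookup dequantizes the int8 embedding rows on the fly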

Test Plan:

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization embed-int8wo --compile
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile
python test_integration.py -k "test_weight_only_groupwise_embedding_quant"

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles authored Oct 25, 2024
1 parent b76a5e1 commit e85c1a3
Showing 3 changed files with 43 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/integration/test_integration.py
@@ -956,6 +956,20 @@ def test_weight_only_groupwise_quant(self):
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 45.0)

def test_weight_only_groupwise_embedding_quant(self):
group_size = 64
m = nn.Embedding(4096, 128)
input = torch.randint(0, 4096, (1, 6))

quantize_(m, int8_weight_only(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding))
y_q = m(input)
y_ref = m.weight.dequantize()[input]

sqnr = compute_error(y_ref, y_q)

self.assertGreater(sqnr, 45.0)


@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
4 changes: 3 additions & 1 deletion torchao/_models/llama/generate.py
@@ -237,6 +237,8 @@ def main(
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "embed-int8wo" in quantization:
quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))
if quantization.startswith("awq"):
from torchao._models._eval import TransformerEvalWrapper
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
@@ -463,7 +465,7 @@ def callback(x):
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant'
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
)
)
parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
26 changes: 26 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -1840,6 +1840,32 @@ def _(func, types, args, kwargs):
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements(torch.nn.functional.embedding)
def _(func, types, args, kwargs):
# new_arg1 = args[1].dequantize()
# return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs)
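# only the plain layout and the default embedding kwargs are supported here;
# the asserts below enforce that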
assert isinstance(args[1].tensor_impl, PlainAQTTensorImpl), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}"
assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None and not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"]==2.0
idx = args[0]
int_data, scale, zero_point = args[1].tensor_impl.get_plain()

sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx]
# block_size expects 2 dimensions [1, group_size], but the batch and any other
# leading dims of idx get added to sliced_data, sliced_scale, and sliced_zero_point,
# so block_size has to be padded with leading 1s to match the sliced tensors' rank
new_blocks = idx.dim()-1
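# example shapes (illustrative values, assuming a (4096, 128) table quantized with group_size=64):
#   idx: [1, 6] -> sliced_data: [1, 6, 128], sliced_scale: [1, 6, 2]
#   block_size [1, 64] is padded with idx.dim()-1 = 1 leading one -> [1, 1, 64]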
return dequantize_affine(
sliced_data,
new_blocks*[1]+list(args[1].block_size),
sliced_scale,
sliced_zero_point,
sliced_data.dtype,
args[1].quant_min,
args[1].quant_max,
args[1].zero_point_domain,
output_dtype=sliced_scale.dtype,
)

@implements(aten.addmm.default)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
