Commit 9516764

Adds Q/DQ layout support for embedding quantization with IntxWeightOnlyConfig (#1972)
1 parent 245e158 commit 9516764

7 files changed (+337, -52 lines)
torchao/dtypes/affine_quantized_tensor_ops.py (+9)
@@ -64,6 +64,12 @@
     _linear_int8_act_int8_weight_check,
     _linear_int8_act_int8_weight_impl,
 )
+from torchao.dtypes.uintx.q_dq_layout import (
+    _embedding_check as _embedding_q_dq_check,
+)
+from torchao.dtypes.uintx.q_dq_layout import (
+    _embedding_impl as _embedding_q_dq_impl,
+)
 from torchao.dtypes.uintx.q_dq_layout import (
     _linear_check as _linear_q_dq_check,
 )
@@ -263,6 +269,9 @@ def _(func, types, args, kwargs):
 
 @implements(torch.nn.functional.embedding)
 def _(func, types, args, kwargs):
+    if _embedding_q_dq_check(args, kwargs):
+        return _embedding_q_dq_impl(args, kwargs)
+
     # new_arg1 = args[1].dequantize()
     # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs)
     assert isinstance(

torchao/dtypes/uintx/q_dq_layout.py (+13)
@@ -50,3 +50,16 @@ def _linear_impl(input_tensor, weight_tensor, bias):
     if isinstance(weight_tensor, AffineQuantizedTensor):
         weight_tensor = weight_tensor.dequantize()
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+
+
+def _embedding_check(args, kwargs):
+    _, weight_tensor = args
+    layout = weight_tensor.tensor_impl.get_layout()
+    return isinstance(layout, QDQLayout)
+
+
+def _embedding_impl(args, kwargs):
+    input_tensor, weight_tensor = args
+    if isinstance(weight_tensor, AffineQuantizedTensor):
+        weight_tensor = weight_tensor.dequantize()
+    return torch.nn.functional.embedding(input_tensor, weight_tensor, **kwargs)
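
Note: the new Q/DQ embedding path is just dequantize-then-embed. _embedding_check routes AffineQuantizedTensor weights with QDQLayout to _embedding_impl, which dequantizes the weight and calls torch.nn.functional.embedding. A minimal self-contained sketch of that arithmetic with plain tensors (the names qvals, scales, zeros and the sizes below are illustrative, not taken from this diff):

import torch
import torch.nn.functional as F

# Illustrative low-bit weight data: int4-range qvals with per-group scales/zeros.
num_embeddings, embedding_dim, group_size = 8, 16, 8
qvals = torch.randint(-8, 8, (num_embeddings, embedding_dim), dtype=torch.int8)
scales = torch.rand(num_embeddings, embedding_dim // group_size)
zeros = torch.zeros(num_embeddings, embedding_dim // group_size, dtype=torch.int8)

# What the Q/DQ path computes: dequantize group-wise, then a normal embedding lookup.
weight = (
    (qvals.reshape(num_embeddings, -1, group_size).float() - zeros.unsqueeze(-1).float())
    * scales.unsqueeze(-1)
).reshape(num_embeddings, embedding_dim)

indices = torch.tensor([0, 3, 5])
out = F.embedding(indices, weight)  # equivalent of _embedding_impl after dequantize()
print(out.shape)  # torch.Size([3, 16])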

torchao/experimental/quant_api.py (+43, -44)
@@ -15,7 +15,7 @@
     quantize_per_channel_group,
 )
 
-from torchao.quantization.granularity import PerGroup, PerRow
+from torchao.quantization.granularity import Granularity, PerAxis, PerGroup, PerRow
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
 
 logger = logging.getLogger(__name__)
@@ -366,32 +366,44 @@ def __init__(
     ):
         super().__init__()
         self.bit_width = bit_width
-        self.pack_weights_op = getattr(
-            torch.ops.torchao, f"_pack_embedding_{bit_width}bit"
-        )
-        self.embedding_op = getattr(torch.ops.torchao, f"_embedding_{bit_width}bit")
 
     def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
         assert has_weight_zeros, "has_weight_zeros must be True for QuantizedEmbedding"
         num_embeddings, embedding_dim = weights.shape
-        if group_size == -1:
-            group_size = embedding_dim
-        self.group_size = group_size
 
-        weight_qvals, weight_scales, weight_zeros = _quantize(
-            weights, self.group_size, self.bit_width, has_weight_zeros=True
+        embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
+        embedding.weight = weights
+        quantize_(
+            embedding,
+            IntxWeightOnlyConfig(
+                weight_dtype=getattr(torch, f"int{self.bit_width}"),
+                granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0),
+                zero_point_domain=ZeroPointDomain.INT
+                if has_weight_zeros
+                else ZeroPointDomain.NONE,
+                mapping_type=MappingType.ASYMMETRIC,
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        weight_qvals, weight_scales, weight_zeros = (
+            embedding.weight.tensor_impl.get_plain()
         )
+        weight_scales = weight_scales.reshape(num_embeddings, -1)
+        weight_zeros = weight_zeros.reshape(num_embeddings, -1).to(torch.int8)
         self.register_buffer(
-            "packed_weight_qvals", self.pack_weights_op(weight_qvals.to(torch.int8))
+            "packed_weight_qvals",
+            getattr(torch.ops.torchao, f"_pack_embedding_{self.bit_width}bit")(
+                weight_qvals.to(torch.int8)
+            ),
         )
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.register_buffer("weight_scales", weight_scales)
-        self.register_buffer("weight_zeros", weight_zeros.to(torch.int8))
+        self.register_buffer("weight_zeros", weight_zeros)
 
     def forward(self, x):
         shape = x.shape
-        return self.embedding_op(
+        return getattr(torch.ops.torchao, f"_embedding_{self.bit_width}bit")(
             self.packed_weight_qvals,
             self.num_embeddings,
             self.embedding_dim,
@@ -410,38 +422,23 @@ def __init__(
         self.bit_width = bit_width
 
     def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
-        assert (
-            has_weight_zeros
-        ), "has_weight_zeros must be True for QuantizedEmbeddingFallback"
-        num_embeddings, embedding_dim = weights.shape
-        if group_size == -1:
-            group_size = embedding_dim
-        self.group_size = group_size
-
-        weight_qvals, weight_scales, weight_zeros = _quantize(
-            weights, self.group_size, self.bit_width, has_weight_zeros=True
+        self.embedding = torch.nn.Embedding(*weights.shape)
+        self.embedding.weight = weights
+        quantize_(
+            self.embedding,
+            IntxWeightOnlyConfig(
+                weight_dtype=getattr(torch, f"int{self.bit_width}"),
+                granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0),
+                zero_point_domain=ZeroPointDomain.INT
+                if has_weight_zeros
+                else ZeroPointDomain.NONE,
+                mapping_type=MappingType.ASYMMETRIC,
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
         )
-        self.weight_qvals = weight_qvals.to(torch.int32)
-        self.weight_scales = weight_scales
-        self.weight_zeros = weight_zeros.to(torch.int32)
 
     def forward(self, x):
-        shape = x.shape
-        res = []
-        for i in x:
-            res.append(
-                dequantize_per_channel_group(
-                    w_int8=self.weight_qvals[i, :].reshape(1, -1),
-                    scales=self.weight_scales[i, :].reshape(1, -1),
-                    zero_points=self.weight_zeros[i, :].reshape(1, -1),
-                    quant_min=None,  # TODO: why is this an arg for this function
-                    quant_max=None,  # TODO: why is this an arg for this function
-                    dtype=None,  # TODO: why is this an arg for this function
-                    group_size=self.group_size,
-                    output_dtype=torch.float32,
-                ).reshape(-1)
-            )
-        return torch.stack(res).reshape(*shape, -1)
+        return self.embedding(x)
 
 
 class QuantizedSharedEmbedding(nn.Module):
@@ -586,15 +583,16 @@ class EmbeddingQuantizer:
     def __init__(
         self,
        weight_dtype: torch.dtype = torch.int4,
-        granularity: Union[PerRow, PerGroup] = PerRow(),
+        granularity: Granularity = PerAxis(0),
         has_weight_zeros: bool = True,
         use_fallback: bool = False,
     ):
         bit_width = _dtype_to_bit_width(weight_dtype)
 
         if isinstance(granularity, PerGroup):
             group_size = granularity.group_size
-        elif isinstance(granularity, PerRow):
+        elif isinstance(granularity, PerAxis):
+            assert granularity.axis == 0
             group_size = -1
         else:
             raise ValueError(f"Unsupported granularity: {granularity}")
@@ -630,6 +628,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
             to_linear_activation_quantized,
         )
         from torchao.quantization.quant_api import (
+            IntxWeightOnlyConfig,
             MappingType,
             ZeroPointDomain,
             to_affine_quantized_intx,
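
Note: with the signature change above, EmbeddingQuantizer now accepts any Granularity and defaults to PerAxis(0) instead of PerRow(). A short usage sketch, assuming EmbeddingQuantizer is imported from torchao.experimental.quant_api (mirroring how the tests below call it); the model shapes are illustrative:

import torch
from torchao.experimental.quant_api import EmbeddingQuantizer
from torchao.quantization.granularity import PerAxis, PerGroup

model = torch.nn.Sequential(torch.nn.Embedding(131, 4096))

# Row-wise (per-axis-0) int4 embedding quantization, the new default granularity.
EmbeddingQuantizer(
    weight_dtype=torch.int4,
    granularity=PerAxis(0),
    has_weight_zeros=True,
).quantize(model)

# Grouped quantization is still expressed with PerGroup.
grouped_model = torch.nn.Sequential(torch.nn.Embedding(131, 4096))
EmbeddingQuantizer(
    weight_dtype=torch.int4,
    granularity=PerGroup(128),
    has_weight_zeros=True,
).quantize(grouped_model)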

torchao/experimental/quant_passes.py (+98)
@@ -215,3 +215,101 @@ def replace_q_dq_patterns_with_quantized_linear_ops_pass(
 
     # Re-export
     return torch.export.export(gm, *ep.example_inputs)
+
+
+def _get_q_dq_embedding_patterns_replacements_and_filters(
+    weight_bit_width,
+):
+    w_quant_min = -(1 << (weight_bit_width - 1))
+    w_quant_max = (1 << (weight_bit_width - 1)) - 1
+    w_target_dtype = torch.int8
+
+    def pattern(
+        indices,
+        w_int_data,
+        w_block_size,
+        w_scale,
+        w_zero_point,
+    ):
+        dq_w = torch.ops.quant.dequantize_affine.default(
+            w_int_data,
+            w_block_size,
+            w_scale,
+            w_zero_point,
+            w_target_dtype,
+            w_quant_min,
+            w_quant_max,
+        )
+        return torch.ops.aten.embedding.default(dq_w, indices)
+
+    def replacement(
+        indices,
+        w_int_data,
+        w_block_size,
+        w_scale,
+        w_zero_point,
+    ):
+        num_embeddings, embedding_dim = w_int_data.size()
+        packed_weight_qvals = getattr(
+            torch.ops.torchao, f"_pack_embedding_{weight_bit_width}bit"
+        )(w_int_data)
+        out_shape = indices.shape + (embedding_dim,)
+        group_size = w_block_size[-1]
+        n_groups = embedding_dim // group_size
+        w_scale = w_scale.reshape(-1, n_groups)
+        w_zero_point = w_zero_point.reshape(-1, n_groups)
+        return getattr(torch.ops.torchao, f"_embedding_{weight_bit_width}bit")(
+            packed_weight_qvals,
+            num_embeddings,
+            embedding_dim,
+            w_scale,
+            w_zero_point,
+            indices.reshape(-1),
+        ).reshape(out_shape)
+
+    def match_filter(match, x, y):
+        def get_val(name):
+            node = [n for n in match.nodes_map if n.name == name][0]
+            return match.nodes_map[node]
+
+        # We only want w_block_size with shape [1, group_size]
+        w_block_size = get_val("w_block_size")
+        if len(w_block_size) != 2 or w_block_size[0] != 1:
+            return False
+
+        return True
+
+    return pattern, replacement, match_filter
+
+
+def replace_q_dq_patterns_with_quantized_embedding_ops_pass(
+    ep: torch.export.ExportedProgram,
+) -> torch.export.ExportedProgram:
+    """
+    This replaces Q/DQ patterns with torchao quantized embedding ops.
+    It is intended for converting Q/DQ nodes exported with QDQLayout to using
+    the lowbit quantized embedding ops.
+    """
+    # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export)
+    # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/
+    assert (
+        len(ep.range_constraints) == 0
+    ), "ExportedProgram with range constraints are not supported"
+
+    # ep.module() unlifts the weight inputs, which we need for constant folding
+    gm = ep.module()
+    for weight_bit_width in range(1, 9):
+        pattern, replacement, match_filter = (
+            _get_q_dq_embedding_patterns_replacements_and_filters(
+                weight_bit_width,
+            )
+        )
+        subgraph_rewriter.replace_pattern_with_filters(
+            gm, pattern, replacement, match_filters=[match_filter]
+        )
+
+    # Constant fold evaluates and removes the packing ops
+    constant_fold(gm)
+
+    # Re-export
+    return torch.export.export(gm, *ep.example_inputs)
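
Note: the new pass operates on a torch.export.ExportedProgram and rewrites any matching Q/DQ embedding patterns (bit widths 1 through 8) into the packed lowbit ops. A hedged sketch of how it would be invoked; the step that quantizes the embedding weights into the Q/DQ (QDQLayout) form, via quantize_ with IntxWeightOnlyConfig per the commit title, is assumed and not shown:

import torch
from torchao.experimental.quant_passes import (
    replace_q_dq_patterns_with_quantized_embedding_ops_pass,
)

model = torch.nn.Sequential(torch.nn.Embedding(131, 4096))
# Assumed: the embedding weights have already been quantized so the exported
# graph contains dequantize_affine + aten.embedding nodes (QDQLayout export).

indices = torch.randint(0, 131, (1, 8), dtype=torch.long)
ep = torch.export.export(model, (indices,))

# Replaces the Q/DQ embedding patterns with torchao._embedding_{n}bit ops and
# constant-folds the weight-packing ops; returns a re-exported program.
ep = replace_q_dq_patterns_with_quantized_embedding_ops_pass(ep)
print(ep.graph_module.code)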

torchao/experimental/tests/test_embedding_xbit_quantizer.py (+5, -6)
@@ -19,7 +19,7 @@
     Int8DynamicActivationIntxWeightConfig,
     SharedEmbeddingQuantizer,
 )
-from torchao.quantization.granularity import PerGroup, PerRow
+from torchao.quantization.granularity import PerAxis, PerGroup, PerRow
 from torchao.quantization.quant_api import quantize_
 
 
@@ -68,7 +68,7 @@ def test_accuracy(self):
 
     def test_export_compile_aoti(self):
         weight_dtype = torch.int4
-        granularity = PerRow()
+        granularity = PerAxis(0)
         embedding_dim = 4096
         num_embeddings = 131
         model = torch.nn.Sequential(
@@ -113,7 +113,6 @@ def test_export_compile_aoti(self):
 
     def test_shared_embedding(self):
         weight_dtype = torch.int4
-        granularity = PerRow()
         has_weight_zeros = True
         embedding_dim = 4096
         num_embeddings = 131
@@ -134,14 +133,14 @@ def test_shared_embedding(self):
         quantized_model_reference = copy.deepcopy(model)
         EmbeddingQuantizer(
             weight_dtype=weight_dtype,
-            granularity=granularity,
+            granularity=PerAxis(0),
             has_weight_zeros=has_weight_zeros,
         ).quantize(quantized_model_reference)
         quantize_(
             quantized_model_reference,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
-                granularity=granularity,
+                granularity=PerRow(),
                 has_weight_zeros=has_weight_zeros,
                 round_weight_scale_to_bf16=False,
                 layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
@@ -155,7 +154,7 @@ def test_shared_embedding(self):
         quantized_model = copy.deepcopy(model)
         SharedEmbeddingQuantizer(
             weight_dtype=weight_dtype,
-            granularity=granularity,
+            granularity=PerRow(),
             has_weight_zeros=has_weight_zeros,
         ).quantize(quantized_model)
 