Skip to content

Commit 0f7fa57

Browse files
committed
Register codebook quant ops
Summary: Register the codebook quant / dequant ops as custom ops so they can be recognized after export. Test Plan: python test/prototype/test_codebook_quant.py -k test_export Reviewers: Subscribers: Tasks: Tags:
1 parent f38c272 commit 0f7fa57

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

test/prototype/test_codebook_quant.py

+14
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ def test_quantize_api(self):
6969
quantize_(m, codebook_weight_only())
7070
assert type(m[0].weight) == CodebookQuantizedTensor
7171

72+
def test_export(self):
73+
m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(
74+
dtype=torch.bfloat16, device="cuda"
75+
)
76+
quantize_(m, codebook_weight_only())
77+
# quantize_(m, int4_weight_only(group_size=16))
78+
example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"),)
79+
print("m:", m)
80+
# torchao.utils.unwrap_tensor_subclass(m)
81+
m = torch.export.export_for_training(m, example_inputs).module()
82+
print("m:", m)
83+
targets = [n.target for n in m.graph.nodes]
84+
self.assertTrue(torch.ops.quant.quantize_codebook.default in targets)
85+
7286

7387
if __name__ == "__main__":
7488
unittest.main()

torchao/prototype/quantization/codebook/codebook_ops.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@
66
_DTYPE_TO_QVALUE_BOUNDS,
77
_SUB_BYTE_UINT_BOUNDS,
88
)
9+
from torchao.utils import _register_custom_op
910

11+
quant_lib = torch.library.Library("quant", "FRAGMENT")
12+
register_custom_op = _register_custom_op(quant_lib)
1013

14+
15+
@register_custom_op
1116
def quantize_codebook(
1217
input: torch.Tensor,
1318
codebook: torch.Tensor,
@@ -90,20 +95,24 @@ def quantize_codebook(
9095
return codes.to(code_dtype)
9196

9297

98+
@register_custom_op
9399
def dequantize_codebook(
94100
codes: torch.Tensor,
95101
codebook: torch.Tensor,
102+
input_dtype: torch.dtype,
96103
scales: torch.Tensor,
97104
output_dtype: torch.dtype = torch.float32,
98105
) -> torch.Tensor:
99106
"""
100107
Reconstructs the original tensor from codes and the codebook.
101108
102109
Args:
103-
codes (torch.Tensor): Indices of codebook entries for each block,
104-
shape (d1//b1, d2//b2, ..., dN//bN).
110+
codes (torch.Tensor): torch.int32 dtype, indices of codebook entries for each block,
111+
shape (d1//b1, d2//b2, ..., dN//bN).
105112
codebook (torch.Tensor): Codebook tensor used for quantization,
106113
shape (k, b1, b2, ..., bN) where b_i are block sizes.
114+
input_dtype (torch.dtype): Input dtype for `codes`, used for downstream pattern matching
115+
and not enforced in `codes`. can be sub byte dtype like torch.uint4
107116
scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
108117
output_dtype (torch.dtype): dtype for the output tensor.
109118
@@ -137,7 +146,7 @@ def dequantize_codebook(
137146
dequant = dequant.view(
138147
*new_shape
139148
) # (d1, d2, ..., num_scale_blocks, scale_block_size)
140-
dequant.mul_(scales)
149+
dequant = dequant * scales
141150

142151
dequant = dequant.view(*original_shape)
143152

torchao/prototype/quantization/codebook/codebook_quantized_tensor.py

+3
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,15 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
9191
codes = self.codes.get_plain()
9292
else:
9393
codes = self.codes
94+
9495
if codes.dtype != torch.int32:
9596
# TODO: Investigate and support not casting to torch.int32 for indexing to improve performance
9697
codes = codes.to(torch.int32)
98+
9799
return dequantize_codebook(
98100
codes,
99101
self.codebook,
102+
self.codes.dtype,
100103
self.scales,
101104
output_dtype=output_dtype,
102105
)

torchao/utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,13 @@ def decorator(fn):
205205

206206
# expecting fn.__name__ starts with `_` and we want to take the rest
207207
# to be the name of the custom op
208-
assert (
209-
fn.__name__[0] == "_"
210-
), f"Expecting function name starts with `_`, got {fn.__name__}"
211208
assert not any(
212209
c in fn.__name__ for c in ".<>"
213210
), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
214-
op_name = fn.__name__[1:]
211+
op_name = fn.__name__
212+
if op_name[0] == "_":
213+
op_name = op_name[1:]
214+
215215
schema = op_name + infer_schema(fn, mutates_args={})
216216
lib.define(schema)
217217
lib.impl(op_name, fn, "CompositeImplicitAutograd")

0 commit comments

Comments
 (0)