From 0117572ea8f98124ecd524732f9207973e5f4642 Mon Sep 17 00:00:00 2001
From: Jerry Zhang <jerryzh168@gmail.com>
Date: Tue, 1 Apr 2025 10:54:06 -0700
Subject: [PATCH 1/3] Register codebook quant ops

Summary:
Register the codebook quant / dequant ops as custom ops so they can be
recognized after export.

Test Plan:
python test/prototype/test_codebook_quant.py -k test_export

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/prototype/test_codebook_quant.py         | 16 +++++++++++++++-
 .../quantization/codebook/codebook_ops.py     | 19 +++++++++++++++----
 .../codebook/codebook_quantized_tensor.py     |  3 +++
 torchao/utils.py                              |  8 ++++----
 4 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/test/prototype/test_codebook_quant.py b/test/prototype/test_codebook_quant.py
index 0614fc9632..5a71539f29 100644
--- a/test/prototype/test_codebook_quant.py
+++ b/test/prototype/test_codebook_quant.py
@@ -20,7 +20,7 @@ class TestCodebookQuantization(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.input = torch.randn(100, 256, dtype=torch.float32)
-        self.block_size = (1, 1)
+        self.block_size = (2, 2)
         self.scale_block_size = 64
         self.code_dtype = torch.uint8
         self.chunk_size = 1024
@@ -74,6 +74,20 @@ def test_quantize_api(self):
         quantize_(m, codebook_weight_only())
         assert type(m[0].weight) == CodebookQuantizedTensor
 
+    def test_export(self):
+        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(
+            dtype=torch.bfloat16, device="cuda"
+        )
+        quantize_(m, codebook_weight_only())
+        # quantize_(m, int4_weight_only(group_size=16))
+        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"),)
+        print("m:", m)
+        # torchao.utils.unwrap_tensor_subclass(m)
+        m = torch.export.export_for_training(m, example_inputs).module()
+        print("m:", m)
+        targets = [n.target for n in m.graph.nodes]
+        self.assertTrue(torch.ops.quant.quantize_codebook.default in targets)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/prototype/quantization/codebook/codebook_ops.py b/torchao/prototype/quantization/codebook/codebook_ops.py
index a7db8774ea..ff51194b1b 100644
--- a/torchao/prototype/quantization/codebook/codebook_ops.py
+++ b/torchao/prototype/quantization/codebook/codebook_ops.py
@@ -11,8 +11,13 @@
     _DTYPE_TO_QVALUE_BOUNDS,
     _SUB_BYTE_UINT_BOUNDS,
 )
+from torchao.utils import _register_custom_op
 
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
+
+@register_custom_op
 def quantize_codebook(
     input: torch.Tensor,
     codebook: torch.Tensor,
@@ -25,7 +30,8 @@ def quantize_codebook(
 
     Args:
         input (torch.Tensor): Input tensor to quantize, shape (d1, d2, ..., dN).
-        codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes.
+        codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes and k is the codebook size, e.g. for uint4 (4 bit) the codebook size is 2**4,
+            i.e. one dequantized vector of dimension (b1, b2, ..., bN) for each uint4 integer value from 0 to 15.
         scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
         chunk_size (int): Number of elements to process per chunk to control memory usage.
         code_dtype (torch.dtype): dtype for the codes.
@@ -95,9 +101,11 @@ def quantize_codebook(
     return codes.to(code_dtype)
 
 
+@register_custom_op
 def dequantize_codebook(
     codes: torch.Tensor,
     codebook: torch.Tensor,
+    input_dtype: torch.dtype,
     scales: torch.Tensor,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -105,10 +113,12 @@ def dequantize_codebook(
     Reconstructs the original tensor from codes and the codebook.
 
     Args:
-        codes (torch.Tensor): Indices of codebook entries for each block,
-        shape (d1//b1, d2//b2, ..., dN//bN).
+        codes (torch.Tensor): torch.int32 dtype, indices of codebook entries for each block,
+            shape (d1//b1, d2//b2, ..., dN//bN).
         codebook (torch.Tensor): Codebook tensor used for quantization,
             shape (k, b1, b2, ..., bN) where b_i are block sizes.
+        input_dtype (torch.dtype): Input dtype for `codes`, used for downstream pattern matching
+            and not enforced on `codes`. Can be a sub-byte dtype like torch.uint4.
         scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
         output_dtype (torch.dtype): dtype for the output tensor.
@@ -142,7 +152,7 @@ def dequantize_codebook(
     dequant = dequant.view(
         *new_shape
     )  # (d1, d2, ..., num_scale_blocks, scale_block_size)
-    dequant.mul_(scales)
+    dequant = dequant * scales
 
     dequant = dequant.view(*original_shape)
 
@@ -172,6 +182,7 @@ def choose_qparams_codebook(
     Returns:
         torch.Tensor: The codebook tensor, shape (codebook_size, *block_size).
     """
+    breakpoint()
     if code_dtype == torch.int32:
         codebook_size = 2**16
     else:
diff --git a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
index e16a339e82..82d2be7c47 100644
--- a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
+++ b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
@@ -96,12 +96,15 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             codes = self.codes.get_plain()
         else:
             codes = self.codes
+
         if codes.dtype != torch.int32:
             # TODO: Investigate and support not casting to torch.int32 for indexing to improve performance
             codes = codes.to(torch.int32)
+
         return dequantize_codebook(
             codes,
             self.codebook,
+            self.codes.dtype,
             self.scales,
             output_dtype=output_dtype,
         )
diff --git a/torchao/utils.py b/torchao/utils.py
index 5577a66637..08850f8f5a 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -210,13 +210,13 @@ def decorator(fn):
 
         # expecting fn.__name__ starts with `_` and we want to take the rest
         # to be the name of the custom op
-        assert (
-            fn.__name__[0] == "_"
-        ), f"Expecting function name starts with `_`, got {fn.__name__}"
         assert not any(
             c in fn.__name__ for c in ".<>"
         ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
-        op_name = fn.__name__[1:]
+        op_name = fn.__name__
+        if op_name[0] == "_":
+            op_name = op_name[1:]
+
         schema = op_name + infer_schema(fn, mutates_args={})
         lib.define(schema)
         lib.impl(op_name, fn, "CompositeImplicitAutograd")

From 7600e9bcccf603d8f8c71367ccef66d9aec74e49 Mon Sep 17 00:00:00 2001
From: Jerry Zhang <jerryzh168@gmail.com>
Date: Tue, 8 Apr 2025 15:15:06 -0700
Subject: [PATCH 2/3] update

---
 test/prototype/test_codebook_quant.py               | 17 ++++++++---------
 torchao/prototype/quantization/codebook/__init__.py |  4 ++--
 .../quantization/codebook/codebook_ops.py           |  1 -
 torchao/testing/utils.py                            |  1 +
 4 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/test/prototype/test_codebook_quant.py b/test/prototype/test_codebook_quant.py
index 5a71539f29..9fbd9ef7d7 100644
--- a/test/prototype/test_codebook_quant.py
+++ b/test/prototype/test_codebook_quant.py
@@ -9,18 +9,19 @@
 
 from torchao.prototype.quantization.codebook import (
     CodebookQuantizedTensor,
+    CodebookWeightOnlyConfig,
     choose_qparams_codebook,
-    codebook_weight_only,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
+from torchao.testing.utils import skip_if_no_cuda
 
 
 class TestCodebookQuantization(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.input = torch.randn(100, 256, dtype=torch.float32)
-        self.block_size = (2, 2)
+        self.block_size = (1, 1)
         self.scale_block_size = 64
         self.code_dtype = torch.uint8
         self.chunk_size = 1024
@@ -71,16 +72,14 @@ def test_codebook_quantized_tensor_from_float2(self):
 
     def test_quantize_api(self):
         m = torch.nn.Sequential(torch.nn.Linear(64, 64))
-        quantize_(m, codebook_weight_only())
+        quantize_(m, CodebookWeightOnlyConfig())
         assert type(m[0].weight) == CodebookQuantizedTensor
 
+    @skip_if_no_cuda()
     def test_export(self):
-        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(
-            dtype=torch.bfloat16, device="cuda"
-        )
-        quantize_(m, codebook_weight_only())
-        # quantize_(m, int4_weight_only(group_size=16))
-        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"),)
+        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(dtype=torch.bfloat16)
+        quantize_(m, CodebookWeightOnlyConfig())
+        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16),)
         print("m:", m)
         # torchao.utils.unwrap_tensor_subclass(m)
         m = torch.export.export_for_training(m, example_inputs).module()
diff --git a/torchao/prototype/quantization/codebook/__init__.py b/torchao/prototype/quantization/codebook/__init__.py
index 3fba2beedd..c8ee2a2e34 100644
--- a/torchao/prototype/quantization/codebook/__init__.py
+++ b/torchao/prototype/quantization/codebook/__init__.py
@@ -3,11 +3,11 @@
     dequantize_codebook,
     quantize_codebook,
 )
-from .codebook_quantized_tensor import CodebookQuantizedTensor, codebook_weight_only
+from .codebook_quantized_tensor import CodebookQuantizedTensor, CodebookWeightOnlyConfig
 
 __all__ = [
     "CodebookQuantizedTensor",
-    "codebook_weight_only",
+    "CodebookWeightOnlyConfig",
     "quantize_codebook",
     "dequantize_codebook",
     "choose_qparams_codebook",
diff --git a/torchao/prototype/quantization/codebook/codebook_ops.py b/torchao/prototype/quantization/codebook/codebook_ops.py
index ff51194b1b..dd37b517a9 100644
--- a/torchao/prototype/quantization/codebook/codebook_ops.py
+++ b/torchao/prototype/quantization/codebook/codebook_ops.py
@@ -182,7 +182,6 @@ def choose_qparams_codebook(
     Returns:
         torch.Tensor: The codebook tensor, shape (codebook_size, *block_size).
""" - breakpoint() if code_dtype == torch.int32: codebook_size = 2**16 else: diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index da6512468a..b401ea382a 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -96,6 +96,7 @@ def skip_if_no_cuda(): def decorator(test_func): def wrapper(*args, **kwargs): if not torch.cuda.is_available(): + print("no cuda available") raise unittest.SkipTest("No cuda available") return test_func(*args, **kwargs) From ebaa5fc2d7e31b621458e347f064032fd090e7ad Mon Sep 17 00:00:00 2001 From: Jerry Zhang <jerryzh168@gmail.com> Date: Tue, 8 Apr 2025 19:42:06 -0700 Subject: [PATCH 3/3] version guard --- test/prototype/test_codebook_quant.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/prototype/test_codebook_quant.py b/test/prototype/test_codebook_quant.py index 9fbd9ef7d7..3317b598ab 100644 --- a/test/prototype/test_codebook_quant.py +++ b/test/prototype/test_codebook_quant.py @@ -15,6 +15,7 @@ from torchao.quantization import quantize_ from torchao.quantization.utils import compute_error from torchao.testing.utils import skip_if_no_cuda +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class TestCodebookQuantization(unittest.TestCase): @@ -76,6 +77,7 @@ def test_quantize_api(self): assert type(m[0].weight) == CodebookQuantizedTensor @skip_if_no_cuda() + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") def test_export(self): m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(dtype=torch.bfloat16) quantize_(m, CodebookWeightOnlyConfig())