From 0117572ea8f98124ecd524732f9207973e5f4642 Mon Sep 17 00:00:00 2001
From: Jerry Zhang <jerryzh168@gmail.com>
Date: Tue, 1 Apr 2025 10:54:06 -0700
Subject: [PATCH 1/3] Register codebook quant ops

Summary:
Register the codebook quant / dequant ops as custom ops so they can be recognized after export
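
A minimal sketch of the mechanism (illustrative only: `my_identity` is a
made-up example op; `quant_lib` and `_register_custom_op` are the names
used in this patch):

    import torch

    from torchao.utils import _register_custom_op

    quant_lib = torch.library.Library("quant", "FRAGMENT")
    register_custom_op = _register_custom_op(quant_lib)

    @register_custom_op
    def my_identity(x: torch.Tensor) -> torch.Tensor:
        # Decomposed reference implementation; registered as
        # CompositeImplicitAutograd, so torch.export.export_for_training
        # records a single torch.ops.quant.my_identity node instead of
        # tracing through the body.
        return x.clone()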

Test Plan:
python test/prototype/test_codebook_quant.py -k test_export

---
 test/prototype/test_codebook_quant.py         | 16 +++++++++++++++-
 .../quantization/codebook/codebook_ops.py     | 19 +++++++++++++++----
 .../codebook/codebook_quantized_tensor.py     |  3 +++
 torchao/utils.py                              |  8 ++++----
 4 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/test/prototype/test_codebook_quant.py b/test/prototype/test_codebook_quant.py
index 0614fc9632..5a71539f29 100644
--- a/test/prototype/test_codebook_quant.py
+++ b/test/prototype/test_codebook_quant.py
@@ -20,7 +20,7 @@ class TestCodebookQuantization(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.input = torch.randn(100, 256, dtype=torch.float32)
-        self.block_size = (1, 1)
+        self.block_size = (2, 2)
         self.scale_block_size = 64
         self.code_dtype = torch.uint8
         self.chunk_size = 1024
@@ -74,6 +74,20 @@ def test_quantize_api(self):
         quantize_(m, codebook_weight_only())
         assert type(m[0].weight) == CodebookQuantizedTensor
 
+    def test_export(self):
+        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(
+            dtype=torch.bfloat16, device="cuda"
+        )
+        quantize_(m, codebook_weight_only())
+        # quantize_(m, int4_weight_only(group_size=16))
+        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"),)
+        print("m:", m)
+        # torchao.utils.unwrap_tensor_subclass(m)
+        m = torch.export.export_for_training(m, example_inputs).module()
+        print("m:", m)
+        targets = [n.target for n in m.graph.nodes]
+        self.assertTrue(torch.ops.quant.quantize_codebook.default in targets)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/prototype/quantization/codebook/codebook_ops.py b/torchao/prototype/quantization/codebook/codebook_ops.py
index a7db8774ea..ff51194b1b 100644
--- a/torchao/prototype/quantization/codebook/codebook_ops.py
+++ b/torchao/prototype/quantization/codebook/codebook_ops.py
@@ -11,8 +11,13 @@
     _DTYPE_TO_QVALUE_BOUNDS,
     _SUB_BYTE_UINT_BOUNDS,
 )
+from torchao.utils import _register_custom_op
 
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
+
+@register_custom_op
 def quantize_codebook(
     input: torch.Tensor,
     codebook: torch.Tensor,
@@ -25,7 +30,8 @@ def quantize_codebook(
 
     Args:
         input (torch.Tensor): Input tensor to quantize, shape (d1, d2, ..., dN).
-        codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes.
+        codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN), where b_i are block sizes and k is the codebook size, e.g. for uint4 (4 bit) the codebook size is 2**4,
+            with one dequantized vector of shape (b1, b2, ..., bN) for each of the uint4 integer values 0 to 15
         scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
         chunk_size (int): Number of elements to process per chunk to control memory usage.
         code_dtype (torch.dtype): dtype for the codes.
@@ -95,9 +101,11 @@ def quantize_codebook(
     return codes.to(code_dtype)
 
 
+@register_custom_op
 def dequantize_codebook(
     codes: torch.Tensor,
     codebook: torch.Tensor,
+    input_dtype: torch.dtype,
     scales: torch.Tensor,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -105,10 +113,12 @@ def dequantize_codebook(
     Reconstructs the original tensor from codes and the codebook.
 
     Args:
-        codes (torch.Tensor): Indices of codebook entries for each block,
-                                          shape (d1//b1, d2//b2, ..., dN//bN).
+        codes (torch.Tensor): torch.int32 dtype, indices of codebook entries for each block,
+                              shape (d1//b1, d2//b2, ..., dN//bN).
         codebook (torch.Tensor): Codebook tensor used for quantization,
                                  shape (k, b1, b2, ..., bN) where b_i are block sizes.
+        input_dtype (torch.dtype): Input dtype for `codes`, used for downstream pattern matching
+                             and not enforced on `codes`. Can be a sub-byte dtype such as torch.uint4
         scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
         output_dtype (torch.dtype): dtype for the output tensor.
 
@@ -142,7 +152,7 @@ def dequantize_codebook(
     dequant = dequant.view(
         *new_shape
     )  # (d1, d2, ..., num_scale_blocks, scale_block_size)
-    dequant.mul_(scales)
+    dequant = dequant * scales
 
     dequant = dequant.view(*original_shape)
 
@@ -172,6 +182,7 @@ def choose_qparams_codebook(
     Returns:
         torch.Tensor: The codebook tensor, shape (codebook_size, *block_size).
     """
+    breakpoint()
     if code_dtype == torch.int32:
         codebook_size = 2**16
     else:
diff --git a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
index e16a339e82..82d2be7c47 100644
--- a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
+++ b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
@@ -96,12 +96,15 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             codes = self.codes.get_plain()
         else:
             codes = self.codes
+
         if codes.dtype != torch.int32:
             # TODO: Investigate and support not casting to torch.int32 for indexing to improve performance
             codes = codes.to(torch.int32)
+
         return dequantize_codebook(
             codes,
             self.codebook,
+            self.codes.dtype,
             self.scales,
             output_dtype=output_dtype,
         )
diff --git a/torchao/utils.py b/torchao/utils.py
index 5577a66637..08850f8f5a 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -210,13 +210,13 @@ def decorator(fn):
 
             # expecting fn.__name__ starts with `_` and we want to take the rest
             # to be the name of the custom op
-            assert (
-                fn.__name__[0] == "_"
-            ), f"Expecting function name starts with `_`, got {fn.__name__}"
             assert not any(
                 c in fn.__name__ for c in ".<>"
             ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
-            op_name = fn.__name__[1:]
+            op_name = fn.__name__
+            if op_name[0] == "_":
+                op_name = op_name[1:]
+
             schema = op_name + infer_schema(fn, mutates_args={})
             lib.define(schema)
             lib.impl(op_name, fn, "CompositeImplicitAutograd")

From 7600e9bcccf603d8f8c71367ccef66d9aec74e49 Mon Sep 17 00:00:00 2001
From: Jerry Zhang <jerryzh168@gmail.com>
Date: Tue, 8 Apr 2025 15:15:06 -0700
Subject: [PATCH 2/3] Rename codebook_weight_only to CodebookWeightOnlyConfig
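
Summary:
Expose the config-style API and tidy the export test: rename
codebook_weight_only to CodebookWeightOnlyConfig, gate test_export with
skip_if_no_cuda, move the test tensors to CPU, restore block_size to
(1, 1), and drop the leftover breakpoint() and commented-out int4 call.

After this change the quantize API reads as follows (a usage sketch,
using only names from this patch):

    import torch

    from torchao.prototype.quantization.codebook import CodebookWeightOnlyConfig
    from torchao.quantization import quantize_

    m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(dtype=torch.bfloat16)
    quantize_(m, CodebookWeightOnlyConfig())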

---
 test/prototype/test_codebook_quant.py           | 17 ++++++++---------
 .../prototype/quantization/codebook/__init__.py |  4 ++--
 .../quantization/codebook/codebook_ops.py       |  1 -
 torchao/testing/utils.py                        |  1 +
 4 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/test/prototype/test_codebook_quant.py b/test/prototype/test_codebook_quant.py
index 5a71539f29..9fbd9ef7d7 100644
--- a/test/prototype/test_codebook_quant.py
+++ b/test/prototype/test_codebook_quant.py
@@ -9,18 +9,19 @@
 
 from torchao.prototype.quantization.codebook import (
     CodebookQuantizedTensor,
+    CodebookWeightOnlyConfig,
     choose_qparams_codebook,
-    codebook_weight_only,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
+from torchao.testing.utils import skip_if_no_cuda
 
 
 class TestCodebookQuantization(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.input = torch.randn(100, 256, dtype=torch.float32)
-        self.block_size = (2, 2)
+        self.block_size = (1, 1)
         self.scale_block_size = 64
         self.code_dtype = torch.uint8
         self.chunk_size = 1024
@@ -71,16 +72,14 @@ def test_codebook_quantized_tensor_from_float2(self):
 
     def test_quantize_api(self):
         m = torch.nn.Sequential(torch.nn.Linear(64, 64))
-        quantize_(m, codebook_weight_only())
+        quantize_(m, CodebookWeightOnlyConfig())
         assert type(m[0].weight) == CodebookQuantizedTensor
 
+    @skip_if_no_cuda()
     def test_export(self):
-        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(
-            dtype=torch.bfloat16, device="cuda"
-        )
-        quantize_(m, codebook_weight_only())
-        # quantize_(m, int4_weight_only(group_size=16))
-        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"),)
+        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(dtype=torch.bfloat16)
+        quantize_(m, CodebookWeightOnlyConfig())
+        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16),)
         print("m:", m)
         # torchao.utils.unwrap_tensor_subclass(m)
         m = torch.export.export_for_training(m, example_inputs).module()
diff --git a/torchao/prototype/quantization/codebook/__init__.py b/torchao/prototype/quantization/codebook/__init__.py
index 3fba2beedd..c8ee2a2e34 100644
--- a/torchao/prototype/quantization/codebook/__init__.py
+++ b/torchao/prototype/quantization/codebook/__init__.py
@@ -3,11 +3,11 @@
     dequantize_codebook,
     quantize_codebook,
 )
-from .codebook_quantized_tensor import CodebookQuantizedTensor, codebook_weight_only
+from .codebook_quantized_tensor import CodebookQuantizedTensor, CodebookWeightOnlyConfig
 
 __all__ = [
     "CodebookQuantizedTensor",
-    "codebook_weight_only",
+    "CodebookWeightOnlyConfig",
     "quantize_codebook",
     "dequantize_codebook",
     "choose_qparams_codebook",
diff --git a/torchao/prototype/quantization/codebook/codebook_ops.py b/torchao/prototype/quantization/codebook/codebook_ops.py
index ff51194b1b..dd37b517a9 100644
--- a/torchao/prototype/quantization/codebook/codebook_ops.py
+++ b/torchao/prototype/quantization/codebook/codebook_ops.py
@@ -182,7 +182,6 @@ def choose_qparams_codebook(
     Returns:
         torch.Tensor: The codebook tensor, shape (codebook_size, *block_size).
     """
-    breakpoint()
     if code_dtype == torch.int32:
         codebook_size = 2**16
     else:
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index da6512468a..b401ea382a 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -96,6 +96,7 @@ def skip_if_no_cuda():
     def decorator(test_func):
         def wrapper(*args, **kwargs):
             if not torch.cuda.is_available():
+                print("no cuda available")
                 raise unittest.SkipTest("No cuda available")
             return test_func(*args, **kwargs)
 

From ebaa5fc2d7e31b621458e347f064032fd090e7ad Mon Sep 17 00:00:00 2001
From: Jerry Zhang <jerryzh168@gmail.com>
Date: Tue, 8 Apr 2025 19:42:06 -0700
Subject: [PATCH 3/3] Add torch version guard for export test

---
 test/prototype/test_codebook_quant.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/prototype/test_codebook_quant.py b/test/prototype/test_codebook_quant.py
index 9fbd9ef7d7..3317b598ab 100644
--- a/test/prototype/test_codebook_quant.py
+++ b/test/prototype/test_codebook_quant.py
@@ -15,6 +15,7 @@
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_no_cuda
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 
 class TestCodebookQuantization(unittest.TestCase):
@@ -76,6 +77,7 @@ def test_quantize_api(self):
         assert type(m[0].weight) == CodebookQuantizedTensor
 
     @skip_if_no_cuda()
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
     def test_export(self):
         m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(dtype=torch.bfloat16)
         quantize_(m, CodebookWeightOnlyConfig())