Register codebook quant ops #1988

Open · wants to merge 3 commits into main
19 changes: 17 additions & 2 deletions test/prototype/test_codebook_quant.py
@@ -9,11 +9,13 @@

from torchao.prototype.quantization.codebook import (
CodebookQuantizedTensor,
CodebookWeightOnlyConfig,
choose_qparams_codebook,
codebook_weight_only,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_no_cuda
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


class TestCodebookQuantization(unittest.TestCase):
@@ -71,9 +73,22 @@ def test_codebook_quantized_tensor_from_float2(self):

def test_quantize_api(self):
m = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(m, codebook_weight_only())
quantize_(m, CodebookWeightOnlyConfig())
assert type(m[0].weight) == CodebookQuantizedTensor

@skip_if_no_cuda()
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
def test_export(self):
m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(dtype=torch.bfloat16)
quantize_(m, CodebookWeightOnlyConfig())
example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16),)
print("m:", m)
# torchao.utils.unwrap_tensor_subclass(m)
m = torch.export.export_for_training(m, example_inputs).module()
print("m:", m)
targets = [n.target for n in m.graph.nodes]
self.assertTrue(torch.ops.quant.quantize_codebook.default in targets)


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions torchao/prototype/quantization/codebook/__init__.py
@@ -3,11 +3,11 @@
dequantize_codebook,
quantize_codebook,
)
from .codebook_quantized_tensor import CodebookQuantizedTensor, codebook_weight_only
from .codebook_quantized_tensor import CodebookQuantizedTensor, CodebookWeightOnlyConfig

__all__ = [
"CodebookQuantizedTensor",
"codebook_weight_only",
"CodebookWeightOnlyConfig",
"quantize_codebook",
"dequantize_codebook",
"choose_qparams_codebook",
18 changes: 14 additions & 4 deletions torchao/prototype/quantization/codebook/codebook_ops.py
@@ -11,8 +11,13 @@
_DTYPE_TO_QVALUE_BOUNDS,
_SUB_BYTE_UINT_BOUNDS,
)
from torchao.utils import _register_custom_op

quant_lib = torch.library.Library("quant", "FRAGMENT")
register_custom_op = _register_custom_op(quant_lib)


@register_custom_op
def quantize_codebook(
input: torch.Tensor,
codebook: torch.Tensor,
@@ -25,7 +30,8 @@ def quantize_codebook(

Args:
input (torch.Tensor): Input tensor to quantize, shape (d1, d2, ..., dN).
codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes.
codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes and k is the codebook size, e.g. for uint4 (4-bit) codes the codebook size is 2**4 = 16,
i.e. one dequantized block of shape (b1, b2, ..., bN) for each integer code value from 0 to 15.
scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
chunk_size (int): Number of elements to process per chunk to control memory usage.
code_dtype (torch.dtype): dtype for the codes.
@@ -95,20 +101,24 @@ def quantize_codebook(
return codes.to(code_dtype)
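
To make the shapes above concrete, here is a minimal sketch of the lookup semantics described in the docstring. It is not the library implementation; chunking and per-scale-block scaling are omitted for brevity.

import torch

d1, d2 = 8, 8        # tensor shape
b1, b2 = 2, 2        # block shape of each codebook entry
k = 16               # codebook size, e.g. 2**4 entries for uint4 codes

codebook = torch.randn(k, b1, b2)                  # (k, b1, b2)
codes = torch.randint(0, k, (d1 // b1, d2 // b2))  # one code per (b1, b2) block

# each code selects one (b1, b2) block from the codebook ...
blocks = codebook[codes]                           # (d1//b1, d2//b2, b1, b2)
# ... and the selected blocks are stitched back into the original layout
dequant = blocks.permute(0, 2, 1, 3).reshape(d1, d2)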


@register_custom_op
def dequantize_codebook(
metascroy (Contributor) commented on Apr 1, 2025:

IIUC, this does not look like it supports granularity, which we will want.

From what I can tell, k is the idx range, e.g., for 4-bit quantization, k = 16. Each idx=i is mapped to the tensor codebook[i]. So we have 1 codebook/LUT for the tensor that maps indices to tensors.

This seems a bit complicated to me. For CoreML, the default is each idx maps to a scalar (but they also support mapping to a vector). I'm not sure if anyone will need tensor-valued lookup values.

But we do want granularity in the sense that we can have one codebook per channel, grouped channel, tensor, etc.

Maybe this is what was originally intended for the block_size (based on https://github.com/pytorch/ao/pull/1299/files/53874a005cb174f764363a7c3a22f653ccf738df#r1870108715), but if I understand the code correctly, that's not what got implemented.

Contributor Author:

I think the scale_block_size in choose_qparams_codebook, or the shape of scales in the dequant op, is supposed to let us control the granularity. The block_size arg seems to have a different meaning than the block_size in other ops, so we should probably rename it; my guess is that it's the block size of tensor values that share the same k-means cluster value.

Contributor Author:
oh wait, the granularity of codebook is separate, let me take a look again
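
To illustrate the granularity point in this thread, here is a hypothetical sketch; the names are illustrative only and nothing below is part of this PR. Per the comment above, the op currently uses a single codebook (LUT) shared by the whole tensor; per-channel granularity would mean one LUT per output channel. Scalar LUT entries are used here for simplicity, whereas the op in this PR maps each code to a block.

import torch

out_features, in_features, k = 64, 128, 16
codes = torch.randint(0, k, (out_features, in_features))

# per-tensor granularity: one LUT of k values shared by every element
per_tensor_lut = torch.randn(k)
dequant_per_tensor = per_tensor_lut[codes]

# per-channel granularity: one LUT of k values per output channel, so
# row i of codes is looked up in per_channel_lut[i]
per_channel_lut = torch.randn(out_features, k)
dequant_per_channel = torch.gather(per_channel_lut, 1, codes)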

codes: torch.Tensor,
codebook: torch.Tensor,
input_dtype: torch.dtype,
scales: torch.Tensor,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Reconstructs the original tensor from codes and the codebook.

Args:
codes (torch.Tensor): Indices of codebook entries for each block,
shape (d1//b1, d2//b2, ..., dN//bN).
codes (torch.Tensor): Indices of codebook entries for each block, torch.int32 dtype,
shape (d1//b1, d2//b2, ..., dN//bN).
codebook (torch.Tensor): Codebook tensor used for quantization,
shape (k, b1, b2, ..., bN) where b_i are block sizes.
Contributor:

nit: say what k is

Contributor Author:

will update docs after I update the code to support block_size

input_dtype (torch.dtype): Input dtype for `codes`, used for downstream pattern matching
and not enforced in `codes`. Can be a sub-byte dtype like torch.uint4.
scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
output_dtype (torch.dtype): dtype for the output tensor.

@@ -142,7 +152,7 @@ def dequantize_codebook(
dequant = dequant.view(
*new_shape
) # (d1, d2, ..., num_scale_blocks, scale_block_size)
dequant.mul_(scales)
dequant = dequant * scales

dequant = dequant.view(*original_shape)
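
As a hedged illustration of the "downstream pattern matching" mentioned in the input_dtype docstring, a consumer could scan an exported graph for the registered op. This assumes the codebook ops module has been imported so that the op exists under torch.ops.quant.dequantize_codebook, mirroring what the new test checks for the quantize op.

import torch

def find_codebook_dequants(gm: torch.fx.GraphModule):
    # collect the call_function nodes that hit the registered custom op
    matches = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.quant.dequantize_codebook.default:
            # node.args follow the op signature: codes, codebook, input_dtype, scales, ...
            matches.append(node)
    return matches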

torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
@@ -96,12 +96,15 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
codes = self.codes.get_plain()
else:
codes = self.codes

if codes.dtype != torch.int32:
# TODO: Investigate and support not casting to torch.int32 for indexing to improve performance
codes = codes.to(torch.int32)

return dequantize_codebook(
codes,
self.codebook,
self.codes.dtype,
self.scales,
output_dtype=output_dtype,
)
1 change: 1 addition & 0 deletions torchao/testing/utils.py
@@ -96,6 +96,7 @@ def skip_if_no_cuda():
def decorator(test_func):
def wrapper(*args, **kwargs):
if not torch.cuda.is_available():
print("no cuda available")
raise unittest.SkipTest("No cuda available")
return test_func(*args, **kwargs)

8 changes: 4 additions & 4 deletions torchao/utils.py
@@ -210,13 +210,13 @@ def decorator(fn):

# expecting fn.__name__ starts with `_` and we want to take the rest
# to be the name of the custom op
assert (
fn.__name__[0] == "_"
), f"Expecting function name starts with `_`, got {fn.__name__}"
assert not any(
c in fn.__name__ for c in ".<>"
), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
op_name = fn.__name__[1:]
op_name = fn.__name__
if op_name[0] == "_":
op_name = op_name[1:]

schema = op_name + infer_schema(fn, mutates_args={})
lib.define(schema)
lib.impl(op_name, fn, "CompositeImplicitAutograd")
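
A minimal usage sketch of the updated helper follows; the real registrations for the codebook ops live in codebook_ops.py above, and the namespace and function here are hypothetical. With this change the decorated function no longer has to start with an underscore, and a leading underscore is still stripped from the op name when present.

import torch
from torchao.utils import _register_custom_op

demo_lib = torch.library.Library("demo_quant", "FRAGMENT")  # hypothetical namespace
register_demo_op = _register_custom_op(demo_lib)

@register_demo_op
def scale_by_two(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# the op is now addressable as torch.ops.demo_quant.scale_by_two and is
# preserved as a single node when a model that uses it is exported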