-
Notifications
You must be signed in to change notification settings - Fork 236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Register codebook quant ops #1988
Open
jerryzh168
wants to merge
3
commits into
pytorch:main
Choose a base branch
from
jerryzh168:register-codebook
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,8 +11,13 @@ | |
_DTYPE_TO_QVALUE_BOUNDS, | ||
_SUB_BYTE_UINT_BOUNDS, | ||
) | ||
from torchao.utils import _register_custom_op | ||
|
||
quant_lib = torch.library.Library("quant", "FRAGMENT") | ||
register_custom_op = _register_custom_op(quant_lib) | ||
|
||
|
||
@register_custom_op | ||
def quantize_codebook( | ||
input: torch.Tensor, | ||
codebook: torch.Tensor, | ||
|
@@ -25,7 +30,8 @@ def quantize_codebook( | |
|
||
Args: | ||
input (torch.Tensor): Input tensor to quantize, shape (d1, d2, ..., dN). | ||
codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes. | ||
codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes and k is the codebook_size, e.g. for uint4 (4 bit), codebook size is 2**4 | ||
one corresponding dequantized vector of (b1, b2, .., bN) dimension for each of uint4 integer value of 0 to 15 | ||
scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1). | ||
chunk_size (int): Number of elements to process per chunk to control memory usage. | ||
code_dtype (torch.dtype): dtype for the codes. | ||
|
@@ -95,20 +101,24 @@ def quantize_codebook( | |
return codes.to(code_dtype) | ||
|
||
|
||
@register_custom_op | ||
def dequantize_codebook( | ||
codes: torch.Tensor, | ||
codebook: torch.Tensor, | ||
input_dtype: torch.dtype, | ||
scales: torch.Tensor, | ||
output_dtype: torch.dtype = torch.float32, | ||
) -> torch.Tensor: | ||
""" | ||
Reconstructs the original tensor from codes and the codebook. | ||
|
||
Args: | ||
codes (torch.Tensor): Indices of codebook entries for each block, | ||
shape (d1//b1, d2//b2, ..., dN//bN). | ||
codes (torch.Tensor): torch.int32 dtype, indices of codebook entries for each block, | ||
shape (d1//b1, d2//b2, ..., dN//bN). | ||
codebook (torch.Tensor): Codebook tensor used for quantization, | ||
shape (k, b1, b2, ..., bN) where b_i are block sizes. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: say what k is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will update docs after I update the code to support block_size |
||
input_dtype (torch.dtype): Input dtype for `codes`, used for downstream pattern matching | ||
and not enforced in `codes`. can be sub byte dtype like torch.uint4 | ||
scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1). | ||
output_dtype (torch.dtype): dtype for the output tensor. | ||
|
||
|
@@ -142,7 +152,7 @@ def dequantize_codebook( | |
dequant = dequant.view( | ||
*new_shape | ||
) # (d1, d2, ..., num_scale_blocks, scale_block_size) | ||
dequant.mul_(scales) | ||
dequant = dequant * scales | ||
|
||
dequant = dequant.view(*original_shape) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, this does not look like it supports granularity, which we will want.
From what I can tell, k is the idx range, e.g., for 4-bit quantization, k = 16. Each idx=i is mapped to the tensor codebook[i]. So we have 1 codebook/LUT for the tensor that maps indices to tensors.
This seems a bit complicated to me. For CoreML, the default is each idx maps to a scalar (but they also support mapping to a vector). I'm not sure if anyone will need tensor-valued look up values.
But we do want granularity in the sense that we can have one codebook per channel, grouped channel, tensor, etc.
Maybe this is what was originally intended for the block_size (based on https://github.com/pytorch/ao/pull/1299/files/53874a005cb174f764363a7c3a22f653ccf738df#r1870108715), but I understand the code correctly, that's not what got implemented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the
scale_block_size
inchoose_qparams_codebook
or the shape ofscales
in the dequant op is supposed to allow us to control the granularity, theblock_size
arg seems to have a different meaning than theblock_size
in other ops, so we should probably rename it, may guess is theblock_size
of tensor values that share the same kmeans cluster value.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh wait, the granularity of codebook is separate, let me take a look again