-
Notifications
You must be signed in to change notification settings - Fork 171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
intx weight only linear quantizer for mps #1192
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1192
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 8e43385 with merge base c546c5c (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This pull request was exported from Phabricator. Differential Revision: D65079774 |
Summary: Pull Request resolved: pytorch#1192 Differential Revision: D65079774
edd4a18
to
ed83de7
Compare
This pull request was exported from Phabricator. Differential Revision: D65079774 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a few comments
from torchao.experimental.quant_api import IntxWeightOnlyLinearQuantizer | ||
|
||
|
||
def parameterized(test_cases): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isnt there something equivalent within torchao?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well, there is something in torch.testing._internal, but then I ran into dependency issues with 'expecttest'.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we use this in other places:
ao/test/integration/test_integration.py
Line 647 in c546c5c
@parameterized.expand(COMMON_DEVICE_DTYPE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we please use what Jerry is pointing to? Less code is better.
Unresolving the comment
Summary: Pull Request resolved: pytorch#1192 Differential Revision: D65079774
ed83de7
to
7dede6f
Compare
Summary: Pull Request resolved: pytorch#1192 Differential Revision: D65079774
7dede6f
to
df02572
Compare
This pull request was exported from Phabricator. Differential Revision: D65079774 |
1 similar comment
This pull request was exported from Phabricator. Differential Revision: D65079774 |
Summary: Pull Request resolved: pytorch#1192 Differential Revision: D65079774
df02572
to
b1d27e1
Compare
This pull request was exported from Phabricator. Differential Revision: D65079774 |
Summary: Pull Request resolved: pytorch#1192 Differential Revision: D65079774
b1d27e1
to
b93f3ea
Compare
Summary: Pull Request resolved: pytorch#1192 Differential Revision: D65079774
This pull request was exported from Phabricator. Differential Revision: D65079774 |
Summary: Pull Request resolved: pytorch#1192 Differential Revision: D65079774
faa0c6b
to
8e43385
Compare
This pull request was exported from Phabricator. Differential Revision: D65079774 |
1 similar comment
This pull request was exported from Phabricator. Differential Revision: D65079774 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some comments
from torchao.experimental.quant_api import IntxWeightOnlyLinearQuantizer | ||
|
||
|
||
def parameterized(test_cases): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we please use what Jerry is pointing to? Less code is better.
Unresolving the comment
if dtype == torch.int8: | ||
qmin = -(1 << (nbit - 1)) | ||
qmax = (1 << (nbit - 1)) - 1 | ||
elif dtype == torch.uint8: | ||
qmin = 0 | ||
qmax = (1 << nbit) - 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of overloading dtype=int8 to convey signed vs. unsigned, can you just do signed=True
?
) | ||
|
||
|
||
def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you at least add todo to the effect
_replace_linear_with_quantized_linear_mps(child, kwargs) | ||
else: | ||
assert child.bias is None | ||
qlinear = UIntxWeightOnlyQuantizedLinear( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have groupsize restriction? If so where is that asserted? I would have expected that groupsize will be constructor arg so that constructor can check and throw if the quantized linear supports it or not. I dont exactly like exceptions in this scenario but maybe thats a better choice because you cannot create an instance of quantized linear for invaild group size
), | ||
) | ||
setattr(module, name, qlinear) | ||
getattr(module, name).quantize_and_pack_weights( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getattr(module, name).quantize_and_pack_weights( | |
qlinear.quantize_and_pack_weights( |
|
||
@parameterized(cases) | ||
def test_export(self, nbit): | ||
model, group_size, k0, n = self._model_setup() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are all tests doing only group size of 32? If so we should test other group sizes as well including those that should result in exception.
Differential Revision: D65079774