Intx Quantization Tensor Class #468
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/468. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 1505bca with merge base de4a1fb. This comment was automatically generated by Dr. CI and updates every 15 minutes.
I think this is a good starting point; next we could think more about kernels. The current path we hit is dequantize() and then the F.linear path, I think. Also, we could refactor IntxTensor to work with the native uint1 to uint7 dtypes as well.
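For context, a minimal sketch of the kind of fallback path being described, assuming the subclass exposes a dequantize() method (the class and method names here are illustrative, not the PR's actual dispatch code):

```python
import torch
import torch.nn.functional as F

class PackedTensorSketch(torch.Tensor):
    """Illustrative subclass only; the real UintxTensor lives in this PR."""

    def dequantize(self) -> torch.Tensor:
        # Assumed to unpack the shards back into a dense float tensor.
        raise NotImplementedError

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, weight, *rest = args
            bias = rest[0] if rest else kwargs.get("bias")
            # Current path: dequantize the packed weight, then reuse the
            # ordinary dense F.linear kernel.
            return F.linear(x, weight.dequantize(), bias)
        return super().__torch_function__(func, types, args, kwargs)
```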
Not an expert on QAT, but does this mean we would have to enable support for autograd?
torchao/prototype/intx/bitpacking.py
Outdated
2 bit shard: [0b00100111, 0b00010001]
4 bit shard: [0b00000000, 0b01101001, 0b10010111, 0b00100101]
So this seems to be interleaved packing, that is, packing two elements that are far apart into the same byte. Would packing adjacent elements together be more efficient because of data locality? I guess this might be covered by setting different pack_dim values, but I still want to see if we want to explicitly test this.
Yeah, it is interleaved. I think it's something worth exploring when writing optimized kernels. The issue is that if you pack adjacent elements, then you have to perform interleaved shifting and bitwise OR. Not sure which is faster without a memory/compute profiler.
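For reference, a rough sketch of the two layouts for the 2-bit case (not the PR's actual bitpacking code; it just illustrates why the adjacent layout ends up shifting strided views):

```python
import torch

def pack_2bit_interleaved(x: torch.Tensor) -> torch.Tensor:
    # x: flat uint8 tensor of 2-bit values, numel divisible by 4.
    # Elements a quarter of the tensor apart share a byte, so every
    # shift/OR runs over a contiguous slice.
    n = x.numel() // 4
    return (x[:n] << 6) | (x[n:2 * n] << 4) | (x[2 * n:3 * n] << 2) | x[3 * n:]

def pack_2bit_adjacent(x: torch.Tensor) -> torch.Tensor:
    # Neighbouring elements share a byte, so the shifts/ORs operate on
    # strided columns instead of contiguous slices.
    x = x.view(-1, 4)
    return (x[:, 0] << 6) | (x[:, 1] << 4) | (x[:, 2] << 2) | x[:, 3]

# Example: both produce 4 packed bytes from 16 two-bit values.
vals = torch.randint(0, 4, (16,), dtype=torch.uint8)
packed_a = pack_2bit_interleaved(vals)
packed_b = pack_2bit_adjacent(vals)
```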
Looks great! Mostly just minor comments and suggestions on testing
torchao/prototype/intx/Intx.py
Outdated
shards = [shard.to(torch.uint8) for shard in shards]
self.shard = shards
for i, atrib in enumerate(self.bits_to_shard[bit_size]):
    setattr(self, atrib, shards[i])
should we assert len(shards) == len(bits_to_shard[self.bit_size])?
Users can only pass in the data to be quantized and the element_size. The pack function (called by the constructor) automatically creates the correct number of shards based on the element size, so I don't think it would ever break.
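To illustrate (the exact mapping and names here are assumptions, not the PR's code): if each element size decomposes into fixed power-of-two shard widths, the number of shards is determined entirely by the element size, so the constructor always receives a matching list.

```python
# Hypothetical decomposition of an element size into shard bit widths.
SHARD_WIDTHS = {
    1: [1], 2: [2], 3: [2, 1], 4: [4],
    5: [4, 1], 6: [4, 2], 7: [4, 2, 1],
}

def expected_num_shards(element_size: int) -> int:
    # pack() would emit exactly this many shards for a given element size.
    return len(SHARD_WIDTHS[element_size])
```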
aten = torch.ops.aten

class UintxTensor(torch.Tensor):
nit: UIntx feels better I think, same for the filename
I was thinking about that, but I felt it looks too much like UI (user interface).
Oh OK, I was mainly concerned about the naming convention, but Uintx seems fine for that as well (assuming it means the uintx datatype with a capitalized "U"), although it looks a bit weird.
setattr(self, attrib, shards[i])

self.packed_shape = packed_shape
self.bit_size = bit_size
Also, I feel it would be helpful to add an assert for accepted bit_size values. And maybe rename this to bit_width? It feels more common.
What would the assert be checking?
It should be one of 1 to 7, right?
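Something like the following (a sketch only; bit_width here is the suggested rename, not what the PR currently uses):

```python
def _validate_bit_width(bit_width: int) -> None:
    # Guard against unsupported widths before packing.
    assert 1 <= bit_width <= 7, f"expected a bit width between 1 and 7, got {bit_width}"
```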
Merging now since the remaining comments are minor nits.
PR fulfilling #439
benchmark results:
Performance with dtypes that aren't multiples of 2 is significantly worse, but that was to be expected without custom kernels.