
Intx Quantization Tensor Class #468
Merged: 37 commits merged into pytorch:main on Aug 7, 2024

Conversation

@vayuda (Collaborator) commented on Jul 2, 2024:

PR fulfilling #439

benchmark results: [benchmark images attached]

Performance with dtypes whose bit widths aren't multiples of 2 is significantly worse, but that was to be expected without custom kernels.

@pytorch-bot commented on Jul 2, 2024:

🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/468
✅ No failures as of commit 1505bca with merge base de4a1fb.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Jul 2, 2024.
torchao/_models/llama/eval.py: outdated review thread (resolved)
@msaroufim (Member) commented:

  • @andrewor14 and @Hanxian97 who have been looking into intX for QAT. Do y'all mind giving this a quick review as well?

@jerryzh168 (Contributor) left a comment:

I think this is a good starting point; next we could think more about kernels. The current path we hit is dequantize() and then F.linear, I think. We could also refactor IntxTensor to work with the native uint1 to uint7 dtypes.
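
As a rough illustration of that fallback path (a sketch only; uintx_linear and the dequantize() call are assumptions, not the PR's actual dispatch code):

```python
import torch
import torch.nn.functional as F

# Hypothetical fallback: with no custom intx matmul kernel yet, the packed
# weight is first dequantized to a regular dtype and then handed to F.linear.
# `uintx_weight.dequantize()` is assumed to return a plain torch.Tensor.
def uintx_linear(activation, uintx_weight, bias=None):
    weight = uintx_weight.dequantize()         # unpack shards and rescale
    return F.linear(activation, weight, bias)  # ordinary dense matmul
```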

@vayuda (Collaborator, Author) commented on Aug 7, 2024:

  • @andrewor14 and @Hanxian97 who have been looking into intX for QAT. Do y'all mind giving this a quick review as well?

Not an expert on QAT, but does this mean we would have to enable support for autograd?
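
For context on the autograd question: fake quantization in QAT is commonly made differentiable with a straight-through estimator. The sketch below is a generic example of that pattern, not code from this PR.

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Straight-through estimator: round in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Treat round() as the identity so gradients flow to the float weights.
        return grad_output
```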

Comment on lines 88 to 89
2 bit shard: [0b00100111, 0b00010001]
4 bit shard: [0b00000000, 0b01101001, 0b10010111, 0b00100101]
@jerryzh168 (Contributor) commented on Aug 7, 2024:

So this seems to be interleaved packing, that is, packing two elements that are far apart together. Would packing adjacent elements together be more efficient because of data locality? I guess this might be covered by setting different pack_dim values, but I still want to see if we should test this explicitly.

@vayuda (Collaborator, Author) replied:

Yeah, it is interleaved. I think it's something worth exploring when writing optimized kernels. The issue is that if you pack adjacent elements, you have to perform interleaved shifting and bitwise OR; I'm not sure which is faster without a memory/compute profiler.
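
For illustration, here is a toy sketch of the two 2-bit packing layouts being compared. Assumptions: a 1-D uint8 tensor whose values already fit in 2 bits and whose length is divisible by 4; this is not the PR's actual bitpacking.py code.

```python
import torch

def pack_adjacent_2bit(x: torch.Tensor) -> torch.Tensor:
    # Byte k holds elements 4k..4k+3: reads are contiguous, but unpacking a byte
    # needs per-element shifts to restore the original element order.
    x = x.view(-1, 4)
    return x[:, 0] | (x[:, 1] << 2) | (x[:, 2] << 4) | (x[:, 3] << 6)

def pack_interleaved_2bit(x: torch.Tensor) -> torch.Tensor:
    # Byte k holds elements k, k+n/4, k+2n/4, k+3n/4: elements far apart share a
    # byte, but each 2-bit field maps to one contiguous quarter of the tensor.
    n = x.numel()
    x = x.view(4, n // 4)
    return x[0] | (x[1] << 2) | (x[2] << 4) | (x[3] << 6)

vals = torch.randint(0, 4, (8,), dtype=torch.uint8)
print(pack_adjacent_2bit(vals))
print(pack_interleaved_2bit(vals))
```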

@andrewor14 (Contributor) left a comment:

Looks great! Mostly just minor comments and suggestions on testing

torchao/prototype/intx/Intx.py: outdated review thread (resolved)
shards = [shard.to(torch.uint8) for shard in shards]
self.shard = shards
for i, atrib in enumerate(self.bits_to_shard[bit_size]):
    setattr(self, atrib, shards[i])
A reviewer (Contributor) commented:

should we assert len(shards) == len(bits_to_shard[self.bit_size])?

@vayuda (Collaborator, Author) replied:

Users can only pass in the data to be quantized and the element_size; the pack function (called by the constructor) automatically creates the correct number of shards based on the element size, so I don't think it would ever break.
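
Even so, a defensive check along the lines suggested above would be cheap. A sketch, using names that follow the snippet quoted earlier and may not match the final code:

```python
# Hypothetical helper mirroring the suggestion; `bits_to_shard` maps a bit size
# to the list of shard attribute names, as in the snippet above.
def check_shards(shards, bits_to_shard, bit_size):
    expected = len(bits_to_shard[bit_size])
    assert len(shards) == expected, (
        f"expected {expected} shards for bit_size={bit_size}, got {len(shards)}"
    )
```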

torchao/prototype/intx/Intx.py: outdated review thread (resolved)
torchao/prototype/intx/bitpacking.py: three outdated review threads (resolved)

aten = torch.ops.aten

class UintxTensor(torch.Tensor):
@jerryzh168 (Contributor) commented on Aug 7, 2024:

Nit: UIntx feels better, I think; same for the filename.

@vayuda (Collaborator, Author) replied:

I was thinking about that, but I felt it looks too much like "UI" for user interface.

@jerryzh168 (Contributor) replied:

Oh, OK. I was mainly concerned about naming convention, but Uintx seems fine for that as well (assuming it means the uintx datatype with the leading u capitalized), although it looks a bit weird.

setattr(self, attrib, shards[i])

self.packed_shape = packed_shape
self.bit_size = bit_size
A reviewer (Contributor) commented:

Also, I feel it would be helpful to add an assert for the accepted bit_size values. Maybe also rename this to bit_width? That feels more common.

@vayuda (Collaborator, Author) replied:

What would the assert be checking?

@jerryzh168 (Contributor) replied on Aug 7, 2024:

It should be one of 1 to 7, right?
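
Concretely, the suggested validation could look something like this (a sketch using the names from the snippet above; exact placement is up to the PR):

```python
# Hypothetical check at construction time: sub-byte bit widths only.
def check_bit_size(bit_size):
    assert isinstance(bit_size, int) and 1 <= bit_size <= 7, (
        f"bit_size must be an int between 1 and 7, got {bit_size!r}"
    )
```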

@jerryzh168 (Contributor) commented:

Merging now since the remaining comments are minor nits.

@jerryzh168 jerryzh168 merged commit 87869f2 into pytorch:main Aug 7, 2024
13 checks passed
@vayuda vayuda deleted the intx branch August 8, 2024 16:42