
Add semi-structured sparse + dynamic int8 subclasses #36

Merged 33 commits into main on Apr 26, 2024

Conversation

@jcaip (Contributor) commented Feb 6, 2024

This PR adds int8 dynamic quantization + semi-structured sparsity support to torchao.

This is implemented by extending the existing quantization subclasses to use sparse ops.
Ideally we would be able to compose subclasses and call to_sparse_semi_structured from inside the quantization subclass, but at the moment nested subclass tracing does not work with torch.compile. Besides, for optimizations like fusing scales into the sparse multiply, you would probably want a flattened implementation like this anyway.
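For readers unfamiliar with the "dynamic int8" half of this: the scales are computed on the fly from the tensor values rather than calibrated ahead of time. This is a minimal numpy sketch of per-row symmetric quantization, not the torchao implementation (the function name is made up):

```python
import numpy as np

def dynamic_int8_quantize(x: np.ndarray):
    # Per-row (per-token) symmetric quantization: the scale is derived
    # dynamically from the max absolute value of each row.
    scale = np.abs(x).max(axis=1, keepdims=True) / 127.0
    q = np.clip(np.round(x / scale), -128, 127).astype(np.int8)
    return q, scale

x = np.array([[0.5, -1.0, 0.25],
              [2.0,  0.1, -0.2]], dtype=np.float32)
q, scale = dynamic_int8_quantize(x)
x_hat = q.astype(np.float32) * scale  # dequantize to check the error
```

The dequantized `x_hat` matches `x` up to the quantization step size, which is what lets the int matmul stand in for the float one.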

In particular, this PR adds two new subclasses:

Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight
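For context, semi-structured (2:4) sparsity means that in every contiguous group of four weights, at most two are nonzero. The real kernels do this on GPU via to_sparse_semi_structured; this is a standalone numpy sketch of just the pruning pattern (the function name is illustrative):

```python
import numpy as np

def prune_2_4(w: np.ndarray) -> np.ndarray:
    # Zero the two smallest-magnitude values in every group of four
    # along each row, yielding a 2:4 semi-structured sparsity pattern.
    rows, cols = w.shape
    out = w.copy()
    groups = out.reshape(rows, cols // 4, 4)  # view into `out`
    drop = np.argsort(np.abs(groups), axis=-1)[..., :2]
    np.put_along_axis(groups, drop, 0.0, axis=-1)
    return out

w = np.array([[4.0, -1.0, 0.5, 3.0, 2.0, 0.1, -5.0, 0.3]])
w_24 = prune_2_4(w)
# w_24 == [[4., 0., 0., 3., 2., 0., -5., 0.]]
```

The fixed 2-of-4 structure is what lets the hardware skip the zeros with a compact metadata encoding, unlike unstructured sparsity.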

For the cuSPARSELt subclass, I can extend Int8DynamicallyQuantizedWeightBase by storing the compressed representation in W_int_repr.

FuseMulWeight will fuse one of the dequant multiplies into the cuSPARSELt matmul op. However, cuSPARSELt expects this value in float32, which eats into our previous speedups, since we were passing it as a bfloat16 tensor before.
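The fusion relies on the fact that a per-output-channel dequant multiply commutes with the matmul, so it can be absorbed into the sparse mm (as cuSPARSELt's alpha scaling) instead of running as a separate elementwise pass. A toy numpy illustration of the identity, with made-up shapes and scales:

```python
import numpy as np

rng = np.random.default_rng(0)
W_q = rng.integers(-128, 128, size=(4, 8))     # int8-range weight
X_q = rng.integers(-128, 128, size=(8, 3))     # int8-range activation
s_w = np.array([0.02, 0.04, 0.01, 0.03])       # per-output-channel weight scales
s_x = 0.01                                     # activation scale

# Unfused dequant: int matmul, then two elementwise multiplies.
y_unfused = (W_q @ X_q) * s_w[:, None] * s_x

# "Fused": the weight-scale multiply is folded into the matmul operand,
# leaving one elementwise pass for the activation scale.
y_fused = ((W_q * s_w[:, None]) @ X_q) * s_x
```

Both paths produce the same result; the fused form just saves a pass over the output tensor when the hardware supports the scaled matmul directly.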

However, for the general subclass, I need to extend QuantizedLinearWeightBase, because I need to pass two tensors (packed and meta) to the CUTLASS sparse mm op. This relies on to_sparse_semi_structured to decide between CUTLASS and cuSPARSELt, which is the right choice for the UI, but it makes benchmarking the two backends difficult, since a class variable decides which backend gets used. Maybe we should add a flag to to_sparse_semi_structured, because you might want to mix CUTLASS and cuSPARSELt.
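On the "two tensors" point: the CUTLASS kernel wants the 2:4 weight split into a packed tensor of kept values plus a meta tensor recording each kept value's position within its group of four. The real torch metadata is a packed bitfield; this numpy sketch (function name made up) just shows the layout idea:

```python
import numpy as np

def pack_2_4(w_24: np.ndarray):
    # Split a 2:4-pruned matrix into the kept values ("packed") and the
    # column index of each kept value within its group of four ("meta").
    rows, cols = w_24.shape
    groups = w_24.reshape(rows, cols // 4, 4)
    # A stable argsort of the zero-mask lists the nonzero positions first.
    meta = np.argsort(groups == 0, kind="stable", axis=-1)[..., :2]
    packed = np.take_along_axis(groups, meta, axis=-1).reshape(rows, cols // 2)
    return packed, meta

w_24 = np.array([[4.0, 0.0, 0.0, 3.0, 2.0, 0.0, -5.0, 0.0]])
packed, meta = pack_2_4(w_24)
# packed == [[4., 3., 2., -5.]]; meta[0] == [[0, 3], [0, 2]]
```

Since the dense matrix is recoverable only from both tensors together, a subclass wrapping this format has to carry both, which is why the single-tensor base class doesn't fit here.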

I've also added a benchmarking script for SAM. I don't know how we plan on handling dependencies in torchao, but let me know if there's a better place for that.

[Screenshot: SAM benchmark results]

On batch size 32, I see a 1.16x speedup over the bfloat16 torch.compile baseline, from 21.96 to 25.54 img/s.
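For the record, the 1.16x figure is just the throughput ratio:

```python
baseline, sparse_int8 = 21.96, 25.54   # img/s at batch size 32
print(f"{sparse_int8 / baseline:.2f}x speedup")  # 1.16x speedup
```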

Other benchmarks (BS=16)

[Screenshots: additional benchmark results]

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 6, 2024
@HDCharles (Contributor) commented:

can you move the benchmark_sam and other .py files to one of the other directories? Maybe make a torch/benchmarks dir?

@jcaip jcaip changed the title Add 24 sparse + dynamic int8 composed subclasses [wip] Add 24 sparse + dynamic int8 subclasses Apr 1, 2024
)

int_data = w_int_repr.contiguous()
int_data = torch._cslt_compress(int_data)
Contributor:

Is it currently possible to replace this with

from torch.sparse import to_sparse_semi_structured
int_data = to_sparse_semi_structured(int_data)

Contributor Author (@jcaip):

Leaving this one here because it's the special cuSPARSELt fuse-mul one, but I have changed the subclass to be backend-agnostic (Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight).

Side note: do you care much about the naming convention? This name is so long that I kind of want to change it to something simpler, like QuantizedSemiSparseLinearWeight.

@jcaip jcaip changed the title [wip] Add 24 sparse + dynamic int8 subclasses Add semi-structured sparse + dynamic int8 subclasses Apr 22, 2024
)


class Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight(QuantizedLinearWeightBase):
Contributor:

Is sparsity also implemented with a tensor subclass? I thought we should be able to compose them in some way?

Contributor Author (@jcaip):

We don't currently have nested subclass support for tracing, so we can't compose them yet :( hence why we're landing this in prototype.

I can tag you in the issue I'll make to raise this for core.

@cpuhrsch (Contributor) left a comment:

Approving for prototype. Thanks for sending this :D

@jcaip jcaip force-pushed the jcaip/quant+sparse_subclasses branch from afc9d3d to 4e0c8b3 on April 24, 2024 09:57
@msaroufim
Copy link
Member

@jcaip just saw the CI failure. The issue is that our CI GPU is too old, so I'm updating it now to an A10G, which should work for your code. After I merge #176, make sure to rebase onto main.

@jcaip jcaip merged commit 739e62d into main Apr 26, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024