-
Notifications
You must be signed in to change notification settings - Fork 349
W4A8 based on CUTLASS #880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/880
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 015896b with merge base f7f20e9 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Hi @alexsamardzic! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks! |
The kernel implements W4A8 GEMM, with float16 scaling factors. The zero point support is to be eventually added later, for now several hacks (to be removed) are put in the code, that will force There are several points to discuss: CUTLASS would have to be made a dependency. IMO, the best approach to satisfy the dependency would be to install The group quantization may be a problem. Let's say The sum in the last expression could be efficiently calculated as mixed integer data types GEMM on tensor cores, and the result could be then updated by mulitplying the scale factors in. However, if group size parameter is less than Now, the only approach possible in CUTLASS to do this calculation in integer mixed data types on tensor cores would be to split it into Another related issue is zero point handling. Let's say Only the first expression within parentheses could be calculated on tensor cores as mixed integer data types GEMM, while the sums in the next two expression are best to be pre-calculated in case of weight values, or calculated on the fly during the input quantization. So it seems to me these are also calling for specialized type of quantization. (Note also that if group quantization used, above mentioned complications for All comments/suggestions welcome; in particular I'm pretty much new to quantization specifics so please let me know if I'm missing something obvious. |
I'm on PTO today and tomorrow so will review asap, apologies for the delay |
@alexsamardzic - Can we use the CUTLASS that ships with PyTorch? As in, should we change PyTorch to ship the headers used to build its CUTLASS kernels / does the PyTorch nightly already ship those? I see the test is using group size 128. I think it's ok if we don't necessarily support all group sizes or shapes right away. We have some int4 support via the pattern matched in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/post_grad.py#L345-L403 which dispatches to https://github.com/pytorch/pytorch/blob/dab7d646d55a2b6696d51dee4816a6743ec1ae5a/torch/_inductor/kernel/unpack_mixed_mm.py#L76 - would an extension for int4x2 X int8 of this be interesting here? |
Thanks Mark - it's really just a draft, so not yet ready for review, but it would be useful to discuss points that I mentioned in my comment above. |
This CUTLASS version is also lagging behind. My CUTLASS PR with mixed int4/int8 GEMM is merged after the latest (3.5.1) CUTLASS release, hopefully there will be a new release soon. But in any case, this is a kind of problem that we'll have if we use more CUTLASS from torchao - for lots of time, the torchao build will have to be pointed to a bleeding edge CUTLASS checkout.
It uses group size 128 in order to force weight scale to be a vector, and not a matrix. I tried to explain the issue in my comment above, if group quantization is obligatory here, then it's going to be rather complicated to make this work.
I'm just looking into the quantization code, to see is it possible to do it there - it's not hard to make this change, but CUTLASS in general doesn't support doing things before GEMM (while fusing operations after GEMM calculated is reasonably well supported), so it would be the best if the quantization code actually put the weight values in int4x2 format. |
1bacd02
to
e1a1ff1
Compare
Updated so that there is a new |
torchao/quantization/quant_api.py
Outdated
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type) | ||
|
||
|
||
def apply_int8_dynamic_activation_int4_weight_quant_cutlass(weight): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this be represented as a different Layout
for int8 dynamic activation/int4 weight quantization? docs for Packing/Layout can be found in #391 "Layout and Packing" and simplified example in https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the pointer! Yes, this will need refinement on this and several other places, as I learn about doing things the "torchao way"; but my main goal initially is to connect the dots, so that some benchmarks could be run, and that we could verify that CUTLASS provides some value here.
f6383ca
to
02f8805
Compare
Made some minor updates, including added support for bfloat16. Micro-benchmarking script
For particular shapes given in the script above, on A100 the micro-benchmark shows around 2x speedup over the case when float16 MM used, and around 1.8x speedup over the case when bfloat16 MM used. (Note that this is for eager mode execution, as compilation to corresponding CUTLASS kernel is not yet supported by PyTorch.) Patch to run torchao/_models/llama/generate.py
With the patch above, I was able to run Llama
and the output is as follows (again, this is run on A100):
while the reference output, for the case when no arguments supplied to
So the tokens/sec is more than 3x slower, but this is not even that bad, considering that batch size is 1 here, and that the CUTLASS code has it hard-coded for a block of threads to handle input tile size that is 128 for the same dimension, so most of the work is wasted. So there is a room for improvement regarding the speed. The text generated is garbage, however. Even for the micro-benchmark above, output values visibly deviate from the values produced when native precision used (but at least they resemble each other). |
575e074
to
956fc80
Compare
Made an update - turns out that actually CUTLASS needs a fix (posted below for now), and then CUTLASS fix
On the other side, I tried with adapting tile sizes processed by block/warp of threads of corresponding CUTLASS kernel, in order to adapt to the fact that batch size is 1 here. Here is an example of such change:
However, tokens/sec is not much improved this way. Thus, the performance of this kernel for Llama model will require more work. Edit: CUTLASS fix posted upstream here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will make a second pass for the kernel code
setup.py
Outdated
extension = CUDAExtension if use_cuda else CppExtension | ||
|
||
if not IS_WINDOWS: | ||
import cutlass_library |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
interesting: not too familiar with cutlass packaging but what is cutlass_library
exactly? only reference I found is this https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a recent addition to CUTLASS: a Python library that is able to generate C++ code for CUTLASS GEMM templates instantiation (which is nice to have, as these templates have dozen or more arguments, and it's oftentimes hard to get them right). It's used in CUTLASS codegen for TorchInductor, like here. However, recently CUTLASS itself also added a functionality to generate and compile C++ code for GEMM kernels, from a high-level specification in Python - this is part of cutlass
Python package, see here. Both cutlass
and cutlass_library
are available through nvidia-cutlass pip package. It's important to note that this package also contains all of the CUTLASS C++ header files, in order to make it possible to compile the C++ generated kernels.
setup.py
Outdated
cutlass_library_dir = os.path.dirname(cutlass_library.__file__) | ||
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include") | ||
# FIXME: remove this once CUTLASS package updated to include int4/int8 MM | ||
cutlass_include_dir = "/data/quansight/scratch/cutlass/include" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
n00b q: what is this exactly? Do you need any help packaging CUTLASS?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I discussed this a bit in my first comment on this PR. In order ao to compile after this PR eventually merged, CUTLASS C++ header files are to be made available. There are at least two ways to do it:
- To make CUTLASS repo a submodule of ao repo, just like PyTorch did it.
- To make above mentioned
nvidia-cutlass
package a dependency of ao.
I'm leaning towards the later, and this is what above code, before "FIXME" is expecting. However, in both of above cases, we'll certainly face an issue of having to depend on stuff that is not yet merged into CUTLASS, but we need it. For example, at this very moment:
- My CUTLASS PR with int4/int8 GEMM support for CUTLASS is merged, but CUTLASS team has not made a release in the meantime, so this functionality is only available in CUTLASS
main
branch, and also above mentionednvidia-cutlass
package doesn't contain it yet. - As mentioned in one of my comments above, while working in this PR, I found an omission in CUTLASS. I created a CUTLASS PR with a fix, but this one is not yet merged, so neither CUTLASS
main
branch nornvidia-cutlass
package contain the fix at the moment, it's only available in my branch. So the only way to proceed with the development of my PR was to create a local copy of this branch - I created it in/data/quansight/scratch/cutlass
directory on my machine; in order to try this PR, the local copy of this branch is to be created, and this last line in the snippet above is to be changed to the local directory.
From my experience with this stuff from PyTorch development based on CUTLASS, this is going to be permanent issue - if we decide to use CUTLASS in ao, the for the most of the time we'll need bleeding edge features. So this is to be discussed further, IMO the best approach would be to build our own nvidia-cutlass
package, from whatever CUTLASS branch we find the most appropriate.
torchao/quantization/quant_api.py
Outdated
"_get_subclass_inserter", | ||
"quantize_", | ||
"int8_dynamic_activation_int4_weight", | ||
"int8_dynamic_activation_int4_weight_cutlass", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you have some baseline numbers vs int8_dynamic_activation_int4_weight
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now when I have the dots connected, in the sense that I can run a micro-benchmark, and also Lllama model, using this kernel, I'm working on a more detailed profiling, part of this is also comparing the performance of this kernel with int8_dynamic_activation_int4_weight
kernel. I'll report all my findings here when I'm done with the profiling.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a quick update here: Using the micro-benchmarking script above, it seems this PR is just 3-5% faster than int8_dynamic_activation_int4_weight
. However, on the Llama generator, it seems about 2x faster, when tokens/sec numbers compared. (Remember that all the caveats from my first comment above still apply, so let's not jump into any conclusions for now.)
test/test_s8s4_linear_cutlass.py
Outdated
@@ -0,0 +1,51 @@ | |||
# FIXME: move this test to the appropriate test file!!! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah maybe make yourself a cutlass folder to park all your work
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Again, as mentioned in one of my comments above: At the moment, most of the "FIXME"-s in the PR are as I'm aware that I took shortcuts to make things work. If/when we're happy with the main stuff, I'll revisit all of these, and redo them in the proper "ao-way".
test/test_s8s4_linear_cutlass.py
Outdated
output_ref = model(input) | ||
|
||
modelq = copy.deepcopy(model) | ||
quantize_(modelq, int8_dynamic_activation_int4_weight_cutlass()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe another reference would be the non cutlass variant
torchao/quantization/quant_api.py
Outdated
# then corresponding changes made in | ||
# _linear_int8_act_int4_weight_cutlass_check and for the check in | ||
# the CUTLASS kernel!!! | ||
weight.original_weight_tensor.layout_tensor.int_data = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a comment like
# Combine pairs of 4-bit values into single bytes
weight.original_weight_tensor.layout_tensor.int_data = (
# Take odd-indexed columns, keep lower 4 bits, shift left by 4 bits
(weight.original_weight_tensor.layout_tensor.int_data[:, 1::2] & 0xF) << 4
) | (
# Take even-indexed columns, keep lower 4 bits
weight.original_weight_tensor.layout_tensor.int_data[:, 0::2] & 0xF
)
torchao/quantization/quant_api.py
Outdated
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant_cutlass) | ||
|
||
|
||
def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated comment, what is this use_hqq
? @jerryzh168 do you know?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah this means use hqq
algorithm to choose qparams and quantize the weight, since it is reusing the tinygemm kernel, we just added this as a separate option here
const int n = tensor_b.size(0); | ||
const int k = tensor_a.size(1); | ||
|
||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: mind adding a comment for why 128
Also how do you think about padding vs erroring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The 128 bits here is because of how tensor cores work (so it's not CUTLASS-specific), at least for SM 8.x. It's related to the layout of tiles of matrix operands that single warp of thread is multiplying cooperatively. The best explanation that I found so far is in GTC 2020 talk, by CUTLASS team, around slide 15.
We can consider padding (maybe at the later stage?), I believe it would the best to incorporate padding together with the quantization.
using SmArch = cutlass::arch::Sm80; | ||
using ThreadblockSwizzle = | ||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; | ||
constexpr auto NumStages = 4; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cutlass n00b but how do you pick these hyperparams?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These, and others, are the CUTLASS GEMM C++ template arguments. As mentioned above, there is dozen of these to set, but on the other side only small number of combinations of these arguments actually works. The above mentioned cutlass_library
package enumerates some of these working combinations. The CUTLASS itself doesn't include any sort of heuristic for selection of these parameters, for example based on GEMM operand shapes. So I had to hard-code some values, at least for now. The values selected here are based on my previous experimentation with different combinations, and different operand shapes - in the sense that these values should provide acceptable performance for number of cases. But certainly there are cases where these values are not good fit, Lllama inference, having batch size 1, is one such example. So we may want to consider adding some heuristic here, but on the longer term we'd probably prefer to do support some auto-tuning, just like what is possible with Triton kernels.
956fc80
to
bc85146
Compare
(Pushed an update, where the branch is just rebased on the latest main.) I did lots of profiling in the meantime, focusing primarily on running Llama generator (
and the run for this PR was as follows (with the patch mentioned above applied beforehand):
TLDR (note that each of these items could be verified by profiling W8A8DQ alone, without using this PR at all):
As an example for item 1 above, here are the performance results, as printed by
and when moved to the last place in the list:
The generator runs are profiled using
here is the relevant part of the So, for the attention segment of the model, one could see that everything related to running the linear operator takes about 34s in total. Out of this time, 24s are spend in the dynamic quantization, while about 9.4s only are spent on the linear operator itself, and then out of these 9.4s, only 2.4s are spent on the CUTLASS MM kernel execution, while the rest of time get spent on checking to which kernel to dispatch (note that for this run, the check for applicability of the CUTLASS kernel is added last to the list) - these checks are not visible in this snippet, as Here is the As mentioned above, profiling results are verified using
here is a screenshot of the timeline as shown by Here, one could see that loading of model takes about 30s, then there is a short sequence of copying model to GPU and doing weights quantization, and then the rest of the timeline is the inference. The CUTLASS MM kernel, designated as |
@alexsamardzic - Was the model torch.compile'd with mode 'max-autotune'? Also you can use |
eef7b09
to
825df6b
Compare
Moved |
Merge at will @alexsamardzic - I assume the failed CI jobs are unrelated. |
583560b
to
aebc20e
Compare
@jerryzh168 Do you maybe want to give another look to the changes on the Python side? |
aebc20e
to
9f824fe
Compare
Made changes suggested, and rebased on the latest main. |
9f824fe
to
356ab28
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, one minor comment about zero_point
356ab28
to
7fda24e
Compare
Again - made the change suggested, rebased on the latest main. If CI goes green, will merge. |
7fda24e
to
e1b1690
Compare
CUTLASS-based s8s4_linear_cutlass() operator is introduced, performing linear transformation over quantized 8-bit input and quantized 4-bit weight tensors, with corresponding floating point scale tensors attached. A benchmark script, for comparing performance of MM based on this linear operator with MM over 16-bit floating point tensors is supplied in benchmarks/benchmarks/benchmark_s8s4_cutlass.py. The Llama generator script torchao/_models/llama/generate.py is changed, to add "int8adq-int4w-symm" quantization as an option, that will in turn activate s8s4_linear_cutlass() operator. With this type of quantization activated, i.e. if generate.py script run as follows: python generate.py --compile --precision=torch.float16 -q int8adq-int4w-symm the generator achieves around 133 tok/sec on A100, vs. around 93 tok/sec without quantization, i.e. when generate.py script run as follows: python generate.py --compile --precision=torch.float16
e1b1690
to
015896b
Compare
CUTLASS-based s8s4_linear_cutlass() operator is introduced, performing linear transformation over quantized 8-bit input and quantized 4-bit weight tensors, with corresponding floating point scale tensors attached. A benchmark script, for comparing performance of MM based on this linear operator with MM over 16-bit floating point tensors is supplied in benchmarks/benchmarks/benchmark_s8s4_cutlass.py. The Llama generator script torchao/_models/llama/generate.py is changed, to add "int8adq-int4w-symm" quantization as an option, that will in turn activate s8s4_linear_cutlass() operator. With this type of quantization activated, i.e. if generate.py script run as follows: python generate.py --compile --precision=torch.float16 -q int8adq-int4w-symm the generator achieves around 133 tok/sec on A100, vs. around 93 tok/sec without quantization, i.e. when generate.py script run as follows: python generate.py --compile --precision=torch.float16
@alexsamardzic for W4A8, does cutlass also support int4 weight and fp8 activation as well? |
Yes, it does, see https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu. But note that it won't be a simple update for this PR, as this one is for Ampere, but Hopper kernels in CUTLASS are written differently.
|
CUTLASS-based
s8s4_linear_cutlass()
operator is introduced, performing linear transformation over quantized 8-bit input and quantized 4-bit weight tensors, with corresponding floating point scale tensors attached.A benchmark script, for comparing performance of MM based on this linear operator with MM over 16-bit floating point tensors is supplied in
benchmarks/benchmarks/benchmark_s8s4_cutlass.py
.The Llama generator
script torchao/_models/llama/generate.py
is changed, to addint8adq-int4w-symm
quantization as an option, that will in turn activates8s4_linear_cutlass()
operator. With this type of quantization activated, i.e. ifgenerate.py
script run as follows:the generator achieves around 133 tok/sec on A100, vs. around 93 tok/sec without quantization, i.e. when
generate.py
script run as follows: