[BUG] w4a8 mixed-input gemm for fine-grained quantization #1332
Comments
@rawnhenry Are we missing a static assert somewhere in the collective for valid tile shapes?
It is a bug. I'll take a look later in the week.
@jianfei-wangg I have a fix for this internally and am working on tuning its performance. If you need to run something before I am done, I can provide some small changes to make your command work. It may not give the best performance, but you can use it for accuracy evaluations.
It would be very nice to evaluate accuracy before the release. Please send the "small changes" to my email: jianfei.wangg@outlook.com. Thanks.
@rawnhenry Hello, any update?
@jianfei-wangg It has been fixed in the v3.5 release.
Hi @jianfei-wangg, sorry, but I could not reproduce your results. Are you using the current main? I cannot run with … I tried with …
Hi @Njuapp, I'm using the cutlass-v3.5 branch. Because the main branch only supports w4a8 with coarse-grained groups, I opened this issue. As @rawnhenry said, fine-grained-group w4a8 has been fixed in cutlass-v3.5. You can check out v3.5 and change the 55th example like this:
@jianfei-wangg I modified the 55th example as you listed, but the results show that fused_mixed_input_gemm takes 358.8 us, which is less than unfused_dequant (1003 us) plus a normal FP8 GEMM (276 us). That is far from your result of 3813.2 us. Did you miss any details?
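For reference, the timings in the comment above compare as follows. This is plain arithmetic on the reported numbers; no new measurements are implied:

```math
t_{\text{unfused}} = 1003\,\mu\text{s} + 276\,\mu\text{s} = 1279\,\mu\text{s},\qquad
\frac{t_{\text{unfused}}}{t_{\text{fused}}} = \frac{1279\,\mu\text{s}}{358.8\,\mu\text{s}} \approx 3.56,\qquad
\frac{3813.2\,\mu\text{s}}{358.8\,\mu\text{s}} \approx 10.6
```

So the fused kernel was roughly 3.6× faster than the unfused path in this run, while the originally reported 3813.2 us was about 10.6× slower than this fused measurement.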
@Njuapp Sorry about the odd performance; it probably came from a debug compilation option. I recompiled the 55th example, and the performance on H800 is now reasonable, as shown below. Thanks.
Following up on #1316, I tried the 55th example, 55_hopper_mixed_dtype_gemm.
It works fine for w4a8 with groupsize=128, using the following changes from the baseline:
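```cpp
// MMA (activation) type changed to int8
using MmaType = int8_t;
// Integer output, with int32 accumulation and epilogue compute
using ElementC = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
// Epilogue arguments: alpha and beta cast to the accumulator type
{ElementAccumulator(options.alpha), ElementAccumulator(options.beta)}
```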
However, when I test with groupsize=64, or any group size below 128:
```shell
./55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=64 --mode=2
```
it fails with an error like:
```
Got cutlass error: Invalid status at: 626
```
I found that the group size must not be smaller than TileShapeK (128 by default in the 55th example).
So I changed TileShapeK from 128:
```cpp
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
```
to 64:
```cpp
constexpr int TileShapeK = 64 * 8 / sizeof_bits<MmaType>::value;
```
and then ran:
```shell
./55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=64 --mode=2
```
It hangs, as shown below:
Is it the case that TileShapeK cannot be less than 128? Are 16 int8_t values not enough for 16-byte alignment?
Are there any suggestions for implementing a mixed-input GEMM with groupsize < 128 on Hopper?
Thanks for any reply.