
[BUG] w4a8 mixed-input gemm for fine-grained quantization #1332

Closed
jianfei-wangg opened this issue Feb 5, 2024 · 12 comments
Labels: bug (Something isn't working), feature request (New feature or request), inactive-30d

Comments

@jianfei-wangg

Referring to #1316, I have tried the 55th example, 55_hopper_mixed_dtype_gemm.
It works fine for w4a8 with groupsize=128, which includes changes from the baseline like:

```cpp
using MmaType = int8_t;
using ElementC = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
{ElementAccumulator(options.alpha), ElementAccumulator(options.beta)}
```

However, when I test with groupsize=64 or any value lower than 128:

```
./55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=64 --mode=2
```

it returns an error like:

```
Got cutlass error: Invalid status at: 626
```

I found that groupsize should not be less than TileShapeK (which is 128 by default in the 55th example). So I changed TileShapeK from 128:

```cpp
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
```

to 64:

```cpp
constexpr int TileShapeK = 64 * 8 / sizeof_bits<MmaType>::value;
```

and ran the same command again:

```
./55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=64 --mode=2
```

It then hangs, as shown below:

[screenshot: the kernel hangs with no further output]

I wonder: can TileShapeK not be less than 128? Are 16 int8_t elements not enough for the 16-byte alignment requirement?
Is there any suggestion on how to implement a groupsize<128 mixed-input GEMM on Hopper?
Thanks for any reply.
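The constraint observed above can be sketched as a standalone check. This is illustrative pseudologic, not CUTLASS source; the names `tile_shape_k` and `group_size_ok` are hypothetical, and the only grounded facts are the `TileShapeK = bytes * 8 / element_bits` formula from the example and the reported requirement that groupsize be at least TileShapeK:

```python
def tile_shape_k(k_bytes: int, elem_bits: int) -> int:
    """Mirror of the example's TileShapeK formula: K-extent in elements."""
    return k_bytes * 8 // elem_bits

def group_size_ok(g: int, tile_k: int) -> bool:
    """Hypothetical validity predicate mirroring the behavior reported above:
    the scale groupsize must be at least the K-mode tile shape."""
    return g >= tile_k

# With the default 128-byte K extent and int8 (8-bit) MmaType:
tile_k_int8 = tile_shape_k(128, 8)   # 128 int8 elements per K tile
assert tile_k_int8 == 128
assert group_size_ok(128, tile_k_int8)       # g=128 works
assert not group_size_ok(64, tile_k_int8)    # g=64 is rejected with the default tile
```

Shrinking TileShapeK to satisfy this predicate is exactly what the comment above attempts, which then exposed the hang.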

@thakkarV
Collaborator

thakkarV commented Feb 5, 2024

@rawnhenry are we missing a static assert somewhere in the collective for valid tile shapes?

@rawnhenry

It is a bug. I'll take a look later in the week.

@rawnhenry

@jianfei-wangg I have a fix for this internally. Working to tune the perf of it.

If you need to run something before I am done, I can provide you with some small changes to make your command work. It may not be the best performance, but you can use it for accuracy evaluations.

@mnicely mnicely modified the milestones: CUTLASS 3.05, CUTLASS 3.5 Feb 15, 2024
@jianfei-wangg
Author

> @jianfei-wangg I have a fix for this internally. Working to tune the perf of it.
>
> If you need to run something before I am done, I can provide you with some small changes to make your command work. It may not be the best performance, but you can use it for accuracy evaluations.

It would be very nice to evaluate accuracy before the release. Please send the small changes to my email: jianfei.wangg@outlook.com. Thanks.

@mnicely mnicely added bug Something isn't working and removed ? - Needs Triage labels Feb 22, 2024
@mnicely mnicely changed the title [FEA] w4a8 mixed-input gemm for fine-grained quantization [BUG] w4a8 mixed-input gemm for fine-grained quantization Feb 22, 2024
@jianfei-wangg
Author

@rawnhenry Hello, any updates?

@rawnhenry

@jianfei-wangg It has been fixed with the release of v3.5

@jianfei-wangg
Author

> @jianfei-wangg It has been fixed with the release of v3.5

Hello, I have tested CUTLASS 3.5 and it supports the w4a8 groupsize=32 case, but the performance is weird.

[screenshot: profiling results comparing fused and unfused kernels]

As the figure above shows, unfused dequant (924.8 µs) + normal GEMM (243 µs) costs much less time than the fused mixed-input GEMM (3813.2 µs). The test command is `./55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=32 --mode=2`. Is this performance loss reasonable, or is there a mistake in my test?
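For reference, the totals in the comparison above work out as follows (a quick sanity check using only the timings reported in this comment):

```python
# Timings reported above, in microseconds.
unfused_dequant_us = 924.8
normal_gemm_us = 243.0
fused_mixed_input_us = 3813.2

# The unfused path is dequant kernel + plain GEMM, run back to back.
unfused_total_us = unfused_dequant_us + normal_gemm_us   # 1167.8 µs

# How much slower the fused kernel is than the two-kernel path.
slowdown = fused_mixed_input_us / unfused_total_us       # ≈ 3.27x
```

A fused kernel being over 3x slower than the unfused pair is what prompts the question; the follow-up comments below trace it to a build configuration issue rather than the kernel itself.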

@Njuapp

Njuapp commented Apr 11, 2024

Hi @jianfei-wangg , sorry to say that I could not reproduce your results. Are you using current main? I cannot run with --g=32 or --g=64, and only --g=128 or larger is allowed. If not, could you provide the commit you are using right now?

I tried with --g=128 on an H100 PCIe, and it shows that fused_mixed_input_gemm costs 238.7 µs, which is less than unfused_dequant (1003 µs) + normal FP8 GEMM (209 µs). It is reasonable for the normal FP8 GEMM kernel alone to be faster: at such a large problem size the kernel is compute-bound rather than memory-bound, and normal FP8 can be more efficient there. Mixed-input GEMM incurs dequantization overhead inside its kernel and only shows its advantage when the problem size is small and memory-bound, such as when m=1.
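The compute-bound vs. memory-bound distinction above can be made concrete with a rough arithmetic-intensity estimate. This is a back-of-the-envelope sketch, not a measurement: the per-element byte counts (1 byte for the int8 A operand, 0.5 bytes for the int4 B operand, 4 bytes for the output) are simplifying assumptions that ignore scales, epilogue traffic, and cache reuse:

```python
def arithmetic_intensity(m, n, k, bytes_a=1.0, bytes_b=0.5, bytes_c=4.0):
    """FLOPs per byte of main-memory traffic for an m x n x k GEMM,
    under the simplifying byte-count assumptions stated above."""
    flops = 2 * m * n * k
    bytes_moved = m * k * bytes_a + n * k * bytes_b + m * n * bytes_c
    return flops / bytes_moved

# The problem size discussed in this thread: heavily compute-bound.
large = arithmetic_intensity(2048, 5120, 8192)   # thousands of FLOPs/byte

# A decode-style m=1 problem: memory-bound, only a few FLOPs/byte.
small = arithmetic_intensity(1, 5120, 8192)
```

At thousands of FLOPs per byte, the large problem saturates the math pipes, so in-kernel dequantization competes with useful compute; at a few FLOPs per byte, the m=1 case is starved for bandwidth, which is where reading 4-bit weights (and dequantizing for free in the shadow of memory latency) pays off.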

@jianfei-wangg
Author

Hi @Njuapp, I'm using the cutlass-v3.5 branch. Since the main branch only supported w4a8 with coarse-grained groups, I opened this issue; as @rawnhenry said, fine-grained-group w4a8 has been fixed in v3.5. You can check out v3.5 and change the 55th example like:

```cpp
using MmaType = int8_t;
using ElementC = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
constexpr int TileShapeK = 32 * 8 / sizeof_bits<MmaType>::value;
{ElementAccumulator(options.alpha), ElementAccumulator(options.beta)}
```

It then supports --g=32, but the performance is poor.

@Njuapp

Njuapp commented Apr 11, 2024

@jianfei-wangg I modified the 55th example as you listed, but the result shows that fused_mixed_input_gemm costs 358.8 µs, which is less than unfused_dequant (1003 µs) + normal FP8 GEMM (276 µs). That is far from your result of 3813.2 µs. Did you miss any details?

@jianfei-wangg
Author

@Njuapp Sorry about the weird performance; it probably came from a debug compilation option. I re-compiled the 55th example and the performance is now reasonable on an H800, as shown below. Thanks.

[screenshot: re-measured timings on H800]


This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
