[QST] Slower than cublas kernel with same tile size on A100 #430
Comments
What is your compiler version? |
It's "release 11.2, V11.2.152" |
CUDA 11.2 is not recent enough to get the best perf. Below is the minimum CUDA version for each type of kernel. CUDA 11.3: GEMM. With a new enough toolkit you will see a significant performance improvement. Enjoy! |
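A minimal sketch of the version check discussed above, assuming the `release X.Y` string that `nvcc --version` prints (as quoted in the earlier comment). The helper names `parse_nvcc_release` and `meets_minimum` are hypothetical, and only the GEMM minimum of CUDA 11.3 is taken from the thread.

```python
import re

# Minimum toolkit version for GEMM kernels, as stated in the comment above.
MIN_CUDA_FOR_GEMM = (11, 3)


def parse_nvcc_release(version_text):
    """Extract (major, minor) from text such as 'release 11.2, V11.2.152'."""
    match = re.search(r"release\s+(\d+)\.(\d+)", version_text)
    if match is None:
        raise ValueError(f"no 'release X.Y' found in: {version_text!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version_text, minimum=MIN_CUDA_FOR_GEMM):
    """True if the reported toolkit is at least the given minimum version."""
    return parse_nvcc_release(version_text) >= minimum


print(meets_minimum("release 11.2, V11.2.152"))  # False: 11.2 < 11.3
print(meets_minimum("release 11.6"))             # True
```

Tuples compare element by element, so `(11, 2) >= (11, 3)` is false without any string handling.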
I recompiled the library with CUDA 11.6, but got the same result. |
If you open CMakeCache.txt Does |
I cleaned the build directory and then rebuilt. Now the cutlass library is really compiled with CUDA 11.6, and the performance of the cutlass kernel is better, although it's still a little slower than the cublas kernel. The current comparison of GEMM (M=3072, N=2048, K=768) on A100: |
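One way to confirm which compiler a build directory actually used is to read the cache file directly. The sketch below parses the `KEY:TYPE=VALUE` lines of a `CMakeCache.txt`. The thread does not show which entry was being asked about, so looking up `CMAKE_CUDA_COMPILER` is an assumption, and the path in the sample text is made up for illustration.

```python
def read_cmake_cache(text):
    """Parse CMakeCache.txt content into a {key: value} dict.

    Cache entries have the form KEY:TYPE=VALUE; lines starting with
    '#' or '//' are comments and are skipped.
    """
    entries = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith(("#", "//")):
            continue
        key_and_type, sep, value = line.partition("=")
        if not sep:
            continue  # not a KEY:TYPE=VALUE line
        key = key_and_type.split(":", 1)[0]
        entries[key] = value
    return entries


# Hypothetical cache content; a real file would be read with open(...).read().
sample = """\
# This is the CMakeCache file.
//CUDA compiler
CMAKE_CUDA_COMPILER:FILEPATH=/usr/local/cuda-11.6/bin/nvcc
CMAKE_BUILD_TYPE:STRING=Release
"""

cache = read_cmake_cache(sample)
print(cache.get("CMAKE_CUDA_COMPILER"))
```

If the printed path still points at the old toolkit, the stale cache explains why a rebuild did not change anything until the build directory was cleaned.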
It's known that a cutlass kernel can be slightly slower than cublas. See the perf chart in the README. To compare the performance apples to apples, we need to lock the frequency. Here are the steps needed on a 400W A100
BTW, the way to reset clock is
cutlass also has a profiler to measure performance, so you don't need to use Nsight or any other tool. |
With frequency locked, cutlass gets the same performance as cublas. Cool! Seems like cublas runs at the max frequency if not locked, but cutlass doesn't. |
May I know the reason why you would benchmark the performance with the freq locked, or in other words, why locking the freq gives us an "apples-to-apples" perf comparison? Thank you! @hwu36 |
Performance scales linearly with the clock frequency, and the frequency changes dynamically at runtime. To compare performance fairly, we want every run to use the same frequency. |
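To illustrate the linear-scaling argument: if runtime is inversely proportional to clock frequency (an assumption that fits compute-bound kernels, not memory-bound ones), a timing taken at one clock can be rescaled to a reference clock. The function name and the numbers below are hypothetical, not measurements from this thread.

```python
def normalize_time_us(measured_us, measured_mhz, reference_mhz):
    """Rescale a kernel time to a reference clock.

    Assumes time is inversely proportional to frequency, i.e.
    performance is linear in the clock.
    """
    return measured_us * measured_mhz / reference_mhz


# Hypothetical example: a run that took 100 us while the GPU sat at 1200 MHz
# would take 80 us at 1500 MHz under the linear-scaling assumption.
print(normalize_time_us(100.0, 1200, 1500))  # 80.0
```

This also shows why unlocked runs are hard to compare: two kernels timed at different clocks can differ by the clock ratio alone.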
Thanks for the prompt response! Yeah, I fully agree that the freq changes dynamically at runtime. If that is the case, a perf benchmark with no freq/power locks may also be of interest to some users :) |
No, this is not true. |
Oh. Gosh. |
Mostly just the noise. |
Thanks for the comment! |
Oh, I just found time to post an update. A few days ago I got the results: when correctly linked against an updated nvcc and with the optimal cutlass param setup, CUTLASS obtains fairly similar performance to cuBLAS, even without the freq locked. What an amazing project. Thank you. |
I compared cutlass with cublas on a GEMM with M=3072, N=2048, K=768 on A100
It turned out that the cutlass kernel is more than 10% slower than the cublas kernel, even with the same tile size.
Cublas picks the kernel ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_32x5_nn, which takes 62.4 us,
while the kernel cutlass_tensorop_f16_s16816gemm_f16_128x128_32x5_nn_align8 from the cutlass library takes 69.2 us.
So my question is what makes cutlass kernel slower than the cublas kernel with the same tile size?
Thanks!
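For reference, the two timings above can be converted into a relative slowdown and an achieved throughput. This sketch assumes the usual count of 2·M·N·K floating-point operations for a GEMM (one multiply and one add per inner-product term); the problem size and timings are the ones reported in the question.

```python
# Problem size and kernel times as reported in the question.
M, N, K = 3072, 2048, 768
CUBLAS_US = 62.4
CUTLASS_US = 69.2

# Assumed operation count: one multiply plus one add per (m, n, k) triple.
FLOPS = 2 * M * N * K


def tflops(time_us):
    """Achieved throughput in TFLOP/s for the GEMM above."""
    return FLOPS / (time_us * 1e-6) / 1e12


slowdown = CUTLASS_US / CUBLAS_US - 1.0

print(f"cublas : {tflops(CUBLAS_US):.1f} TFLOP/s")
print(f"cutlass: {tflops(CUTLASS_US):.1f} TFLOP/s")
print(f"cutlass is {slowdown:.1%} slower")  # about 10.9%
```

The ratio of the two times is what backs the "more than 10% slower" claim; the TFLOP/s figures are just the same timings expressed as throughput.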