
[QST] Slower than cublas kernel with same tile size on A100 #430

Closed
minminsun opened this issue Mar 15, 2022 · 16 comments
@minminsun
Contributor

I compared cutlass with cublas on a GEMM with M=3072, N=2048, K=768 on an A100. It turned out the cutlass kernel is more than 10% slower than the cublas kernel, even with the same tile size. Cublas picks the kernel ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_32x5_nn, which takes 62.4 us, while the kernel cutlass_tensorop_f16_s16816gemm_f16_128x128_32x5_nn_align8 from the cutlass library takes 69.2 us.

So my question is: what makes the cutlass kernel slower than the cublas kernel at the same tile size?

Thanks!
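For scale, the quoted timings translate into achieved throughput using the standard 2*M*N*K FLOP count for GEMM (a minimal sketch; the timings are the ones reported above):

```python
# Achieved throughput for the GEMM sizes and timings quoted above.
M, N, K = 3072, 2048, 768
flops = 2 * M * N * K  # each multiply-add counted as 2 FLOPs

for name, seconds in [("cublas", 62.4e-6), ("cutlass", 69.2e-6)]:
    tflops = flops / seconds / 1e12
    print(f"{name}: {tflops:.1f} TFLOP/s")
# -> cublas: 154.9 TFLOP/s
# -> cutlass: 139.6 TFLOP/s
```

Both are well below the A100's 312 TFLOP/s FP16 tensor-core peak, which is expected for a K this small.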

@hwu36
Collaborator

hwu36 commented Mar 15, 2022

What is your compiler version?

@minminsun
Contributor Author

What is your compiler version?

It's "release 11.2, V11.2.152"

@hwu36
Collaborator

hwu36 commented Mar 15, 2022

11.2 is not enough to get the best performance. Below are the minimum CUDA versions for different kinds of kernels.

CUDA 11.3: GEMM
CUDA 11.4: Conv
CUDA 11.5: Sparse
CUDA 11.6: TF32x3

You will see significant performance improvement. Enjoy!
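A minimal rebuild sketch, assuming CUDA 11.6 is installed under /usr/local/cuda-11.6 (the install path and build options here are illustrative):

```shell
# Make sure CMake picks up the newer nvcc, and rebuild from a clean tree.
export CUDACXX=/usr/local/cuda-11.6/bin/nvcc
rm -rf build && mkdir build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS=80   # SM80 targets A100
make cutlass_profiler -j
```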

@minminsun
Contributor Author

I recompiled the library with CUDA 11.6, but got the same result.

@hwu36
Collaborator

hwu36 commented Mar 15, 2022

If you open CMakeCache.txt, does CMAKE_CUDA_COMPILER point to the 11.6 compiler?

@minminsun
Contributor Author

I cleaned the build directory and then rebuilt. Now the cutlass library is really compiled with CUDA 11.6, and the performance of the cutlass kernel is better, although it is still slightly slower than the cublas kernel.

The current comparison of GEMM(M=3072, N=2048, K=768) on A100:
nsight gpu-time: cutlass 64.8us, cublas 62.4us
measured time: cutlass 55.3us, cublas 48.3us

@hwu36
Collaborator

hwu36 commented Mar 16, 2022

It's known that cutlass kernels can be slightly slower than cublas. See the perf chart in the README.

To compare performance apples to apples, we need to lock the frequency. Here are the steps needed on a 400 W A100:

sudo nvidia-smi -i 0 -pm 1              # persistence mode
sudo nvidia-smi -lgc 1005 -i 0          # lock to 1005 MHz; the max on A100 is 1410 MHz
sudo nvidia-smi --power-limit=400 -i 0  # lock to 400 W

BTW, the way to reset the clocks is

sudo nvidia-smi -rgc

cutlass also has a profiler to measure performance, so you don't need to use nsight or anything else.
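For example, something along these lines (the problem-size and filter flags follow the profiler's usage text; the kernel filter is the one named above):

```shell
# Profile the given GEMM with the CUTLASS profiler; it prints
# runtime and GFLOP/s for each matching kernel.
./tools/profiler/cutlass_profiler --operation=Gemm \
  --m=3072 --n=2048 --k=768 \
  --kernels=cutlass_tensorop_f16_s16816gemm_f16_128x128_32x5_nn_align8
```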

@minminsun
Contributor Author

With the frequency locked, cutlass gets the same performance as cublas. Cool!

It seems that cublas runs at the max frequency when it's not locked, but cutlass doesn't.

@mnicely mnicely closed this as completed Mar 22, 2022
@yzhaiustc
Contributor

yzhaiustc commented Mar 23, 2022

May I know why you benchmark the performance with the frequency locked? In other words, why does locking the frequency give us an "apples-to-apples" perf comparison?

Thank you! @hwu36

@hwu36
Collaborator

hwu36 commented Mar 23, 2022

Performance is linear in the frequency, and the frequency changes dynamically. To compare performance fairly, we want every run to use the same frequency.
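A minimal sketch of that linearity (assuming a compute-bound kernel whose runtime scales inversely with the SM clock; the pairing of this particular timing with this clock is illustrative):

```python
# If performance is linear in clock frequency, runtime scales as 1/f.
f_locked, f_max = 1005.0, 1410.0  # MHz: locked clock vs. A100 max boost
t_at_locked = 64.8e-6             # s: a kernel runtime measured at f_locked

# Predicted runtime for the same kernel at the max clock:
t_at_max = t_at_locked * f_locked / f_max
print(f"{t_at_max * 1e6:.1f} us")  # -> 46.2 us
```

This is why comparing one run at a boosted clock against another at a throttled clock can show a double-digit percentage gap that has nothing to do with the kernels themselves.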

@yzhaiustc
Contributor

Thanks for the prompt response! I fully agree that the frequency changes dynamically at runtime.
My understanding is that the frequency typically starts to decrease to save energy when the compute units become idle due to memory latency. cuBLAS seems better at taking advantage of the dynamic voltage/frequency scaling strategy, and it can better keep the compute units busy.

If that is the case, a perf benchmark with no frequency/power locks may also be of interest to some users :)

@hwu36
Collaborator

hwu36 commented Mar 23, 2022

My understanding is that the frequency typically starts to decrease to save energy when the compute units become idle due to memory latency. cuBLAS seems better at taking advantage of the dynamic voltage/frequency scaling strategy, and it can better keep the compute units busy.

No, this is not true.

@yzhaiustc
Contributor

Oh, gosh.
Just curious: how then should I understand the gap between cublas and cutlass diminishing under a locked frequency? Thanks a lot for your time.

@hwu36
Collaborator

hwu36 commented Mar 23, 2022

Mostly just the noise.

@yzhaiustc
Contributor

Thanks for the comment!

@yzhaiustc
Contributor

Mostly just the noise.

Oh, I just got time to update. Days ago I got the results: when correctly linking to an updated nvcc and with the optimal cutlass parameter setup, CUTLASS obtains fairly similar performance to cuBLAS, even without the frequency locked.

What an amazing project. Thank you.
