CUDA: add fused rms norm #14800

Open · am17an wants to merge 7 commits into master from cuda_fused_rms_norm

Conversation

@am17an (Collaborator) commented Jul 21, 2025

Similar to the Vulkan PR (#14366). Perhaps ggml_vk_can_fuse and ggml_cuda_can_fuse could live inside ggml rather than in their respective backends, since they don't contain backend-specific code.
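
As a rough illustration of the idea, here is a minimal sketch of what a shared, backend-agnostic check might look like (the function name and the exact set of conditions are illustrative assumptions, not the actual ggml API):

```cpp
// Hypothetical sketch of a backend-agnostic fusion check for RMS_NORM + MUL.
// This is NOT the actual ggml/CUDA/Vulkan code; it only illustrates that the
// conditions involved do not depend on backend-specific state.
#include "ggml.h"

static bool can_fuse_rms_norm_mul(struct ggml_cgraph * cgraph, int i) {
    if (i + 1 >= ggml_graph_n_nodes(cgraph)) {
        return false;
    }
    struct ggml_tensor * rms_norm = ggml_graph_node(cgraph, i);
    struct ggml_tensor * mul      = ggml_graph_node(cgraph, i + 1);

    // the two nodes must form an RMS_NORM -> MUL chain ...
    if (rms_norm->op != GGML_OP_RMS_NORM || mul->op != GGML_OP_MUL) {
        return false;
    }
    // ... where the norm output feeds the multiplication ...
    if (mul->src[0] != rms_norm && mul->src[1] != rms_norm) {
        return false;
    }
    // ... and both tensors are contiguous F32
    return rms_norm->type == GGML_TYPE_F32 && mul->type == GGML_TYPE_F32 &&
           ggml_is_contiguous(rms_norm) && ggml_is_contiguous(mul);
}
```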

Decent speedup in PP on my RTX 3090

| Model | Test | t/s c82d48e | t/s cuda_fused_rms_norm | Speedup |
| --- | --- | --- | --- | --- |
| llama 7B Q5_K_M | pp1 | 124.25 | 123.34 | 0.99 |
| llama 7B Q5_K_M | pp512 | 4114.28 | 4610.30 | 1.12 |
| llama 7B Q5_K_M | tg1 | 113.82 | 120.39 | 1.06 |
| llama 7B Q5_K_M | tg128 | 115.84 | 118.67 | 1.02 |

@am17an requested a review from JohannesGaessler on Jul 21, 2025 15:16
@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jul 21, 2025
@JohannesGaessler (Collaborator)

Sorry, ignore my previous suggestion. I forgot that the kernel modifies the x and dst pointers relatively early. In that case, the point at which this PR modifies the pointer is the correct one after all.
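
For readers following along, the pattern being referred to looks roughly like this (a generic illustration of an early pointer rebase in a row-wise kernel, not the actual kernel touched by this PR):

```cuda
#include <cstdint>

// Generic illustration only: a row-wise norm kernel typically rebases its
// input/output pointers to the current row near the top, so any fused
// multiplicand pointer has to be offset at the same early point to stay
// consistent with x and dst.
__global__ void rms_norm_like(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;

    // pointers are moved to the current row early on ...
    x   += (int64_t) row * ncols;
    dst += (int64_t) row * ncols;

    // ... and all later accesses are relative to the rebased pointers
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[col] = x[col]; // the actual reduction/scaling is omitted here
    }
}
```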

@JohannesGaessler (Collaborator) left a comment

Sorry for the confusion; the non-deleted suggestions are how I intend the code to be modified (though my preference is only slight).

@exxocism

I've encountered this error while trying it with Qwen/Qwen3-235B-A22B-Instruct-2507 🥲

/home/user/Documents/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:80: CUDA error
CUDA error: an illegal memory access was encountered
  current device: 0, in function ggml_backend_cuda_synchronize at /home/user/Documents/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2582
  cudaStreamSynchronize(cuda_ctx->stream())

@am17an (Collaborator, Author) commented Jul 22, 2025

@exxocism do you have a stack trace? Also, does the problem go away if you set the environment variable GGML_CUDA_DISABLE_FUSION=1?

@JohannesGaessler (Collaborator)

Some quick performance numbers:

| GPU | Model | Microbatch size | Test | t/s master | t/s 9d605df | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| P40 | llama 8B Q4_0 | 1 | pp512 | 54.39 | 54.92 | 1.01 |
| P40 | llama 8B Q4_0 | 2 | pp512 | 109.62 | 110.63 | 1.01 |
| P40 | llama 8B Q4_0 | 4 | pp512 | 157.67 | 158.81 | 1.01 |
| P40 | llama 8B Q4_0 | 8 | pp512 | 200.13 | 200.81 | 1.00 |
| P40 | llama 8B Q4_0 | 16 | pp512 | 467.05 | 469.90 | 1.01 |
| P40 | llama 8B Q4_0 | 32 | pp512 | 664.70 | 669.26 | 1.01 |
| P40 | llama 8B Q4_0 | 64 | pp512 | 781.94 | 783.55 | 1.00 |
| P40 | llama 8B Q4_0 | 128 | pp512 | 905.37 | 908.52 | 1.00 |
| P40 | llama 8B Q4_0 | 256 | pp512 | 983.71 | 991.55 | 1.01 |
| P40 | llama 8B Q4_0 | 512 | pp512 | 1025.81 | 1027.67 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 1 | pp512 | 157.07 | 158.58 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 2 | pp512 | 279.16 | 282.83 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 4 | pp512 | 468.61 | 482.90 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 8 | pp512 | 569.25 | 598.82 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp512 | 1242.52 | 1277.09 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp512 | 2057.66 | 2117.68 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp512 | 3199.04 | 3361.83 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp512 | 4149.02 | 4433.66 | 1.07 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp512 | 4895.11 | 5256.71 | 1.07 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp512 | 5184.98 | 5559.70 | 1.07 |
| RTX 4090 | llama 8B Q4_0 | 1 | pp512 | 189.14 | 190.01 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 2 | pp512 | 333.98 | 336.39 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 4 | pp512 | 652.00 | 656.24 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp512 | 1066.79 | 1095.86 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp512 | 1827.19 | 1843.04 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp512 | 3327.40 | 3367.65 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp512 | 5826.00 | 5899.43 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp512 | 8750.81 | 8858.00 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp512 | 11709.08 | 11973.87 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp512 | 13187.04 | 13471.53 | 1.02 |
| RX 6800 | llama 8B Q4_0 | 1 | pp512 | 56.92 | 58.03 | 1.02 |
| RX 6800 | llama 8B Q4_0 | 2 | pp512 | 102.93 | 104.52 | 1.02 |
| RX 6800 | llama 8B Q4_0 | 4 | pp512 | 142.11 | 143.96 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 8 | pp512 | 162.26 | 163.80 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 16 | pp512 | 233.37 | 234.84 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 32 | pp512 | 329.30 | 331.28 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 64 | pp512 | 405.96 | 408.59 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 128 | pp512 | 488.22 | 493.10 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 256 | pp512 | 549.70 | 554.56 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 512 | pp512 | 530.23 | 534.62 | 1.01 |

@JohannesGaessler (Collaborator) commented Jul 22, 2025

I get NaN with llama-perplexity using an RTX 4090 and LLaMA 3 8b q4_0. How did you check the code for correctness?

@JohannesGaessler (Collaborator)

The code seems to work correctly for -ub 1 but not for -ub 2.

@am17an (Collaborator, Author) commented Jul 22, 2025

> I get NaN with llama-perplexity using an RTX 4090 and LLaMA 3 8b q4_0. How did you check the code for correctness?

I only ran test-backend-ops; let me check what's happening. It's probably something to do with batch size, since that isn't covered by test-backend-ops.

@JohannesGaessler (Collaborator)

The PPL value can be fixed with GGML_CUDA_DISABLE_FUSION=1; GGML_CUDA_DISABLE_GRAPHS=1 on its own does not help.

@am17an (Collaborator, Author) commented Jul 22, 2025

Thanks @JohannesGaessler for quickly figuring out the bug! Could you please try your PPL run again? I was able to replicate the issue, and it appears to be fixed now. Also cc @exxocism, if you are willing to give this another try.

@exxocism

@am17an Yes, it works with the env variable GGML_CUDA_DISABLE_FUSION=1.
Here's the stack trace:

#0  0x000076c372f107e3 in __GI___wait4 (pid=202905, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30	in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x000076c3738f8de3 in ggml_print_backtrace () from /home/user/Documents/llama.cpp/build/bin/libggml-base.so
#2  0x000076c3738f8f8b in ggml_abort () from /home/user/Documents/llama.cpp/build/bin/libggml-base.so
#3  0x000076c370ccfb67 in ggml_cuda_error(char const*, char const*, char const*, int, char const*) () from /home/user/Documents/llama.cpp/build/bin/libggml-cuda.so
#4  0x000076c370cd1192 in ggml_backend_cuda_synchronize(ggml_backend*) () from /home/user/Documents/llama.cpp/build/bin/libggml-cuda.so
#5  0x000076c373910ae5 in ggml_backend_sched_graph_compute_async () from /home/user/Documents/llama.cpp/build/bin/libggml-base.so
#6  0x000076c37369a1e1 in llama_context::graph_compute(ggml_cgraph*, bool) () from /home/user/Documents/llama.cpp/build/bin/libllama.so
#7  0x000076c37369b84e in llama_context::process_ubatch(llama_ubatch const&, llm_graph_type, llama_memory_context_i*, ggml_status&) () from /home/user/Documents/llama.cpp/build/bin/libllama.so
#8  0x000076c37369f152 in llama_context::decode(llama_batch const&) () from /home/user/Documents/llama.cpp/build/bin/libllama.so
#9  0x000076c3736a01df in llama_decode () from /home/user/Documents/llama.cpp/build/bin/libllama.so
#10 0x00005f8e613e9e1e in server_context::update_slots() ()
#11 0x00005f8e613b013b in server_queue::start_loop() ()
#12 0x00005f8e61377896 in main ()
[Inferior 1 (process 202644) detached]
/home/user/models/2_qwen3/2507_non_thinking.sh: line 44: 202644 Aborted                 (core dumped)  

@exxocism

@am17an Thanks! It works with the new commit. 🎉

@JohannesGaessler (Collaborator)

In my testing it now also works with the new commit. I think the problem with test_rms_norm_mul_add in test-backend-ops is that all tensors have the same shape, so issues due to broadcasting are not detected. Can you extend the test with a boolean argument that conditionally increases the dimensions of one of the input tensors, so that broadcasting is also covered?
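
A rough sketch of what that could look like (the class below is illustrative; the real test_rms_norm_mul_add in tests/test-backend-ops.cpp and its helpers may be shaped differently):

```cpp
// Illustrative sketch only: add a broadcast flag that enlarges the higher
// dimensions of the normed input, so the fused RMS_NORM + MUL (+ ADD) path is
// also exercised with broadcasting. The real test class may differ.
struct test_rms_norm_mul_add_bcast : public test_case {
    const std::array<int64_t, 4> ne;
    const float eps;
    const bool  broadcast;

    test_rms_norm_mul_add_bcast(std::array<int64_t, 4> ne = {64, 5, 4, 3},
            float eps = 1e-6f, bool broadcast = false)
        : ne(ne), eps(eps), broadcast(broadcast) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        std::array<int64_t, 4> a_ne = ne;
        if (broadcast) {
            // make a's higher dims integer multiples of b/c so they broadcast
            a_ne[1] *= 2; a_ne[2] *= 3; a_ne[3] *= 4;
        }
        ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a_ne.data());
        ggml_tensor * b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
        ggml_tensor * c = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());

        // b and c are broadcast over the (larger) normed tensor a
        return ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
    }
};
```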

@am17an requested a review from JohannesGaessler on Jul 22, 2025 11:52
@github-actions bot added the testing (Everything test related) label on Jul 22, 2025
Comment on lines 2789 to 2790
if (rms_norm == mul->src[1] &&
    mul->src[0]->ne[1] != rms_norm->src[1]->ne[1]) {
Collaborator

Suggested change:

if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {

Looking at the broadcasting logic again I think this is what would be correct since the broadcasting could be in any of the dimensions. For the new test cases this doesn't matter because all dimensions are varied simultaneously (I think it's fine to keep it that way). @0cc4m @jeffbolznv since this code seems to have been copied from the Vulkan backend, should it be changed there too?
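
For example (hypothetical shapes, purely to illustrate the point), two tensors can agree in ne[1] and still require broadcasting in another dimension, which only a full shape comparison catches:

```cpp
// hypothetical shapes: ne[1] matches, but broadcasting happens along dim 2,
// so comparing ne[1] alone would not detect it while ggml_are_same_shape does
ggml_tensor * big   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4096, 1, 8, 1);
ggml_tensor * small = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4096, 1, 1, 1);
// big->ne[1] == small->ne[1]  &&  !ggml_are_same_shape(big, small)
```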

Collaborator

I'll take a look.

Collaborator

Looks like the Vulkan shader isn't handling repeat in the innermost dimension. I'll make a PR shortly.

Collaborator

Should be fixed by #14817.

Collaborator Author

I will merge after #14817, since this PR adds the new test.

@am17an (Collaborator, Author) commented Jul 22, 2025

Actually, the Vulkan fusion path also does not seem to implement the broadcast; the new test cases are failing for it.

@am17an force-pushed the cuda_fused_rms_norm branch from 757b81c to ed9f84e on Jul 22, 2025 16:37
Labels: ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs), testing (Everything test related)
5 participants