-
Notifications
You must be signed in to change notification settings - Fork 13.3k
CUDA: larger SRAM reads for tile FA, AMD FP16 dot #15927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
JohannesGaessler
merged 2 commits into
ggml-org:master
from
JohannesGaessler:cuda-fa-tile-mem-pattern-4
Sep 11, 2025
Merged
CUDA: larger SRAM reads for tile FA, AMD FP16 dot #15927
JohannesGaessler
merged 2 commits into
ggml-org:master
from
JohannesGaessler:cuda-fa-tile-mem-pattern-4
Sep 11, 2025
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 task
4ff6731
to
fe4eb4f
Compare
slaren
approved these changes
Sep 11, 2025
} else if constexpr (nbytes == 16) { | ||
*(int4 *) dst = *(const int4 *) src; | ||
} else { | ||
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't this work?
Suggested change
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); | |
static_assert(false, "bad nbytes"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried this first, it failed during the host pass.
LunNova
added a commit
to LunNova/nixpkgs
that referenced
this pull request
Sep 16, 2025
Includes fix for v_dot2_f32_f16 being used on ISAs without that instruction. ggml-org/llama.cpp#15927
13 tasks
Nexesenex
added a commit
to Nexesenex/croco.cpp
that referenced
this pull request
Sep 30, 2025
)" This reverts commit 75a3a6c. d Update cudart64_12.dll Revert "Cudart 12.9" This reverts commit f79c687. Revert "Allow compile exe, pdf features off" This reverts commit 5e1c154. Update fattn.cu Update set-rows.cu batches Revert "try fix fattn again, porting some older code. the cc detection is not working well, so its hacky" This reverts commit 7b04191. Update ggml-cuda.cu Update fattn.cu Update fattn.cu Update fattn.cu Add option to disable MMA support on Turing Author : pt13762104 GGML_CUDA_NO_PEER_COPY to try to fix a crash on Gemma 3 Deactivate SWA when Fast Forwarding, commented Wrench Fix for the SWA I borked Clean-up quantkv algo comment warp sizes for now in IQ_K MMQ Kernels KV 24 -> KV 31 Add a readme. ngxson's commented hack Try some hack for gpt-oss Update llama-vocab.cpp Bump Windows max open files from 512 to 2048 Author : Thireus CLI - Specify GGML_TYPE to quantize for the main tensors. (#91) To complement the token_embd.weight and output.weight : attn_v.weight attn_k.weight. attn_q_weight attn_output.weight attn_qkv.weight ffn_gate ffn_down ffn_up EsoCroK naming v1.99430_b6645-6_Q6-IO2346_RMv1.17.99m Disable I2_K cpu quantization. To allow compilation. MMQ code adaptation Update mmq.cuh MMQ Initial code for IQ2,3,4,5,6_K IQ_K quants first gen (4, 5, 6) Some logs back Batches Croco Bench. Double the anti-abuse limits Allow compile exe, pdf features off Revert "Allow compile exe, pdf features off" This reverts commit 5e2451f129f0bca326f74aae24df475c0410cdbf. Update koboldcpp.py Revert "Allow compile exe, pdf features off" This reverts commit 2a7e9e004e8578a05fb67967d09cf36263867b9b. Revert "Allow compile exe, pdf features off" This reverts commit b4fd7809a4f77ff18bd415fcfb2d5f435e3b63a3. quantization tweaks iq3_ks quantization tweaks Minor iq3_k tweak q2_K tweaks q3_K tweaks q4_K tweaks q5_K tweaks GGUF v14 attempt of second fix. loosen gguf restrictions. Quantization improvements #295 and #302, GGML part only Improved IQ2_XS quantization #312 Improved IQ1_M quantization #327 ggml_row_size accounting fix for GGUF v14 Credits : @ikawrakow Fighting with cmake #279 Drop the GGML count limitation limit Old markings Customize KCPP.py Croco additional chat adapters andtemplates Reinstate "skip barrier of noop" Allow q8_0 KV cache for head size 256 #330 Up FA KV modes 256 candidates (1024 with Grammar) Adapt q6_0 MMQ to llama.cpp mainline Q6_0 MMQ Kernel attempt MMQ for Q6_0 authored by Ikawrakow Add Q6_0 MMQ to template generator authored by Ikawrakow Q6_0 KVQ for KCPP/Croco -> KV22 For release. fix a few lazy-cuts and hiccups left during the merge of IQ4_NL. dequantize for q6_0 and related cpy Enable q6_0 for flash attention As with IQ4_NL, just for head size of 128 for now. Without GGML_CUDA_FA_ALL_QUANTS set, only Q6_0 + Q5_0 and Q8_0 + Q6_0 are included. With this the VRAM poor have better options for selecting the best possible (as allowed by VRAM, model size, context length) quantized KV-cache. PR by Ikawrakow on ik_llama.cpp Adding Q6_0 (#77) Rev 20240807 * Adding q6_0 - basics + AVX2/Zen4 working * Adding q6_0: CUDA dequantize works, but not mmvq * Adding q6_0: CUDA mmvq works * Adding q6_0: CUDA cpy, so Q6_0 can be used for KV-cache * Add q6_0 to CPU flash attention Disappointing result: for LlaMA-3.2-1B, q6_0 K- and V-cache gives about the same PPL as q8_0 K-cache and q4_0 V-cache, while needing the exact same RAM. I.e., what was the point? * q6_0: slightly better kv-cache result Better than q8_0+q4_0, but not as good as q8_0+iq4_nl * q6_0: works on ARM_NEON * q6_0: dequantize works on Metal, but not vector dot product * q6_0: it now works on Metal Outperforms q5_0 by a significant margin. E.g. | model | size | params | backend | ngl | threads | test | t/s | | ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | ------------: | ---------------: | | llama 8B Q6_0 | 6.08 GiB | 8.03 B | Metal | 100 | 4 | tg128 | 44.02 ± 0.08 | | llama 8B Q5_0 | 5.21 GiB | 8.03 B | Metal | 100 | 4 | tg128 | 40.13 ± 0.12 | | llama 8B Q6_0 | 6.08 GiB | 8.03 B | Metal | 100 | 4 | pp512 | 500.55 ± 0.32 | | llama 8B Q5_0 | 5.21 GiB | 8.03 B | Metal | 100 | 4 | pp512 | 448.02 ± 0.27 | * q6_0: can now be used for kv-cache on Metal -> skipped. --------- Adaptation to mainline by me! IQ4_NL KVQ for KCPP/Croco missing templates instances for KVQ IQ4_NL Update fattn.cu for KVQ IQ4_NL Update fattn-vec-f16.cuh for KVQ IQ4_NL Update fattn-vec-f32.cuh for KVQ IQ4_NL CML and Makefile FOR IQ4_NL KV_IQ4_NL uncommenting VEC16 cases KV_IQ4_NL uncommenting VEC32 cases Enable IQ4_NL for V-cache in token generation Add IQ4_NL + IQ4_NL to FA This is a better alternative than Q4_0 + Q4_0 for the VRAM poor. Comment unwanted add-in in makefile iq4_nl: faster quantization (#76) CUDA: faster float -> iq4_nl conversion (#73) * iqk_mul_mat: better iq4_nl implementation on Zen4/AVX2 PP-512 performance for LLaMA-3.1-8B goes to 162.6 t/s up from 133.2 t/s. Default Blas Batch Size = 128 Quant KV and Draft QKV, 24 modes With customizable QKV for the draft as well. And reduced Blas Batch Size for the draft model. Default Draft Amount = 4 Bench context size Max contextsize and steps Croco CML SCHED_MAX_COPIES = 1 And Croco usual additions to the CMakeList Cudart 12.9 Revert "CUDA: faster tile FA (Pascal/AMD), headsize 256 (ggml-org#15769)" This reverts commit 79bc429. Revert "HIP: use v_dot2_f32_f16 instruction for FA (ggml-org#15884)" This reverts commit 17bc5a8. Revert "CUDA: larger SRAM reads for tile FA, AMD FP16 dot (ggml-org#15927)" This reverts commit 0e6ff00. Revert "CUDA: fix FA occupancy, optimize tile kernel (ggml-org#15982)" This reverts commit c959b67. Revert "CUDA: fix compilation on CC 6.0 (ggml-org#16091)" This reverts commit 368560a. Co-Authored-By: Kawrakow <iwankawrakow@gmail.com> Co-Authored-By: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
ggml
changes relating to the ggml tensor library for machine learning
Nvidia GPU
Issues specific to Nvidia GPUs
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
See https://github.com/iacopPBK/llama.cpp-gfx906 . AMD GPUs support reads of up to 16 bytes from SRAM. This PR extends the tile FlashAttention CUDA kernel with support for reads of 8 or 16 bytes. The FP32 -> FP16 type conversion is also done prior to writing the data to SRAM to reduce I/O further.
I also checked the AMD ISA documentation for
v_dot2_f32_f16
support and adjusted the code paths accordingly; it seems to be available everywhere except for RDNA 1.Performance changes