
vulkan: Optimize mul_mat_vec p021 and nc shaders #12505

Merged: 2 commits into ggml-org:master on Mar 22, 2025

Conversation

jeffbolznv (Collaborator)

These shaders are used in attention calculations, and when the KV cache grows large they start to dominate the run time. For the nc shader (which is called with large 'k' dimension), use unrolling and vector loads. For the p021 shader (which is called with large 'm' and small 'k' dimensions), take advantage of grouped query attention to reuse loads from the A matrix for the whole group, and reduce the number of workgroups (too much overhead from tiny dispatches).

Using subgroupAdd in the p021 shader also helps, so it is enabled conditionally.
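
For context, here is a minimal GLSL sketch of the kind of conditional reduction described above. It is not the shader code from this PR: the USE_SUBGROUP_ADD macro name, the buffer layout, and the scalar f32 loads are placeholder assumptions, and the real p021/nc shaders additionally handle fp16 data, batching, and the permuted layouts.

```glsl
#version 450

#ifdef USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

#define BLOCK_SIZE 32

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout(std430, binding = 0) readonly  buffer A { float data_a[]; };
layout(std430, binding = 1) readonly  buffer B { float data_b[]; };
layout(std430, binding = 2) writeonly buffer D { float data_d[]; };

layout(push_constant) uniform Params { uint ncols; uint row_stride; } p;

shared float tmpsh[BLOCK_SIZE];

void main() {
    const uint row = gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    // Each invocation accumulates a strided partial dot product for one row.
    float sum = 0.0;
    for (uint col = tid; col < p.ncols; col += BLOCK_SIZE) {
        sum += data_a[row * p.row_stride + col] * data_b[col];
    }

#ifdef USE_SUBGROUP_ADD
    // Fast path: a single subgroup reduction. Assumes BLOCK_SIZE equals the
    // subgroup size (32 on the NVIDIA GPUs in the numbers below).
    sum = subgroupAdd(sum);
#else
    // Fallback: classic shared-memory tree reduction.
    tmpsh[tid] = sum;
    barrier();
    for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            tmpsh[tid] += tmpsh[tid + s];
        }
        barrier();
    }
    sum = tmpsh[0];
#endif

    if (tid == 0) {
        data_d[row] = sum;
    }
}
```

In the backend, whether the subgroup variant may be used is gated on a device capability check (the device->subgroup_add flag quoted later in this thread).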

I added new directed perf tests based on the multiplies that occur when the KV cache is 16k:

before:
  MUL_MAT(type_a=f16,type_b=f32,m=16416,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0):                 2232 runs -   590.53 us/run - 134.48 MFLOP/run - 227.73 GFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=16416,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1):                 2976 runs -   358.81 us/run - 134.48 MFLOP/run - 374.80 GFLOPS

after:
  MUL_MAT(type_a=f16,type_b=f32,m=16416,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0):                11904 runs -    84.17 us/run - 134.48 MFLOP/run -   1.60 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=16416,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1):                 7440 runs -   138.90 us/run - 134.48 MFLOP/run - 968.19 GFLOPS  
  
cuda:
  MUL_MAT(type_a=f16,type_b=f32,m=16416,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0):                 9672 runs -   105.18 us/run - 134.48 MFLOP/run -   1.28 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=16416,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1):                 8184 runs -   134.20 us/run - 134.48 MFLOP/run -   1.00 TFLOPS

llama-bench results with large KV cache:

before:

llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 0 -n 4096,8192,16384 -fa 0 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |        tg4096 |         55.70 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |        tg8192 |         44.18 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |       tg16384 |         30.90 ± 0.00 |

llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 0 -n 0 -pg 16384,128 -fa 0 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 | pp16384+tg128 |        860.84 ± 0.00 |  
  
after:

llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 0 -n 4096,8192,16384 -fa 0 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |        tg4096 |         68.84 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |        tg8192 |         63.41 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |       tg16384 |         55.05 ± 0.00 |

llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 0 -n 0 -pg 16384,128 -fa 0 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 | pp16384+tg128 |       1028.64 ± 0.00 |

jeffbolznv requested a review from 0cc4m on March 21, 2025 19:18
github-actions bot added the testing, Vulkan, and ggml labels on Mar 21, 2025
0cc4m (Collaborator) commented Mar 22, 2025

Wow, great work. That's a good improvement across all my tests.

ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | matrix cores: NV_coopmat2

| model                          |       size |     params | backend    | ngl |          test |           t/s master |               t/s PR |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: | -------------------: |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |        tg4096 |         66.32 ± 0.00 |         82.59 ± 0.00 |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |        tg8192 |         51.65 ± 0.00 |         75.71 ± 0.00 |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |       tg16384 |         33.57 ± 0.00 |         64.76 ± 0.00 |

ggml_vulkan: 0 = AMD Radeon (TM) Pro VII (RADV VEGA20) (radv) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 65536 | matrix cores: none

| model                          |       size |     params | backend    | ngl |          test |           t/s master |               t/s PR |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: | -------------------: |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |        tg4096 |         43.01 ± 0.00 |         51.70 ± 0.00 |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |        tg8192 |         32.39 ± 0.00 |         45.45 ± 0.00 |

ggml_vulkan: 0 = Intel(R) Arc(tm) A770 Graphics (DG2) (Intel open-source Mesa driver) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 65536 | matrix cores: none

| model                          |       size |     params | backend    | ngl |          test |           t/s master |               t/s PR |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: | -------------------: |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |        tg4096 |         16.41 ± 0.00 |         20.09 ± 0.00 |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |        tg8192 |         13.53 ± 0.00 |         19.47 ± 0.00 |

0cc4m merged commit eddfb43 into ggml-org:master on Mar 22, 2025
51 of 52 checks passed
stduhpf (Contributor) commented Mar 22, 2025

The performance bump is very nice!

However,

> .\build\bin\Release\test-backend-ops.exe | Select-String -pattern "FAIL"
ggml_vulkan: Found 2 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon RX 6800 (AMD proprietary driver) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 32768 | matrix cores: none
ggml_vulkan: 1 = AMD Radeon RX 5700 XT (AMD proprietary driver) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 32768 | matrix cores: none

  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): [MUL_MAT] NMSE = 0.015111144 > 0.000500000 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): [MUL_MAT] NMSE = 0.000550270 > 0.000500000 FAIL
  Backend Vulkan0: FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): [MUL_MAT] NMSE = 0.009975651 > 0.000500000 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): [MUL_MAT] NMSE = 0.009067123 > 0.000500000 FAIL
  Backend Vulkan1: FAIL
FAIL
eddfb438502bd5d1014d63a812e9b6d03d326f8c is the first bad commit
commit eddfb438502bd5d1014d63a812e9b6d03d326f8c (HEAD, origin/master, origin/HEAD)
Author: Jeff Bolz <jbolz@nvidia.com>
Date:   Sat Mar 22 03:40:11 2025 -0500

    vulkan: Optimize mul_mat_vec p021 and nc shaders (#12505)

    * tests: add mul_mat perf/functional tests for p021/nc vulkan shaders

    * vulkan: Optimize mul_mat_vec p021 and nc shaders.

    These shaders are used in attention calculations, and when the KV cache grows
    large they start to dominate the run time. For the nc shader (which is called
    with large 'k' dimension), use unrolling and vector loads. For the p021 shader
    (which is called with large 'm' and small 'k' dimensions), take advantage of
    grouped query attention to reuse loads from the A matrix for the whole group,
    and reduce the number of workgroups (too much overhead from tiny dispatches).

    Using subgroupAdd in the p021 shader also helps, use that conditionally.

 ggml/src/ggml-vulkan/ggml-vulkan.cpp               |  36 +++++-
 .../ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp |  66 +++++++++--
 .../vulkan-shaders/mul_mat_vec_p021.comp           | 125 +++++++++++++++++----
 .../vulkan-shaders/vulkan-shaders-gen.cpp          |   5 +-
 tests/test-backend-ops.cpp                         |  31 ++++-
 5 files changed, 219 insertions(+), 44 deletions(-)

Note that the outputs still seem perfectly fine and coherent, so this isn't really a problem.

0cc4m (Collaborator) commented Mar 22, 2025

> Note that the outputs still seem perfectly fine and coherent, so this isn't really a problem.

Oh no, not the proprietary driver again...

stduhpf (Contributor) commented Mar 22, 2025

The tests were finally all passing on the proprietary driver with the commit just before this one 😢

0cc4m (Collaborator) commented Mar 22, 2025

> The tests were finally all passing on the proprietary driver with the commit just before this one 😢

Can you try disabling subgroup_add? Maybe that is what the driver doesn't like (even though it claims to support it).

device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                               (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
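// (from ggml-vulkan.cpp; hard-coding this to false should disable the subgroupAdd
//  path in the p021 shader, which is useful for confirming the hypothesis)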

Set this to false to test whether subgroupAdd is the culprit.
