
cuda : improve text-generation and batched decoding performance #3776

Merged: 7 commits into master on Oct 27, 2023

Conversation

@ggerganov (Owner) commented on Oct 25, 2023:

ref #3479
ref #3771

Description

This PR should significantly improve text-generation, prompt-processing and batched-decoding speed for all models on NVIDIA cards with tensor cores (i.e. Volta, Ampere, etc.).

Prompt processing

By default, llama.cpp uses MMQ=1, which means that matrix-matrix multiplications for quantized models are performed with a custom kernel for integer multiplications. Recently (#3412), we found that for a large batch dimension (which is the case when processing prompts), MMQ=0 offers a significant performance boost by first dequantizing src0 to F16 and performing the GEMM with cuBLAS. This PR essentially enables the same optimization for MMQ=1 by not using the custom kernel for batch sizes > 32.
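For illustration, a minimal host-side sketch of this fallback path (the names `dequantize_to_f16` and `mul_mat_dequant_cublas` are hypothetical placeholders, not actual ggml symbols):

```cpp
// Sketch: for large batches, dequantize the quantized src0 to F16 once and let
// cuBLAS run the GEMM on tensor cores instead of using the custom MMQ kernel.
#include <cstdint>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

// placeholder for the backend's dequantization kernel launcher
void dequantize_to_f16(const void * src0_q, half * dst, int64_t n_elements, cudaStream_t stream);

void mul_mat_dequant_cublas(cublasHandle_t handle,
                            const void * src0_q,    // quantized weights, ne00 x ne01
                            const half * src1_f16,  // activations in F16,  ne00 x ne11
                            float      * dst,       // output,              ne01 x ne11
                            half       * tmp_f16,   // scratch: ne00*ne01 halves (the extra VRAM)
                            int ne00, int ne01, int ne11, cudaStream_t stream) {
    // 1. dequantize src0 into the temporary F16 buffer
    dequantize_to_f16(src0_q, tmp_f16, (int64_t) ne00 * ne01, stream);

    // 2. run the GEMM through cuBLAS: dst = src0^T * src1
    const float alpha = 1.0f, beta = 0.0f;
    cublasSetStream(handle, stream);
    cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne00, &alpha,
                 tmp_f16,  CUDA_R_16F, ne00,
                 src1_f16, CUDA_R_16F, ne00,
                 &beta,
                 dst,      CUDA_R_32F, ne01,
                 CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```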

Batched decoding

In this mode, the batch size is larger than 1, but typically small (for example, not more than 32). In #3545 we found that the currently used constants MMQ_X, MMQ_Y and NWARPS are not optimal for small batch sizes; they were probably tuned for prompt processing. However, since we now fall back to cuBLAS for prompt processing, these constants can be adjusted for small batch sizes.

Text-generation

So far, for the KV cache related ops (KQ and KQV) we have been using custom matrix-vector kernels. For small sequence lengths (~128) and no prompt, these kernels are quite efficient. However, as the KV cache grows with the sequence length, it becomes more efficient to use the tensor cores via cuBLAS GEMM. This PR applies that change to achieve TG improvements for all models when the context is big.

In summary, we now have the following strategy for matrix multiplications (a sketch follows the list):

  • batch size == 1:
    • non-attention ops && quantized src0: use custom matrix-vector kernel
    • otherwise: use cuBLAS GEMM
  • batch size <= 32: use custom matrix-matrix kernel
  • batch size > 32: use cuBLAS GEMM
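A rough sketch of that dispatch, with illustrative placeholder names rather than the literal ggml-cuda.cu code:

```cpp
// Illustrative only: how the mul_mat paths are chosen based on the batch
// dimension of src1 (ne11). The helper functions stand in for the
// corresponding ggml-cuda code paths.
#include <cstdint>

struct tensor { int64_t ne[4]; /* ... */ };

bool is_quantized   (const tensor & t);                          // placeholder predicate
bool is_attention_op(const tensor & t);                          // placeholder predicate
void mul_mat_vec_q  (const tensor &, const tensor &, tensor &);  // custom matrix-vector kernel
void mul_mat_q      (const tensor &, const tensor &, tensor &);  // custom MMQ matrix-matrix kernel
void mul_mat_cublas (const tensor &, const tensor &, tensor &);  // (dequantize +) cuBLAS GEMM

void mul_mat_dispatch(const tensor & src0, const tensor & src1, tensor & dst) {
    const int64_t batch = src1.ne[1]; // ne11

    if (batch == 1) {
        if (is_quantized(src0) && !is_attention_op(dst)) {
            mul_mat_vec_q(src0, src1, dst);   // text generation on quantized weights
        } else {
            mul_mat_cublas(src0, src1, dst);  // e.g. KQ / KQV ops on the F16 KV cache
        }
    } else if (batch <= 32) {
        mul_mat_q(src0, src1, dst);           // small batches: MMQ with retuned constants
    } else {
        mul_mat_cublas(src0, src1, dst);      // large batches (prompt processing)
    }
}
```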

Results

RTX 3090

`LLAMA_CUBLAS=1 make -j batched batched-bench && ./batched-bench ./models/codellama-7b/ggml-model-q4_0.gguf 8704 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32,64`

master

main: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.248 | 2063.97 | 1.098 | 116.60 | 1.346 | 475.54 |
| 512 | 128 | 2 | 768 | 0.234 | 2191.11 | 7.435 | 34.43 | 7.669 | 100.15 |
| 512 | 128 | 3 | 896 | 0.233 | 2198.40 | 7.515 | 51.09 | 7.748 | 115.64 |
| 512 | 128 | 4 | 1024 | 0.234 | 2189.40 | 7.553 | 67.78 | 7.787 | 131.50 |
| 512 | 128 | 5 | 1152 | 0.236 | 2168.46 | 7.635 | 83.83 | 7.871 | 146.37 |
| 512 | 128 | 6 | 1280 | 0.236 | 2167.85 | 7.700 | 99.74 | 7.936 | 161.28 |
| 512 | 128 | 7 | 1408 | 0.238 | 2152.19 | 7.768 | 115.34 | 8.006 | 175.86 |
| 512 | 128 | 8 | 1536 | 0.242 | 2115.82 | 7.832 | 130.75 | 8.074 | 190.25 |
| 512 | 128 | 16 | 2560 | 0.244 | 2095.92 | 8.391 | 244.06 | 8.636 | 296.45 |
| 512 | 128 | 32 | 4608 | 0.264 | 1937.69 | 8.852 | 462.71 | 9.116 | 505.46 |
| 512 | 128 | 64 | 8704 | 0.245 | 2085.75 | 11.646 | 703.42 | 11.891 | 731.95 |

PR

main: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.173 | 2953.07 | 1.073 | 119.32 | 1.246 | 513.57 |
| 512 | 128 | 2 | 768 | 0.165 | 3109.78 | 1.715 | 149.23 | 1.880 | 408.49 |
| 512 | 128 | 3 | 896 | 0.162 | 3155.39 | 1.768 | 217.16 | 1.931 | 464.11 |
| 512 | 128 | 4 | 1024 | 0.164 | 3123.65 | 1.803 | 283.98 | 1.967 | 520.63 |
| 512 | 128 | 5 | 1152 | 0.163 | 3138.81 | 2.403 | 266.36 | 2.566 | 448.96 |
| 512 | 128 | 6 | 1280 | 0.166 | 3077.79 | 2.433 | 315.70 | 2.599 | 492.49 |
| 512 | 128 | 7 | 1408 | 0.166 | 3082.24 | 2.515 | 356.21 | 2.681 | 525.08 |
| 512 | 128 | 8 | 1536 | 0.166 | 3080.13 | 2.551 | 401.48 | 2.717 | 565.37 |
| 512 | 128 | 16 | 2560 | 0.169 | 3028.26 | 4.723 | 433.60 | 4.892 | 523.27 |
| 512 | 128 | 32 | 4608 | 0.167 | 3059.51 | 7.973 | 513.75 | 8.140 | 566.09 |
| 512 | 128 | 64 | 8704 | 0.164 | 3126.79 | 11.462 | 714.74 | 11.625 | 748.71 |

`LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,800 1`

master

main: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.437 | 1171.55 | 2.124 | 60.28 | 2.561 | 249.95 |
| 512 | 800 | 1 | 1312 | 0.435 | 1176.46 | 14.269 | 56.07 | 14.704 | 89.23 |
| 3200 | 128 | 1 | 3328 | 3.233 | 989.79 | 3.422 | 37.40 | 6.655 | 500.07 |
| 3200 | 800 | 1 | 4000 | 3.219 | 994.22 | 22.392 | 35.73 | 25.611 | 156.19 |

PR

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.246 | 2084.35 | 2.021 | 63.32 | 2.267 | 282.30 |
| 512 | 800 | 1 | 1312 | 0.247 | 2074.85 | 13.436 | 59.54 | 13.683 | 95.88 |
| 3200 | 128 | 1 | 3328 | 1.977 | 1618.55 | 2.501 | 51.18 | 4.478 | 743.16 |
| 3200 | 800 | 1 | 4000 | 1.984 | 1613.30 | 16.140 | 49.57 | 18.123 | 220.71 |

`make -j && ../scripts/run-all-perf.sh codellama-7b "f16 q8_0 q4_0 q4_k" "-ngl 999 -t 1 -n 128,512 -p 512"`

master

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 4493.64 ± 25.65 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 54.67 ± 0.04 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 52.65 ± 0.24 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 2038.44 ± 73.46 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 87.10 ± 0.09 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 82.33 ± 0.25 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 1901.29 ± 154.12 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 133.48 ± 0.17 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 123.13 ± 0.38 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 1998.98 ± 52.16 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 110.47 ± 0.18 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 101.51 ± 0.29 |

PR

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 4513.66 ± 32.46 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 54.14 ± 0.13 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 53.00 ± 0.26 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 3183.63 ± 16.25 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 85.08 ± 0.15 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 82.86 ± 0.14 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 3250.48 ± 22.76 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 130.02 ± 0.09 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 126.12 ± 0.29 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 3661.56 ± 53.75 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 109.18 ± 0.19 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 104.73 ± 0.08 |

RTX 4090

`LLAMA_CUBLAS=1 make -j batched batched-bench && ./batched-bench ./models/codellama-7b/ggml-model-q4_0.gguf 8704 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32,64`

master

main: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.142 | 3615.61 | 0.915 | 139.85 | 1.057 | 605.55 |
| 512 | 128 | 2 | 768 | 0.107 | 4779.37 | 5.170 | 49.52 | 5.277 | 145.55 |
| 512 | 128 | 3 | 896 | 0.126 | 4050.31 | 5.201 | 73.83 | 5.328 | 168.17 |
| 512 | 128 | 4 | 1024 | 0.120 | 4268.98 | 5.210 | 98.27 | 5.330 | 192.12 |
| 512 | 128 | 5 | 1152 | 0.128 | 3988.04 | 5.240 | 122.15 | 5.368 | 214.60 |
| 512 | 128 | 6 | 1280 | 0.126 | 4062.23 | 5.242 | 146.51 | 5.368 | 238.45 |
| 512 | 128 | 7 | 1408 | 0.138 | 3721.15 | 5.336 | 167.91 | 5.474 | 257.22 |
| 512 | 128 | 8 | 1536 | 0.114 | 4493.71 | 5.302 | 193.12 | 5.416 | 283.59 |
| 512 | 128 | 16 | 2560 | 0.115 | 4435.90 | 5.558 | 368.49 | 5.673 | 451.24 |
| 512 | 128 | 32 | 4608 | 0.118 | 4325.20 | 5.853 | 699.78 | 5.972 | 771.65 |
| 512 | 128 | 64 | 8704 | 0.120 | 4251.54 | 7.116 | 1151.21 | 7.236 | 1202.81 |

PR

main: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.109 | 4697.51 | 0.943 | 135.73 | 1.052 | 608.36 |
| 512 | 128 | 2 | 768 | 0.089 | 5738.82 | 1.389 | 184.31 | 1.478 | 519.57 |
| 512 | 128 | 3 | 896 | 0.099 | 5188.49 | 1.410 | 272.26 | 1.509 | 593.72 |
| 512 | 128 | 4 | 1024 | 0.091 | 5633.43 | 1.438 | 355.94 | 1.529 | 669.57 |
| 512 | 128 | 5 | 1152 | 0.093 | 5476.76 | 1.508 | 424.27 | 1.602 | 719.12 |
| 512 | 128 | 6 | 1280 | 0.086 | 5968.90 | 1.520 | 505.42 | 1.605 | 797.35 |
| 512 | 128 | 7 | 1408 | 0.092 | 5567.34 | 1.546 | 579.54 | 1.638 | 859.57 |
| 512 | 128 | 8 | 1536 | 0.091 | 5653.09 | 1.574 | 650.69 | 1.664 | 922.91 |
| 512 | 128 | 16 | 2560 | 0.084 | 6129.17 | 2.196 | 932.58 | 2.280 | 1123.01 |
| 512 | 128 | 32 | 4608 | 0.099 | 5172.66 | 3.436 | 1192.04 | 3.535 | 1303.50 |
| 512 | 128 | 64 | 8704 | 0.097 | 5279.28 | 6.336 | 1293.00 | 6.433 | 1353.10 |

`LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,800 1`

master

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.226 | 2269.96 | 1.607 | 79.66 | 1.832 | 349.29 |
| 512 | 800 | 1 | 1312 | 0.214 | 2387.60 | 10.557 | 75.78 | 10.771 | 121.80 |
| 3200 | 128 | 1 | 3328 | 1.640 | 1950.80 | 2.294 | 55.80 | 3.934 | 845.87 |
| 3200 | 800 | 1 | 4000 | 1.626 | 1968.49 | 14.876 | 53.78 | 16.501 | 242.41 |

PR

main: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.152 | 3363.09 | 1.612 | 79.40 | 1.764 | 362.73 |
| 512 | 800 | 1 | 1312 | 0.145 | 3522.46 | 10.201 | 78.42 | 10.347 | 126.80 |
| 3200 | 128 | 1 | 3328 | 1.269 | 2521.20 | 1.986 | 64.45 | 3.255 | 1022.30 |
| 3200 | 800 | 1 | 4000 | 1.268 | 2523.72 | 12.692 | 63.03 | 13.960 | 286.53 |

`make -j && ../scripts/run-all-perf.sh codellama-7b "f16 q8_0 q4_0 q4_k" "-ngl 999 -t 1 -n 128,512 -p 512"`

master

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 8972.06 ± 345.72 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 62.27 ± 0.19 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 61.54 ± 0.04 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 5272.91 ± 99.01 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 101.73 ± 0.15 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 99.43 ± 0.05 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 5017.87 ± 132.87 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 158.38 ± 0.13 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 152.41 ± 0.85 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 4763.11 ± 163.93 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 147.54 ± 0.30 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 142.81 ± 0.15 |

PR

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 9473.31 ± 227.98 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 61.76 ± 0.29 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 61.27 ± 0.01 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 6473.99 ± 4.55 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 100.22 ± 0.14 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 98.65 ± 0.07 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 6693.25 ± 6.18 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 154.33 ± 0.36 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 150.87 ± 0.12 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 7277.83 ± 4.73 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 144.40 ± 0.25 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 141.19 ± 0.07 |

V100

`LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench ./models/openllama-7b-v2/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,800 1`

master

main: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.353 | 1451.95 | 1.388 | 92.23 | 1.740 | 367.73 |
| 512 | 800 | 1 | 1312 | 0.351 | 1457.49 | 9.431 | 84.83 | 9.782 | 134.12 |
| 3200 | 128 | 1 | 3328 | 2.648 | 1208.35 | 2.329 | 54.97 | 4.977 | 668.71 |
| 3200 | 800 | 1 | 4000 | 2.653 | 1206.38 | 15.285 | 52.34 | 17.937 | 223.00 |

PR

main: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.211 | 2421.42 | 1.376 | 92.99 | 1.588 | 403.05 |
| 512 | 800 | 1 | 1312 | 0.212 | 2419.59 | 9.145 | 87.48 | 9.357 | 140.22 |
| 3200 | 128 | 1 | 3328 | 1.767 | 1811.34 | 2.026 | 63.17 | 3.793 | 877.43 |
| 3200 | 800 | 1 | 4000 | 1.765 | 1812.75 | 13.127 | 60.94 | 14.892 | 268.60 |

A100 80GB

`LLAMA_CUBLAS=1 make -j batched batched-bench && ./batched-bench ./models/codellama-7b/ggml-model-q4_0.gguf 8704 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32,64`

master

main: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.234 | 2185.63 | 1.031 | 124.18 | 1.265 | 505.93 |
| 512 | 128 | 2 | 768 | 0.217 | 2363.62 | 10.829 | 23.64 | 11.045 | 69.53 |
| 512 | 128 | 3 | 896 | 0.217 | 2360.95 | 10.870 | 35.33 | 11.087 | 80.81 |
| 512 | 128 | 4 | 1024 | 0.217 | 2361.65 | 10.904 | 46.95 | 11.121 | 92.08 |
| 512 | 128 | 5 | 1152 | 0.217 | 2359.50 | 10.967 | 58.36 | 11.184 | 103.01 |
| 512 | 128 | 6 | 1280 | 0.217 | 2360.19 | 10.993 | 69.86 | 11.210 | 114.19 |
| 512 | 128 | 7 | 1408 | 0.217 | 2360.06 | 11.044 | 81.13 | 11.261 | 125.04 |
| 512 | 128 | 8 | 1536 | 0.217 | 2360.97 | 11.081 | 92.41 | 11.297 | 135.96 |
| 512 | 128 | 16 | 2560 | 0.217 | 2355.54 | 11.536 | 177.53 | 11.754 | 217.81 |
| 512 | 128 | 32 | 4608 | 0.218 | 2346.86 | 11.580 | 353.72 | 11.798 | 390.58 |
| 512 | 128 | 64 | 8704 | 0.220 | 2331.26 | 13.576 | 603.41 | 13.796 | 630.92 |

PR

main: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.115 | 4469.78 | 1.006 | 127.18 | 1.121 | 570.93 |
| 512 | 128 | 2 | 768 | 0.097 | 5258.29 | 1.853 | 138.13 | 1.951 | 393.72 |
| 512 | 128 | 3 | 896 | 0.097 | 5298.23 | 1.893 | 202.84 | 1.990 | 450.31 |
| 512 | 128 | 4 | 1024 | 0.098 | 5247.73 | 1.931 | 265.14 | 2.029 | 504.78 |
| 512 | 128 | 5 | 1152 | 0.097 | 5277.15 | 2.147 | 298.05 | 2.244 | 513.31 |
| 512 | 128 | 6 | 1280 | 0.097 | 5276.45 | 2.176 | 352.95 | 2.273 | 563.13 |
| 512 | 128 | 7 | 1408 | 0.097 | 5289.42 | 2.228 | 402.20 | 2.325 | 605.71 |
| 512 | 128 | 8 | 1536 | 0.097 | 5282.33 | 2.274 | 450.38 | 2.371 | 647.94 |
| 512 | 128 | 16 | 2560 | 0.097 | 5259.21 | 3.386 | 604.89 | 3.483 | 734.98 |
| 512 | 128 | 32 | 4608 | 0.098 | 5212.26 | 5.141 | 796.67 | 5.240 | 879.46 |
| 512 | 128 | 64 | 8704 | 0.100 | 5132.47 | 8.464 | 967.89 | 8.564 | 1016.40 |

`LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,800 1`

master

main: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.462 | 1108.58 | 1.879 | 68.13 | 2.341 | 273.42 |
| 512 | 800 | 1 | 1312 | 0.444 | 1153.87 | 12.526 | 63.87 | 12.970 | 101.16 |
| 3200 | 128 | 1 | 3328 | 3.100 | 1032.42 | 2.914 | 43.92 | 6.014 | 553.38 |
| 3200 | 800 | 1 | 4000 | 3.101 | 1032.07 | 19.025 | 42.05 | 22.126 | 180.78 |

PR

main: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.163 | 3150.34 | 1.842 | 69.50 | 2.004 | 319.31 |
| 512 | 800 | 1 | 1312 | 0.145 | 3532.96 | 12.018 | 66.57 | 12.163 | 107.87 |
| 3200 | 128 | 1 | 3328 | 1.239 | 2582.01 | 2.445 | 52.36 | 3.684 | 903.34 |
| 3200 | 800 | 1 | 4000 | 1.241 | 2579.48 | 15.755 | 50.78 | 16.995 | 235.36 |

`make -j && ../scripts/run-all-perf.sh codellama-7b "f16 q8_0 q4_0 q4_k" "-ngl 999 -t 1 -n 128,512 -p 512"`

master

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 7883.90 ± 3.58 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 77.14 ± 0.05 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 74.74 ± 0.03 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 2316.34 ± 1.83 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 98.67 ± 0.04 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 94.81 ± 0.04 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 2389.03 ± 1.81 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 146.72 ± 0.07 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 138.33 ± 0.14 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 2034.51 ± 4.29 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 115.35 ± 0.06 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 110.23 ± 0.12 |

PR

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 7912.18 ± 32.14 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 75.84 ± 0.05 |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 74.30 ± 0.02 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 5356.38 ± 15.04 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 96.57 ± 0.02 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 94.09 ± 0.14 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 5412.66 ± 8.32 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 141.65 ± 0.06 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 136.63 ± 0.62 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | pp 512 | 6044.39 ± 24.43 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 128 | 112.27 ± 0.14 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 999 | 1 | tg 512 | 109.27 ± 0.02 |

TODO

  • Perform ppl tests to make sure I didn't break something
  • Run tests on other cards
  • Run tests on other models
  • Try to add full MMQ fallback
  • Fix BACKEND_SPLIT support for src0
  • Tune compile-time constants for other CUDA / AMD architectures

@slaren (Collaborator) commented on Oct 25, 2023:

cuBLAS should still be optional, since it increases memory usage significantly and reduces the number of layers that can be offloaded. What is the reason for adding a new function instead of just using the same `ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);` as in the other cases?

@JohannesGaessler (Collaborator) commented on Oct 25, 2023:

If this PR does what I think it does I very much do not agree with it. Half the motivation behind mmq was to reduce the VRAM usage by avoiding the allocation of a temporary buffer for the dequantized weight matrix. Even on relatively recent hardware where the speed is now comparatively worse this can be worthwhile if VRAM is tight.

Also as far as I can tell there is no check for CUDA architectures without tensor cores where mmq should be universally faster or for AMD GPUs (where I don't know which one is faster).

@ggerganov (Owner, Author):

> What is the reason for adding a new function

Likely no reason, but I wanted to eliminate the convoluted logic of `ggml_cuda_op_mul_mat` in the process - it was just easier to work this way. I've added a TODO to remove the duplicated call.

> Also as far as I can tell there is no check for CUDA architectures without tensor cores where mmq should be universally faster or for AMD GPUs (where I don't know which one is faster).

I will try to add an elegant way to accommodate older video cards and fall back to MMQ (should be easy), but for sure we are not going to hold back performance on modern hardware where there is plenty of VRAM just to be able to offload a few more layers. So this change in one way or another is going to make it to master unless there is some bug of course.

@JohannesGaessler (Collaborator):

> for sure we are not going to hold back performance on modern hardware where there is plenty of VRAM just to be able to offload a few more layers. So this change in one way or another is going to make it to master unless there is some bug of course.

Let me be frank: I currently do not have the time to work on llama.cpp, but I would consider this unacceptable. Even on modern hardware, VRAM is a considerable bottleneck, especially on consumer products. If and when I return to working on llama.cpp, this would be enough reason for me to fork the project.

@ggerganov (Owner, Author):

What do you propose?

With this change:

We gain:

  • 1.5x PP speed
  • 5x batched decoding speed, needed for real-world applications and speculative decoding

We lose:

  • ~100 MB VRAM for quantum 7B models
  • ~200 MB VRAM for quantum 13B models

@JohannesGaessler (Collaborator):

Firstly, I propose keeping some version of the current mmq behavior, regardless of what the default is. Secondly, the mmq functions are templates, so I suggest compiling additional variants that are optimized for small batch sizes and switching between the two variants at runtime based on the batch size.
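A minimal sketch of that idea, assuming a templated MMQ-style kernel; the kernel name, constants and launch configuration below are made up for illustration:

```cpp
// Sketch: two instantiations of the same tiled kernel with different
// compile-time constants, selected at launch time based on the batch size (ne11).
template <int MMQ_X, int MMQ_Y, int NWARPS>
__global__ void mul_mat_q_kernel(const void * src0, const void * src1, float * dst,
                                 int ne00, int ne01, int ne11) {
    // ... tiled quantized matrix multiplication using MMQ_X x MMQ_Y tiles ...
}

void launch_mul_mat_q(const void * src0, const void * src1, float * dst,
                      int ne00, int ne01, int ne11, cudaStream_t stream) {
    if (ne11 <= 32) {
        // constants tuned for small batches (batched decoding)
        constexpr int nwarps = 4;
        mul_mat_q_kernel<64, 64, nwarps>
            <<<dim3((ne01 + 63)/64, (ne11 + 63)/64), dim3(32, nwarps), 0, stream>>>
            (src0, src1, dst, ne00, ne01, ne11);
    } else {
        // constants tuned for large batches (prompt processing)
        constexpr int nwarps = 8;
        mul_mat_q_kernel<128, 128, nwarps>
            <<<dim3((ne01 + 127)/128, (ne11 + 127)/128), dim3(32, nwarps), 0, stream>>>
            (src0, src1, dst, ne00, ne01, ne11);
    }
}
```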

@slaren (Collaborator) commented on Oct 25, 2023:

> • ~100 MB VRAM for quantum 7B models
> • ~200 MB VRAM for quantum 13B models

In practice, it is probably worse than this. In theory, we should only need enough memory to store a copy of the largest tensor in F16. When offloading the output tensor this requires at least 260MB for 7B and 327MB for 13B, excluding the activations. However, the CUDA memory pool is far from optimal and will waste a lot of memory in some cases. We could improve on this somewhat.
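For reference, those figures are roughly what you get from holding the output tensor (n_embd x n_vocab) in F16, assuming n_vocab ≈ 32,000 and n_embd = 4096 for 7B / 5120 for 13B:

```math
\text{7B: } 4096 \times 32000 \times 2\ \text{bytes} \approx 262\ \text{MB}, \qquad
\text{13B: } 5120 \times 32000 \times 2\ \text{bytes} \approx 328\ \text{MB}
```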

Ultimately, the best solution would be to store the activations as F16 and implement our own MMQ kernels using tensor cores. This would remove the need for these buffers entirely.

@JohannesGaessler (Collaborator):

> Ultimately, the best solution would be to store the activations as F16 and implement our own MMQ kernels using tensor cores. This would remove the need for these buffers entirely.

I agree that long-term this would be the best solution. However, using tensor cores on their own does not seem to be sufficient. Several weeks ago I did a prototype mmq implementation using tensor cores and there was no speedup, because the tensor cores were dramatically underutilized. In fact, even without tensor cores, the ALU is underutilized with the current mmq implementation. The problem is most likely that mmq does not use asynchronous data loading (available on Ampere or newer). For an mmq implementation that outperforms FP16 cuBLAS for quantized models, I think the following changes would be necessary:

  1. Implement an mmq kernel that loads the data asynchronously (see the sketch at the end of this comment).
  2. Modify the mmq kernels to make the tiles thinner in direction of src0 and src1 dimension 0 through another template parameter. The tradeoff would be potentially worse data coalescing but reduced shared memory usage. With thinner tiles you should be able to increase the MMQ_X and MMQ_Y parameters which in turn should increase the arithmetic intensity. Also there would need to be special consideration for k-quants since currently the tiles are always at least as big as a single k-quant block.
  3. Utilize tensor cores since now I/O is less of a bottleneck.

Currently it looks like I will be busy until the end of December. Afterwards I wanted to start working on this unless someone else does something similar in the meantime.
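As a small illustration of the asynchronous loading in step 1 above, assuming compute capability 8.0+ and the pipeline primitives from `<cuda_pipeline.h>` (the tile size and kernel structure are made up):

```cpp
#include <cuda_pipeline.h>

// Sketch: stage a tile of data from global into shared memory with cp.async,
// so the loads can overlap with computation on the previously staged tile.
constexpr int TILE_BYTES = 4096;

__global__ void tiled_kernel(const char * __restrict__ src, float * __restrict__ dst) {
    __shared__ char tile[TILE_BYTES];

    // each thread issues 16-byte asynchronous copies into shared memory
    for (int i = 16 * threadIdx.x; i < TILE_BYTES; i += 16 * blockDim.x) {
        __pipeline_memcpy_async(tile + i, src + (size_t) blockIdx.x * TILE_BYTES + i, 16);
    }
    __pipeline_commit();

    // ... independent work (e.g. computing on the previously staged tile) ...

    __pipeline_wait_prior(0); // wait for the copies issued above to land
    __syncthreads();

    // ... consume `tile`, write results to dst ...
    (void) dst;
}
```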

@ggerganov (Owner, Author):

> Ultimately, the best solution would be to store the activations as F16 and implement our own MMQ kernels using tensor cores.

I agree that this is the right direction and we will do it eventually. It's just that this change is so easy (if you remove the duplicated code, it's just a few lines) that I don't see any reason not to get the benefits from it.

> When offloading the output tensor

I just realized that the output tensor currently does not go through the new cuBLAS GEMM branch, because I use `all_on_device` and the dst is not on the device. Will see how it affects the performance and, depending on the result, we can leave it as it is.

@JohannesGaessler When I wrote earlier that I will try to accommodate older cards and fall back to MMQ, I had something similar in mind to what you propose.

@ggerganov changed the title from "cuda : improve batched decoding performance for quantum models" to "cuda : improve text-generation and batched decoding performance for quantum models" on Oct 26, 2023
@ggerganov (Owner, Author):

Using `cublasGemmBatchedEx` for single-batch (i.e. text-generation) mode slightly degrades the TG speed for short sequence lengths, but significantly improves the performance for long sequence lengths (i.e. when the context is big). So I've enabled this as well when tensor cores are available.
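For context, the KQ op is one GEMM per attention head over the F16 KV cache. The PR path calls `cublasGemmBatchedEx` with pointer arrays; the sketch below uses the strided-batched variant instead, simply because it is shorter to show, and the shapes/strides are illustrative:

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Sketch: KQ = K^T * Q for all heads in one call. Per head, K is
// [head_dim x n_kv] and Q is [head_dim x n_tokens] in F16 (column-major),
// producing an F32 [n_kv x n_tokens] score matrix per head.
void batched_kq(cublasHandle_t handle,
                const half * K, const half * Q, float * KQ,
                int head_dim, int n_kv, int n_tokens, int n_heads, cudaStream_t stream) {
    const float alpha = 1.0f, beta = 0.0f;
    cublasSetStream(handle, stream);
    cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
        n_kv, n_tokens, head_dim, &alpha,
        K,  CUDA_R_16F, head_dim, (long long) head_dim * n_kv,     // stride to the next head's K
        Q,  CUDA_R_16F, head_dim, (long long) head_dim * n_tokens, // stride to the next head's Q
        &beta,
        KQ, CUDA_R_32F, n_kv,     (long long) n_kv * n_tokens,     // stride to the next head's KQ
        n_heads,
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```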

@oobabooga (Contributor) commented on Oct 26, 2023:

I built a custom version of llama-cpp-python using this PR branch and re-did the tests in this post for llama-2-13b-Q4_K_M.gguf. These are the results using my RTX 3090:

| llama.cpp | Prompt Processing Time (3200 tokens, in seconds) | Evaluation Time (800 tokens, in seconds) | Tokens/Second | VRAM (MB) |
| --- | --- | --- | --- | --- |
| llama-cpp-python 0.2.11 | 3.73 | 25.95 | 30.83 | 8985 |
| ggerganov PR branch | 2.19 | 16.44 | 48.66 | 8983 |

This is a +58% increase in the tokens/second during evaluation, and a +70% increase in the tokens/second during processing, at no VRAM cost (I measured VRAM at load time with context length = 1 so it doesn't apply to the discussion here).

@ggerganov changed the title from "cuda : improve text-generation and batched decoding performance for quantum models" to "cuda : improve text-generation and batched decoding performance" on Oct 26, 2023
@JohannesGaessler (Collaborator) commented on Oct 26, 2023:

Sorry, but are you sure you tested that correctly? This PR only affects batch sizes >1 so there should be no difference whatsoever when generating one token at a time.

Also, if you are measuring VRAM by what is printed to console that is not reflective of the actual VRAM usage. The print only includes the VRAM allocated initially and does not account for the additional VRAM potentially needed later when it may be possible to dequantize a weight matrix to either FP16 or FP32 (i.e. when mmq is not used). Instead something like nvidia-smi should be used to monitor the VRAM usage.

@Dampfinchen commented on Oct 26, 2023:

I will give my input here. Mind you, this is before I got the time to test this PR:

With MMQ I am able to run 7B Q4_K_S models with GQA, like Mistral, fully offloaded on my RTX 2060 with 6 GB VRAM and at 4K context. And I still have some VRAM left to spare for other applications.

With cuBLAS on master, it slows down fast as it overflows into system RAM.

So for me, the VRAM savings MMQ in master provides are essential. I hope MMQ will continue to stay VRAM efficient.

I do hope tensor cores can be used for it eventually while staying super VRAM efficient at the same time. cuBLAS is noticeably faster (when the VRAM is not spilling over).

@oobabooga (Contributor) commented on Oct 26, 2023:

I did one more llama-cpp-python build with the master branch instead of the PR branch, to better separate the changes in this PR from other changes in llama.cpp over the past few weeks (the llama.cpp in llama-cpp-python 0.2.11 is a bit outdated). These are the results:

| llama.cpp | Prompt Processing Time (3200 tokens, in seconds) | Evaluation Time (800 tokens, in seconds) | Tokens/Second | VRAM (MB, n_ctx = 4096) | VRAM (MB, n_ctx = 4096, after generating 800 tokens with 3200 context) |
| --- | --- | --- | --- | --- | --- |
| llama-cpp-python 0.2.11 | 3.73 | 25.95 | 30.83 | 12215 | ? |
| master branch | 3.34 | 22.96 | 34.84 | 12189 | 12967 |
| PR branch | 2.19 | 16.44 | 48.66 | 12187 | 13095 |

So the bulk of the speed gains are due to this PR.

I don't see any increase in VRAM, at least not at load time, so I am not sure what I am missing. Nevermind, I found an increase of 128MB in the scenario that I added to the table above.

@JohannesGaessler all these VRAM values come from nvidia-smi.

@JohannesGaessler (Collaborator):

> at load time

The VRAM increase will only happen once a weight matrix is actually dequantized, i.e. when the model is being evaluated with a large enough batch size (32, I think, with this PR).

Also, I forgot: unless someone else has worked on this, the VRAM allocated for temporary buffers is not freed until the process exits, which may be relevant for measurements.

@Ph0rk0z commented on Oct 26, 2023:

There are tons of inference backends with super fast ampere+ support. A big draw of llama.cpp is the wide HW compatibility, especially low end. It's the last one left with decent pascal speeds. Stuff like falcon for under $800 of GPUs is IMO more worth it than a few extra t/s on already well supported platforms.

@Tostino commented on Oct 27, 2023:

@Ph0rk0z you have one use case, but to seriously use this software in a business context, good batched performance is necessary for quite a few use cases. I do hope they maintain good compatibility though.

@Dampfinchen commented on Oct 27, 2023:

I was testing this PR and I didn't notice a difference in VRAM usage compared to MMQ on master with a 3600-token prompt, while it is indeed a lot faster (batch size was set to the default, 512).

I'm very happy with it. From my point of view, this PR is safe to merge. Great job!

@JohannesGaessler (Collaborator) commented on Oct 27, 2023:

To clarify my perspective: if the increase in VRAM usage is ~1% as previously suggested, I am completely fine with this PR and do not think a compilation option for mmq only is necessary. However, to my understanding, mmq is currently still used for the output tensor, which is by far the largest and therefore requires the most VRAM. So prior to merging I would like there to be a final VRAM measurement.

Also, the multi-GPU performance should be checked. Currently, with mmq, the hidden state is converted to q8_1 prior to being distributed from the main GPU to the other GPUs. This significantly reduces latency and bandwidth and therefore improves performance. So cuBLAS may still be slower in multi-GPU settings, even when tensor cores are present. Although for batched decoding in a setting with one server and multiple clients, a different parallelization scheme where the GPUs run sequentially rather than in parallel would be much more efficient anyway (it is extremely unlikely that I will implement this, because it is not a use case that I care about).

@Ph0rk0z commented on Oct 27, 2023:

> @Ph0rk0z you have one use case, but to seriously use this software in a business context, good batched performance is necessary for quite a few use cases. I do hope they maintain good compatibility though.

It hobbles a lot of accessible hardware that people invested money into. I'm not the only one. There are no cheaper 24 GB cards available. Enjoy running a lot of tiny ineffectual models really, really fast, I guess. The V100/A100 folks will be using vLLM and TGI as they currently do. You could say "stay with the old version" if the format didn't change so often, but that hasn't been the case. So much for good ML being accessible rather than big-business oriented.

@ggerganov merged commit 2f9ec7e into master on Oct 27, 2023 (33 checks passed)
@Ph0rk0z commented on Oct 27, 2023:

Nvlink isn't supported on 4090. It will use faux nvlink via PCIE though.

@JohannesGaessler (Collaborator):

> Not sure - it was x2 RTX 4090 system that I spinned for a bit. Will try to do more detailed test later.

RTX 4090s, to my knowledge, do not have NVLink support.

> But atm I'm more interested about why our text-generation speed is degraded when using multiple GPUs.

It's due to the overhead of transferring data between GPUs. Also note that as the batch size increases, the computation time per token decreases, but the data transfer time per token does not decrease nearly as much. So as the batch size increases, generation becomes increasingly bottlenecked by the interconnect speed/latency. Comparatively faster GPUs are also more bottlenecked, especially when using small models.

There is no easy fix, but one thing you could do is convert src1 to FP16 prior to distributing it from the main GPU, similar to what I did for mmq in #3110. Of course, if the entire model evaluation were done in FP16, no conversion of the hidden state would be necessary at all. Conceivably you could also convert the result calculated on a single GPU to q8_1 (or compress it in some other way) prior to writing it back to the main GPU; I would assume that the loss in quality would be negligible. You could also try changing the tile size for distributing and writing back the data from the main GPU.
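A trivial sketch of the kind of conversion kernel that suggestion implies (illustrative only):

```cpp
#include <cstdint>
#include <cuda_fp16.h>

// Sketch: convert the F32 hidden state to F16 on the main GPU before copying it
// to the other GPUs, roughly halving the data sent over PCIe/NVLink.
__global__ void convert_f32_to_f16(const float * __restrict__ x, half * __restrict__ y, int64_t n) {
    const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = __float2half(x[i]);
    }
}

// launch example:
//   convert_f32_to_f16<<<(n + 255) / 256, 256, 0, stream>>>(src1_f32, src1_f16, n);
```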

@ggerganov (Owner, Author) commented on Oct 27, 2023:

> How much speedup could I expect from NVLink?

I posted some more results in #3814 for 2x RTX A6000 with PCIe, AFAICT.
If I'm not missing something, there is an easy change, as proposed, to enable tensor cores at the cost of some extra VRAM on the main device. It does have some positive effect for multi-batch decoding.

Also, it seems TG for 7B F16 now actually benefits slightly from 2x cards compared to one, even with just this PR (#3776). It's small, but at least it is not slower now as it was before. For quantized models there is not much difference - 2x GPUs are still slower than 1.

@Ph0rk0z commented on Oct 27, 2023:

Well, for this: with my NVLink-ed 3090s I went from 18.86 tokens/s to 17.5 t/s, so this PR is not a net benefit. That's for 70B. Setting the flag returns the old performance. I have yet to test on long contexts (this is only 22 tokens), though. I didn't test what happens with the P40s or mixed architectures. I find that when using llama.cpp vs exllama, llama.cpp beats it at first but then falls short once you get up to 2-3k context.

As for NVLink, it gives a gain of 0.5 t/s to 5 t/s depending on the implementation. Once peer access was enabled my t/s went up, and I still see people post lower speeds. For under $100 it's worth it, for more probably not. If you train, it will have larger gains there.

mattgauf added a commit to mattgauf/llama.cpp that referenced this pull request Oct 27, 2023
* master: (350 commits)
  speculative : ensure draft and target model vocab matches (ggerganov#3812)
  llama : correctly report GGUFv3 format (ggerganov#3818)
  simple : fix batch handling (ggerganov#3803)
  cuda : improve text-generation and batched decoding performance (ggerganov#3776)
  server : do not release slot on image input (ggerganov#3798)
  batched-bench : print params at start
  log : disable pid in log filenames
  server : add parameter -tb N, --threads-batch N (ggerganov#3584) (ggerganov#3768)
  server : do not block system prompt update (ggerganov#3767)
  sync : ggml (conv ops + cuda MSVC fixes) (ggerganov#3765)
  cmake : add missed dependencies (ggerganov#3763)
  cuda : add batched cuBLAS GEMM for faster attention (ggerganov#3749)
  Add more tokenizer tests (ggerganov#3742)
  metal : handle ggml_scale for n%4 != 0 (close ggerganov#3754)
  Revert "make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)"
  issues : separate bug and enhancement template + no default title (ggerganov#3748)
  Update special token handling in conversion scripts for gpt2 derived tokenizers (ggerganov#3746)
  llama : remove token functions with `context` args in favor of `model` (ggerganov#3720)
  Fix baichuan convert script not detecing model (ggerganov#3739)
  make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)
  ...
@LostRuins (Collaborator):

I got around to testing this PR on my 6 GB RTX 2060, and it's a mixed bag.

For models that I'm able to fully offload, it is indeed an improvement in speed - for 7B Q4_K_S I get a decent boost in prompt processing speed. However, when testing a 13B Q4_K_M model, I must now offload one layer less than before - and speeds are slightly slower.

Philosophically, I do agree with the idea that llama.cpp should cater more towards hobbyist hardware like my crappy 6 GB VRAM card. It fills an excellent niche for the home user with old/inferior cards, since all the alternatives are unfeasible. Supporting modern hardware is good, yes, but vLLM and TGI already cater to high-end cards; llama.cpp should play to its unique strengths.

@ggerganov (Owner, Author):

@LostRuins

The change from this PR, combined with the idea from #3457, will most certainly improve the performance with one layer less compared to what was on master + #3457. The reason is that the tensor cores are now utilized for the F16 KV cache ops, which leads to a significant performance gain, as seen in the oobabooga test, specifically for large contexts. However, with the current implementation, the KV ops reach the GPU only when you have offloaded all layers. With #3457 you will start seeing benefits even for fewer layers.

So long term, I'm pretty sure that users with old/inferior cards will be very happy with this change.

@Dampfinchen:

> I got around to testing this PR on my 6 GB RTX 2060, and it's a mixed bag.
>
> For models that I'm able to fully offload, it is indeed an improvement in speed - for 7B Q4_K_S I get a decent boost in prompt processing speed. However, when testing a 13B Q4_K_M model, I must now offload one layer less than before - and speeds are slightly slower.
>
> Philosophically, I do agree with the idea that llama.cpp should cater more towards hobbyist hardware like my crappy 6 GB VRAM card. It fills an excellent niche for the home user with old/inferior cards, since all the alternatives are unfeasible. Supporting modern hardware is good, yes, but vLLM and TGI already cater to high-end cards; llama.cpp should play to its unique strengths.

Don't use Q4_K_M, use Q4_K_S instead. The perplexity is almost the same and the speed is noticeably better.

@cebtenzzre (Collaborator):

Token generation results on my Tesla P40:

| GPU | Model | Test | t/s b1429 | t/s b1430 (PR) | Speedup |
| --- | --- | --- | --- | --- | --- |
| P40 | 7b q4_0 | tg128 | 60.54 | 35.43 | .59 |
| P40 | 13b q4_k_s | tg128 | 31.38 | 20.12 | .64 |

@ggerganov How can this change be adapted to not cripple cards that don't have tensor cores?

@ggerganov (Owner, Author):

Building with the `GGML_CUDA_FORCE_MMQ` flag should restore the old performance for these cards.
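Roughly, the gating that the new macros allow looks like the following sketch; the runtime compute-capability check is an assumption for illustration, not the exact ggml-cuda.cu code:

```cpp
#include <cuda_runtime.h>

// Sketch: GGML_CUDA_FORCE_MMQ forces the quantized MMQ kernels at compile time;
// otherwise the cuBLAS/tensor-core path is used only on Volta (CC 7.0) or newer.
static bool use_tensor_cores(int device) {
#if defined(GGML_CUDA_FORCE_MMQ)
    return false; // user explicitly asked for the MMQ kernels
#else
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major >= 7; // Volta, Turing, Ampere, Ada, Hopper, ...
#endif
}
```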

@Dampfinchen commented on Oct 31, 2023:

So I'm not sure this PR is at fault, but given how much has changed it likely is the culprit.

Previously, I was just testing fully offloaded 7b models on my RTX 2060. But running 13B with 25 layers offloaded, generation speed is absolutely atrocious, while prompt processing is performing as expected.


llama_print_timings:      sample time =      51.63 ms /   180 runs   (    0.29 ms per token,  3486.48 tokens per second)
llama_print_timings: prompt eval time =   18092.37 ms /  1849 tokens (    9.78 ms per token,   102.20 tokens per second)
llama_print_timings:        eval time =  245749.79 ms /   179 runs   ( 1372.90 ms per token,     0.73 tokens per second)
llama_print_timings:       total time =  263988.36 ms

Normally I would get a generation speed of around 250 ms per token, not 1300. There is also no RAM swapping involved (I've disabled this behavior with the new driver and there's enough VRAM left anyway).

@Dampfinchen commented on Nov 1, 2023:

> So I'm not sure this PR is at fault, but given how much has changed it likely is the culprit.
>
> Previously, I was just testing fully offloaded 7b models on my RTX 2060. But running 13B with 25 layers offloaded, generation speed is absolutely atrocious, while prompt processing is performing as expected.
>
> llama_print_timings:      sample time =      51.63 ms /   180 runs   (    0.29 ms per token,  3486.48 tokens per second)
> llama_print_timings: prompt eval time =   18092.37 ms /  1849 tokens (    9.78 ms per token,   102.20 tokens per second)
> llama_print_timings:        eval time =  245749.79 ms /   179 runs   ( 1372.90 ms per token,     0.73 tokens per second)
> llama_print_timings:       total time =  263988.36 ms
>
> Normally I would get a generation speed of around 250 ms per token, not 1300. There is also no RAM swapping involved (I've disabled this behavior with the new driver and there's enough VRAM left anyway).

@slaren @ggerganov Can you please test whether you can reproduce the significant slowdown with partial offloading? Generation speed using 25 layers on a 13B model is 5x slower in builds after this PR compared to ones before it, and my GPU has tensor cores. This is related to #3860.

All testing in this PR was done using full GPU offloading (ngl=999), so it might be possible this slipped under the radar. Please always test partial offloading as well.

@slaren (Collaborator) commented on Nov 1, 2023:

@Dampfinchen what model are you using, quant, etc?

@Dampfinchen:

> @Dampfinchen what model are you using, quant, etc?

This one, Q4_K_S: https://huggingface.co/TheBloke/Mythical-Destroyer-V2-L2-13B-GGUF/tree/main

@Dampfinchen commented on Nov 2, 2023:

@slaren Since LostRuins' hardware, which is the same as mine, is unaffected, I suspect there might be an incompatibility between the latest driver 546.01 and partial offloading at play.

While I previously mentioned that I didn't have that issue in earlier builds, note that I meant koboldcpp, which is built on llama.cpp. However, after testing various llama.cpp builds, old and new, I can confirm the issue is not due to tensor core support or batched CUDA processing, nor cuBLAS in general, as it happens with FORCE_MMQ as well.

Could you perhaps test the latest driver 546.01 in combination with partial offloading? Even when I'm using partial offloading with a 7B model and 28 layers, it's much, much slower than expected.

However, full GPU offloading performs as expected.

@slaren (Collaborator) commented on Nov 2, 2023:

I am already using the driver 546.01, but I can't reproduce this under WSL. If you are building it yourself, are you sure that you are building with AVX? It is not done by default anymore. Otherwise, can you bisect exactly the commit that caused your issue?

@Dampfinchen commented on Nov 2, 2023:

> I am already using the driver 546.01, but I can't reproduce this under WSL. If you are building it yourself, are you sure that you are building with AVX? It is not done by default anymore. Otherwise, can you bisect exactly the commit that caused your issue?

@slaren You are right, I did not compile it with AVX2 support. I did not notice it was compiling without it anymore; I just ticked cuBLAS and nothing else.

I've now compiled with AVX2 and the speed is exactly as expected again. Thank you!

olexiyb pushed a commit to Sanctum-AI/llama.cpp that referenced this pull request Nov 23, 2023
…ganov#3776)

* cuda : prints wip

* cuda : new cublas gemm branch for multi-batch quantized src0

* cuda : add F32 sgemm branch

* cuda : fine-tune >= VOLTA params + use MMQ only for small batches

* cuda : remove duplicated cuBLAS GEMM code

* cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ macros

* build : add compile option to force use of MMQ kernels
cebtenzzre added a commit to cebtenzzre/llama.cpp that referenced this pull request Nov 27, 2023
…ce (ggerganov#3776)"

This commit introduces a performance regression on my Tesla P40.