Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

batched : add bench tool #3545

Merged
merged 7 commits into from
Oct 11, 2023
Merged

batched : add bench tool #3545

merged 7 commits into from
Oct 11, 2023

Conversation

ggerganov
Copy link
Owner

@ggerganov ggerganov commented Oct 8, 2023

Inspired by this blog post, implemented a tool to generate similar stats for llama.cpp

  • PP - prompt tokens per batch
  • TG - generated tokens per batch
  • B - number of batches
  • N_KV - required KV cache size
  • T_PP - prompt processing time (i.e. time to first token)
  • S_PP - prompt processing speed ((B*PP)/T_PP or PP/T_PP)
  • T_TG - time to generate all batches
  • S_TG - text generation speed ((B*TG)/T_TG)
  • T - total time
  • S - total speed (i.e. all tokens / total time)

There are 2 modes of operation:

  • prompt not shared - each batch has a separate prompt of size PP (i.e. N_KV = B*(PP + TG))
  • prompt is shared - there is a common prompt of size PP used by all batches (i.e. N_KV = PP + B*TG)
LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), M2 Ultra, prompt not shared
./bin/batched-bench ../models/llama-7b/ggml-model-f16.gguf 0 99
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
128 128 1 256 0.108 1186.64 3.079 41.57 3.187 80.32
128 128 2 512 0.198 1295.19 5.029 50.90 5.227 97.95
128 128 4 1024 0.373 1373.96 6.878 74.44 7.251 141.23
128 128 8 2048 0.751 1363.27 7.344 139.43 8.095 252.99
128 128 16 4096 1.570 1304.68 8.455 242.23 10.024 408.60
128 128 32 8192 3.408 1201.73 8.801 465.40 12.209 670.96
128 256 1 384 0.107 1196.70 6.329 40.45 6.436 59.67
128 256 2 768 0.194 1317.45 10.239 50.00 10.433 73.61
128 256 4 1536 0.366 1399.03 13.960 73.35 14.326 107.22
128 256 8 3072 0.751 1363.92 15.110 135.54 15.861 193.69
128 256 16 6144 1.569 1304.93 18.073 226.64 19.642 312.80
128 256 32 12288 3.409 1201.35 19.223 426.15 22.633 542.93
128 512 1 640 0.107 1200.71 12.784 40.05 12.891 49.65
128 512 2 1280 0.194 1317.47 20.789 49.26 20.984 61.00
128 512 4 2560 0.366 1400.18 28.471 71.93 28.837 88.78
128 512 8 5120 0.751 1363.92 31.929 128.28 32.680 156.67
128 512 16 10240 1.570 1304.49 41.024 199.69 42.594 240.41
256 128 1 384 0.195 1315.57 3.172 40.36 3.366 114.07
256 128 2 768 0.366 1399.75 5.170 49.51 5.536 138.73
256 128 4 1536 0.751 1363.24 7.054 72.59 7.805 196.81
256 128 8 3072 1.570 1304.84 7.751 132.11 9.321 329.59
256 128 16 6144 3.408 1201.89 9.618 212.92 13.026 471.65
256 128 32 12288 8.059 1016.56 10.422 393.01 18.481 664.91
256 256 1 512 0.195 1312.30 6.367 40.21 6.562 78.02
256 256 2 1024 0.365 1402.23 10.408 49.19 10.773 95.05
256 256 4 2048 0.750 1364.93 14.236 71.93 14.986 136.66
256 256 8 4096 1.569 1305.32 15.946 128.43 17.515 233.86
256 256 16 8192 3.410 1201.18 20.458 200.21 23.868 343.22
256 256 32 16384 8.064 1015.85 22.504 364.02 30.569 535.97
256 512 1 768 0.195 1313.84 12.747 40.17 12.942 59.34
256 512 2 1536 0.365 1401.96 21.045 48.66 21.410 71.74
256 512 4 3072 0.751 1362.82 29.032 70.54 29.784 103.14
256 512 8 6144 1.572 1302.94 33.604 121.89 35.176 174.66
256 512 16 12288 3.410 1201.02 46.006 178.06 49.416 248.66
512 128 1 640 0.366 1398.68 3.231 39.61 3.597 177.92
512 128 2 1280 0.751 1363.34 5.287 48.42 6.038 211.99
512 128 4 2560 1.570 1304.28 7.330 69.85 8.901 287.62
512 128 8 5120 3.409 1201.46 8.616 118.85 12.025 425.78
512 128 16 10240 8.058 1016.58 12.104 169.20 20.163 507.87
512 256 1 768 0.366 1399.96 6.530 39.21 6.895 111.38
512 256 2 1536 0.751 1363.10 10.707 47.82 11.458 134.05
512 256 4 3072 1.569 1305.34 14.808 69.15 16.377 187.58
512 256 8 6144 3.408 1201.79 17.650 116.03 21.058 291.76
512 256 16 12288 8.064 1015.92 25.537 160.40 33.600 365.71
512 512 1 1024 0.366 1399.14 13.204 38.78 13.570 75.46
512 512 2 2048 0.751 1363.61 21.671 47.25 22.421 91.34
512 512 4 4096 1.570 1304.64 30.222 67.77 31.792 128.84
512 512 8 8192 3.405 1203.05 36.993 110.72 40.398 202.78
512 512 16 16384 8.062 1016.14 56.692 144.50 64.754 253.02
1024 128 1 1152 0.752 1362.53 3.370 37.99 4.121 279.53
1024 128 2 2304 1.570 1304.82 5.613 45.61 7.183 320.77
1024 128 4 4608 3.409 1201.63 7.933 64.54 11.342 406.28
1024 128 8 9216 8.063 1015.94 10.336 99.08 18.399 500.89
1024 256 1 1280 0.751 1363.25 6.805 37.62 7.557 169.39
1024 256 2 2560 1.570 1304.79 11.307 45.28 12.877 198.80
1024 256 4 5120 3.408 1201.91 16.052 63.79 19.460 263.11
1024 256 8 10240 8.062 1016.17 21.167 96.75 29.229 350.34
1024 512 1 1536 0.751 1364.36 13.753 37.23 14.504 105.90
1024 512 2 3072 1.569 1305.08 22.975 44.57 24.544 125.16
1024 512 4 6144 3.409 1201.42 32.709 62.61 36.118 170.11
1024 512 8 12288 8.064 1015.89 44.231 92.60 52.295 234.97
2048 128 1 2176 1.570 1304.68 3.672 34.85 5.242 415.10
2048 128 2 4352 3.409 1201.67 6.247 40.98 9.656 450.71
2048 128 4 8704 8.063 1016.05 9.129 56.09 17.191 506.30
2048 256 1 2304 1.569 1305.56 7.393 34.63 8.962 257.09
2048 256 2 4608 3.410 1201.19 12.605 40.62 16.015 287.73
2048 256 4 9216 8.064 1015.81 18.422 55.59 26.486 347.95
2048 512 1 2560 1.569 1305.22 14.957 34.23 16.526 154.91
2048 512 2 5120 3.409 1201.61 25.513 40.14 28.922 177.03
2048 512 4 10240 8.063 1015.94 37.486 54.63 45.550 224.81
3584 128 1 3712 2.922 1226.65 4.085 31.33 7.007 529.77
3584 128 2 7424 6.761 1060.22 7.235 35.38 13.996 530.43
3584 128 4 14848 18.082 792.81 11.162 45.87 29.244 507.73
LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), M2 Ultra, prompt is shared
./bin/batched-bench ../models/llama-7b/ggml-model-f16.gguf 1 99
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
128 128 1 256 0.108 1181.98 3.088 41.46 3.196 80.10
128 128 2 384 0.106 1204.21 5.008 51.12 5.115 75.08
128 128 4 640 0.107 1200.32 6.759 75.75 6.866 93.22
128 128 8 1152 0.107 1201.12 6.949 147.36 7.056 163.27
128 128 16 2176 0.106 1204.66 7.391 277.10 7.497 290.25
128 128 32 4224 0.107 1200.86 7.367 556.03 7.473 565.22
128 256 1 384 0.107 1200.66 6.226 41.11 6.333 60.63
128 256 2 640 0.107 1199.51 10.056 50.92 10.162 62.98
128 256 4 1152 0.107 1201.43 13.672 74.90 13.779 83.61
128 256 8 2176 0.106 1204.99 14.332 142.90 14.438 150.71
128 256 16 4224 0.107 1201.31 15.940 256.96 16.046 263.24
128 256 32 8320 0.107 1198.13 16.178 506.38 16.284 510.92
256 128 1 384 0.197 1300.92 3.113 41.12 3.310 116.02
256 128 2 512 0.195 1315.41 5.027 50.92 5.222 98.05
256 128 4 768 0.194 1319.74 6.799 75.31 6.993 109.82
256 128 8 1280 0.194 1317.29 7.011 146.05 7.206 177.64
256 128 16 2304 0.195 1315.78 7.471 274.12 7.666 300.56
256 128 32 4352 0.195 1314.05 7.423 551.80 7.618 571.29
256 256 1 512 0.194 1316.38 6.255 40.93 6.450 79.39
256 256 2 768 0.194 1318.78 10.115 50.62 10.309 74.50
256 256 4 1280 0.194 1319.14 13.757 74.43 13.951 91.75
256 256 8 2304 0.194 1319.89 14.445 141.78 14.639 157.39
256 256 16 4352 0.195 1314.28 16.142 253.74 16.337 266.39
256 256 32 8448 0.195 1311.93 16.309 502.29 16.504 511.87
512 128 1 640 0.371 1380.58 3.219 39.77 3.590 178.29
512 128 2 768 0.365 1401.37 5.135 49.85 5.501 139.62
512 128 4 1024 0.366 1398.62 6.868 74.55 7.234 141.55
512 128 8 1536 0.366 1400.81 7.109 144.04 7.475 205.49
512 128 16 2560 0.365 1401.21 7.607 269.21 7.973 321.09
512 128 32 4608 0.366 1399.47 7.536 543.50 7.902 583.13
512 256 1 768 0.366 1398.01 6.470 39.57 6.836 112.34
512 256 2 1024 0.366 1400.50 10.351 49.46 10.717 95.55
512 256 4 1536 0.365 1402.32 13.920 73.57 14.285 107.53
512 256 8 2560 0.365 1401.19 14.666 139.64 15.031 170.31
512 256 16 4608 0.365 1401.81 16.410 249.61 16.775 274.70
512 256 32 8704 0.366 1397.87 16.473 497.30 16.839 516.89
1024 128 1 1152 0.752 1361.28 3.366 38.03 4.118 279.75
1024 128 2 1280 0.750 1365.23 5.290 48.40 6.040 211.93
1024 128 4 1536 0.751 1364.32 7.035 72.78 7.786 197.29
1024 128 8 2048 0.751 1362.96 7.333 139.63 8.085 253.32
1024 128 16 3072 0.751 1362.91 7.895 259.41 8.646 355.30
1024 128 32 5120 0.753 1359.82 7.703 531.71 8.456 605.45
1024 256 1 1280 0.751 1363.21 6.650 38.50 7.401 172.94
1024 256 2 1536 0.750 1364.62 10.545 48.55 11.296 135.98
1024 256 4 2048 0.751 1364.22 14.213 72.05 14.963 136.87
1024 256 8 3072 0.751 1364.29 15.062 135.97 15.813 194.27
1024 256 16 5120 0.751 1362.66 16.941 241.77 17.693 289.38
1024 256 32 9216 0.753 1360.55 16.865 485.75 17.617 523.12
2048 128 1 2176 1.570 1304.40 3.644 35.13 5.214 417.33
2048 128 2 2304 1.570 1304.49 5.616 45.59 7.186 320.64
2048 128 4 2560 1.570 1304.68 7.324 69.91 8.893 287.86
2048 128 8 3072 1.569 1304.93 7.742 132.27 9.311 329.92
2048 128 16 4096 1.571 1303.90 8.462 242.01 10.033 408.25
2048 128 32 6144 1.570 1304.21 8.053 508.62 9.623 638.44
2048 256 1 2304 1.572 1303.18 7.353 34.82 8.924 258.17
2048 256 2 2560 1.569 1305.67 11.295 45.33 12.863 199.01
2048 256 4 3072 1.570 1304.31 14.793 69.22 16.363 187.74
2048 256 8 4096 1.568 1305.74 15.917 128.67 17.485 234.26
2048 256 16 6144 1.570 1304.66 18.095 226.36 19.665 312.44
2048 256 32 10240 1.572 1302.62 17.650 464.14 19.222 532.73
3584 128 1 3712 2.923 1226.22 4.076 31.40 6.999 530.38
3584 128 2 3840 2.920 1227.51 6.083 42.08 9.003 426.52
3584 128 4 4096 2.923 1226.16 7.773 65.87 10.696 382.94
3584 128 8 4608 2.921 1226.93 8.393 122.01 11.314 407.28
3584 128 16 5632 2.924 1225.76 9.346 219.14 12.269 459.03
3584 128 32 7680 2.925 1225.14 8.615 475.48 11.540 665.52
3584 256 1 3840 2.924 1225.52 8.212 31.17 11.137 344.80
3584 256 2 4096 2.921 1226.88 12.253 41.79 15.174 269.94
3584 256 4 4608 2.922 1226.65 15.715 65.16 18.636 247.26
3584 256 8 5632 2.922 1226.63 17.215 118.96 20.137 279.68
3584 256 16 7680 2.924 1225.91 19.900 205.83 22.824 336.49
3584 256 32 11776 2.928 1224.16 18.860 434.36 21.787 540.49
7680 128 1 7808 7.396 1038.35 5.255 24.36 12.652 617.16
7680 128 2 7936 7.395 1038.52 7.364 34.76 14.759 537.69
7680 128 4 8192 7.395 1038.61 8.974 57.06 16.368 500.49
7680 128 8 8704 7.399 1037.93 10.118 101.21 17.517 496.88
7680 128 16 9728 7.396 1038.43 11.843 172.92 19.239 505.63
7680 128 32 11776 7.399 1038.03 10.314 397.15 17.712 664.86
LLaMA 7B, Q8_0, N_KV_MAX = 16384 (8GB), M2 Ultra, prompt not shared
./bin/batched-bench ../models/llama-7b/ggml-model-q8_0.gguf 0 99
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
128 128 1 256 0.122 1048.05 1.935 66.14 2.058 124.42
128 128 2 512 0.222 1151.63 2.628 97.42 2.850 179.64
128 128 4 1024 0.417 1229.06 4.122 124.21 4.539 225.62
128 128 8 2048 0.845 1211.86 7.901 129.60 8.746 234.16
128 128 16 4096 1.757 1165.75 8.977 228.14 10.734 381.61
128 128 32 8192 3.779 1083.88 9.233 443.64 13.012 629.59
128 256 1 384 0.121 1053.98 3.902 65.60 4.024 95.43
128 256 2 768 0.220 1161.39 5.338 95.91 5.559 138.16
128 256 4 1536 0.412 1241.96 8.422 121.58 8.834 173.86
128 256 8 3072 0.843 1214.59 16.208 126.36 17.051 180.16
128 256 16 6144 1.756 1166.48 19.106 214.38 20.862 294.51
128 256 32 12288 3.778 1084.23 20.149 406.58 23.926 513.58
256 128 1 384 0.221 1158.98 1.966 65.09 2.187 175.56
256 128 2 768 0.411 1245.02 2.706 94.59 3.118 246.34
256 128 4 1536 0.845 1211.79 4.288 119.41 5.133 299.25
256 128 8 3072 1.757 1165.76 8.252 124.09 10.009 306.92
256 128 16 6144 3.777 1084.58 10.141 201.95 13.918 441.45
256 128 32 12288 8.789 932.05 10.907 375.53 19.696 623.87
256 256 1 512 0.221 1159.34 3.965 64.56 4.186 122.31
256 256 2 1024 0.410 1247.64 5.487 93.32 5.897 173.65
256 256 4 2048 0.841 1218.23 8.707 117.60 9.548 214.50
256 256 8 4096 1.751 1169.72 16.797 121.93 18.548 220.83
256 256 16 8192 3.770 1086.45 21.527 190.27 25.297 323.83
256 256 32 16384 8.792 931.76 23.416 349.85 32.208 508.70
512 128 1 640 0.412 1242.45 2.031 63.01 2.444 261.92
512 128 2 1280 0.841 1217.54 2.854 89.70 3.695 346.41
512 128 4 2560 1.748 1171.42 4.572 111.98 6.321 405.02
512 128 8 5120 3.769 1086.82 9.092 112.62 12.861 398.10
512 128 16 10240 8.791 931.86 12.679 161.53 21.470 476.95
512 256 1 768 0.412 1241.44 4.101 62.42 4.513 170.16
512 256 2 1536 0.844 1213.23 5.777 88.63 6.621 232.00
512 256 4 3072 1.760 1163.84 9.281 110.34 11.040 278.25
512 256 8 6144 3.779 1083.89 18.681 109.63 22.460 273.56
512 256 16 12288 8.794 931.53 26.663 153.62 35.457 346.56
1024 128 1 1152 0.845 1212.53 2.176 58.82 3.021 381.38
1024 128 2 2304 1.756 1166.32 3.179 80.54 4.935 466.91
1024 128 4 4608 3.780 1083.64 5.181 98.82 8.961 514.24
1024 128 8 9216 8.794 931.60 10.867 94.23 19.661 468.75
1024 256 1 1280 0.845 1212.34 4.381 58.43 5.226 244.95
1024 256 2 2560 1.757 1165.53 6.445 79.45 8.202 312.13
1024 256 4 5120 3.777 1084.35 10.531 97.24 14.308 357.83
1024 256 8 10240 8.795 931.44 22.240 92.09 31.035 329.95
2048 128 1 2176 1.750 1170.19 2.471 51.80 4.221 515.51
2048 128 2 4352 3.769 1086.73 3.836 66.74 7.605 572.27
2048 128 4 8704 8.793 931.66 6.398 80.03 15.191 572.99
2048 256 1 2304 1.757 1165.65 4.998 51.22 6.755 341.10
2048 256 2 4608 3.780 1083.61 7.762 65.96 11.542 399.24
2048 256 4 9216 8.783 932.67 12.961 79.01 21.744 423.83
3584 128 1 3712 3.246 1104.05 2.919 43.85 6.165 602.11
3584 128 2 7424 7.402 968.35 4.820 53.11 12.222 607.41
3584 128 4 14848 19.331 741.60 8.382 61.08 27.714 535.76
3584 256 1 3840 3.248 1103.34 5.871 43.60 9.120 421.07
3584 256 2 7680 7.400 968.68 9.752 52.50 17.152 447.77
3584 256 4 15360 19.322 741.94 16.915 60.54 36.237 423.87
7680 128 1 7808 8.080 950.54 4.073 31.43 12.152 642.51
7680 128 2 15616 21.488 714.81 7.513 34.08 29.001 538.46
7680 256 1 7936 8.092 949.07 8.208 31.19 16.300 486.86
7680 256 2 15872 21.496 714.57 15.091 33.93 36.587 433.82
# LLaMA 7B F16, V100, no prompt sharing
./bin/batched-bench /mnt/llama.cpp/models/open-llama/7B-v2/ggml-model-f16.gguf 4800 0 100 0 50 100 1,2,3,4,5,6,7,8,16,32,64
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
50 100 1 150 0.078 638.86 1.916 52.18 1.995 75.20
50 100 2 300 0.084 1193.52 4.121 48.53 4.205 71.34
50 100 3 450 0.094 1602.63 4.345 69.04 4.439 101.37
50 100 4 600 0.109 1842.13 4.426 90.38 4.534 132.33
50 100 5 750 0.115 2181.50 4.753 105.20 4.868 154.08
50 100 6 900 0.141 2122.00 4.983 120.42 5.124 175.64
50 100 7 1050 0.152 2301.19 5.217 134.18 5.369 195.57
50 100 8 1200 0.178 2250.71 5.340 149.81 5.518 217.48
50 100 16 2400 0.355 2252.53 6.471 247.25 6.826 351.58
50 100 32 4800 0.763 2097.53 9.137 350.22 9.900 484.85

@ggerganov ggerganov added the need feedback Testing and feedback with results are needed label Oct 9, 2023
@ggerganov
Copy link
Owner Author

ggerganov commented Oct 11, 2023

Played a bit with the MMQ kernel parameters on V100 (ref #3479)

./bin/batched-bench /mnt/llama.cpp/models/open-llama/7B-v2/ggml-model-q4_0.gguf 4800 0 100 1 50 100 1,2,3,4,5,6,7,8,16,32,64

The default settings give the following performance:

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
50 100 1 150 0.128 389.59 0.894 111.91 1.022 146.78
50 100 2 300 0.139 719.99 10.978 18.22 11.117 26.99
50 100 3 450 0.184 813.15 11.233 26.71 11.417 39.41
50 100 4 600 0.198 1011.63 11.358 35.22 11.556 51.92
50 100 5 750 0.206 1215.62 11.675 42.83 11.881 63.13
50 100 6 900 0.219 1370.03 11.853 50.62 12.072 74.55
50 100 7 1050 0.289 1209.67 12.051 58.09 12.340 85.09
50 100 8 1200 0.320 1249.17 12.222 65.45 12.543 95.67
50 100 16 2400 0.647 1236.23 13.319 120.13 13.966 171.85
50 100 32 4800 1.433 1116.65 15.521 206.18 16.954 283.13

Applying the following patch results in > x3 faster TG speed, but slower PP speed for Q4_0 at low-batches:

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 654d363..32eee8b 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -3552,9 +3552,9 @@ static __device__ __forceinline__ void mul_mat_q(
 #define  MMQ_X_Q4_0_RDNA1  64
 #define  MMQ_Y_Q4_0_RDNA1  64
 #define NWARPS_Q4_0_RDNA1  8
-#define  MMQ_X_Q4_0_AMPERE 64
-#define  MMQ_Y_Q4_0_AMPERE 128
-#define NWARPS_Q4_0_AMPERE 4
+#define  MMQ_X_Q4_0_AMPERE 8
+#define  MMQ_Y_Q4_0_AMPERE 32
+#define NWARPS_Q4_0_AMPERE 8
 #define  MMQ_X_Q4_0_PASCAL 64
 #define  MMQ_Y_Q4_0_PASCAL 64
 #define NWARPS_Q4_0_PASCAL 8
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
50 100 1 150 0.103 485.62 0.883 113.28 0.986 152.17
50 100 2 300 0.158 632.24 3.439 58.16 3.597 83.40
50 100 3 450 0.218 688.58 3.666 81.83 3.884 115.85
50 100 4 600 0.270 741.20 3.778 105.86 4.048 148.21
50 100 5 750 0.332 751.93 4.099 121.98 4.432 169.24
50 100 6 900 0.394 761.96 4.255 141.00 4.649 193.58
50 100 7 1050 0.456 767.10 4.491 155.86 4.947 212.23
50 100 8 1200 0.517 773.18 4.662 171.60 5.179 231.69
50 100 16 2400 1.044 766.05 6.621 241.65 7.666 313.09
50 100 32 4800 2.146 745.71 10.419 307.15 12.564 382.04

It looks like just by tuning the MMQ constants, we can achieve significant boosts for batched decoding with quantized models, using the existing kernels. Ideally, this should work optimally without modifications, but it is not obvious how to achieve this in generic way. Similar observations are valid for Mac (#3524)

cc @slaren @JohannesGaessler to bring your attention

@JohannesGaessler
Copy link
Collaborator

For the mmq defines the "x" dimension is the dimension that varies with batch size. So the optimal way to tune the kernels would be to determine the optimal values for the y tile size and the number of warps for each x tile size. In practice it would be sufficient to determine only the optimal values for the x tile size being a power of 2. Then at runtime you could just select the kernel with the optimal parameters for a given batch size. The downside of this (apart from the effort needed to determine the optimal values, possibly multiple times) is that the compile time and binary size will increase with each additional batch size considered. So it may make sense to add a compile option that only compiles a single tile size to speed up the compilation for development.

@ggerganov
Copy link
Owner Author

ggerganov commented Oct 11, 2023

Yup, it's quite tricky. The optimal values likely depend not only on the batch size but also on the other dimension. So for different model sizes, we would need different sets of optimal parameters. At least this is what my investigation for the Metal kernels showed.

The way I think about this is that we have 2 sets of kernels:

  • first is optimized for memory-bandwidth bound single-batch multiplications (matrix-vector)
  • second is optimized for compute bound large-batch multiplications (matrix-matrix)

Increasing the batch size, we are transitioning from the first kernel being optimal to the second kernel being optimal. On top of this, the second kernel needs adjustments depending on the compute size.

In the Metal implementation, we have a more general matrix-vector kernel that can be applied for batches > 1.
I can always tune manually the transition from the first kernel to the second for a given model as a function of the batch size so that the performance scales with the batches without having the performance dip at n_batch == 2 as we do for CUDA (i.e. at some "break-even" batch size, the 2 kernels perform equally well). Probably if we extended the CUDA matrix-vector in a similar way, we could do the same thing.

However, I wish we had a way to not have to do any kind of manual adjustments and always get the best performance for any batch size and model size. But it is not clear to me atm how to do this

@JohannesGaessler
Copy link
Collaborator

The optimal values likely depend not only on the batch size but also on the other dimension. So for different model sizes, we would need different sets of optimal parameters. At least this is what my investigation for the Metal kernels showed.

I don't know what you did for the metal kernels but for CUDA the optimal values for the y tile size and the number of warps should not depend on model size. As long as the weight matrix has a number of rows divisible by the y tile size the only difference should be the size of the CUDA grid, i.e. how many tiles need to be worked on. At most you should get a small effect if the grid size is too small and leaves some of the streaming multiprocessors idle but I very much do not expect this to make much of a difference if at all.

However, I wish we had a way to not have to do any kind of manual adjustments and always get the best performance for any batch size and model size. But it is not clear to me atm how to do this

The CUDA tile sizes need to be known at compile time for good performance so it is fundamentally impossible to somehow readjust the kernels at runtime. At most I think we could write a script that repeatedly compiles llama.cpp with different tile sizes and benchmarks the performance.

@ggerganov
Copy link
Owner Author

ggerganov commented Oct 11, 2023

What I meant is that break-even point where the 2 kernels become equally performant depends on the model size.

To illustrate, try to put any kind of tile sizes to the matrix-matrix kernel and it's TG speed at batch == 2 will always be slower than the TG speed of the single-batch matrix-vector kernel at batch == 1. I.e. below a certain batch size, the mv kernel will dominate over the mm kernel. And this break-even batch size I think depends on the model size.

Anyway, still thinking about this and sharing a few thoughts - I'm not 100% about those conclusions yet. Might try the script idea at some point and provide a way to pass configuration to the backends to use a certain set of kernels based on the config.

@ggerganov ggerganov merged commit 8c70a5f into master Oct 11, 2023
joelkuiper added a commit to vortext/llama.cpp that referenced this pull request Oct 12, 2023
…example

* 'master' of github.com:ggerganov/llama.cpp: (34 commits)
  examples: support LLaVA v1.5 (multimodal model) (ggerganov#3436)
  docs : fix typo GOMP_CPU_AFFINITY (ggerganov#3597)
  cmake : fix add_compile_options on macOS
  typo : it is `--n-gpu-layers` not `--gpu-layers` (ggerganov#3592)
  ci : check if there is enough VRAM (ggerganov#3596)
  server : add completion mode (no chat) (ggerganov#3582)
  prompts : add mnemonics.txt
  server : fix kv cache management (ggerganov#3588)
  main : fix session loading bug (ggerganov#3400)
  server : add parameter -tb N, --threads-batch N (ggerganov#3584)
  common : fix mirostat state when using multiple sequences (ggerganov#3543)
  batched : add bench tool (ggerganov#3545)
  examples : add batched.swift + improve CI for swift (ggerganov#3562)
  Add MPT model to supported models in README.md (ggerganov#3574)
  Minor improvements in GPT2 tokenizer (ggerganov#3567)
  readme : add bloom (ggerganov#3570)
  llm : add bloom models (ggerganov#3553)
  swift : improvements and fixes (ggerganov#3564)
  llm : add MPT support (ggerganov#3417)
  infill. : fix tokenization (ggerganov#3508)
  ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
need feedback Testing and feedback with results are needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants