Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : ggml-backend integration #4766

Merged
merged 39 commits into from
Jan 12, 2024
Merged

llama : ggml-backend integration #4766

merged 39 commits into from
Jan 12, 2024

Conversation

slaren
Copy link
Collaborator

@slaren slaren commented Jan 4, 2024

This PR completes the integration of llama.cpp with ggml-backend. The main change is that partial offloading is handled through ggml_backend_sched, and it is supported with all the backends. This also implements multi-GPU support by offloading different layers to different GPUs, instead of splitting matrices by rows. While this does not achieve the same level of parallelism as row-level splitting, it requires less synchronization and it may be more efficient in some cases.

There is still some work to do, but layer-level multi-GPU with CUDA should already work. For me with WSL, this is significantly faster than the current row-level splitting, so I am opening this now in case anyone wants to give it a try. -ts can be used as usual to configure the fraction of layers to offload to each GPU.

TODO

  • Improve graph splitting logic in ggml_backend_sched
  • Fix --no-kv-offload
  • CUDA split buffers
  • OpenCL
  • LoRA
  • Buffer type names
  • Session saving

Fixes #4055

@ggerganov ggerganov added high priority Very important issue need feedback Testing and feedback with results are needed labels Jan 4, 2024
Copy link
Collaborator

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if you want me to look at something specific.

ggml-cuda.cu Outdated Show resolved Hide resolved
ggml-cuda.cu Outdated Show resolved Hide resolved
ggml-cuda.cu Outdated Show resolved Hide resolved
llama.cpp Show resolved Hide resolved
ggml-cuda.cu Outdated Show resolved Hide resolved
llama.cpp Show resolved Hide resolved
llama.cpp Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
@JohannesGaessler
Copy link
Collaborator

On my machine with 3x P40 LLaMA 2 70b q6_K crashes:

    ~/Pro/llama.cpp    sl/backend-sched wip *4 ?45  ./perplexity --n-gpu-layers 99 --model models/opt/${model_name}-${quantization}.gguf --file wikitext-2-raw/wiki.test.raw --mlock --chunks 10 
main: build = 1760 (af8a3742)
main: built with cc (GCC) 13.2.1 20230801 for x86_64-pc-linux-gnu
main: seed  = 1704376616
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 3 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: yes
  Device 1: Tesla P40, compute capability 6.1, VMM: yes
  Device 2: Tesla P40, compute capability 6.1, VMM: yes
llama_model_loader: loaded meta data with 22 key-value pairs and 723 tensors from models/opt/llama_2-70b-q6_k.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = LLaMA v2
llama_model_loader: - kv   2:                       llama.context_length u32              = 4096
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 8192
llama_model_loader: - kv   4:                          llama.block_count u32              = 80
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 28672
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 64
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 18
llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  12:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  13:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  14:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  15:                      tokenizer.ggml.merges arr[str,61249]   = ["▁ t", "e r", "i n", "▁ a", "e n...
llama_model_loader: - kv  16:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  17:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  18:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  19:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  20:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  21:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  161 tensors
llama_model_loader: - type q6_K:  562 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 8192
llm_load_print_meta: n_head           = 64
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_layer          = 80
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 28672
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 70B
llm_load_print_meta: model ftype      = Q6_K
llm_load_print_meta: model params     = 68.98 B
llm_load_print_meta: model size       = 52.70 GiB (6.56 BPW) 
llm_load_print_meta: general.name     = LLaMA v2
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
split[0] = 0.33
split[1] = 0.67
split[2] = 1.00
layer 0 -> gpu 0
layer 1 -> gpu 0
layer 2 -> gpu 0
layer 3 -> gpu 0
layer 4 -> gpu 0
layer 5 -> gpu 0
layer 6 -> gpu 0
layer 7 -> gpu 0
layer 8 -> gpu 0
layer 9 -> gpu 0
layer 10 -> gpu 0
layer 11 -> gpu 0
layer 12 -> gpu 0
layer 13 -> gpu 0
layer 14 -> gpu 0
layer 15 -> gpu 0
layer 16 -> gpu 0
layer 17 -> gpu 0
layer 18 -> gpu 0
layer 19 -> gpu 0
layer 20 -> gpu 0
layer 21 -> gpu 0
layer 22 -> gpu 0
layer 23 -> gpu 0
layer 24 -> gpu 0
layer 25 -> gpu 0
layer 26 -> gpu 0
layer 27 -> gpu 1
layer 28 -> gpu 1
layer 29 -> gpu 1
layer 30 -> gpu 1
layer 31 -> gpu 1
layer 32 -> gpu 1
layer 33 -> gpu 1
layer 34 -> gpu 1
layer 35 -> gpu 1
layer 36 -> gpu 1
layer 37 -> gpu 1
layer 38 -> gpu 1
layer 39 -> gpu 1
layer 40 -> gpu 1
layer 41 -> gpu 1
layer 42 -> gpu 1
layer 43 -> gpu 1
layer 44 -> gpu 1
layer 45 -> gpu 1
layer 46 -> gpu 1
layer 47 -> gpu 1
layer 48 -> gpu 1
layer 49 -> gpu 1
layer 50 -> gpu 1
layer 51 -> gpu 1
layer 52 -> gpu 1
layer 53 -> gpu 1
layer 54 -> gpu 2
layer 55 -> gpu 2
layer 56 -> gpu 2
layer 57 -> gpu 2
layer 58 -> gpu 2
layer 59 -> gpu 2
layer 60 -> gpu 2
layer 61 -> gpu 2
layer 62 -> gpu 2
layer 63 -> gpu 2
layer 64 -> gpu 2
layer 65 -> gpu 2
layer 66 -> gpu 2
layer 67 -> gpu 2
layer 68 -> gpu 2
layer 69 -> gpu 2
layer 70 -> gpu 2
layer 71 -> gpu 2
layer 72 -> gpu 2
layer 73 -> gpu 2
layer 74 -> gpu 2
layer 75 -> gpu 2
layer 76 -> gpu 2
layer 77 -> gpu 2
layer 78 -> gpu 2
layer 79 -> gpu 2
output -> gpu 2
llm_load_tensors: ggml ctx size       =    1.10 MiB
llm_load_tensors: offloading 80 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 81/81 layers to GPU
llm_load_tensors:        ??? buffer size =  205.08 MiB
llm_load_tensors:        ??? buffer size = 18074.81 MiB
llm_load_tensors:        ??? buffer size = 18074.81 MiB
llm_load_tensors:        ??? buffer size = 17610.48 MiB
....................................................................................................
warning: munmap failed: Invalid argument
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: KV        ??? buffer size: 54.00 MiB
llama_kv_cache_init: KV        ??? buffer size: 54.00 MiB
llama_kv_cache_init: KV        ??? buffer size: 52.00 MiB
llama_new_context_with_model: KV self size  =  160.00 MiB, K (f16):   80.00 MiB, V (f16):   80.00 MiB
GGML_ASSERT: ggml-backend.c:1011: cur_split < GGML_MAX_SPLITS
ptrace: Operation not permitted.
No stack.
The program is not being run.
zsh: IOT instruction (core dumped)  ./perplexity --n-gpu-layers 99 --model  --file wikitext-2-raw/wiki.test.raw

LLaMA 2 7b q4_0 and LLaMA 2 13b q6_K work correctly.

@JohannesGaessler
Copy link
Collaborator

I did some performance testing on my Linux system with 3x P40:

GPU Model Test t/s master t/s PR Speedup
3x P40, PCIe 3.0 x16/x8/x8 7b q4_0 pp512 879 630 0.72
3x P40, PCIe 3.0 x16/x8/x8 7b q4_0 tg128 55.16 47.91 0.87
3x P40, PCIe 3.0 x8/x4/x4 7b q4_0 pp512 712 629 0.88
3x P40, PCIe 3.0 x8/x4/x4 7b q4_0 tg128 53.91 46.46 0.86
2x P40, PCIe 3.0 x8/x8 7b q4_0 pp512 781 676 0.87
2x P40, PCIe 3.0 x8/x8 7b q4_0 tg128 50.20 49.92 0.99

The P40s are connected to the motherboard via PCIe 3.0 x16/x8/x8. Choosing the x16 GPU as the main GPU lets you use the full PCIe bandwidth. Choosing any of the other GPUs effectively reduces the interconnect speed to x8/x4/x4. For dual GPU use the interconnect speed is effectively limited to x8/x8. I was not able to test LLaMA 2 70b or Mixtral because they currently do not work with this PR.

On my system the new scheme for distributing weights is slower. One of the reasons is probably that P40s are comparatively slow so the interconnect speed doesn't bottleneck them as much as e.g. 3x RTX 4090. Synchronization overhead may also depend on OS. Finally, because the P40s use MMQ they benefit from the performance optimizations in this PR: #3110 . For FP16 cuBLAS similar optimizations could be done by converting the hidden state to FP16 on the main device and caching the dequantized weight matrix (otherwise it is dequantized multiple times if tiling is enabled).

@slaren
Copy link
Collaborator Author

slaren commented Jan 4, 2024

Thanks for testing. There is an issue in the graph splitting logic that is causing some operations of each layer to be run on a different GPU, and that's why it fails with 70B, it creates too many splits. GGML_MAX_SPLITS is 256, while it should only need 4 or 5 splits with 3 GPUs. So there is still a lot of room for improvement there, the performance should improve a lot after fixing that. For me in WSL, with 3080+3090 7B q4_0 I get ~2200 t/s pp512, 70 t/s tg128, about 4 times faster pp and 7 times faster tg than master with row-level splitting.

@slaren
Copy link
Collaborator Author

slaren commented Jan 4, 2024

The graph split generation should be much better now. With 3080+3090Ti I get 3600 t/s pp 512, and 105 t/s tg 128. Compared to a single GPU, I get 4020 t/s pp, 121 t/s tg with only the 3090Ti, and 3320 t/s pp, 110 t/s tg with only the 3080.

@ggerganov
Copy link
Owner

ggerganov commented Jan 5, 2024

Did some tests on 4x RTX 4090 - significant improvements are observed

master

ggml_init_cublas: GGML_CUDA_FORCE_MMQ: no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
Device 1: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
Device 2: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
Device 3: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

model size params backend ngl test t/s
llama 34B F16 62.85 GiB 33.74 B CUDA 99 pp 512 250.84 ± 0.71
llama 34B F16 62.85 GiB 33.74 B CUDA 99 tg 128 15.66 ± 0.11
llama 34B Q8_0 33.39 GiB 33.74 B CUDA 99 pp 512 252.90 ± 1.06
llama 34B Q8_0 33.39 GiB 33.74 B CUDA 99 tg 128 22.71 ± 0.08
llama 34B Q4_0 17.74 GiB 33.74 B CUDA 99 pp 512 253.03 ± 1.32
llama 34B Q4_0 17.74 GiB 33.74 B CUDA 99 tg 128 25.50 ± 0.05

build: b3a7c20 (1768)

PR

(bumped GGML_MAX_BACKENDS to 8)

ggml_init_cublas: GGML_CUDA_FORCE_MMQ: no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
Device 1: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
Device 2: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
Device 3: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

model size params backend ngl test t/s
llama 34B F16 62.85 GiB 33.74 B CUDA 99 pp 512 2415.13 ± 133.18
llama 34B F16 62.85 GiB 33.74 B CUDA 99 tg 128 12.81 ± 0.01
llama 34B Q8_0 33.39 GiB 33.74 B CUDA 99 pp 512 1518.36 ± 13.90
llama 34B Q8_0 33.39 GiB 33.74 B CUDA 99 tg 128 23.68 ± 0.01
llama 34B Q4_0 17.74 GiB 33.74 B CUDA 99 pp 512 1572.18 ± 4.97
llama 34B Q4_0 17.74 GiB 33.74 B CUDA 99 tg 128 40.79 ± 0.08

build: d4fca23 (1763)

Any idea why F16 TG is slower with this PR?

Will keep the pod active for the day - can do some more tests if needed

P.S. Here are some additional results using the batched-bench tool:

LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench models/codellama-34b/ggml-model-f16.gguf 18432 0 99 1 100 128 1,2,3,4,5,6,7,8,16,32,64

master

llm_load_tensors: ggml ctx size       =    0.17 MiB
llm_load_tensors: using CUDA for GPU acceleration
llm_load_tensors: system memory used  =  500.17 MiB
llm_load_tensors: VRAM used           = 63863.03 MiB
llm_load_tensors: offloading 48 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 49/49 layers to GPU
....................................................................................................
llama_new_context_with_model: n_ctx      = 18432
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: VRAM kv self = 3456.00 MB
llama_new_context_with_model: KV self size  = 3456.00 MiB, K (f16): 1728.00 MiB, V (f16): 1728.00 MiB
llama_build_graph: non-view tensors processed: 1012/1012
llama_new_context_with_model: compute buffer total size = 2391.19 MiB
llama_new_context_with_model: VRAM scratch buffer: 2388.00 MiB
llama_new_context_with_model: total VRAM used: 69707.04 MiB (model: 63863.03 MiB, context: 5844.00 MiB)

main: n_kv_max = 18432, is_pp_shared = 0, n_gpu_layers = 99, mmq = 1, n_threads = 64, n_threads_batch = 64

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|   100 |    128 |    1 |    228 |    0.395 |   253.18 |    6.744 |    18.98 |    7.139 |    31.94 |
|   100 |    128 |    2 |    456 |    0.598 |   334.30 |   16.628 |    15.40 |   17.226 |    26.47 |
|   100 |    128 |    3 |    684 |    0.894 |   335.73 |   18.027 |    21.30 |   18.921 |    36.15 |
|   100 |    128 |    4 |    912 |    1.231 |   324.94 |   19.466 |    26.30 |   20.697 |    44.06 |
|   100 |    128 |    5 |   1140 |    1.444 |   346.25 |   20.630 |    31.02 |   22.074 |    51.65 |
|   100 |    128 |    6 |   1368 |    1.955 |   306.93 |   22.029 |    34.86 |   23.984 |    57.04 |
|   100 |    128 |    7 |   1596 |    2.183 |   320.72 |   23.209 |    38.61 |   25.391 |    62.86 |
|   100 |    128 |    8 |   1824 |    2.344 |   341.36 |   24.595 |    41.63 |   26.939 |    67.71 |
|   100 |    128 |   16 |   3648 |    4.755 |   336.51 |   35.122 |    58.31 |   39.876 |    91.48 |
|   100 |    128 |   32 |   7296 |    9.611 |   332.95 |   54.635 |    74.97 |   64.246 |   113.56 |
|   100 |    128 |   64 |  14592 |   19.879 |   321.96 |  101.774 |    80.49 |  121.653 |   119.95 |

PR

llm_load_tensors: ggml ctx size =    0.83 MiB
llm_load_tensors: offloading 48 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 49/49 layers to GPU
llm_load_tensors:        CPU buffer size = 64364.67 MiB
llm_load_tensors:      CUDA0 buffer size = 17160.81 MiB
llm_load_tensors:      CUDA1 buffer size = 15840.75 MiB
llm_load_tensors:      CUDA2 buffer size = 15840.75 MiB
llm_load_tensors:      CUDA3 buffer size = 15020.72 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 18432
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =  936.00 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =  864.00 MiB
llama_kv_cache_init:      CUDA2 KV buffer size =  864.00 MiB
llama_kv_cache_init:      CUDA3 KV buffer size =  792.00 MiB
llama_new_context_with_model: KV self size  = 3456.00 MiB, K (f16): 1728.00 MiB, V (f16): 1728.00 MiB
llama_new_context_with_model: graph splits (measure): 10
llama_new_context_with_model:      CUDA0 compute buffer size = 2388.00 MiB
llama_new_context_with_model:      CUDA1 compute buffer size = 2388.00 MiB
llama_new_context_with_model:      CUDA2 compute buffer size = 2388.00 MiB
llama_new_context_with_model:      CUDA3 compute buffer size = 2388.00 MiB
llama_new_context_with_model:        CPU compute buffer size =   52.00 MiB

main: n_kv_max = 18432, is_pp_shared = 0, n_gpu_layers = 99, mmq = 1, n_threads = 64, n_threads_batch = 64

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|   100 |    128 |    1 |    228 |    0.125 |   799.06 |   10.012 |    12.78 |   10.137 |    22.49 |
|   100 |    128 |    2 |    456 |    0.144 |  1392.12 |   10.290 |    24.88 |   10.433 |    43.71 |
|   100 |    128 |    3 |    684 |    0.182 |  1645.47 |   10.395 |    36.94 |   10.577 |    64.67 |
|   100 |    128 |    4 |    912 |    0.203 |  1974.31 |   10.488 |    48.82 |   10.691 |    85.30 |
|   100 |    128 |    5 |   1140 |    0.226 |  2214.94 |   10.670 |    59.98 |   10.896 |   104.62 |
|   100 |    128 |    6 |   1368 |    0.339 |  1770.26 |   10.725 |    71.61 |   11.064 |   123.64 |
|   100 |    128 |    7 |   1596 |    0.359 |  1951.42 |   10.881 |    82.35 |   11.240 |   142.00 |
|   100 |    128 |    8 |   1824 |    0.402 |  1987.88 |   10.988 |    93.19 |   11.391 |   160.13 |
|   100 |    128 |   16 |   3648 |    0.837 |  1910.98 |   12.507 |   163.75 |   13.344 |   273.37 |
|   100 |    128 |   32 |   7296 |    1.799 |  1778.85 |   13.207 |   310.14 |   15.006 |   486.21 |
|   100 |    128 |   64 |  14592 |    4.336 |  1475.96 |   20.349 |   402.57 |   24.685 |   591.12 |

@JohannesGaessler
Copy link
Collaborator

Any idea why F16 TG is slower with this PR?

Token generation is I/O bound. So FP16 is going to be comparatively slow (which reduces the interconnect bottleneck). At the same time the hidden state is very small and thus can be transferred quickly between devices. So I think that this is just a scenario where the implementation on master has favorable conditions.

@slaren
Copy link
Collaborator Author

slaren commented Jan 5, 2024

Any idea why F16 TG is slower with this PR?

I think there are some cases where the row splitting implementation can achieve higher tg performance than with a single GPU, so maybe this is one of them, but I am just guessing. I am already aware of some inefficiencies in the current implementation, for instance the way the data is copied between GPUs is very inefficient, but there is still a lot of work to do before this can be merged, so I think it is better to leave that for another PR.

@slaren
Copy link
Collaborator Author

slaren commented Jan 5, 2024

I have no idea what needs to be done to fix the swift build, the errors don't make sense to me.

@ggerganov
Copy link
Owner

The swift builds are failing because #4691 makes the swift examples pull the ggml repo instead of using the sources from llama.cpp. It's a dependency problem that I did not consider at the time

The builds will be resolved after upstreaming the changes from this PR to the ggml repo. In the meantime I will try to figure out a better solution

@cmp-nct
Copy link
Contributor

cmp-nct commented Jan 6, 2024

Nice job, glad to see! (And thanks, as this will close #4055 which I held dear :)

For parallelism in prompt processing there is possibly another optimization (https://www.deepspeed.ai/tutorials/pipeline/)

@slaren
Copy link
Collaborator Author

slaren commented Jan 6, 2024

Does that work for inference? The article seems to be about training, by interleaving forward and backward passes, but I don't see how to apply that during inference.

@cmp-nct
Copy link
Contributor

cmp-nct commented Jan 6, 2024

Does that work for inference? The article seems to be about training, by interleaving forward and backward passes, but I don't see how to apply that during inference.

I'll still have to wrap my head around it and it's 7:00 AM. maybe I'm mistaken but if we process a batch of tokens through GPU1 (layers 1-10), then pass those to GPU2 (11-20), wouldn't we be able to process the next on GPU1 while GPU2 is still working ?

@JohannesGaessler
Copy link
Collaborator

Does that work for inference? The article seems to be about training, by interleaving forward and backward passes, but I don't see how to apply that during inference.

What you can do is split a batch of e.g. 512 tokens into 4 batches of 128 tokens. If you then have 4 GPUs you can have the GPU with the first few layers work on the first batch first. Once the first batch has passed through the layers of the first GPU, the first GPU can start working on the second batch while the second GPU can start working on the first batch. In this particular example you could get up to 57% GPU utilization, so more than 2x more than without pipelining. For very large batches you could get up to ~100% GPU utilization.

The problem with this approach is that the code in its current form very much does not like small batches in terms of performance:

cublas_vs_mmq_q8_0

So even if you do pipelining and get higher GPU utilization the overall performance could be worse. I have a prototype for int8 matrix matrix multiplication using tensor cores; I'll make a related PR later today (did not yet investigate performance in detail but could be better for small batches).

@ikawrakow
Copy link
Contributor

@slaren @ggerganov

This PR changes perplexity in non-negligible ways. I came across this while looking into #4900. I checked out the version just before #4872 was merged (49662cb), quantized Mixtral-instruct-8x7B with Q3_K_L, and computed perplexity for 200 chunks (not a full calculation as it is quite slow with my 16 GB GPU where I can offload only 22 layers to the GPU). I then checked out current master, prepared a fix that makes the new Q3_K_S identical to Q3_K_L pre-#4872 (see PR #4906), and ran perplexity. After 200 chunks I had PPL = 4.8084 versus PPL = 4.7936 on 49662cb. Assuming that the difference is due to my fix not working as intended, I wasted some time trying to sort out why what it seems a very simple fix wasn't working. But at the end I saw that the model produced for Q3_K_S by #4906 is identical to the model produced for Q3_K_L by 49662cb, so I started going back in history and saw that this PR causes the change.

This is just one anecdotal example, but 0.015 change in PPL is a bit too large for my taste. Did you run a more detailed examination of PPL changes due to this PR?

@ggerganov
Copy link
Owner

ggerganov commented Jan 13, 2024

I'm looking into this and a difference in the PPL values with partial offloading is expected from this PR, but I cannot say yet if the large delta that you observe is within that expectation.

The reason that the results change after this PR is that for the non-offloaded layers, the RoPE is now computed on the CPU, while before this PR, it was computed on the GPU - i.e. the Q and K data for each non-offloaded layer was being copied to the GPU where it was roped with the data from the inp_pos tensor and then the result was moved back to the CPU. It was like this, because the inp_pos tensor was forced to by in GPU memory for all layers:

llama.cpp/llama.cpp

Lines 6174 to 6178 in 584d674

{ "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
{ "KQ_mask", OFFLOAD_FUNC_FRC },
{ "K_shift", OFFLOAD_FUNC_FRC },

The same analysis is also valid for the softmax in the attention, because the KQ_mask tensor was also forced to offload, while now it is on the CPU for non-offloaded layers.

Bottomline is that for non-offloaded layers, RoPE and softmax are now running on the CPU, while before they ran on the GPU, so some change in PPL should be observed.

I did the following experiment - I checked out the commit before this PR and applied this patch:

diff --git a/llama.cpp b/llama.cpp
index ce413f60..d22669c2 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6172,8 +6172,8 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
   //{ "inp_embd",                   OFFLOAD_FUNC_NR  }, // TODO: missing K-quants get_rows kernel
     { "pos_embd",                   OFFLOAD_FUNC_NR  },
 
-    { "inp_pos",                    OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
-    { "KQ_mask",                    OFFLOAD_FUNC_FRC },
+    { "inp_pos",                    OFFLOAD_FUNC_NR }, // this is often used for KQ ops (e.g. rope)
+    { "KQ_mask",                    OFFLOAD_FUNC_NR },
     { "K_shift",                    OFFLOAD_FUNC_FRC },
 
     { "K_shifted",                  OFFLOAD_FUNC     },

This will make the inp_pos and KQ_mask stay on the CPU when partial offloading is used and so RoPE and softmax will be done on the CPU for non-offloaded layers and on the GPU for offloaded. This should match the computation with latest master. I did a few tests with partially offloaded LLaMA and the PPL values now match.

However, they still don't match for Mixtral - I guess there is some additional difference that I do not see, likely due to the extra operations involved in the FFN.

I also checked that the PPL with full offload is the same for CUDA and Metal before and after this PR.

I'll keep looking as I am not yet 100% sure that there isn't some other issue

Edit: for Mixtral with -ngl 8 we now have 3 extra tensors offloaded compared to before this PR:

node #1246 (       MUL):  ffn_moe_weighted-23 (  32K) [CUDA0         ]: CUDA0#ffn_moe_down-2 (  32K) [CUDA0         ] CUDA0#ffn_moe_weight (   0K) [CUDA0         ]
node #1247 (       ADD):       ffn_moe_out-23 (  32K) [CUDA0         ]: CUDA0#ffn_moe_weight (  32K) [CUDA0         ]  ffn_moe_weighted-23 (  32K) [CUDA0         ]
node #1248 (       ADD):             l_out-23 (  32K) [CUDA0         ]:       ffn_moe_out-23 (  32K) [CUDA0         ]     CUDA0#ffn_inp-23 (  32K) [CUDA0         ]

These were on the CPU before, so I suspect this should explain the difference in the PPL values. Will try to confirm this

@mononoSaya
Copy link

partially load with opencl dont work llama-b1843-bin-win-clblast-x64.zip
B1842 normal operation
1842
B1843 dont work
1843

@Green-Sky
Copy link
Collaborator

Green-Sky commented Jan 15, 2024

This pr broke file prompt ingestion in the main example.
image
It will spew out mostly escape sequences.
Tested with llama2 7b and goat-storytelling (based on llama2 70b).

Details
$ result/bin/llama -m models/llama-2-7b.Q4_K_M.gguf -c 0 -n -1 --repeat_penalty 1 --color -t 8 -b 16 --ignore-eos --top-p 1.0 --temp 1.0 -ngl 18 -nkvo -f prompts/chat-with-bob.txt
Log start
main: build = 0 (unknown)
main: built with gcc (GCC) 11.4.0 for x86_64-unknown-linux-gnu
main: seed  = 1705314722
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5, VMM: yes
llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from models/llama-2-7b.Q4_K_M.gguf (version GGUF V2)
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = LLaMA v2
llama_model_loader: - kv   2:                       llama.context_length u32              = 4096
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 32
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 15
llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  12:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  13:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  14:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  15:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  16:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  17:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  18:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type q4_K:  193 tensors
llama_model_loader: - type q6_K:   33 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V2
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 4096
llm_load_print_meta: n_embd_v_gqa     = 4096
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 11008
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = Q4_K - Medium
llm_load_print_meta: model params     = 6.74 B
llm_load_print_meta: model size       = 3.80 GiB (4.84 BPW)
llm_load_print_meta: general.name     = LLaMA v2
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.22 MiB
llm_load_tensors: offloading 18 repeating layers to GPU
llm_load_tensors: offloaded 18/33 layers to GPU
llm_load_tensors:        CPU buffer size =  3891.24 MiB
llm_load_tensors:      CUDA0 buffer size =  2091.59 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 4096
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =  2048.00 MiB
llama_new_context_with_model: KV self size  = 2048.00 MiB, K (f16): 1024.00 MiB, V (f16): 1024.00 MiB
llama_new_context_with_model: graph splits (measure): 39
llama_new_context_with_model:      CUDA0 compute buffer size =    73.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     9.00 MiB

system_info: n_threads = 8 / 24 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
sampling:
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 40, tfs_z = 1.000, top_p = 1.000, min_p = 0.050, typical_p = 1.000, temp = 1.000
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temp
generate: n_ctx = 4096, n_batch = 16, n_predict = -1, n_keep = 0


 Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
$ser:$#$ %	"
 $

	 ▅%

	%


	
        	

#"	


!%
 !!!#  #
	



%%
  	▅!



	
        %▅%$"%#"#
"%
%▅$	
         ▅#		!


▅"
  %

"% %$""

$$▅$#"
      #
 !      ▅	#▅"
$




 %


	$      #

#▅
!	!
         $"
           $"!
             !


▅"	

        %


!	$	▅#%▅

Other models like tinyllama and stablelm-3b-4e1t will spew out single letters repeatedly.

@ggerganov
Copy link
Owner

There is likely an issue with the -nkvo argument as discussed earlier. Check if it works as expected when you remove it

@Green-Sky
Copy link
Collaborator

There is likely an issue with the -nkvo argument as discussed earlier. Check if it works as expected when you remove it

you where right, it does indeed work with kv offloading.

it is still funny to watch it decent into madness and mostly print escape sequences/rawbytes at the "end".
image

@ikawrakow
Copy link
Contributor

@ggerganov Thanks for the detailed explanation of differences introduced by this PR. Somehow I find it unlikely that running RoPE and softmax on the CPU instead of the GPU can cause such large differences in PPL. My latest example is Mixtral-8x7B quantized with Q5_0 using PR #4969. I get PPL = 4.1574 on master, and PPL = 4.1308 on 584d674 with the exact same quantized model. To me this looks like a bug.

@Ph0rk0z
Copy link

Ph0rk0z commented Jan 28, 2024

All this work and afterwards 2x3090 went from 18.9 t/s to 15 t/s and 2xP40 went from 8.8 to like 8.5, regardless of which layer split option is used.

@JohannesGaessler
Copy link
Collaborator

Are you on Windows? I also noticed a single GPU performance regression for my AMD GPU but we were not able to pin down what the cause is. Possibly some API call that is fast on Linux with CUDA is slow for Windows and HIP.

@Ph0rk0z
Copy link

Ph0rk0z commented Jan 28, 2024

I am on linux. I was worried about this so I am building an older version to double check my speeds.

Around dec 27th, which is when my backup is from, I was still getting at least 16.3 t/s so there have been regressions along the way. I probably wasn't paying attention because I was using mixtral.

@sorasoras
Copy link

Are you on Windows? I also noticed a single GPU performance regression for my AMD GPU but we were not able to pin down what the cause is. Possibly some API call that is fast on Linux with CUDA is slow for Windows and HIP.

There are performance regression on rdna2 on single GPU,but I haven't notice that on a rdna3 GPU like 7900xtx on Windows at least.

@jukofyork
Copy link
Contributor

Just to add my experience: using 2x RTX A6000 with an Nvlink bridge I'm only getting around 60% of the tokens/s using the new default layer splitting vs row splitting.

I'm guessing without the Nvlink it would be about the same.

@JohannesGaessler
Copy link
Collaborator

Around dec 27th, which is when my backup is from, I was still getting at least 16.3 t/s so there have been regressions along the way. I probably wasn't paying attention because I was using mixtral.

Can you do a git bisect to check which commit is to blame?

@Ph0rk0z
Copy link

Ph0rk0z commented Jan 29, 2024

I have to investigate. I have snapshots from nov 29, dec 27 and fresh. Going back to earlier ones isn't helping at the moment, I also haven't rebooted in about 3 months.

@Ph0rk0z
Copy link

Ph0rk0z commented Jan 30, 2024

Something weird is going on because I no longer hit 18.x back to november backup. Then again.. my outdoor temperatures are around 30f and my CPUs top out at 14c with the fans pushed to low. You'd think it works better colder but I guess not. I still have my old benchmarks saved on the same models so it's not in my head. By the current measurement I went from 14.4 to 15.4t/s. I suppose it will remain a mystery for a while longer.

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Feb 3, 2024
* llama : ggml-backend integration

* ggml-backend : add names to buffers

* fix unmap after loading

* batched-bench : add tensor_split param

* llama : check for null tensor_split

* ggml-backend : increase GGML_MAX_BACKENDS

* improve graph splitting, partial fix for --no-kv-offload

* cuda : add ggml-backend split buffer support

* cuda : do not create buffer types for devices that don't exist (fixes usage without CUDA devices available)

* ggml : fix null backend dereference (ggerganov#4807)

* ggml : fix null backend dereference

* ggml : also check ggml_backend_is_cpu

* test-backend-ops : check buffer allocation failures

* llama : add cparam (split_mode) and command line argument (--split-mode, -sm) to configure the split mode (none, layer or row)

* ggml : fix mul_mat_id work size

* llama : rewrite session kv load/set without graphs

* minor

* llama : only initialize used backends, free backends on context free

* llama : abort ctx if cuda backend init fails

* llama : rewrite lora with ggml-backend and compute on CPU

ggml-ci

* llama : only map to a backend buffer the region of the file mapping containing the tensors used in the buffer

* opencl : add ggml-backend buffer type

* cuda : only use batched_cublas with batched mat muls (fixes fp16 tg perf)

* llama : on Metal, by default offload the full model

ggml-ci

* metal : page align the data ptr (ggerganov#4854)

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* cuda : fix split buffer free

* address review comments

* llama-bench : add split-mode parameter

* fix whitespace

* opencl : fix double initialization

* server : add --split-mode parameter

* use async copy and compute to improve multi-gpu performance

ggml-ci

* use async memcpys to copy the graph outputs to the CPU

* fix opencl

* use a host buffer for the cpu compute buffer for faster copies to the gpu

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
@Ph0rk0z
Copy link

Ph0rk0z commented Feb 5, 2024

The offending commit is: #4606

@slaren slaren deleted the sl/backend-sched branch March 21, 2024 12:44
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
* llama : ggml-backend integration

* ggml-backend : add names to buffers

* fix unmap after loading

* batched-bench : add tensor_split param

* llama : check for null tensor_split

* ggml-backend : increase GGML_MAX_BACKENDS

* improve graph splitting, partial fix for --no-kv-offload

* cuda : add ggml-backend split buffer support

* cuda : do not create buffer types for devices that don't exist (fixes usage without CUDA devices available)

* ggml : fix null backend dereference (ggerganov#4807)

* ggml : fix null backend dereference

* ggml : also check ggml_backend_is_cpu

* test-backend-ops : check buffer allocation failures

* llama : add cparam (split_mode) and command line argument (--split-mode, -sm) to configure the split mode (none, layer or row)

* ggml : fix mul_mat_id work size

* llama : rewrite session kv load/set without graphs

* minor

* llama : only initialize used backends, free backends on context free

* llama : abort ctx if cuda backend init fails

* llama : rewrite lora with ggml-backend and compute on CPU

ggml-ci

* llama : only map to a backend buffer the region of the file mapping containing the tensors used in the buffer

* opencl : add ggml-backend buffer type

* cuda : only use batched_cublas with batched mat muls (fixes fp16 tg perf)

* llama : on Metal, by default offload the full model

ggml-ci

* metal : page align the data ptr (ggerganov#4854)

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* cuda : fix split buffer free

* address review comments

* llama-bench : add split-mode parameter

* fix whitespace

* opencl : fix double initialization

* server : add --split-mode parameter

* use async copy and compute to improve multi-gpu performance

ggml-ci

* use async memcpys to copy the graph outputs to the CPU

* fix opencl

* use a host buffer for the cpu compute buffer for faster copies to the gpu

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
hodlen added a commit to hodlen/llama.cpp that referenced this pull request Apr 3, 2024
readme : update hot topics

common : add `--version` option to show build info in CLI (#4433)

build : detect host compiler and cuda compiler separately (#4414)

sync : ggml (SD ops, tests, kernels) (#4444)

* sync : ggml (SD ops, tests, kernels)

ggml-ci

* cuda : restore im2col

ggml-ci

* metal : fix accuracy of dequantization kernels

ggml-ci

* cuda : restore correct im2col

ggml-ci

* metal : try to fix moe test by reducing expert size

ggml-ci

* cuda : fix bin bcast when src1 and dst have different types

ggml-ci

---------

Co-authored-by: slaren <slarengh@gmail.com>

server : fix handling of characters that span multiple tokens when streaming (#4446)

readme : update supported model list (#4457)

convert : support loading vocab from fast tokenizer config (#3633)

* Add HFVocab into convert.py

* Update convert.py

* Update convert.py

* add bytes_to_unicode function

* change add_meta_vocab fucntion

* remove debug code

* remove byte_encoder

* Add newline between classes

* Check tokenizer.json when tokenizer.model is not exist.

* Move transformers dependency to local code

* Add error context with 'raise from'

* Add fast tokenizer option to BpeVocab

* Update convert.py

* Add VocabLoader and remove *Vocab class

* Add transformers dependency

* remove added tokens and check newline token to decide spm or bpe

* Update convert.py

* Add special token type

* Update convert.py

* Update convert.py

* Update convert.py

* Fix typo in convert.py

* Fix when params.n_vocab < tokenizer vocab size

* update vocab class

* change funtion name

* Remove unused variable/functions, add types to class variable and methods, delete blank liens

* fix flake8 warnings

* code style cleanup

* make mypy happy

* change exception

---------

Co-authored-by: Jared Van Bortel <jared@nomic.ai>

ggml : fix OpenCL broadcast requirement for ggml_mul (close #4453)

ggml : add ggml_row_size() (fixes llama out of space) (#4461)

* Fixes "Not enough space in the context's memory pool" encountered on certain models, which seems to be caused by some imprecision related to the automatic casting of floating point values

* do not cast to size_t, instead just use doubles

* ggml : add ggml_row_size(), deprecate ggml_type_sizef()

* ggml : fix row size compute to avoid overflows

* tests : fix sizey -> sizez

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

py : add protobuf dependency (#4466)

ggml : remove n_dims from ggml_tensor (#4469)

ggml-ci

ggml : use ggml_row_size where possible (#4472)

* ggml : use ggml_row_size where possible

ggml-ci

* ggml : move ggml_nbytes_split to ggml-cuda.cu

ggml : group mul_mat_id rows by matrix (cpu only) (#4480)

* ggml : group mul_mat_id rows by matrix (cpu only)

* remove mmid parameters from mm forward

* store row groups in wdata and calculate only once in GGML_TASK_INIT

ggml-ci

server : add optional API Key Authentication example (#4441)

* Add API key authentication for enhanced server-client security

* server : to snake_case

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

llama : sanity checks for access to logits (#4274)

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

lora : add support for non-llama models (#3333)

* lora : add support for non-llama models

ggml-ci

* avoid leaking ggml_context on failure
cleanup

ggml-ci

* lora : allow 1d tensors

* lora : include embd and output layers in size calculation

* fix style

Link to cublas dynamically on Windows even with LLAMA_STATIC (#4506)

server : allow requests larger than 8K (#4500)

server : fix possible ambiguity in content type charset (#4501)

server : fix grammar being ignored (#4494)

Fix bug in identifying the grammar.

server : disable llm logs if SERVER_VERBOSE is off (#3792)

finetune : keep allocs alive until all allocations are done (#4486)

build : Check the ROCm installation location (#4485)

* build : Check the ROCm installation location

* more generic approach

* fixup! It was returning the path instead of the command output

* fixup! Trailing whitespace

gguf-py : fail fast on nonsensical special token IDs (#4489)

llama.swiftui : add bench functionality (#4483)

* llama.swiftui : add bench button

* llama.swiftui : initial bench functionality

* force to use n_gpu_layers on simulator

* add download buttons & expose llamaState.loadModel

* update project.pbxproj

* comment #Preview & fix editorconfig check

* gitignore : xcode stuff

* llama.swiftui : UX improvements

* llama.swiftui : avoid data copy via "downloadTask"

* llama.swiftui : remove model from project

* llama : remove "mostly" from model infos

* llama.swiftui : improve bench

---------

Co-authored-by: jhen <developer@jhen.me>

readme : update hot topics

decode : fix logits_valid for legacy API (#4516)

llama : fix try_override for bool_value which always return true (#4519)

llama : add phi-2 + fix NeoX rope + ggml_mul_mat_set_prec (#4490)

* phi2 implementation

* fix breaking change

* phi-2 : various fixes

* phi-2 : use layer norm eps

* py : whitespaces

* llama : fix meta KV override bug

* convert : phi don't add BOS token

* convert : revert "added_tokens_decoder" change

* phi-2 : scale Q instead of KQ for better precision

* ggml : fix NeoX rope to rotate just first n_dims

* cuda : less diff in the rope_neox kernel

* ggml : add ggml_mul_mat_set_prec

ggml-ci

* Update ggml-cuda.cu

Co-authored-by: slaren <slarengh@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: slaren <slarengh@gmail.com>

* cuda : ggml_cuda_op_mul_mat_cublas support F32 precision

* cuda : remove oboslete comment

---------

Co-authored-by: Ebey Abraham <ebeyabraham@microsoft.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>

llama.swiftui : add more models

llama.swiftui : add tinyllama 1.1B F16

ggml-cuda: Fix HIP build (#4528)

regression of #4490
Adds defines for two new datatypes
cublasComputeType_t, cudaDataType_t.

Currently using deprecated hipblasDatatype_t since newer ones very recent.

ggml : fixed check for _MSC_VER (#4535)

Co-authored-by: Eric Sommerlade <ersomme@microsoft.com>

CUDA: Faster Mixtral prompt processing (#4538)

* CUDA: make MoE tensors contiguous for batch size>1

* Update ggml-cuda.cu

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>

Fix access violation in ggml_cuda_free_data if tensor->extra is NULL (#4554)

llama : disable per-tensor info prints on model load (#4562)

cuda : replace asserts in wrong architecture checks with __trap (#4556)

* cuda : replace asserts in wrong architecture checks with __trap

* make bad_arch noreturn, remove returns

cuda : better error message for ggml_get_rows (#4561)

* Update ggml-cuda.cu

* Update ggml-cuda.cu

* Update ggml-cuda.cu

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

py : open merges file as 'utf-8' (#4566)

Otherwise, on Windows converting bling-phi-2-v0 (<https://huggingface.co/llmware/bling-phi-2-v0>) via convert-hf-to-gguf.py will fail with the following error:

```
Traceback (most recent call last):
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 1061, in <module>
    model_instance.set_vocab()
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 52, in set_vocab
    self._set_vocab_gpt2()
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 264, in _set_vocab_gpt2
    special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 33, in __init__
    self._load(Path(path))
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 81, in _load
    self._try_load_merges_txt(path)
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 95, in _try_load_merges_txt
    for line in fp:
  File "C:\Users\User\miniconda3\envs\gguf\lib\encodings\cp1252.py", line 23, in decode
    return codecs.charmap_decode(input,self.errors,decoding_table)[0]
UnicodeDecodeError: 'charmap' codec can't decode byte 0x81 in position 1415: character maps to <undefined>
```

readme : update coding guidelines

CUDA: mul_mat_id always on GPU for batches >= 32 (#4553)

common : remove incorrect --model-draft default (#4568)

ggml-cuda: Fix HIP build by adding define for __trap (#4569)

Regression of 139882392258671ffe5acdfcadc0bc08572d6eef
HIP doesn't have trap, only abort

cuda : ROCm AMD Unified Memory Architecture (UMA) handling (#4449)

* AMD ROCm: handle UMA memory VRAM expansions

This resolves #2797 by allowing ROCm AMD GPU users with a UMA to
dynamically expand the VRAM allocated to the GPU.

Without this, AMD ROCm users with shared CPU/GPU memory usually are
stuck with the BIOS-set (or fixed) framebuffer VRAM, making it
impossible to load more than 1-2 layers.

Note that the model is duplicated in RAM because it's loaded once for
the CPU and then copied into a second set of allocations that are
managed by the HIP UMA system. We can fix this later.

* clarify build process for ROCm on linux with cmake

* avoid using deprecated ROCm hipMallocHost

* keep simplifying the change required for UMA

* cmake: enable UMA-compatible allocation when LLAMA_HIP_UMA=ON

metal : fix `ggml_metal_log` vargs (#4373)

llama : allow getting n_batch from llama_context in c api (#4540)

* allowed getting n_batch from llama_context in c api

* changed to use `uint32_t` instead of `int`

* changed to use `uint32_t` instead of `int` in `llama_n_ctx`

* Update llama.h

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

llama : initial ggml-backend integration (#4520)

* llama : initial ggml-backend integration

* add ggml-metal

* cuda backend can be used though ggml-backend with LLAMA_GGML_BACKEND_CUDA_TEST
access all tensor data with ggml_backend_tensor_get/set

* add ggml_backend_buffer_clear
zero-init KV cache buffer

* add ggml_backend_buffer_is_hos, used to avoid copies if possible when accesing tensor data

* disable gpu backends with ngl 0

* more accurate mlock

* unmap offloaded part of the model

* use posix_fadvise64(.., POSIX_FADV_SEQUENTIAL) to improve performance with mmap

* update quantize and lora

* update session copy/set to use ggml-backend

ggml-ci

* use posix_fadvise instead of posix_fadvise64

* ggml_backend_alloc_ctx_tensors_from_buft : remove old print

* llama_mmap::align_offset : use pointers instead of references for out parameters

* restore progress_callback behavior

* move final progress_callback call to load_all_data

* cuda : fix fprintf format string (minor)

* do not offload scales

* llama_mmap : avoid unmapping the same fragments again in the destructor

* remove unnecessary unmap

* metal : add default log function that prints to stderr, cleanup code

ggml-ci

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

ci : add `jlumbroso/free-disk-space` to docker workflow (#4150)

* [github][workflows][docker]: removes hardcoded `ggerganov` from `ghcr` repo

* [github][workflows][docker]: adds `jlumbroso/free-disk-space`

gguf : simplify example dependencies

gguf-py : fix broken link

ggml : change ggml_scale to take a float instead of tensor (#4573)

* ggml : change ggml_scale to take a float instead of tensor

* ggml : fix CPU implementation

* tests : fix test-grad0

ggml-ci

llama : add ability to cancel model loading (#4462)

* llama : Add ability to cancel model load

Updated llama_progress_callback so that if it returns false, the model
loading is aborted.

* llama : Add test for model load cancellation

* Fix bool return in llama_model_load, remove std::ignore use

* Update llama.cpp

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* Fail test if model file is missing

* Revert "Fail test if model file is missing"

This reverts commit 32ebd525bf7e5a87ee8a3dbaab3d92ce79fbf23d.

* Add test-model-load-cancel to Makefile

* Revert "Revert "Fail test if model file is missing""

This reverts commit 2796953257ee5383fa7c8fe8fa8fc888c048fb0b.

* Simplify .gitignore for tests, clang-tidy fixes

* Label all ctest tests

* ci : ctest uses -L main

* Attempt at writing ctest_with_model

* ci : get ci/run.sh working with test-model-load-cancel

* ci : restrict .github/workflows/build.yml ctest to -L main

* update requirements.txt

* Disable test-model-load-cancel in make

* Remove venv before creation

* Restructure requirements.txt

Top-level now imports the specific additional requirements for each
python file. Using `pip install -r requirements.txt` will fail if
versions become mismatched in the per-file requirements.

* Make per-python-script requirements work alone

This doesn't break the main requirements.txt.

* Add comment

* Add convert-persimmon-to-gguf.py to new requirements.txt scheme

* Add check-requirements.sh script and GitHub workflow

* Remove shellcheck installation step from workflow

* Add nocleanup special arg

* Fix merge

see: https://github.com/ggerganov/llama.cpp/pull/4462#discussion_r1434593573

* reset to upstream/master

* Redo changes for cancelling model load

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

ggml : extend `enum ggml_log_level` with `GGML_LOG_LEVEL_DEBUG` (#4579)

readme : add zig bindings (#4581)

ci : tag docker image with build number (#4584)

make : add LLAMA_HIP_UMA option (#4587)

NB: LLAMA_HIP_UMA=1 (or any value) adds MK_CPPFLAG -DGGML_HIP_UMA

ggml : add comment about backward GGML_OP_DIAG_MASK_INF (#4203)

llama : fix platforms without mmap (#4578)

* llama : fix platforms without mmap

* win32 : limit prefetch size to the file size

* fix win32 error clobber, unnecessary std::string in std::runtime_error

Fix CudaMemcpy direction (#4599)

cuda : fix jetson compile error (#4560)

* fix old jetson compile error

* Update Makefile

* update jetson detect and cuda version detect

* update cuda marco define

* update makefile and cuda,fix some issue

* Update README.md

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update Makefile

* Update README.md

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

sync : ggml (fix im2col) (#4591)

* cuda : fix im2col_f32_f16 (ggml/#658)

ggml-ci

* ggml-alloc : fix ggml_tallocr_is_own

---------

Co-authored-by: leejet <leejet714@gmail.com>

lookup : add prompt lookup decoding example (#4484)

* initial commit, going through initializations

* main loop finished, starting to debug

* BUG: generates gibberish/repeating tokens after a while

* kv_cache management

* Added colors to distinguish drafted tokens (--color). Updated README

* lookup : fix token positions in the draft batch

* lookup : use n_draft from CLI params

* lookup : final touches

---------

Co-authored-by: Leon Ericsson <leon.ericsson@icloud.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

CUDA: fixed row rounding for 0 tensor splits (#4594)

grammar : check the full vocab only if necessary (opt) (#4306)

* Check the full vocab for grammar only if necessary

* Fix missing logit restoration step (?)

Does this matter, actually?

* Fix whitespace / formatting

* Adjust comment

* Didn't mean to push test gbnf

* Split sampling into the helper function (?)

And also revert the changes made to the header

* common : fix final newline

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

server : allow to specify custom prompt for penalty calculation (#3727)

ci(docker): fix tags in "Build and push docker image (tagged)" (#4603)

fallback to CPU buffer if host buffer alloc fails (#4610)

cuda : improve cuda pool efficiency using virtual memory (#4606)

* cuda : improve cuda pool efficiency using virtual memory

* fix mixtral

* fix cmake build

* check for vmm support, disable for hip

ggml-ci

* fix hip build

* clarify granularity

* move all caps to g_device_caps

* refactor error checking

* add cuda_pool_alloc, refactor most pool allocations

ggml-ci

* fix hip build

* CUBLAS_TF32_TENSOR_OP_MATH is not a macro

* more hip crap

* llama : fix msvc warnings

* ggml : fix msvc warnings

* minor

* minor

* cuda : fallback to CPU on host buffer alloc fail

* Update ggml-cuda.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml-cuda.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* ensure allocations are always aligned

* act_size -> actual_size

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

llama : add PLaMo model (#3557)

* add plamo mock

* add tensor loading

* plamo convert

* update norm

* able to compile

* fix norm_rms_eps hparam

* runnable

* use inp_pos

* seems ok

* update kqv code

* remove develop code

* update README

* shuffle attn_q.weight and attn_output.weight for broadcasting

* remove plamo_llm_build_kqv and use llm_build_kqv

* fix style

* update

* llama : remove obsolete KQ_scale

* plamo : fix tensor names for correct GPU offload

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

simplify bug issue template (#4623)

Adding Emeltal reference to UI list (#4629)

Fix new CUDA10 compilation errors (#4635)

Update comment for AdamW implementation reference. (#4604)

Co-authored-by: Will Findley <findley@gmail.com>

cuda : fix vmm pool with multi GPU (#4620)

* cuda : fix vmm pool with multi GPU

* hip

* use recommended granularity instead of minimum

* better error checking

* fix mixtral

* use cudaMemcpy3DPeerAsync

* use cuda_pool_alloc in ggml_cuda_op_mul_mat

* consolidate error checking in ggml_cuda_set_device

* remove unnecessary inlines

ggml-ci

* style fixes

* only use vmm for the main device

* fix scratch buffer size, re-enable vmm pool for all devices

* remove unnecessary check id != g_main_device

Add byte token type when tokenizer.model is not exists (#4641)

* Add byte token type to hf format

* remove unused variable

ggml : fix dot product for ARM (#4630)

ggml-ci

scripts : add sync-ggml-am.sh

finetune : fix output formatting in print_params (#4653)

This commit fixes the output formatting in the print_params function
which currently looks like this:
```console
print_params: n_vocab:   32000
print_params: n_ctx:     128
print_params: n_embd:    4096
print_params: n_ff:      11008
print_params: n_head:    32
print_params: n_head_kv: 32
print_params: n_layer:   32
print_params: norm_rms_eps          : 0.000010
print_params: rope_freq_base        : 10000.000000
print_params: rope_freq_scale       : 1.000000
```
With this comit the output will look like this:
```console
print_params: n_vocab               : 32000
print_params: n_ctx                 : 128
print_params: n_embd                : 4096
print_params: n_ff                  : 11008
print_params: n_head                : 32
print_params: n_head_kv             : 32
print_params: n_layer               : 32
print_params: norm_rms_eps          : 0.000010
print_params: rope_freq_base        : 10000.000000
print_params: rope_freq_scale       : 1.000000
```

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>

llama : add AWQ for llama, llama2, mpt, and mistral models (#4593)

* update: awq support llama-7b model

* update: change order

* update: benchmark results for llama2-7b

* update: mistral 7b v1 benchmark

* update: support 4 models

* fix: Readme

* update: ready for PR

* update: readme

* fix: readme

* update: change order import

* black

* format code

* update: work for bot mpt and awqmpt

* update: readme

* Rename to llm_build_ffn_mpt_awq

* Formatted other files

* Fixed params count

* fix: remove code

* update: more detail for mpt

* fix: readme

* fix: readme

* update: change folder architecture

* fix: common.cpp

* fix: readme

* fix: remove ggml_repeat

* update: cicd

* update: cicd

* uppdate: remove use_awq arg

* update: readme

* llama : adapt plamo to new ffn

ggml-ci

---------

Co-authored-by: Trần Đức Nam <v.namtd12@vinai.io>
Co-authored-by: Le Hoang Anh <v.anhlh33@vinai.io>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

gpt2 : Add gpt2 architecture integration (#4555)

Fix OpenAI server sampling w.r.t. temp and seed (#4668)

The default values for tfs_z and typical_p were being set to zero, which
caused the token candidates array to get shrunk down to one element thus
preventing any sampling. Note this only applies to OpenAI API compatible
HTTP server requests.

The solution is to use the default values that OpenAI documents, as well
as ensuring we use the llama.cpp defaults for the rest. I've tested this
change still ensures deterministic output by default. If a "temperature"
greater than 0 is explicitly passed, then output is unique each time. If
"seed" is specified in addition to "temperature" then the output becomes
deterministic once more.

See mozilla-Ocho/llamafile#117
See mozilla-Ocho/llamafile@9e4bf29

scripts : do not sync commits from this repo

ggml : fix some mul mat cases + add tests for src1 F16 (ggml/669)

* fixed mul-mat error for old GPUs

* style fixes

* add mul mat src1 f16 test cases, fix more cases

ggml-ci

---------

Co-authored-by: bssrdf <bssrdf@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>

sync : ggml

ci : build with CLBlast + ggml-opencl use GGML_API (whisper/1576)

* Build with CLBlast

* Declare GGML_API

After rebasing, examples/talk-llama failed:

"D:\a\whisper.cpp\whisper.cpp\build\ALL_BUILD.vcxproj" (build target) (1) ->
"D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj" (default target) (14) ->
(Link target) ->
  llama.obj : error LNK2019: unresolved external symbol ggml_cl_free_data referenced in function "public: __cdecl llama_model::~llama_model(void)" (??1llama_model@@QEAA@XZ) [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]
  llama.obj : error LNK2019: unresolved external symbol ggml_cl_transform_tensor referenced in function "public: void __cdecl llama_model_loader::load_all_data(struct ggml_context *,void (__cdecl*)(float,void *),void *,struct llama_mlock *)" (?load_all_data@llama_model_loader@@QEAAXPEAUggml_context@@P6AXMPEAX@Z1PEAUllama_mlock@@@Z) [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]
  D:\a\whisper.cpp\whisper.cpp\build\bin\Release\talk-llama.exe : fatal error LNK1120: 2 unresolved externals [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]

scripts : print list of sync commits

llama.swiftui : fix infinite loop, ouput timings, buff UI (#4674)

* fix infinite loop

* slight UI simplification, clearer UX

* clearer UI text, add timings to completion log

main-cmake-pkg : fix build issue (#4665)

* Fix main-cmake-pkg compilation

* Use glob to load common files

* cmake : fix trailing whitespace

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

server : allow to generate multimodal embeddings (#4681)

server : fix OpenAI server sampling w.r.t. penalty. (#4675)

server : replace sleep with condition variables (#4673)

The server currently schedules tasks using a sleep(5ms) busy loop. This
adds unnecessary latency since most sleep implementations do a round up
to the system scheduling quantum (usually 10ms). Other libc sleep impls
spin for smaller time intervals which results in the server's busy loop
consuming all available cpu. Having the explicit notify() / wait() code
also helps aid in the readability of the server code.

See mozilla-Ocho/llamafile@711344b

llava-cli : refactor to use sampling library (#4669)

This change makes it possible to use flags like `--grammar` when using
the `llava-cli` program. The rest is just code cleanup deleting a long
standing TODO comment.

This change also ensures that logging information is emitted to stderr
which helps the `llava-cli` command be more friendly to shell scripts.

See Mozilla-Ocho/llamafile@1cd334f

cmake : fix ld warning duplicate libraries libllama.a (#4671)

* fix "ld: warning: ignoring duplicate libraries: '../libllama.a'"

* fix warning in example.

flake.nix : rewrite (#4605)

* flake.lock: update to hotfix CUDA::cuda_driver

Required to support https://github.com/ggerganov/llama.cpp/pull/4606

* flake.nix: rewrite

1. Split into separate files per output.

2. Added overlays, so that this flake can be integrated into others.
   The names in the overlay are `llama-cpp`, `llama-cpp-opencl`,
   `llama-cpp-cuda`, and `llama-cpp-rocm` so that they fit into the
   broader set of Nix packages from [nixpkgs](https://github.com/nixos/nixpkgs).

3. Use [callPackage](https://summer.nixos.org/blog/callpackage-a-tool-for-the-lazy/)
   rather than `with pkgs;` so that there's dependency injection rather
   than dependency lookup.

4. Add a description and meta information for each package.
   The description includes a bit about what's trying to accelerate each one.

5. Use specific CUDA packages instead of cudatoolkit on the advice of SomeoneSerge.

6. Format with `serokell/nixfmt` for a consistent style.

7. Update `flake.lock` with the latest goods.

* flake.nix: use finalPackage instead of passing it manually

* nix: unclutter darwin support

* nix: pass most darwin frameworks unconditionally

...for simplicity

* *.nix: nixfmt

nix shell github:piegamesde/nixfmt/rfc101-style --command \
    nixfmt flake.nix .devops/nix/*.nix

* flake.nix: add maintainers

* nix: move meta down to follow Nixpkgs style more closely

* nix: add missing meta attributes

nix: clarify the interpretation of meta.maintainers

nix: clarify the meaning of "broken" and "badPlatforms"

nix: passthru: expose the use* flags for inspection

E.g.:

```
❯ nix eval .#cuda.useCuda
true
```

* flake.nix: avoid re-evaluating nixpkgs too many times

* flake.nix: use flake-parts

* nix: migrate to pname+version

* flake.nix: overlay: expose both the namespace and the default attribute

* ci: add the (Nix) flakestry workflow

* nix: cmakeFlags: explicit OFF bools

* nix: cuda: reduce runtime closure

* nix: fewer rebuilds

* nix: respect config.cudaCapabilities

* nix: add the impure driver's location to the DT_RUNPATHs

* nix: clean sources more thoroughly

...this way outPaths change less frequently,
and so there are fewer rebuilds

* nix: explicit mpi support

* nix: explicit jetson support

* flake.nix: darwin: only expose the default

---------

Co-authored-by: Someone Serge <sergei.kozlukov@aalto.fi>

python : add check-requirements.sh and GitHub workflow (#4585)

* python: add check-requirements.sh and GitHub workflow

This script and workflow forces package versions to remain compatible
across all convert*.py scripts, while allowing secondary convert scripts
to import dependencies not wanted in convert.py.

* Move requirements into ./requirements

* Fail on "==" being used for package requirements (but can be suppressed)

* Enforce "compatible release" syntax instead of ==

* Update workflow

* Add upper version bound for transformers and protobuf

* improve check-requirements.sh

* small syntax change

* don't remove venvs if nocleanup is passed

* See if this fixes docker workflow

* Move check-requirements.sh into ./scripts/

---------

Co-authored-by: Jared Van Bortel <jared@nomic.ai>

cuda: fix vmm oom issue on NVIDIA AGX Orin (#4687)

Signed-off-by: hydai <hydai@secondstate.io>

clip : enable gpu backend (#4205)

* clip: enable CUDA backend

* add missing kernels

* add enough padding for alignment

* remove ggml_repeat of clip.cpp

* add metal backend

* llava : fixes

- avoid ggml_repeat
- use GGML_USE_ instead of CLIP_USE_ macros
- remove unused vars

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

clip : use ggml_backend_buffer_is_host (#4205)

CUDA: fix tensor core logic for Pascal and HIP (#4682)

ggml : add ggml_cpu_has_avx_vnni() (#4589)

* feat: add avx_vnni based on intel documents

* ggml: add avx vnni based on intel document

* llama: add avx vnni information display

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* Update ggml.c

Fix indentation upgate

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

CUDA: fixed tensor cores not being used on RDNA3 (#4697)

clip : refactor + bug fixes (#4696)

* clip : refactor + bug fixes

ggml-ci

* server : add log message

ggml : add ggml_vdotq_s32 alias (#4715)

ggml-ci

flake.nix: expose full scope in legacyPackages

flake.nix: rocm not yet supported on aarch64, so hide the output

flake.nix: expose checks

workflows: nix-ci: init; build flake outputs

workflows: nix-ci: add a job for eval

workflows: weekly `nix flake update`

workflows: nix-flakestry: drop tag filters

...and add a job for flakehub.com

workflows: nix-ci: add a qemu job for jetsons

flake.nix: suggest the binary caches

flake.lock: update

to a commit recently cached by nixpkgs-cuda-ci

metal : enable shader debugging (cmake option) (#4705)

* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (#4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

ggml-ci

finetune: fix typo in README.md (#4733)

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>

py : re-enable mmap in convert hf (#4732)

* update: awq support llama-7b model

* update: change order

* update: benchmark results for llama2-7b

* update: mistral 7b v1 benchmark

* update: support 4 models

* fix: Readme

* update: ready for PR

* update: readme

* fix: readme

* update: change order import

* black

* format code

* update: work for bot mpt and awqmpt

* update: readme

* Rename to llm_build_ffn_mpt_awq

* Formatted other files

* Fixed params count

* fix: remove code

* update: more detail for mpt

* fix: readme

* fix: readme

* update: change folder architecture

* fix: common.cpp

* fix: readme

* fix: remove ggml_repeat

* update: cicd

* update: cicd

* uppdate: remove use_awq arg

* update: readme

* llama : adapt plamo to new ffn

ggml-ci

* fix: update torch version

---------

Co-authored-by: Trần Đức Nam <v.namtd12@vinai.io>
Co-authored-by: Le Hoang Anh <v.anhlh33@vinai.io>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

server : add --override-kv parameter (#4710)

* Changes to server to allow metadata override

* documentation

* flake.nix: expose full scope in legacyPackages

* flake.nix: rocm not yet supported on aarch64, so hide the output

* flake.nix: expose checks

* workflows: nix-ci: init; build flake outputs

* workflows: nix-ci: add a job for eval

* workflows: weekly `nix flake update`

* workflows: nix-flakestry: drop tag filters

...and add a job for flakehub.com

* workflows: nix-ci: add a qemu job for jetsons

* flake.nix: suggest the binary caches

* flake.lock: update

to a commit recently cached by nixpkgs-cuda-ci

---------

Co-authored-by: John <john@jLap.lan>
Co-authored-by: Someone Serge <sergei.kozlukov@aalto.fi>

editorconfig : fix whitespace and indentation #4710

llama : differentiate the KV dims in the attention (#4657)

* Add n_key_dim and n_value_dim

Some models use values that are not derived from `n_embd`.
Also remove `n_embd_head` and `n_embd_gqa` because it is not clear
which "head" is referred to (key or value).

Fix issue #4648.

* Fix `llm_build_kqv` to use `n_value_gqa`

* Rebase

* Rename variables

* Fix llm_build_kqv to be more generic wrt n_embd_head_k

* Update default values for n_embd_head_k and n_embd_head_v

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Fix llm_load_tensors: the asserts were not backcompat

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

llama : replace all API facing `int`'s with `int32_t` (#4577)

* replaced all API facing `int`'s with `int32_t`

* formatting and missed `int` in `llama_token_to_piece`

llama : llama_model_desc print number of experts

server : add token counts to html footer (#4738)

* server: add token counts to stats

* server: generate hpp

---------

Co-authored-by: phiharri <ph@got-root.co.uk>

metal : optimize ggml_mul_mat_id (faster Mixtral PP) (#4725)

* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (#4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

* metal : optimizing ggml_mul_mat_id (wip)

* metal : minor fix

* metal : opt mul_mm_id

server : throw an error when `slot unavailable` (#4741)

ggml : extend ggml_get_rows, ggml_repeat, ggml_concat (ggml/639)

* add more int ops

* ggml_compute_forward_dup_bytes

* add tests

* PR comments

* tests : minor indentations

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

scripts : fix sync order + metal sed

metal : add kernel_get_rows_i32

ggml-ci

sync : ggml

ggml-ci

cuda : mark I16 and I32 ops as unsupported

ggml-ci

cuda : simplify expression

Co-authored-by: slaren <slarengh@gmail.com>

swift : update Package.swift to use ggml as dependency (#4691)

* updates the package.swift to use ggml as dependency

* changes the ggml package url src to ggerganov

train : fix typo in overlapping-samples help msg (#4758)

This commit fixes a typo in the help message for the
--overlapping-samples option.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>

llama.swiftui : fix build of ggml.metallib (#4754)

* metal: fix metal backend init failure in swiftui

* metal: build ggml.metallib instead of copy src

* llama.swift : remove debug flags from metallib build

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

ggml : include stdlib.h before intrin.h (#4736)

server : fix options in README.md (#4765)

* fix examples/server/README.md

* minor : fix whitespace

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

llama.swiftui : support loading custom model from file picker (#4767)

* swiftui: support load model from file picker

* swiftui: remove trailing whitespace

Print backend name on test-backend-ops failure (#4751)

server : send token probs for "stream == false" (#4714)

finetune : remove unused includes (#4756)

This commit removes unused includes from finetune.cpp.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>

examples : add few-shot translation example (#4783)

ggml : do not sched_yield when calling BLAS (#4761)

* ggml : do not sched_yield when calling BLAS

ggml-ci

* ggml : fix do_yield logic

ggml-ci

* ggml : simplify do_yield logic

ggml-ci

ggml : add error handling to graph_compute (whisper/1714)

ggml : fix q2_k bpw in comments (ggml/680)

metal : switch back to default.metallib (ggml/681)

ggml-ci

flake.nix : fix typo (#4700)

betwen -> between

cmake : check for openblas64 (#4134)

openblas v0.3.22 64-bit pkg-config file is named openblas64.pc
https://github.com/OpenMathLib/OpenBLAS/issues/3790

examples : improve base-translate.sh script (#4783)

llama.swiftui : use correct pointer for llama_token_eos (#4797)

server : fix n_predict check (#4798)

ggml : use __builtin_amdgcn_sudot4 in __dp4a for gfx11 (#4787)

llama.swiftui : add visionOS target (#4805)

llama : print tensor meta for debugging

llama.swiftui : use llama.cpp as SPM package (#4804)

llama : remove redundant GQA check (#4796)

llama : remove unused vars (#4796)

CUDA: fixed redundant value dequantization (#4809)

llama-bench : add no-kv-offload parameter (#4812)

readme : add lgrammel/modelfusion JS/TS client for llama.cpp (#4814)

examples : add passkey test (#3856)

* examples : add passkey test

* passkey : better prints

* passkey : select pass key pos from CLI

* passkey : simplify n_past logic

* make : add passkey target

* passkey : add "self-extend"-like context extension (#4810)

* llama : "self-extend"-like context extension

* passkey : add comment

* passkey : add readme

main : add self-extend support (#4815)

* examples : add passkey test

* passkey : better prints

* passkey : select pass key pos from CLI

* passkey : simplify n_past logic

* llama : "self-extend"-like context extension

* passkey : add comment

* main : add Self-Extend support

* llama : add comment about llama_kv_cache_seq_div

llama.swiftui : update readme

swift : exclude ggml-metal.metal from the package (#4822)

SOTA 2-bit quants (#4773)

* iq2_xxs: basics

* iq2_xxs: scalar and AVX2 dot products

Needed to change Q8_K to have quants in the -127...127 range,
else the IQ2_XXS AVX implementation becomes very awkward.
The alternative would have been to use Q8_0 instead. Perhaps
I'll change later, for now this is what we have.

* iq2_xxs: ARM_NEON dot product

Somehow strangely slow (112 ms/token).

* iq2_xxs: WIP Metal

Dequantize works, something is still wrong with the
dot product.

* iq2_xxs: Metal dot product now works

We have
PP-512 = 475 t/s
TG-128 = 47.3 t/s

Not the greatest performance, but not complete garbage either.

* iq2_xxs: slighty faster dot product

TG-128 is now 48.4 t/s

* iq2_xxs: slighty faster dot product

TG-128 is now 50.9 t/s

* iq2_xxs: even faster Metal dot product

TG-128 is now 54.1 t/s.

Strangely enough, putting the signs lookup table
into shared memory has a bigger impact than the
grid values being in shared memory.

* iq2_xxs: dequantize CUDA kernel - fix conflict with master

* iq2_xxs: quantized CUDA dot product (MMVQ)

We get TG-128 = 153.1 t/s

* iq2_xxs: slightly faster CUDA dot product

TG-128 is now at 155.1 t/s.

* iq2_xxs: add to llama ftype enum

* iq2_xxs: fix MoE on Metal

* Fix missing MMQ ops when on hipBLAS

I had put the ggml_supports_mmq call at the wrong place.

* Fix bug in qequantize_row_iq2_xxs

The 0.25f factor was missing.
Great detective work by @ggerganov!

* Fixing tests

* PR suggestion

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

readme : add link to SOTA models

common : fix the short form of `--grp-attn-w`, not `-gat` (#4825)

See https://github.com/ggerganov/llama.cpp/blob/master/common/common.cpp#L230C53-L230C57

CUDA: faster softmax via shared memory + fp16 math (#4742)

ggml : fix vld1q_s8_x4 32-bit compat (#4828)

* ggml : fix vld1q_s8_x4 32-bit compat

ggml-ci

* ggml : fix 32-bit ARM compat (cont)

ggml-ci

server : add api-key flag to documentation (#4832)

Document the api-key flag added to server in https://github.com/ggerganov/llama.cpp/pull/4441

server : update readme about token probs (#4777)

* updated server readme to reflect the gg/server-token-probs-4088 commit

added explanation for the API's completion result which now includes `completion_probabilities`. Also added a JSON schema that shows the type/structure of `completion_probabilities`.

* simplified the `completion_probabilities` JSON schema

It's now easier to understand what the structure of `completion_probabilities` looks like.

* minor : fix trailing whitespace

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

scripts : script to get Paul Graham essays in txt format (#4838)

readme : add 3rd party collama reference to UI list (#4840)

Add a VSCode extension for llama.cpp reference to UI list

scripts : improve get-pg.sh (#4838)

metal : improve dequantize precision to match CPU (#4836)

ggml-ci

llava-cli : don't crash if --image flag is invalid (#4835)

This change fixes an issue where supplying `--image missing-file` would
result in a segfault due to a null pointer being dereferenced. This can
result in distracting info being printed if robust crash analysis tools
are being used.

convert.py : fix vanilla LLaMA model conversion (#4818)

* Update Imports and Add Notes for Future Reference

- Updated import statements in `convert.py`.
- Added import for `AutoTokenizer` from `transformers` module.
- Added conditional import for `gguf` from the local directory.
- Added comments and notes for future reference.

Additional Notes:

- Noted removal of a redundant `TypeAlias` import.
- Noted the removal of a `gguf` debug statement.
- Commented on the presence of `ARCH` and `NDArray` definitions.
- Commented on cleaning up and refactoring data type definitions.

* Refine Model Hyperparameters and Params Class

- Updated type annotations to use `Optional` for clarity.
- Improved method names and attribute consistency.
- Removed unnecessary variables for better code readability.

Additional Notes:

- Highlighted the use of `Optional` for clearer intent.
- Ensured backward and forward compatibility.

* Restore BpeVocab and SentencePieceVocab classes

- Restored the BpeVocab class for handling BPE tokenization.
- Restored the SentencePieceVocab class for SentencePiece tokenization.

These classes are essential for maintaining the original behavior of the codebase.

* refactor: Standardize vocabulary handling with HfVocab

- Replaced VocabLoader with HfVocab, aligning vocabulary handling across classes.
- Updated initialization of HfVocab with local_files_only=True for AutoTokenizer.
- Introduced optional parameter fname_added_tokens for flexible added token management.
- Streamlined added token handling for clarity and conciseness.
- Maintained special tokens and IDs, enhancing token management.
- Simplified token processing methods for improved readability.
- Added a placeholder for score computation with a default value of -1000.0.
- Optimized newline token check for efficiency.
- Updated __repr__ function for clarity in representation.
- Adjusted type alias Vocab to include BpeVocab, SentencePieceVocab, and HfVocab.
- Removed redundant code related to special token handling, reverse vocabulary mapping, and vocabulary file detection.

This refactoring promotes a standardized and modular approach to vocabulary management, facilitating future integration with a VocabFactory and improving code maintainability and scalability.

* refactor: Enhance readability, functionality, and code quality

- Improved code formatting and readability for better maintainability.
- Refactored LazyUnpickler's CLASSES dictionary for clarity.
- Added print statements and warnings in check_vocab_size for user feedback.
- Removed find_vocab_file_path, as it's superseded by VocabFactory.
- Preparatory changes for upcoming classes: OutputFile and VocabFactory.
- Overall focus on code quality, error handling, and consistency.

These changes reflect a continuous effort to refine the codebase, ensuring it meets best practices and prepares for future enhancements, such as the VocabFactory.

* refactor: Update OutputFile class for enhanced model vocabulary management

- Restructured the constructor for improved readability.
- Updated `add_meta_arch` method for flexible model name determination.
- Introduced `handle_tokenizer_model` for mapping vocab types to supported tokenizer models.
- Streamlined vocabulary extraction with `extract_vocabulary_from_model`.
- Simplified vocabulary metadata addition using `add_meta_vocab`.
- Refactored `add_tensor_info` for clarity and consistency.
- Improved error handling for better user feedback.

These changes signify the development of a versatile and comprehensive `OutputFile` class, enabling efficient management of model conversion output, metadata, vocabulary, and tensor information.

* feat: Introduce VocabFactory for flexible vocabulary management in model conversion

- The VocabFactory class is added to facilitate modular vocabulary handling.
- The constructor initializes a directory path and detects vocabulary-related files.
- The _select_file method provides file paths based on vocabulary type (e.g., BPE, SentencePiece).
- _create_special_vocab generates special vocabularies, accommodating different types.
- The load_vocab method loads vocabularies, handling BPE, SentencePiece, and Hugging Face Fast Tokenizer.
- Error handling and logging enhance debugging and user feedback.
- The modular and flexible design simplifies vocabulary management and supports future extensions.

The VocabFactory class enhances code modularity and maintainability, allowing versatile vocabulary handling in the model conversion process.

* refactor: Improve code organization, argument parsing, and user interface

- Renamed 'default_outfile' to 'default_output_file' for clarity.
- Refactored argument parser setup into 'get_argument_parser' function.
- Introduced descriptive comments for each argument in the parser.
- Added '--vocab-type' argument with choices ["spm", "bpe", "hfft"] for vocabulary processing.
- Improved flag naming consistency: '--outfile' to '--out-file' and '--bigendian' to '--big-endian'.
- Enhanced error handling to prevent overwriting input data in 'default_output_file'.
- Made 'argv' in 'main' an optional parameter for flexibility.
- Introduced dynamic import for 'awq.apply_awq' based on 'args.awq_path' for conditional dependency.

These changes enhance code clarity, organization, and the user interface of the script, aligning it with Python best practices and improving maintainability.

* refactor: Further refine functionality, improve user interaction, and streamline vocabulary handling

- Renamed command-line arguments for clarity and consistency.
- Improved path resolution and import adjustments for robustness.
- Thoughtfully handled 'awq-path' and conditional logic for the weighted model.
- Enhanced model and vocabulary loading with the 'VocabFactory' class for structured and adaptable loading.
- Strengthened error handling and user feedback for a more user-friendly experience.
- Structured output file handling with clear conditions and defaults.
- Streamlined and organized the 'main' function for better logic flow.
- Passed 'sys.argv[1:]' to 'main' for adaptability and testability.

These changes solidify the script's functionality, making it more robust, user-friendly, and adaptable. The use of the 'VocabFactory' class is a notable enhancement in efficient vocabulary handling, reflecting a thoughtful and iterative approach to script development.

* chore: Apply ruff formatting to convert.py

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Revert to commit 0614c33

* chore: Apply flake8 formatting rules

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* refactor: Revise `check_vocab_size` for Enhanced Clarity and Correctness

- Resolved an unreachable branch issue by reorganizing the conditional structure.
- Moved the special case check for `params.n_vocab == -1` to the top for immediate assertion.
- Flattened the conditional logic for improved clarity and predictability of the function's behavior.

These changes enhance the readability and functional correctness of the `check_vocab_size` function without altering its intended functionality.

* py : fix outfile and outtype

* py : suggest hint for missing vocab size

---------

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

Python script to compare commits with llama-bench (#4844)

clip : support more quantization types (#4846)

Uses ggml functions instead of hardcoded names and adds support to quantize into the modern Q-K variants.
This is just the bare minimum to get k-types working - a more refined choice of types would be needed to get best quality on low quantizations.

I ran a few tests, it doesn't break anything I could notice and a Q6_K ViT works almost as well as Q8_0 but 3 times the inference speed.

llama : recognize 1B phi models (#4847)

This update categorizes models with 24 layers as MODEL_1B, ensuring compatibility with different Phi model variants without impacting existing Phi-2 model functionality.

llama : add additional suffixes for model params (#4834)

* llm_load_print_meta: Add additional suffixs for model params

* Update llama.cpp model param log

remove unneeded comments and convert from > to >=

server : add a `/health` endpoint (#4860)

* added /health endpoint to the server

* added comments on the additional /health endpoint

* Better handling of server state

When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value.

* initialized server_state

* fixed a typo

* starting http server before initializing the model

* Update server.cpp

* Update server.cpp

* fixes

* fixes

* fixes

* made ServerState atomic and turned two-line spaces into one-line

server : fix build + rename enums (#4870)

server : update readme to document the new `/health` endpoint (#4866)

* added /health endpoint to the server

* added comments on the additional /health endpoint

* Better handling of server state

When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value.

* initialized server_state

* fixed a typo

* starting http server before initializing the model

* Update server.cpp

* Update server.cpp

* fixes

* fixes

* fixes

* made ServerState atomic and turned two-line spaces into one-line

* updated `server` readme to document the `/health` endpoint too

fix : cuda order of synchronization when setting a buffer (ggml/679)

* fix : cuda order of synchronization when setting a buffer

* also sync before memcpy

---------

Co-authored-by: slaren <slarengh@gmail.com>

Fix execlp call (ggml/689)

NULL can be an integer constant expression with the value zero, in this case the behavior would be undefined because of an incorrect type being passed to the variable arguments.

ggml : change GGML_MAX_NAME at compile time (ggml/682)

* change GGML_MAX_NAME to 128

* allow controlling the value of GGML_MAX_NAME through external macro definitions

metal : wrap each operation in debug group (ggml/690)

ggml : remove ggml_cpy_inplace and ggml_cont_inplace (ggml/693)

metal : fix deprecation warning (ggml/690)

sync : ggml

metal : put encoder debug group behind a define (#4873)

server : fix typo in model name (#4876)

main : print total token count and tokens consumed so far (#4874)

* Token count changes

* Add show token count

* Updating before PR

* Two requested changes

* Move param def posn

ci: nix-flake-update: new token with pr permissions (#4879)

* ci: nix-flake-update: new token with pr permissions

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

server : add `LOG_INFO` when model is successfully loaded (#4881)

* added /health endpoint to the server

* added comments on the additional /health endpoint

* Better handling of server state

When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value.

* initialized server_state

* fixed a typo

* starting http server before initializing the model

* Update server.cpp

* Update server.cpp

* fixes

* fixes

* fixes

* made ServerState atomic and turned two-line spaces into one-line

* updated `server` readme to document the `/health` endpoint too

* used LOG_INFO after successful model loading

server : support for multiple api keys (#4864)

* server: added support for multiple api keys, added loading api keys from file

* minor: fix whitespace

* added file error handling to --api-key-file, changed code to better
reflect current style

* server: update README.md for --api-key-file

---------

Co-authored-by: Michael Coppola <info@michaeljcoppola.com>

server : implement credentialed CORS (#4514)

* Implement credentialed CORS according to MDN

* Fix syntax error

* Move validate_api_key up so it is defined before its first usage

swift : pin ggml commit + remove ggml.h from spm-headers (#4878)

ggml-ci

ggml : SOTA 2-bit quants (add IQ2_XS) (#4856)

* iq2_xs: basics

* iq2_xs: this should have been in the basics

* iq2_xs: CUDA and scalar CPU works

* iq2_xs: WIP Metal

* iq2_xs: Metal now works

* iq2_xs: working, but dog slow, ARM_NEON dot product

* iq2_xs: better ARM_NEON dot product

We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when
running on the CPU.

* iq2_xs: AVX2 dot product - 19.5 t/s

* iq2_xs: faster AVX2 dit product

21.4 t/s for TG-128, 59.2 t/s for PP-512.
The latter is 2x compared to the previous version.

* iq2_xs: had forgotten to delete iq2-data.h

* Add llama enum for IQ2_XS

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

llama : restore intended k-quants mixes for MoE models (#4872)

* Restore intended k-quants quantization mixes for MoE models

* Update Q2_K_S values in the quantize tool

Still using LLaMA-v1 PPL values in the quant description
today does not make much sense. But let's leave this update
for another PR.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

swift : track ggml release branch (#4867)

main : disable token count by default (#4874)

main : better name for variable n_print (#4874)

server : fix infill when prompt is empty (#4833)

Importance Matrix calculation (#4861)

* imatrix: 1st version

* imatrix: WIP

* Cleanup

* Update examples/imatrix/imatrix.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

llama : fix llm_build_k_shift to use correct n_rot (#4889)

* llama : fix llm_build_k_shift to use correct n_rot

ggml-ci

* llama : always use hparams.n_rot for ggml_rope_custom

ggml-ci

* convert : fix persimmon conversion to write correct n_rot

py : fix lint (#4889)

common : streamline the formatting of help (#4890)

* common : streamline the formatting of help

- Separate alternative parameters by a comma

- Do not indent `--version` differently

* Update common/common.cpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

llama : fix typo "imp_embd" -> "inp_embd"

CUDA: fix softmax compile for old CUDA versions (#4862)

gitignore : imatrix

llama.swiftui : update models layout (#4826)

* Updated Models Layout

- Added a models drawer
- Added downloading directly from Hugging Face
- Load custom models from local folder
- Delete models by swiping left

* trimmed trailing white space

* Updated Models Layout

export-lora : use LLAMA_FILE_MAGIC_GGLA (#4894)

This commit replaces the magic number used in export-lora.cpp with
the one defined in llama.h, which is indirectly included via common.h.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>

llama : remove redundant assert for StableLM (#4901)

llama : ggml-backend integration (#4766)

* llama : ggml-backend integration

* ggml-backend : add names to buffers

* fix unmap after loading

* batched-bench : add tensor_split param

* llama : check for null tensor_split

* ggml-backend : increase GGML_MAX_BACKENDS

* improve graph splitting, partial fix for --no-kv-offload

* cuda : add ggml-backend split buffer support

* cuda : do not create buffer types for devices that don't exist (fixes usage without CUDA devices available)

* ggml : fix null backend dereference (#4807)

* ggml : fix null backend dereference

* ggml : also check ggml_backend_is_cpu

* test-backend-ops : check buffer allocation failures

* llama : add cparam (split_mode) and command line argument (--split-mode, -sm) to configure the split mode (none, layer or row)

* ggml : fix mul_mat_id work size

* llama : rewrite session kv load/set without graphs

* minor

* llama : only initialize used backends, free backends on context free

* llama : abort ctx if cuda backend init fails

* llama : rewrite lora with ggml-backend and compute on CPU

ggml-ci

* llama : only map to a backend buffer the region of the file mapping containing the tensors used in the buffer

* opencl : add ggml-backend buffer type

* cuda : only use batched_cublas with batched mat muls (fixes fp16 tg perf)

* llama : on Metal, by default offload the full model

ggml-ci

* metal : page align the data ptr (#4854)

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* cuda : fix split buffer free

* address review comments

* llama-bench : add split-mode parameter

* fix whitespace

* opencl : fix double initialization

* server : add --split-mode parameter

* use async copy and compute to improve multi-gpu performance

ggml-ci

* use async memcpys to copy the graph outputs to the CPU

* fix opencl

* use a host buffer for the cpu compute buffer for faster copies to the gpu

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

CUDA: faster q8_0 -> f16 dequantization (#4895)

examples : add pydantic models to GBNF grammar generator (#4883)

* Create pydantic-models-to-grammar.py

* Added some comments for usage

* Refactored Grammar Generator

Added example and usage instruction.

* Update pydantic_models_to_grammar.py

* Update pydantic-models-to-grammar-examples.py

* Renamed module and imported it.

* Update pydantic-models-to-grammar.py

* Renamed file and fixed grammar generator issue.

backend_sched : fix assignments

ggml-ci

ggml : fix 32-bit ARM compat for IQ2_XS (whisper/1758)

* ggml : fix 32-bit ARM compat

* ggml : fix fix

* ggml : fix fix fix

sync : ggml

convert : update phi-2 to latest HF repo (#4903)

* convert : update phi-2 to latest HF repo

ggml-ci

* py : try to fix flake stuff

server : fix crash with multimodal models without BOS token (#4904)

server : fix deadlock that occurs in multi-prompt scenarios (#4905)

* * fix deadlock

* * dont ruint all whitespace

compare-llama-bench: tweak output format (#4910)

metal : refactor kernel loading code (#4794)

* metal : detect more GPU families

* metal : refactor kernel loading

* metal : set kernel family requirements

* metal : fix kernel init + fix compile options

* metal : take into account simdgroup reduction support

* metal : print only skipped kernels

* metal : fix check for simdgroup reduction support

* metal : check for Metal 3

* metal : free allocations

* metal : normalize encoder:setComputePipelineStatus calls

ggml-ci

* metal : fix Metal3 family check

ggml-ci

* metal : check for simdgroup matrix mul. feature

ggml-ci

gguf : fix potential infinite for-loop (#4600)

Co-authored-by: Bernhard Gstrein <gstrein@informatik.uni-freiburg.de>

main : add parameter --no-display-prompt (#4541)

* add the parameter : --no-display-prompt , combine with --log-disable it will display only the generated tokens

* remove empty line

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

workflows: unbreak nix-build-aarch64, and split it out (#4915)

The fix should be just the `sudo apt-get update`

llama : minimize size used for state save/load (#4820)

* examples : save-load-state: save only required state

* llama : only reserve n_vocab * n_batch at most for logits

llama_decode asserts that only n_batch tokens are passed each call, and
n_ctx is expected to be bigger than n_batch.

* llama : always reserve n_vocab * n_batch for logits

llama_context de-serialization breaks if the contexts have differing
capacity for logits and llama_decode will at maximum resize to
n_vocab * n_batch.

* llama : only save and restore used logits

for batch sizes of 512 this reduces save state in the best case by
around 62 MB, which can be a lot if planning to save on each message
to allow regenerating messages.

* llama : use ostringstream and istringstream for save and load

* llama : serialize rng into minimum amount of space required

* llama : break session version due to serialization changes

metal : disable log for loaded kernels (#4794)

llama : fix detokenization of non-special added-tokens (#4916)

Co-authored-by: goerch <jhr.walter@t-online.de>

server : fix prompt caching with system prompt (#4914)

metal : remove old API (#4919)

ggml-ci

ggml: cache sin/cos for RoPE (#4908)

sync : ggml

Make Q3_K_S be the same as olf Q3_K_L for Mixtral-8x7B (#4906)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

2-bit quantizations (#4897)

* imatrix: load

* imatrix: WIP

* imatrix: Add Q2_K quantization

* imatrix: also guard against Q2_K_S quantization without importance matrix

* imatrix: guard even more against low-bit quantization misuse

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

llama : support WinXP build with MinGW 8.1.0 (#3419)

metal : correctly set SIMD support flags on iOS (#4923)

* Correctly set support_simdgroup_reduction and support_simdgroup_mm on iPhone/iPad

* log a little bit more info on iOS

Fix ffn_down quantization mix for MoE models (#4927)

* Fix ffn_down quantization mix for MoE models

In #4872 I did not consider the part where every third
tensor is quantized with more bits. Fir MoE this leads to tensors
of the same layer being quantized with different number of bits,
which is not considered as a possibility in the inference implementation
(it is assumed all experts use the same quantization).

* Fix the fix

* Review suggestion

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

llama : use LLAMA_LOG_ macros for logging

scripts : sync-ggml-am.sh option to skip commits

llama : check LLAMA_TRACE env for extra logging (#4929)

* llama : minor fix indent

* llama : check LLAMA_TRACE env for extra logging

ggml-ci

Add ability to use importance matrix for all k-quants (#4930)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

llama : fix missing quotes (#4937)

CUDA: faster dequantize kernels for Q4_0 and Q4_1 (#4938)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

llama : check for 256 divisibility for IQ2_XS, IQ2_XXS (#4950)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

cuda : fix dequantize kernel names (#4938)

awq-py : fix typo in awq-py/README.md (#4947)

llama : apply classifier-free guidance to logits directly (#4951)

pass cpu-architecture arguments only to host code (C;C++) (#4943)

speculative : threading options (#4959)

* speculative: expose draft threading

* fix usage format

* accept -td and -tbd args

* speculative: revert default behavior when -td is unspecified

* fix trailing whitespace

finetune : use LLAMA_FILE_MAGIC_GGLA (#4961)

This commit replaces the magic number LLAMA_FILE_MAGIC_LORA used in
finetune.cpp with LLAMA_FILE_MAGIC_GGLA defined in llama.h.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>

ggml : introduce GGML_CALL function annotation (#4850)

This change makes it possible to build ggml-cuda.cu and ggml-metal.m as
independent dynamic shared objects, that may be conditionally linked at
runtime in a multiplatform binary. It introduces a GGML_CALL annotation
that documents which functions have a cyclic call relationship, between
the application code and GPU modules.

This change does nothing, unless the build defines -DGGML_MULTIPLATFORM
which causes back-references and function pointers to conform to MS ABI
which is supported by NVCC, ROCm, XCode, GCC and Clang across platforms

examples : fix and improv docs for the grammar generator (#4909)

* Create pydantic-models-to-grammar.py

* Added some comments for usage

* Refactored Grammar Generator

Added example and usage instruction.

* Update pydantic_models_to_grammar.py

* Update pydantic-models-to-grammar-examples.py

* Renamed module and imported it.

* Update pydantic-models-to-grammar.py

* Renamed file and fixed grammar generator issue.

* Fixed some issues and bugs of the grammar generator. Imporved Documentation

* Update pydantic_models_to_grammar.py

metal : log `recommendedMaxWorkingSetSize` on iOS 16+ (#4936)

* metal: Log `recommendedMaxWorkingSetSize` on iOS 16+

* Only log on iOS and macOS, ignoring tvOS and other platforms

* Check for Xcode version before using recommendedMaxWorkingSetSize

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

metal : replace loop of dispatch_async with dispatch_apply (#4934)

* Replace loop of dispatch_async with dispatch_apply

* Update ggml-metal.m

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

android : introduce starter project example (#4926)

* Introduce starter project for Android

Based on examples/llama.swiftui.

* Add github workflow

* Set NDK version

* Only build arm64-v8a in CI

* Sync bench code

* Rename CI prop to skip-armeabi-v7a

* Remove unused tests

metal : localized logic in `ggml_metal_graph_compute` (#4924)

* Metal: Localized logic in `ggml_metal_graph_compute`, minor performance improvement

* Whitespace

* Collecting command buffer completions on single t…
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
high priority Very important issue need feedback Testing and feedback with results are needed
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

Multi GPU CUDA - 8x performance degradation when splitting tensors -> let's split by layer as an option