From 93864cda8a024227dec297085f6152662738ea0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Wed, 22 Jan 2025 15:19:34 +0100 Subject: [PATCH 1/9] llama : experimental DeepSeek2 MLA implementation that caches latent kv representations --- src/llama-kv-cache.cpp | 16 ++++++- src/llama-kv-cache.h | 7 +++ src/llama.cpp | 99 +++++++++++++++++++++++++++++++++++------- 3 files changed, 106 insertions(+), 16 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 90b6c56ed068c..99fd1d8df1c32 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -53,7 +53,7 @@ bool llama_kv_cache_init( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { struct ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(4u*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -71,6 +71,10 @@ bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); + // DeepSeek MLA + cache.kr_l.reserve(n_layer); + cache.kv_l.reserve(n_layer); + for (int i = 0; i < n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); @@ -97,6 +101,16 @@ bool llama_kv_cache_init( ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); cache.v_l.push_back(v); + + // DeepSeek MLA + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); + ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); + ggml_format_name(kr, "cache_kr_l%d", i); + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kr_l.push_back(kr); + cache.kv_l.push_back(kv); } // allocate tensors and initialize the buffers to avoid NaNs in the padding diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index dca6f3998c645..7f2e1b3e7b144 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -49,11 +49,18 @@ struct llama_kv_cache { ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; + ggml_type type_kr = GGML_TYPE_F32; + ggml_type type_kv = GGML_TYPE_F32; + std::vector cells; std::vector k_l; // per layer std::vector v_l; + // DeepSeek MLA + std::vector kr_l; // per layer + std::vector kv_l; + std::vector ctxs; std::vector bufs; diff --git a/src/llama.cpp b/src/llama.cpp index 60728e5bb91ca..99af190e1474b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8860,32 +8860,37 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(kv_compressed, "kv_compressed", il); + struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head); + cb(kv_cache_view, "kv_cache_view", il); + + // note: storing c^KV in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view)); + + struct ggml_tensor * kv_cache = + ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank), + 0); + cb(kv_cache, "kv_cache", il); + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache); cb(kv, "kv", il); // 
split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv, ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv, ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), ggml_row_size(kv->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); - - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, @@ -8903,15 +8908,61 @@ struct llm_build_context { ); cb(k_pe, "k_pe", il); + struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head); + cb(kr_cache_view, "kr_cache_view", il); + + // note: storing RoPE-ed version of K^R in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view)); + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); cb(q_states, "q_states", il); - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + struct ggml_tensor * kr_cache = + ggml_view_2d(ctx0, kv_self.kr_l[il], + n_embd_head_qk_rope, n_kv, + ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope), + 0); + cb(kr_cache, "kr_cache", il); + + // TODO is there a better way? 
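+        //      (ggml_repeat needs an explicit target tensor to broadcast into, so a
+        //      throwaway 3-D tensor is created only to supply the shape that expands
+        //      the single shared K^R head across all n_head attention heads)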
+ struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head); + struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape); + kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3); + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0); cb(k_states, "k_states", il); - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, - k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3); + cb(q_states, "q_states", il); + + k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3); + cb(k_states, "k_states", il); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq); + cb(kqv, "kqv", il); + + GGML_ASSERT(kv_self.size == n_ctx); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); } if (il == n_layer - 1) { @@ -12004,6 +12055,24 @@ struct llama_context * llama_new_context_with_model( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + { + size_t memory_size_kr = 0; + size_t memory_size_kv = 0; + + for (auto & kr : ctx->kv_self.kr_l) { + memory_size_kr += ggml_nbytes(kr); + } + + for (auto & kv : ctx->kv_self.kv_l) { + memory_size_kv += ggml_nbytes(kv); + } + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__, + (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f)); + } + // graph outputs buffer { // resized during inference when a batch uses more outputs From de538aa32929a10555097f01cad91639dfbe84ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 25 Jan 2025 18:10:22 +0100 Subject: [PATCH 2/9] llama : optimize DeepSeek MLA implementation --- convert_hf_to_gguf.py | 23 ++++++++++ gguf-py/gguf/constants.py | 6 +++ gguf-py/gguf/tensor_mapping.py | 8 ++++ src/llama-arch.cpp | 6 +++ src/llama-arch.h | 2 + src/llama-kv-cache.cpp | 1 + src/llama-kv-cache.h | 4 +- src/llama-model.cpp | 2 + src/llama-model.h | 2 + src/llama.cpp | 83 ++++++++++++++++++---------------- 10 files changed, 96 insertions(+), 41 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 63b54a9cf6b48..4df55e7b15b93 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4136,6 +4136,29 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter else: return [] + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + 
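+            # kv_b_proj packs the per-head K^C and V up-projections into one matrix;
+            # split it, storing k_b pre-transposed so that at inference time q_nope
+            # can be mapped into the latent space with a single matmul against wk_b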
+ kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2); + k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) + v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) + + return [ + (self.map_tensor_name(name), data_torch), + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] + + return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8fe84df21ea20..12522928a8c28 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -356,6 +356,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -543,6 +545,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -1333,6 +1337,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 617791e240b60..df831ba70594c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -586,6 +586,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a7260f495d945..e6daa1bc4b5ce 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -999,6 +999,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1330,6 +1332,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -1347,6 +1351,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, 
GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 122fdcebe0af6..c6105d59ac1f3 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -277,6 +277,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 8a836c784eca5..51e71437c1391 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -105,6 +105,7 @@ bool llama_kv_cache_init( // DeepSeek MLA const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; + LLAMA_LOG_DEBUG("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); ggml_format_name(kr, "cache_kr_l%d", i); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 7f2e1b3e7b144..a87344c849235 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -49,8 +49,8 @@ struct llama_kv_cache { ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; - ggml_type type_kr = GGML_TYPE_F32; - ggml_type type_kv = GGML_TYPE_F32; + ggml_type type_kr = GGML_TYPE_F16; + ggml_type type_kv = GGML_TYPE_F16; std::vector cells; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 031b4c30b75dd..8007e730d04f8 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2870,6 +2870,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); diff --git a/src/llama-model.h b/src/llama-model.h index a7c30444786fd..1fdbd3721d630 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -161,6 +161,8 @@ struct llama_layer { struct ggml_tensor * wq_b = nullptr; struct ggml_tensor * wkv_a_mqa = nullptr; struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b = nullptr; + struct ggml_tensor * wv_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; diff --git a/src/llama.cpp b/src/llama.cpp index 5a9518a8e93e2..cb9fe8c9714f5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6483,24 +6483,6 @@ struct 
llm_build_context { 0); cb(kv_cache, "kv_cache", il); - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache); - cb(kv, "kv", il); - - // split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); - - // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, @@ -6524,9 +6506,6 @@ struct llm_build_context { // note: storing RoPE-ed version of K^R in the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view)); - struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); - struct ggml_tensor * kr_cache = ggml_view_2d(ctx0, kv_self.kr_l[il], n_embd_head_qk_rope, n_kv, @@ -6534,36 +6513,62 @@ struct llm_build_context { 0); cb(kr_cache, "kr_cache", il); - // TODO is there a better way? - struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head); - struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape); - kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3); - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0); - cb(k_states, "k_states", il); + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); + cb(wk_b, "wk_b", il); - q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3); - cb(q_states, "q_states", il); + struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 3, 1); + cb(q_nope_perm, "q_nope_perm", il); - k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3); - cb(k_states, "k_states", il); + struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm); + cb(q_nope2, "q_nope2", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states); - cb(kq, "kq", il); + struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 1, 3, 2); + cb(q_nope2_perm, "q_nope2_perm", il); + + struct ggml_tensor * kv_cache_perm = ggml_cont(ctx0, ggml_permute(ctx0, kv_cache, 1, 0, 2, 3)); + cb(kv_cache_perm, "kv_cache_perm", il); + + struct ggml_tensor * scores1 = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); + cb(scores1, "scores1", il); + + struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1); + cb(q_pe_perm, "q_pe_perm", il); + + struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); + cb(kr_cache_perm, "kr_cache_perm", il); + + struct ggml_tensor * scores2 = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); + cb(scores2, "scores2", il); + + struct ggml_tensor * scores = ggml_add(ctx0, scores1, 
scores2); + cb(scores, "scores", il); + + struct ggml_tensor * kq = ggml_permute(ctx0, scores, 0, 3, 1, 2); + + struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); + cb(wv_b, "wv_b", il); kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); - v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3); - cb(v_states, "v_states", il); + struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); + cb(kq_perm, "kq_perm", il); - v_states = ggml_cont(ctx0, v_states); + struct ggml_tensor * kqv1 = ggml_mul_mat(ctx0, kv_cache_perm, kq_perm); + cb(kqv1, "kqv1", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq); - cb(kqv, "kqv", il); + struct ggml_tensor * kqv1_trans = ggml_permute(ctx0, kqv1, 0, 1, 3, 2); + cb(kqv1_trans, "kqv1_trans", il); + + struct ggml_tensor * kqv2 = ggml_mul_mat(ctx0, wv_b, kqv1_trans); + cb(kqv2, "kqv2", il); + + struct ggml_tensor * kqv2_trans = ggml_permute(ctx0, kqv2, 0, 3, 2, 1); + cb(kqv2_trans, "kqv2_trans", il); GGML_ASSERT(kv_self.size == n_ctx); - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv2_trans, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); From ce730637e8fe1b86de7d6e1758f33d716c6c7781 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 26 Jan 2025 12:50:17 +0100 Subject: [PATCH 3/9] llama : Update tensor names in DeepSeek2 MLA implementation. --- src/llama.cpp | 46 +++++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index cb9fe8c9714f5..08b27b33add97 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6525,11 +6525,8 @@ struct llm_build_context { struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 1, 3, 2); cb(q_nope2_perm, "q_nope2_perm", il); - struct ggml_tensor * kv_cache_perm = ggml_cont(ctx0, ggml_permute(ctx0, kv_cache, 1, 0, 2, 3)); - cb(kv_cache_perm, "kv_cache_perm", il); - - struct ggml_tensor * scores1 = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); - cb(scores1, "scores1", il); + struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); + cb(kq_nope, "kq_nope", il); struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1); cb(q_pe_perm, "q_pe_perm", il); @@ -6537,13 +6534,14 @@ struct llm_build_context { struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); cb(kr_cache_perm, "kr_cache_perm", il); - struct ggml_tensor * scores2 = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); - cb(scores2, "scores2", il); + struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); + cb(kq_pe, "kq_pe", il); - struct ggml_tensor * scores = ggml_add(ctx0, scores1, scores2); - cb(scores, "scores", il); + struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); + cb(kq, "kq", il); - struct ggml_tensor * kq = ggml_permute(ctx0, scores, 0, 3, 1, 2); + kq = ggml_permute(ctx0, kq, 0, 3, 1, 2); + cb(kq, "kq_perm", il); struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); cb(wv_b, "wv_b", 
il); @@ -6552,27 +6550,25 @@ struct llm_build_context { cb(kq, "kq_soft_max_ext", il); struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); - cb(kq_perm, "kq_perm", il); - - struct ggml_tensor * kqv1 = ggml_mul_mat(ctx0, kv_cache_perm, kq_perm); - cb(kqv1, "kqv1", il); + cb(kq_perm, "kq_soft_max_ext_perm", il); - struct ggml_tensor * kqv1_trans = ggml_permute(ctx0, kqv1, 0, 1, 3, 2); - cb(kqv1_trans, "kqv1_trans", il); + struct ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache)); + cb(kv_cache_trans, "kv_cache_trans", il); - struct ggml_tensor * kqv2 = ggml_mul_mat(ctx0, wv_b, kqv1_trans); - cb(kqv2, "kqv2", il); + struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm); + cb(kqv_compressed, "kqv_compressed", il); - struct ggml_tensor * kqv2_trans = ggml_permute(ctx0, kqv2, 0, 3, 2, 1); - cb(kqv2_trans, "kqv2_trans", il); + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 1, 3, 2); + cb(kqv_compressed, "kqv_compressed_perm", il); - GGML_ASSERT(kv_self.size == n_ctx); + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + cb(kqv, "kqv", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv2_trans, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + kqv = ggml_permute(ctx0, kqv, 0, 3, 1, 2); + cb(kqv, "kqv_perm", il); - cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); + cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); + cb(cur, "kqv_2d", il); ggml_build_forward_expand(gf, cur); From 202f323e66809bb1df192245caddc49471660466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 26 Jan 2025 18:29:54 +0100 Subject: [PATCH 4/9] llama : add a second copy of c^KV cache in DeepSeek2 MLA to avoid transposing the cache during inference --- src/llama-kv-cache.cpp | 6 +++++- src/llama-kv-cache.h | 1 + src/llama.cpp | 16 +++++++++++++--- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 51e71437c1391..57ccbeeae7e26 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -53,7 +53,7 @@ bool llama_kv_cache_init( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { struct ggml_init_params params = { - /*.mem_size =*/ size_t(4u*n_layer*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(5u*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -74,6 +74,7 @@ bool llama_kv_cache_init( // DeepSeek MLA cache.kr_l.reserve(n_layer); cache.kv_l.reserve(n_layer); + cache.kvt_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); @@ -108,10 +109,13 @@ bool llama_kv_cache_init( LLAMA_LOG_DEBUG("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); ggml_format_name(kr, "cache_kr_l%d", i); ggml_format_name(kv, "cache_kv_l%d", i); + ggml_format_name(kvt, "cache_kvt_l%d", i); cache.kr_l.push_back(kr); cache.kv_l.push_back(kv); + cache.kvt_l.push_back(kvt); } // allocate tensors and initialize the buffers to avoid NaNs in the padding diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h 
index a87344c849235..b10540d76442e 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -60,6 +60,7 @@ struct llama_kv_cache { // DeepSeek MLA std::vector kr_l; // per layer std::vector kv_l; + std::vector kvt_l; std::vector ctxs; std::vector bufs; diff --git a/src/llama.cpp b/src/llama.cpp index 08b27b33add97..d9fe40102b346 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6476,6 +6476,12 @@ struct llm_build_context { // note: storing c^KV in the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view)); + struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head)); + cb(kv_cache_trans_view, "kv_cache_trans_view", il); + + // note: storing transposed c^KV in the transposed KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); + struct ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, @@ -6483,6 +6489,13 @@ struct llm_build_context { 0); cb(kv_cache, "kv_cache", il); + struct ggml_tensor * kv_cache_trans = + ggml_view_2d(ctx0, kv_self.kvt_l[il], + n_kv, kv_lora_rank, + ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), + 0); + cb(kv_cache_trans, "kv_cache_trans", il); + q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, @@ -6552,9 +6565,6 @@ struct llm_build_context { struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); cb(kq_perm, "kq_soft_max_ext_perm", il); - struct ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache)); - cb(kv_cache_trans, "kv_cache_trans", il); - struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm); cb(kqv_compressed, "kqv_compressed", il); From 93c5937249c313bf825d020f4a5213e32c94737c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 26 Jan 2025 22:23:13 +0100 Subject: [PATCH 5/9] llama : modified tensor permutations to multiply larger matrices during inference --- src/llama.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index d9fe40102b346..3df9896922254 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6529,13 +6529,13 @@ struct llm_build_context { struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); cb(wk_b, "wk_b", il); - struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 3, 1); + struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); cb(q_nope_perm, "q_nope_perm", il); struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm); cb(q_nope2, "q_nope2", il); - struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 1, 3, 2); + struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3); cb(q_nope2_perm, "q_nope2_perm", il); struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); @@ -6547,34 +6547,34 @@ struct llm_build_context { struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); cb(kr_cache_perm, "kr_cache_perm", il); - struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); + struct 
ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); cb(kq_pe, "kq_pe", il); struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); cb(kq, "kq", il); - kq = ggml_permute(ctx0, kq, 0, 3, 1, 2); + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); cb(kq, "kq_perm", il); - struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); - cb(wv_b, "wv_b", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); - struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); + struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3); cb(kq_perm, "kq_soft_max_ext_perm", il); struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm); cb(kqv_compressed, "kqv_compressed", il); - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 1, 3, 2); + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); cb(kqv_compressed, "kqv_compressed_perm", il); + struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); + cb(wv_b, "wv_b", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); - kqv = ggml_permute(ctx0, kqv, 0, 3, 1, 2); + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); cb(kqv, "kqv_perm", il); cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); From 1eee98f01fca721a889defac3d38e9ada7abb617 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 Jan 2025 09:32:25 +0100 Subject: [PATCH 6/9] llama : removed unnecessary code in DeepSeek V2 implementation --- src/llama.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 3df9896922254..a4c78240b265e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6544,9 +6544,6 @@ struct llm_build_context { struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1); cb(q_pe_perm, "q_pe_perm", il); - struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); - cb(kr_cache_perm, "kr_cache_perm", il); - struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); cb(kq_pe, "kq_pe", il); From 8ff0991eed65e4041e6e3dfa2e3c98aee7fa2c21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 28 Jan 2025 11:02:52 +0100 Subject: [PATCH 7/9] convert : make lint happy --- convert_hf_to_gguf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4df55e7b15b93..2be7de5a59bbe 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4148,7 +4148,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) - k_b = k_b.transpose(1, 2); + k_b = k_b.transpose(1, 2) k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) @@ -4158,7 +4158,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter 
                 (self.map_tensor_name(name_vb), v_b)
             ]
 
-
         return [(self.map_tensor_name(name), data_torch)]
 
     def prepare_tensors(self):

From 8a887decd35083c1542534cfadc3a5ee592da964 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Tue, 28 Jan 2025 19:26:54 +0100
Subject: [PATCH 8/9] llama : prompt processing optimizations in DeepSeek V2

---
 src/llama.cpp | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index a4c78240b265e..5768f9215fea9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6403,6 +6403,10 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        // whether to use n_tokens as the matrix dimension during multiplication or n_head
+        // n_tokens is higher during prompt processing, which allows optimizing for that case
+        bool pp_opt = n_tokens > n_head;
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -6535,14 +6539,18 @@ struct llm_build_context {
             struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
             cb(q_nope2, "q_nope2", il);
 
-            struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
-            cb(q_nope2_perm, "q_nope2_perm", il);
+            if (!pp_opt) {
+                q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+                cb(q_nope2, "q_nope2_perm", il);
+            }
 
-            struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
+            struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
             cb(kq_nope, "kq_nope", il);
 
-            struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
-            cb(q_pe_perm, "q_pe_perm", il);
+            if (pp_opt) {
+                q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+                cb(q_pe, "q_pe_perm", il);
+            }
 
             struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
             cb(kq_pe, "kq_pe", il);
@@ -6550,20 +6558,26 @@ struct llm_build_context {
             struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
             cb(kq, "kq", il);
 
-            kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
-            cb(kq, "kq_perm", il);
+            if (!pp_opt) {
+                kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+                cb(kq, "kq_perm", il);
+            }
 
             kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
             cb(kq, "kq_soft_max_ext", il);
 
-            struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
-            cb(kq_perm, "kq_soft_max_ext_perm", il);
+            if (!pp_opt) {
+                kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+                cb(kq, "kq_soft_max_ext_perm", il);
+            }
 
-            struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
+            struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
             cb(kqv_compressed, "kqv_compressed", il);
 
-            kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
-            cb(kqv_compressed, "kqv_compressed_perm", il);
+            if (!pp_opt) {
+                kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+                cb(kqv_compressed, "kqv_compressed_perm", il);
+            }
 
             struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
             cb(wv_b, "wv_b", il);

From 76543311acc85e1d77575728000f1979faa7591f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Thu, 30 Jan 2025 18:25:36 +0100
Subject: [PATCH 9/9] llama : avoid ggml_cont() if possible in DeepSeek V2 implementation

---
 src/llama.cpp | 29 ++++++++++++++++++-----------
 1 file
changed, 18 insertions(+), 11 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 5768f9215fea9..1a3d1d0bda9d2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6533,10 +6533,10 @@ struct llm_build_context { struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); cb(wk_b, "wk_b", il); - struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); - cb(q_nope_perm, "q_nope_perm", il); + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); - struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm); + struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); cb(q_nope2, "q_nope2", il); if (!pp_opt) { @@ -6547,6 +6547,11 @@ struct llm_build_context { struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2); cb(kq_nope, "kq_nope", il); + if (!pp_opt) { + kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3); + cb(kq_nope, "kq_nope_perm", il); + } + if (pp_opt) { q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); cb(q_pe, "q_pe_perm", il); @@ -6555,14 +6560,14 @@ struct llm_build_context { struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); cb(kq_pe, "kq_pe", il); - struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); - cb(kq, "kq", il); - if (!pp_opt) { - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); + kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3); + cb(kq_pe, "kq_pe_perm", il); } + struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); + cb(kq, "kq", il); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); @@ -6575,7 +6580,7 @@ struct llm_build_context { cb(kqv_compressed, "kqv_compressed", il); if (!pp_opt) { - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 3, 1); cb(kqv_compressed, "kqv_compressed_perm", il); } @@ -6585,8 +6590,10 @@ struct llm_build_context { struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); - kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); - cb(kqv, "kqv_perm", il); + if (pp_opt) { + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); + cb(kqv, "kqv_perm", il); + } cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); cb(cur, "kqv_2d", il);
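
Notes on the series:

The algebra the series exploits is the following (notation from the DeepSeek-V2 paper; W^{UK} and W^{UV} are the per-head blocks that the convert script splits out of kv_b_proj into attn_k_b and attn_v_b). Per token t, only the compressed latent c^{KV}_t (kv_lora_rank values) and one shared RoPE key k^R_t (n_embd_head_qk_rope = hparams.n_rot values) are cached; the per-head keys and values are linear in the latent:

    k^{C}_{t,i} = W^{UK}_{i} \, c^{KV}_{t}, \qquad
    v_{t,i}     = W^{UV}_{i} \, c^{KV}_{t}

so the "nope" part of every attention score can absorb the up-projection into the query:

    \big(q^{C}_{t,i}\big)^{\top} k^{C}_{s,i}
        = \big(q^{C}_{t,i}\big)^{\top} W^{UK}_{i} \, c^{KV}_{s}
        = \big( (W^{UK}_{i})^{\top} q^{C}_{t,i} \big)^{\top} c^{KV}_{s}.

In the graphs above, (W^{UK}_i)^T q^C is q_nope2, its dot products against the cached latents are kq_nope, the shared RoPE term (q^R)^T k^R is kq_pe, the softmax-weighted sum of latents is kqv_compressed, and W^{UV} (wv_b) is applied once at the end.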
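The same identity as a small runnable check (a NumPy sketch on random data; the shapes and variable names here are illustrative only, not llama.cpp symbols). The naive path materializes full per-head K/V from the latents; the absorbed path attends directly in the kv_lora_rank-dimensional latent space, and the outputs match:

import numpy as np

rng = np.random.default_rng(0)
n_head, d_nope, d_v, d_c, n_kv = 4, 16, 16, 32, 10

W_uk = rng.standard_normal((n_head, d_nope, d_c))  # per-head K up-projection (wk_b)
W_uv = rng.standard_normal((n_head, d_v, d_c))     # per-head V up-projection (wv_b)
c_kv = rng.standard_normal((n_kv, d_c))            # cached latents c^KV, one per token
q    = rng.standard_normal((n_head, d_nope))       # one query token, "nope" part only

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# naive path: decompress full K and V for every cached token, then attend
k = np.einsum('hdc,tc->htd', W_uk, c_kv)           # (n_head, n_kv, d_nope)
v = np.einsum('hvc,tc->htv', W_uv, c_kv)           # (n_head, n_kv, d_v)
p = softmax(np.einsum('htd,hd->ht', k, q))
out_naive = np.einsum('ht,htv->hv', p, v)

# absorbed path: map q into latent space once, attend against the latents
q_lat = np.einsum('hdc,hd->hc', W_uk, q)           # plays the role of q_nope2
p2    = softmax(q_lat @ c_kv.T)                    # scores against c^KV (kq_nope)
o_lat = np.einsum('ht,tc->hc', p2, c_kv)           # plays the role of kqv_compressed
out_absorbed = np.einsum('hvc,hc->hv', W_uv, o_lat)

assert np.allclose(out_naive, out_absorbed)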
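On cache size: per token and per layer, the MLA cache stores kv_lora_rank latent values plus n_embd_head_qk_rope RoPE key values, both in F16 after PATCH 2/9 switches type_kr/type_kv from F32, and PATCH 4/9 adds a second, transposed copy of the latent so the V-side multiplication no longer needs a ggml_cont() of the transposed cache on every graph. With DeepSeek-V2-sized hyperparameters (kv_lora_rank = 512, qk_rope_head_dim = 64, qk_nope_head_dim = 128, v_head_dim = 128, n_head = 128; values taken from the published model configuration, not from this diff), that is 512 + 64 + 512 = 1088 cached values per token and layer, versus n_head * (n_embd_head_k + n_embd_head_v) = 128 * (192 + 128) = 40960 for the standard per-head cache, roughly a 38x reduction. PATCH 8/9's pp_opt flag is orthogonal to this: it only reorders permutations so that whichever of n_tokens and n_head is larger becomes the long matrix dimension in the attention matmuls (n_tokens dominates during prompt processing, n_head during single-token generation).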