
Commit daa4228

llama : DeepSeek V2/V3 MLA implementation (ggml-org#12801)
* Merged using squash to remove all noise commit messages
* Force flash attention off for `LLM_ARCH_DEEPSEEK2` - embedding too large
* Removed 3 conts (2x RoPE and 1x RMS-norm)
* Changed to use `<cmath>` instead of `<math.h>`
* Reverted removal of the 3 conts
* Used `reshape` in `llm_graph_context::build_attn_mha()`
* Use `k_pe = ggml_reshape`
* Removed the 3 conts again
* Removed the 3D views of `wk_b` and `wv_b`, and just save them as 3D in GGUF
* Removed MQA optimisation from `build_attn_mha()` as no gains now
* Simplified `is_mla` branch in `llm_build_deepseek2()`
* Removed `build_attn_mla` and added `nullptr` to all `build_attn` calls
* Fixed call to `build_attn` in `llm_build_t5_enc`
1 parent eccc7a1 commit daa4228

13 files changed: +288 additions, -164 deletions
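
To put the metadata changes below in context, here is a rough sketch of why the cached K/V widths change under MLA. The head count and head dimensions are assumptions (roughly DeepSeek-V3-like published configs), not values taken from this commit:

# Per-token KV-cache width, before vs. after the MLA (MQA) layout used in this commit.
# All dimensions are assumptions; adjust to your model's hparams.
n_head           = 128
kv_lora_rank     = 512   # compressed KV latent
qk_rope_head_dim = 64    # RoPE'd part of each key
qk_nope_head_dim = 128   # non-RoPE part of each key
v_head_dim       = 128

old_k = n_head * (qk_nope_head_dim + qk_rope_head_dim)  # per-head K cached for every head
old_v = n_head * v_head_dim                             # per-head V cached for every head

new_k = kv_lora_rank + qk_rope_head_dim                 # single shared latent + RoPE part (MQA)
new_v = kv_lora_rank                                    # values are decompressed at attention time

print(f"old: {old_k + old_v} values/token/layer, new: {new_k + new_v} values/token/layer")

Under these assumed sizes the cache shrinks from 40960 to 1088 values per token per layer, which is why kv_b_proj is split and the decompression matrices are stored separately.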

convert_hf_to_gguf.py

Lines changed: 31 additions & 2 deletions
@@ -4422,6 +4422,10 @@ def set_vocab(self):
         self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
+
+        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
+        self.hparams["num_key_value_heads"] = 1
+
         super().set_gguf_parameters()
         hparams = self.hparams
 
@@ -4430,8 +4434,13 @@ def set_gguf_parameters(self):
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
         self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
-        self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(hparams["v_head_dim"])
+
+        # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+        self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
+        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
@@ -4500,6 +4509,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             else:
                 return []
 
+        # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+
+            return [
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]
+
         return [(self.map_tensor_name(name), data_torch)]
 
     def prepare_tensors(self):

gguf-py/gguf/constants.py

Lines changed: 8 additions & 0 deletions
@@ -139,6 +139,8 @@ class Attention:
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW    = "{arch}.attention.sliding_window"
         SCALE             = "{arch}.attention.scale"
+        KEY_LENGTH_MLA    = "{arch}.attention.key_length_mla"
+        VALUE_LENGTH_MLA  = "{arch}.attention.value_length_mla"
 
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -382,6 +384,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B       = auto()
     ATTN_KV_A_MQA  = auto()
     ATTN_KV_B      = auto()
+    ATTN_K_B       = auto()
+    ATTN_V_B       = auto()
     ATTN_Q_A_NORM  = auto()
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM   = auto()
@@ -590,6 +594,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_Q_B:       "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA:  "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B:      "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B:       "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B:       "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM:  "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM:  "blk.{bid}.attn_sub_norm",
@@ -1517,6 +1523,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_Q_B,
         MODEL_TENSOR.ATTN_KV_A_MQA,
         MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
         MODEL_TENSOR.ATTN_Q_A_NORM,
         MODEL_TENSOR.ATTN_KV_A_NORM,
         MODEL_TENSOR.ATTN_OUT,

gguf-py/gguf/gguf_writer.py

Lines changed: 6 additions & 0 deletions
@@ -689,6 +689,12 @@ def add_key_length(self, length: int) -> None:
     def add_value_length(self, length: int) -> None:
         self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
 
+    def add_key_length_mla(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)
+
+    def add_value_length_mla(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
+
     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
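
A minimal, hypothetical use of the new writer hooks (it assumes the gguf-py from this commit is installed; the numeric values are DeepSeek-like assumptions rather than values read from any model):

from gguf import GGUFWriter

w = GGUFWriter("mla-metadata-demo.gguf", "deepseek2")
w.add_key_length(512 + 64)      # kv_lora_rank + qk_rope_head_dim -> width of the cached K
w.add_value_length(512)         # kv_lora_rank                    -> width of the cached V
w.add_key_length_mla(128 + 64)  # qk_nope_head_dim + qk_rope_head_dim
w.add_value_length_mla(128)     # v_head_dim

w.write_header_to_file()
w.write_kv_data_to_file()
w.close()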

gguf-py/gguf/tensor_mapping.py

Lines changed: 8 additions & 0 deletions
@@ -677,6 +677,14 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
         ),
 
+        MODEL_TENSOR.ATTN_K_B: (
+            "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_V_B: (
+            "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+        ),
+
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
         ),
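
A quick way to confirm the two new mappings resolve as intended, assuming the gguf-py from this commit is installed; get_tensor_name_map() is the same helper convert_hf_to_gguf.py relies on, and 61 is simply an assumed block count:

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DEEPSEEK2, 61)

for hf_name in ("model.layers.0.self_attn.k_b_proj.weight",
                "model.layers.0.self_attn.v_b_proj.weight"):
    print(hf_name, "->", tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))
# expected: blk.0.attn_k_b.weight and blk.0.attn_v_b.weight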

src/llama-arch.cpp

Lines changed: 6 additions & 17 deletions
@@ -140,6 +140,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
+    { LLM_KV_ATTENTION_KEY_LENGTH_MLA,         "%s.attention.key_length_mla"         },
+    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,       "%s.attention.value_length_mla"       },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"      },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"   },
@@ -1103,6 +1105,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_Q_B,       "blk.%d.attn_q_b" },
             { LLM_TENSOR_ATTN_KV_A_MQA,  "blk.%d.attn_kv_a_mqa" },
             { LLM_TENSOR_ATTN_KV_B,      "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_K_B,       "blk.%d.attn_k_b" },
+            { LLM_TENSOR_ATTN_V_B,       "blk.%d.attn_v_b" },
             { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
@@ -1563,23 +1567,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h

Lines changed: 4 additions & 0 deletions
@@ -144,6 +144,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -306,6 +308,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,

src/llama-context.cpp

Lines changed: 12 additions & 3 deletions
@@ -10,6 +10,7 @@
 #include <cstring>
 #include <stdexcept>
 #include <cinttypes>
+#include <cmath>
 
 //
 // llama_context
@@ -473,7 +474,6 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
     const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
     const auto & yarn_beta_fast = cparams.yarn_beta_fast;
     const auto & yarn_beta_slow = cparams.yarn_beta_slow;
 
@@ -482,6 +482,10 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_rot = hparams.n_rot;
     const auto & rope_type = hparams.rope_type;
 
+    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
+    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+
     ggml_tensor * tmp;
 
     if (ggml_is_quantized(cur->type)) {
@@ -500,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(
 
         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
 
         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
     }
 
     return tmp;
@@ -2274,6 +2278,11 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
+    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
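
The yarn_attn_factor_scaled expression above can be sanity-checked numerically. The sketch below assumes freq_scale = 1/40 (DeepSeek-V2's published YaRN factor) and reflects my reading that ggml's YaRN path multiplies attn_factor by the same mscale term, so passing its reciprocal cancels the scaling for this architecture:

import math

freq_scale = 1.0 / 40.0                          # assumed YaRN scaling factor
mscale = 1.0 + 0.1 * math.log(1.0 / freq_scale)  # term ggml applies internally for YaRN (assumption)
yarn_attn_factor_scaled = 1.0 / mscale           # value passed to ggml_rope_ext_inplace above

print(f"mscale = {mscale:.4f}")                                       # ~1.3689
print(f"attn_factor passed to RoPE = {yarn_attn_factor_scaled:.4f}")  # ~0.7305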

src/llama-graph.cpp

Lines changed: 14 additions & 5 deletions
@@ -1188,6 +1188,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
+        ggml_tensor * v_mla,
         bool v_trans,
         float kq_scale) const {
     //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1199,7 +1200,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
+    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
 
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
@@ -1267,6 +1269,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
 
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+        }
+
         ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
@@ -1304,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     GGML_UNUSED(n_tokens);
@@ -1325,7 +1333,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1379,6 +1387,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1464,7 +1473,7 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1504,6 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1523,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1692,4 +1702,3 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
 }
-
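
A shape-level sketch of the new decompression step (kqv = ggml_mul_mat(ctx0, v_mla, kqv)) in plain torch; all sizes are assumptions, and the torch shapes are written with ggml's ne[] order reversed:

import torch

kv_lora_rank, v_head_dim = 512, 128    # assumed MLA sizes
n_head, n_tokens, n_kv = 16, 8, 32     # assumed graph sizes

attn  = torch.softmax(torch.randn(n_head, n_tokens, n_kv), dim=-1)  # kq after mask + softmax
v_lat = torch.randn(n_kv, kv_lora_rank)                             # shared compressed V (single KV head)
wv_b  = torch.randn(n_head, kv_lora_rank, v_head_dim)               # per-head attn_v_b ("v_mla")

kqv_compressed = attn @ v_lat   # [n_head, n_tokens, kv_lora_rank] - MQA-sized result
kqv = kqv_compressed @ wv_b     # [n_head, n_tokens, v_head_dim]   - "decompressed" back to MHA-sized heads

print(kqv_compressed.shape, kqv.shape)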

src/llama-graph.h

Lines changed: 7 additions & 3 deletions
@@ -505,11 +505,12 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
             ggml_cgraph * gf,
-            ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
-            ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
-            ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+            ggml_tensor * q,     // [n_embd_head_q, n_tokens, n_head_q]
+            ggml_tensor * k,     // [n_embd_head_k, n_tokens, n_head_k]
+            ggml_tensor * v,     // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             bool v_trans,
             float kq_scale) const;
 
@@ -524,6 +525,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -538,6 +540,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -552,6 +555,7 @@ struct llm_graph_context {
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
+           ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

src/llama-hparams.h

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,10 @@ struct llama_hparams {
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
 
+    // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+    uint32_t n_embd_head_k_mla = 0;
+    uint32_t n_embd_head_v_mla = 0;
+
     // for WavTokenizer
     struct llama_hparams_posnet   posnet;
     struct llama_hparams_convnext convnext;

src/llama-kv-cache.cpp

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ bool llama_kv_cache_unified::init(
 
     recurrent = llama_model_is_recurrent(&model);
     v_trans   = !recurrent && !cparams.flash_attn;
-    can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    can_shift = !recurrent;
 
     LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
             __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
