diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 337fb5cb0df36..7fd134aab8eef 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1091,22 +1091,25 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); if (v_mla) { -#if 0 - // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. - // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. - cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); - cur = ggml_mul_mat(ctx0, v_mla, cur); -#else - // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1. - // The permutations are noops and only change how the tensor data is interpreted. - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_mul_mat(ctx0, v_mla, cur); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs. -#endif + // To "decompress" from MQA back to MHA, v_mla can either be applied as: + // 1. A matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. + // - The code is optimized for dimensions 0 and 1 being large, so this is inefficient. + // 2. A matrix-matrix multiplication with n_tokens in dimension 1. + // - The added cost of the cont means that (1) is still more efficient for small batches.
+ if (n_tokens < 32) { + cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); + cur = ggml_mul_mat(ctx0, v_mla, cur); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + } else { + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_mul_mat(ctx0, v_mla, cur); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + } + } else { + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); } - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);