@@ -536,27 +536,8 @@ bool falcon_eval(
                 il * n_ctx * ggml_element_size(model.memory_k) * n_head_kv * head_dim),
             0, 2, 1, 3);

-        // K * Q
-
-        K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));
-
         struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
-        // KQ_scaled = KQ / sqrt(n_embd/n_head)
-        struct ggml_tensor * KQ_scaled =
-            ggml_scale_inplace(ctx0,
-                KQ,
-                ggml_new_f32(ctx0, 1.0f/sqrt(float(head_dim)))
-            );

-        // KQ_masked = mask_past(KQ_scaled)
-        struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-
-        // KQ = soft_max(KQ_masked)
-        struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
-
-        // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
         struct ggml_tensor * V = ggml_permute(
             ctx0,
             ggml_view_3d(
@@ -568,10 +549,10 @@ bool falcon_eval(
                 il * n_ctx * ggml_element_size(model.memory_v) * n_head_kv * head_dim),
             0, 2, 1, 3);

+        K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));
         V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy)));

-        // KQV = transpose(V) * KQ_soft_max
-        struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+        struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);

         // KQV_merged = KQV.permute(0, 2, 1, 3)
         struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
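Illustrative sketch (assumptions: ctx0, Q, K, V, head_dim and n_past as prepared in the surrounding falcon_eval code; KQV_manual and KQV_fused are hypothetical names used only here): the single ggml_flash_attn call added above folds the scaling, causal masking and softmax that the removed lines built as separate graph nodes, so the correspondence is roughly:

    // manual attention path (the ops removed by this diff)
    struct ggml_tensor * KQ          = ggml_mul_mat(ctx0, K, Q);                           // K * Q
    struct ggml_tensor * KQ_scaled   = ggml_scale_inplace(ctx0, KQ,
            ggml_new_f32(ctx0, 1.0f/sqrt(float(head_dim))));                               // scale by 1/sqrt(head_dim)
    struct ggml_tensor * KQ_masked   = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); // causal mask
    struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);             // softmax over keys
    struct ggml_tensor * KQV_manual  = ggml_mul_mat(ctx0, V, KQ_soft_max);                 // weighted sum of values

    // fused path (the op added by this diff); masked = true applies the causal mask inside the kernel
    struct ggml_tensor * KQV_fused   = ggml_flash_attn(ctx0, Q, K, V, true);

A fused flash-attention kernel typically avoids materializing the full KQ attention matrix, which reduces memory use for long contexts.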