Commit fac72a2

Alternative implementation using ggml_flash_attn (~16% slower than falcon40b-norepeat in wall timings on CPU)
1 parent: d5295b4

1 file changed: +2 −21 lines

examples/falcon/main.cpp (+2 −21)
@@ -536,27 +536,8 @@ bool falcon_eval(
                         il * n_ctx * ggml_element_size(model.memory_k) * n_head_kv * head_dim),
                 0, 2, 1, 3);
 
-            // K * Q
-
-            K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));
-
             struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
-            // KQ_scaled = KQ / sqrt(n_embd/n_head)
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale_inplace(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(head_dim)))
-                        );
 
-            // KQ_masked = mask_past(KQ_scaled)
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-
-            // KQ = soft_max(KQ_masked)
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
-
-            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             struct ggml_tensor* V = ggml_permute(
                 ctx0,
                 ggml_view_3d(
@@ -568,10 +549,10 @@ bool falcon_eval(
                         il * n_ctx * ggml_element_size(model.memory_v) * n_head_kv * head_dim),
                 0, 2, 1, 3);
 
+            K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));
             V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy)));
 
-            // KQV = transpose(V) * KQ_soft_max
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
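For reference, the fused call replaces the explicit KQ → scale → causal mask → softmax → V·softmax chain with a single op. Below is a minimal standalone sketch (not part of the commit) that builds both paths on the same toy tensors and checks that they agree. The sizes (head_dim, n_head, N, M), the fill values, and the file name are made up, and it assumes the graph API of this ggml vintage (ggml_build_forward / ggml_graph_compute(ctx, &gf)); with n_past = 0, the masked flash-attention path should apply the same causal mask as ggml_diag_mask_inf_inplace.

// attention_check.c — hypothetical sketch: compare the op-by-op attention
// pipeline removed by this commit against the fused ggml_flash_attn call.
#include "ggml/ggml.h"

#include <math.h>
#include <stdio.h>

int main(void) {
    const int head_dim = 8;  // toy dimensions, chosen for illustration
    const int n_head   = 2;
    const int N        = 4;  // query tokens
    const int M        = 4;  // key/value tokens (n_past = 0, so M == N)

    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    // Q: (head_dim, N, n_head), K: (head_dim, M, n_head);
    // V is stored transposed, (M, head_dim, n_head), as in the diff above.
    struct ggml_tensor * Q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, head_dim, N, n_head);
    struct ggml_tensor * K = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, head_dim, M, n_head);
    struct ggml_tensor * V = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, M, head_dim, n_head);
    for (int i = 0; i < (int) ggml_nelements(Q); i++) ggml_set_f32_1d(Q, i, 0.01f*(i % 17));
    for (int i = 0; i < (int) ggml_nelements(K); i++) ggml_set_f32_1d(K, i, 0.02f*(i % 13));
    for (int i = 0; i < (int) ggml_nelements(V); i++) ggml_set_f32_1d(V, i, 0.03f*(i % 11));

    // Path A: the op-by-op attention this commit removes.
    struct ggml_tensor * KQ          = ggml_mul_mat(ctx0, K, Q);
    struct ggml_tensor * KQ_scaled   = ggml_scale_inplace(ctx0, KQ,
            ggml_new_f32(ctx0, 1.0f/sqrtf((float) head_dim)));
    struct ggml_tensor * KQ_masked   = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, /*n_past=*/0);
    struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
    struct ggml_tensor * KQV_ref     = ggml_mul_mat(ctx0, V, KQ_soft_max);

    // Path B: the fused kernel this commit switches to (masked = true).
    struct ggml_tensor * KQV_fa = ggml_flash_attn(ctx0, Q, K, V, true);

    // Both results are (head_dim, N, n_head) and contiguous.
    struct ggml_cgraph gf = ggml_build_forward(KQV_ref);
    ggml_build_forward_expand(&gf, KQV_fa);
    ggml_graph_compute(ctx0, &gf);

    float max_diff = 0.0f;
    for (int i = 0; i < (int) ggml_nelements(KQV_ref); i++) {
        const float d = fabsf(ggml_get_f32_1d(KQV_ref, i) - ggml_get_f32_1d(KQV_fa, i));
        if (d > max_diff) max_diff = d;
    }
    printf("max |op-by-op - flash_attn| = %g\n", max_diff);

    ggml_free(ctx0);
    return 0;
}

Fusing the scale, mask, and softmax avoids materializing the full (n_past + N) × N KQ matrix per head; per the commit message, though, this fused CPU path still measures ~16% slower in wall time than the unfused falcon40b-norepeat implementation.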
