Skip to content

Commit 3b4bab6

Browse files
committed
llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
1 parent c5df72e commit 3b4bab6

File tree

4 files changed

+129
-32
lines changed

4 files changed

+129
-32
lines changed

ggml-metal.m

+42-8
Original file line numberDiff line numberDiff line change
@@ -736,25 +736,59 @@ void ggml_metal_graph_compute(
736736
GGML_ASSERT(ggml_is_contiguous(src0));
737737
GGML_ASSERT(ggml_is_contiguous(src1));
738738

739-
// utilize float4
740-
GGML_ASSERT(ne00 % 4 == 0);
741-
const int64_t nb = ne00/4;
739+
bool bcast_row = false;
742740

743-
if (ggml_nelements(src1) == ne10) {
741+
int64_t nb = ne00;
742+
743+
if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
744744
// src1 is a row
745745
GGML_ASSERT(ne11 == 1);
746+
747+
nb = ne00 / 4;
746748
[encoder setComputePipelineState:ctx->pipeline_add_row];
749+
750+
bcast_row = true;
747751
} else {
748752
[encoder setComputePipelineState:ctx->pipeline_add];
749753
}
750754
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
751755
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
752756
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
753-
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
754-
755-
const int64_t n = ggml_nelements(dst)/4;
757+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
758+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
759+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
760+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
761+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
762+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
763+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
764+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
765+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
766+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
767+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
768+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
769+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
770+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
771+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
772+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
773+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
774+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
775+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
776+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
777+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
778+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
779+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
780+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
781+
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
782+
783+
if (bcast_row) {
784+
const int64_t n = ggml_nelements(dst)/4;
785+
786+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
787+
} else {
788+
const int nth = MIN(1024, ne0);
756789

757-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
790+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
791+
}
758792
} break;
759793
case GGML_OP_MUL:
760794
{

ggml-metal.metal

+53-6
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,59 @@ typedef struct {
2424
int8_t qs[QK8_0]; // quants
2525
} block_q8_0;
2626

27+
// general-purpose kernel for addition of two tensors
28+
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
29+
// cons: not very efficient
2730
kernel void kernel_add(
28-
device const float4 * src0,
29-
device const float4 * src1,
30-
device float4 * dst,
31-
uint tpig[[thread_position_in_grid]]) {
32-
dst[tpig] = src0[tpig] + src1[tpig];
31+
device const char * src0,
32+
device const char * src1,
33+
device char * dst,
34+
constant int64_t & ne00,
35+
constant int64_t & ne01,
36+
constant int64_t & ne02,
37+
constant int64_t & ne03,
38+
constant int64_t & nb00,
39+
constant int64_t & nb01,
40+
constant int64_t & nb02,
41+
constant int64_t & nb03,
42+
constant int64_t & ne10,
43+
constant int64_t & ne11,
44+
constant int64_t & ne12,
45+
constant int64_t & ne13,
46+
constant int64_t & nb10,
47+
constant int64_t & nb11,
48+
constant int64_t & nb12,
49+
constant int64_t & nb13,
50+
constant int64_t & ne0,
51+
constant int64_t & ne1,
52+
constant int64_t & ne2,
53+
constant int64_t & ne3,
54+
constant int64_t & nb0,
55+
constant int64_t & nb1,
56+
constant int64_t & nb2,
57+
constant int64_t & nb3,
58+
uint3 tgpig[[threadgroup_position_in_grid]],
59+
uint3 tpitg[[thread_position_in_threadgroup]],
60+
uint3 ntg[[threads_per_threadgroup]]) {
61+
const int64_t i03 = tgpig.z;
62+
const int64_t i02 = tgpig.y;
63+
const int64_t i01 = tgpig.x;
64+
65+
const int64_t i13 = i03 % ne13;
66+
const int64_t i12 = i02 % ne12;
67+
const int64_t i11 = i01 % ne11;
68+
69+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
70+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
71+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
72+
73+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
74+
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
75+
76+
src0_ptr += ntg.x*nb00;
77+
src1_ptr += ntg.x*nb10;
78+
dst_ptr += ntg.x*nb0;
79+
}
3380
}
3481

3582
// assumption: src1 is a row
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
3885
device const float4 * src0,
3986
device const float4 * src1,
4087
device float4 * dst,
41-
constant int64_t & nb,
88+
constant int64_t & nb [[buffer(27)]],
4289
uint tpig[[thread_position_in_grid]]) {
4390
dst[tpig] = src0[tpig] + src1[tpig % nb];
4491
}

ggml.c

-2
Original file line numberDiff line numberDiff line change
@@ -8797,8 +8797,6 @@ static void ggml_compute_forward_add_f32(
87978797
#else
87988798
ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
87998799
#endif
8800-
// }
8801-
// }
88028800
}
88038801
} else {
88048802
// src1 is not contiguous

llama.cpp

+34-16
Original file line numberDiff line numberDiff line change
@@ -2404,13 +2404,30 @@ static struct ggml_cgraph * llm_build_llama(
24042404
}
24052405
#endif // GGML_USE_CUBLAS
24062406

2407+
// KQ_scale
24072408
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
24082409
ggml_allocr_alloc(lctx.alloc, KQ_scale);
24092410
if (!ggml_allocr_is_measure(lctx.alloc)) {
24102411
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
24112412
}
24122413
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
24132414

2415+
// KQ_mask
2416+
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1);
2417+
ggml_allocr_alloc(lctx.alloc, KQ_mask);
2418+
if (!ggml_allocr_is_measure(lctx.alloc)) {
2419+
float * data = (float *) KQ_mask->data;
2420+
memset(data, 0, ggml_nbytes(KQ_mask));
2421+
2422+
for (int h = 0; h < 1; ++h) {
2423+
for (int j = 0; j < N; ++j) {
2424+
for (int i = n_past + j + 1; i < n_past + N; ++i) {
2425+
data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY;
2426+
}
2427+
}
2428+
}
2429+
}
2430+
24142431
for (int il = 0; il < n_layer; ++il) {
24152432
ggml_format_name(inpL, "layer_inp_%d", il);
24162433

@@ -2447,11 +2464,11 @@ static struct ggml_cgraph * llm_build_llama(
24472464
offload_func_kq(tmpq);
24482465
ggml_set_name(tmpq, "tmpq");
24492466

2450-
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
2467+
struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
24512468
offload_func_kq(Kcur);
24522469
ggml_set_name(Kcur, "Kcur");
24532470

2454-
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
2471+
struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
24552472
offload_func_kq(Qcur);
24562473
ggml_set_name(Qcur, "Qcur");
24572474

@@ -2502,17 +2519,18 @@ static struct ggml_cgraph * llm_build_llama(
25022519

25032520
// KQ_scaled = KQ / sqrt(n_embd_head)
25042521
// KQ_scaled shape [n_past + N, N, n_head, 1]
2505-
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
2522+
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
25062523
offload_func_kq(KQ_scaled);
25072524
ggml_set_name(KQ_scaled, "KQ_scaled");
25082525

25092526
// KQ_masked = mask_past(KQ_scaled)
2510-
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
2527+
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
2528+
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
25112529
offload_func_kq(KQ_masked);
25122530
ggml_set_name(KQ_masked, "KQ_masked");
25132531

25142532
// KQ = soft_max(KQ_masked)
2515-
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
2533+
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
25162534
offload_func_v(KQ_soft_max);
25172535
ggml_set_name(KQ_soft_max, "KQ_soft_max");
25182536

@@ -2783,8 +2801,8 @@ static struct ggml_cgraph * llm_build_baichaun(
27832801
struct ggml_tensor * Qcur;
27842802
switch (model.type) {
27852803
case MODEL_7B:
2786-
Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
2787-
Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
2804+
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
2805+
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
27882806
break;
27892807
case MODEL_13B:
27902808
Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N);
@@ -2847,7 +2865,7 @@ static struct ggml_cgraph * llm_build_baichaun(
28472865

28482866
// KQ_scaled = KQ / sqrt(n_embd_head)
28492867
// KQ_scaled shape [n_past + N, N, n_head, 1]
2850-
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
2868+
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
28512869
offload_func_kq(KQ_scaled);
28522870
ggml_set_name(KQ_scaled, "KQ_scaled");
28532871

@@ -2856,7 +2874,7 @@ static struct ggml_cgraph * llm_build_baichaun(
28562874

28572875
switch (model.type) {
28582876
case MODEL_7B:
2859-
KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
2877+
KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
28602878
break;
28612879
case MODEL_13B:
28622880
KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
@@ -2867,13 +2885,13 @@ static struct ggml_cgraph * llm_build_baichaun(
28672885
GGML_ASSERT(false);
28682886
}
28692887
// KQ_masked = mask_past(KQ_scaled)
2870-
// struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
2888+
// struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
28712889
// struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
28722890
// offload_func_kq(KQ_masked);
28732891
// ggml_set_name(KQ_masked, "KQ_masked");
28742892

28752893
// KQ = soft_max(KQ_masked)
2876-
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
2894+
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
28772895
offload_func_v(KQ_soft_max);
28782896
ggml_set_name(KQ_soft_max, "KQ_soft_max");
28792897

@@ -3179,9 +3197,9 @@ static struct ggml_cgraph * llm_build_falcon(
31793197
offload_func_v(tmpv);
31803198

31813199
// using mode = 2 for neox mode
3182-
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
3200+
struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
31833201
offload_func_kq(Qcur);
3184-
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
3202+
struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
31853203
offload_func_kq(Kcur);
31863204

31873205
{
@@ -3220,15 +3238,15 @@ static struct ggml_cgraph * llm_build_falcon(
32203238
offload_func_kq(KQ);
32213239
ggml_set_name(KQ, "KQ");
32223240

3223-
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
3241+
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
32243242
offload_func_kq(KQ_scaled);
32253243
ggml_set_name(KQ_scaled, "KQ_scaled");
32263244

3227-
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
3245+
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
32283246
offload_func_kq(KQ_masked);
32293247
ggml_set_name(KQ_masked, "KQ_masked");
32303248

3231-
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
3249+
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
32323250
offload_func_v(KQ_soft_max);
32333251
ggml_set_name(KQ_soft_max, "KQ_soft_max");
32343252

0 commit comments

Comments (0)