Skip to content

Commit d4cd263

Browse files
committed
ggml : ggml_rope now takes a vector with positions instead of n_past
1 parent 3b4bab6 commit d4cd263

File tree

9 files changed

+270
-131
lines changed

9 files changed

+270
-131
lines changed

examples/baby-llama/baby-llama.cpp

+31-6
Original file line number | Diff line number | Diff line change
@@ -556,6 +556,14 @@ struct ggml_tensor * forward(
556556
struct ggml_tensor * kc = kv_self.k;
557557
struct ggml_tensor * vc = kv_self.v;
558558

559+
struct ggml_tensor * KQ_rope = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
560+
{
561+
int * data = (int *) KQ_rope->data;
562+
for (int i = 0; i < N; ++i) {
563+
data[i] = n_past + i;
564+
}
565+
}
566+
559567
// inpL shape [n_embd,N,1,1]
560568
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
561569
for (int il = 0; il < n_layer; ++il) {
@@ -583,8 +591,8 @@ struct ggml_tensor * forward(
583591
// wk shape [n_embd, n_embd, 1, 1]
584592
// Qcur shape [n_embd/n_head, n_head, N, 1]
585593
// Kcur shape [n_embd/n_head, n_head, N, 1]
586-
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
587-
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
594+
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_rope, n_rot, 0, 0);
595+
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_rope, n_rot, 0, 0);
588596

589597
// store key and value to memory
590598
{
@@ -810,9 +818,18 @@ struct ggml_tensor * forward_batch(
810818
struct ggml_tensor * kc = kv_self.k;
811819
struct ggml_tensor * vc = kv_self.v;
812820

821+
struct ggml_tensor * KQ_rope = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
822+
{
823+
int * data = (int *) KQ_rope->data;
824+
for (int i = 0; i < N; ++i) {
825+
data[i] = n_past + i;
826+
}
827+
}
828+
813829
// inpL shape [n_embd,N*n_batch,1]
814830
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
815831
assert_shape_2d(inpL, n_embd, N*n_batch);
832+
816833
for (int il = 0; il < n_layer; ++il) {
817834
struct ggml_tensor * inpSA = inpL;
818835

@@ -840,8 +857,8 @@ struct ggml_tensor * forward_batch(
840857
// wk shape [n_embd, n_embd, 1, 1]
841858
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
842859
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
843-
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
844-
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
860+
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_rope, n_rot, 0, 0);
861+
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_rope, n_rot, 0, 0);
845862
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
846863
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
847864

@@ -1100,6 +1117,14 @@ struct ggml_tensor * forward_lora(
11001117
struct ggml_tensor * kc = kv_self.k;
11011118
struct ggml_tensor * vc = kv_self.v;
11021119

1120+
struct ggml_tensor * KQ_rope = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1121+
{
1122+
int * data = (int *) KQ_rope->data;
1123+
for (int i = 0; i < N; ++i) {
1124+
data[i] = n_past + i;
1125+
}
1126+
}
1127+
11031128
// inpL shape [n_embd,N,1,1]
11041129
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
11051130
for (int il = 0; il < n_layer; ++il) {
@@ -1133,7 +1158,7 @@ struct ggml_tensor * forward_lora(
11331158
model->layers[il].wqb,
11341159
cur)),
11351160
n_embd/n_head, n_head, N),
1136-
n_past, n_rot, 0, 0);
1161+
KQ_rope, n_rot, 0, 0);
11371162
struct ggml_tensor * Kcur = ggml_rope(ctx0,
11381163
ggml_reshape_3d(ctx0,
11391164
ggml_mul_mat(ctx0,
@@ -1142,7 +1167,7 @@ struct ggml_tensor * forward_lora(
11421167
model->layers[il].wkb,
11431168
cur)),
11441169
n_embd/n_head, n_head, N),
1145-
n_past, n_rot, 0, 0);
1170+
KQ_rope, n_rot, 0, 0);
11461171

11471172
// store key and value to memory
11481173
{

examples/train-text-from-scratch/train-text-from-scratch.cpp

+11-3
Original file line number | Diff line number | Diff line change
@@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs(
679679
}
680680
};
681681

682+
// KQ_rope - contains the positions
683+
struct ggml_tensor * KQ_rope = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
684+
{
685+
int * data = (int *) KQ_rope->data;
686+
for (int i = 0; i < N; ++i) {
687+
data[i] = n_past + i;
688+
}
689+
}
690+
682691
// rope has so many parameters that we make a custom function for it
683-
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
692+
auto rope = [ctx, KQ_rope, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
684693
(struct ggml_tensor * t) -> struct ggml_tensor * {
685694
// not capturing these, to silence warnings
686-
const int n_past = 0;
687695
const int rope_mode = 0;
688696

689697
return ggml_rope_custom(ctx,
690-
t, n_past, n_rot, rope_mode, n_ctx,
698+
t, KQ_rope, n_rot, rope_mode, n_ctx,
691699
rope_freq_base, rope_freq_scale);
692700
};
693701

ggml-metal.m

+26-23
Original file line number | Diff line number | Diff line change
@@ -1210,7 +1210,9 @@ void ggml_metal_graph_compute(
12101210
} break;
12111211
case GGML_OP_ROPE:
12121212
{
1213-
const int n_past = ((int32_t *) dst->op_params)[0];
1213+
GGML_ASSERT(ne10 == ne02);
1214+
1215+
//const int n_past = ((int32_t *) dst->op_params)[0];
12141216
const int n_dims = ((int32_t *) dst->op_params)[1];
12151217
const int mode = ((int32_t *) dst->op_params)[2];
12161218

@@ -1221,28 +1223,29 @@ void ggml_metal_graph_compute(
12211223

12221224
[encoder setComputePipelineState:ctx->pipeline_rope];
12231225
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1224-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1225-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1226-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1227-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1228-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1229-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1230-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1231-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1232-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1233-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1234-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1235-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1236-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1237-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1238-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1239-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1240-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1241-
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
1242-
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
1243-
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
1244-
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
1245-
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
1226+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1227+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1228+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1229+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
1230+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
1231+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
1232+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
1233+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
1234+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
1235+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
1236+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
1237+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
1238+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
1239+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
1240+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
1241+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
1242+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
1243+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1244+
//[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1245+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
1246+
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
1247+
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
1248+
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
12461249

12471250
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
12481251
} break;

ggml-metal.metal

+29-26
Original file line number | Diff line number | Diff line change
@@ -854,29 +854,30 @@ kernel void kernel_alibi_f32(
854854
}
855855

856856
kernel void kernel_rope(
857-
device const void * src0,
858-
device float * dst,
859-
constant int64_t & ne00,
860-
constant int64_t & ne01,
861-
constant int64_t & ne02,
862-
constant int64_t & ne03,
863-
constant uint64_t & nb00,
864-
constant uint64_t & nb01,
865-
constant uint64_t & nb02,
866-
constant uint64_t & nb03,
867-
constant int64_t & ne0,
868-
constant int64_t & ne1,
869-
constant int64_t & ne2,
870-
constant int64_t & ne3,
871-
constant uint64_t & nb0,
872-
constant uint64_t & nb1,
873-
constant uint64_t & nb2,
874-
constant uint64_t & nb3,
875-
constant int & n_past,
876-
constant int & n_dims,
877-
constant int & mode,
878-
constant float & freq_base,
879-
constant float & freq_scale,
857+
device const void * src0,
858+
device const int32_t * src1,
859+
device float * dst,
860+
constant int64_t & ne00,
861+
constant int64_t & ne01,
862+
constant int64_t & ne02,
863+
constant int64_t & ne03,
864+
constant uint64_t & nb00,
865+
constant uint64_t & nb01,
866+
constant uint64_t & nb02,
867+
constant uint64_t & nb03,
868+
constant int64_t & ne0,
869+
constant int64_t & ne1,
870+
constant int64_t & ne2,
871+
constant int64_t & ne3,
872+
constant uint64_t & nb0,
873+
constant uint64_t & nb1,
874+
constant uint64_t & nb2,
875+
constant uint64_t & nb3,
876+
constant int & n_past,
877+
constant int & n_dims,
878+
constant int & mode,
879+
constant float & freq_base,
880+
constant float & freq_scale,
880881
uint tiitg[[thread_index_in_threadgroup]],
881882
uint3 tptg[[threads_per_threadgroup]],
882883
uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -886,7 +887,9 @@ kernel void kernel_rope(
886887

887888
const bool is_neox = mode & 2;
888889

889-
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
890+
device const int32_t * pos = src1;
891+
892+
const int64_t p = pos[i2];
890893

891894
const float theta_0 = freq_scale * (float)p;
892895
const float inv_ndims = -1.f/n_dims;
@@ -1320,8 +1323,8 @@ kernel void kernel_mul_mat_q3_K_f32(
13201323

13211324
float yl[32];
13221325

1323-
const uint16_t kmask1 = 0x3030;
1324-
const uint16_t kmask2 = 0x0f0f;
1326+
//const uint16_t kmask1 = 0x3030;
1327+
//const uint16_t kmask2 = 0x0f0f;
13251328

13261329
const int tid = tiisg/4;
13271330
const int ix = tiisg%4;

0 commit comments

Comments
 (0)