Skip to content

Commit 57cea73

Browse files
committed
metal : add rope_f16 kernel + optimize cpy kernels
1 parent 1fb033f commit 57cea73

File tree

2 files changed

+57
-16
lines changed

2 files changed

+57
-16
lines changed

ggml-metal.m

+18-10
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@
100100
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
101101
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
102102
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
103-
GGML_METAL_DECL_KERNEL(rope);
103+
GGML_METAL_DECL_KERNEL(rope_f32);
104+
GGML_METAL_DECL_KERNEL(rope_f16);
104105
GGML_METAL_DECL_KERNEL(alibi_f32);
105106
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
106107
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -261,7 +262,8 @@ @implementation GGMLMetalClass
261262
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
262263
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
263264
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
264-
GGML_METAL_ADD_KERNEL(rope);
265+
GGML_METAL_ADD_KERNEL(rope_f32);
266+
GGML_METAL_ADD_KERNEL(rope_f16);
265267
GGML_METAL_ADD_KERNEL(alibi_f32);
266268
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
267269
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -335,7 +337,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
335337
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
336338
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
337339
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
338-
GGML_METAL_DEL_KERNEL(rope);
340+
GGML_METAL_DEL_KERNEL(rope_f32);
341+
GGML_METAL_DEL_KERNEL(rope_f16);
339342
GGML_METAL_DEL_KERNEL(alibi_f32);
340343
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
341344
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
@@ -870,7 +873,7 @@ void ggml_metal_graph_compute(
870873
} break;
871874
case GGML_OP_SOFT_MAX:
872875
{
873-
const int nth = 32;
876+
const int nth = MIN(32, ne00);
874877

875878
if (ne00%4 == 0) {
876879
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
@@ -1134,7 +1137,7 @@ void ggml_metal_graph_compute(
11341137
float eps;
11351138
memcpy(&eps, dst->op_params, sizeof(float));
11361139

1137-
const int nth = 512;
1140+
const int nth = MIN(512, ne00);
11381141

11391142
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
11401143
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1153,7 +1156,7 @@ void ggml_metal_graph_compute(
11531156
float eps;
11541157
memcpy(&eps, dst->op_params, sizeof(float));
11551158

1156-
const int nth = 256;
1159+
const int nth = MIN(256, ne00);
11571160

11581161
[encoder setComputePipelineState:ctx->pipeline_norm];
11591162
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1212,7 +1215,7 @@ void ggml_metal_graph_compute(
12121215
{
12131216
GGML_ASSERT(ne10 == ne02);
12141217

1215-
//const int n_past = ((int32_t *) dst->op_params)[0];
1218+
const int n_past = ((int32_t *) dst->op_params)[0];
12161219
const int n_dims = ((int32_t *) dst->op_params)[1];
12171220
const int mode = ((int32_t *) dst->op_params)[2];
12181221

@@ -1221,7 +1224,12 @@ void ggml_metal_graph_compute(
12211224
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
12221225
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
12231226

1224-
[encoder setComputePipelineState:ctx->pipeline_rope];
1227+
switch (src0->type) {
1228+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
1229+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
1230+
default: GGML_ASSERT(false);
1231+
};
1232+
12251233
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
12261234
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
12271235
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1241,7 +1249,7 @@ void ggml_metal_graph_compute(
12411249
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
12421250
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
12431251
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1244-
//[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1252+
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
12451253
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
12461254
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
12471255
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
@@ -1253,7 +1261,7 @@ void ggml_metal_graph_compute(
12531261
case GGML_OP_CPY:
12541262
case GGML_OP_CONT:
12551263
{
1256-
const int nth = 32;
1264+
const int nth = MIN(1024, ne00);
12571265

12581266
switch (src0t) {
12591267
case GGML_TYPE_F32:

ggml-metal.metal

+39-6
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,36 @@ kernel void kernel_alibi_f32(
853853
}
854854
}
855855

856+
typedef void (rope_t)(
857+
device const void * src0,
858+
device const int32_t * src1,
859+
device float * dst,
860+
constant int64_t & ne00,
861+
constant int64_t & ne01,
862+
constant int64_t & ne02,
863+
constant int64_t & ne03,
864+
constant uint64_t & nb00,
865+
constant uint64_t & nb01,
866+
constant uint64_t & nb02,
867+
constant uint64_t & nb03,
868+
constant int64_t & ne0,
869+
constant int64_t & ne1,
870+
constant int64_t & ne2,
871+
constant int64_t & ne3,
872+
constant uint64_t & nb0,
873+
constant uint64_t & nb1,
874+
constant uint64_t & nb2,
875+
constant uint64_t & nb3,
876+
constant int & n_past,
877+
constant int & n_dims,
878+
constant int & mode,
879+
constant float & freq_base,
880+
constant float & freq_scale,
881+
uint tiitg[[thread_index_in_threadgroup]],
882+
uint3 tptg[[threads_per_threadgroup]],
883+
uint3 tgpig[[threadgroup_position_in_grid]]);
884+
885+
template<typename T>
856886
kernel void kernel_rope(
857887
device const void * src0,
858888
device const int32_t * src1,
@@ -901,11 +931,11 @@ kernel void kernel_rope(
901931
const float cos_theta = cos(theta);
902932
const float sin_theta = sin(theta);
903933

904-
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
905-
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
934+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
935+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
906936

907-
const float x0 = src[0];
908-
const float x1 = src[1];
937+
const T x0 = src[0];
938+
const T x1 = src[1];
909939

910940
dst_data[0] = x0*cos_theta - x1*sin_theta;
911941
dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -920,8 +950,8 @@ kernel void kernel_rope(
920950

921951
const int64_t i0 = ib*n_dims + ic/2;
922952

923-
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
924-
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
953+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
954+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
925955

926956
const float x0 = src[0];
927957
const float x1 = src[n_dims/2];
@@ -933,6 +963,9 @@ kernel void kernel_rope(
933963
}
934964
}
935965

966+
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
967+
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
968+
936969
kernel void kernel_cpy_f16_f16(
937970
device const half * src0,
938971
device half * dst,

0 commit comments

Comments
 (0)