100
100
GGML_METAL_DECL_KERNEL (mul_mm_q4_K_f32);
101
101
GGML_METAL_DECL_KERNEL (mul_mm_q5_K_f32);
102
102
GGML_METAL_DECL_KERNEL (mul_mm_q6_K_f32);
103
- GGML_METAL_DECL_KERNEL (rope);
103
+ GGML_METAL_DECL_KERNEL (rope_f32);
104
+ GGML_METAL_DECL_KERNEL (rope_f16);
104
105
GGML_METAL_DECL_KERNEL (alibi_f32);
105
106
GGML_METAL_DECL_KERNEL (cpy_f32_f16);
106
107
GGML_METAL_DECL_KERNEL (cpy_f32_f32);
@@ -261,7 +262,8 @@ @implementation GGMLMetalClass
261
262
GGML_METAL_ADD_KERNEL (mul_mm_q4_K_f32);
262
263
GGML_METAL_ADD_KERNEL (mul_mm_q5_K_f32);
263
264
GGML_METAL_ADD_KERNEL (mul_mm_q6_K_f32);
264
- GGML_METAL_ADD_KERNEL (rope);
265
+ GGML_METAL_ADD_KERNEL (rope_f32);
266
+ GGML_METAL_ADD_KERNEL (rope_f16);
265
267
GGML_METAL_ADD_KERNEL (alibi_f32);
266
268
GGML_METAL_ADD_KERNEL (cpy_f32_f16);
267
269
GGML_METAL_ADD_KERNEL (cpy_f32_f32);
@@ -335,7 +337,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
335
337
GGML_METAL_DEL_KERNEL (mul_mm_q4_K_f32);
336
338
GGML_METAL_DEL_KERNEL (mul_mm_q5_K_f32);
337
339
GGML_METAL_DEL_KERNEL (mul_mm_q6_K_f32);
338
- GGML_METAL_DEL_KERNEL (rope);
340
+ GGML_METAL_DEL_KERNEL (rope_f32);
341
+ GGML_METAL_DEL_KERNEL (rope_f16);
339
342
GGML_METAL_DEL_KERNEL (alibi_f32);
340
343
GGML_METAL_DEL_KERNEL (cpy_f32_f16);
341
344
GGML_METAL_DEL_KERNEL (cpy_f32_f32);
@@ -870,7 +873,7 @@ void ggml_metal_graph_compute(
870
873
} break ;
871
874
case GGML_OP_SOFT_MAX:
872
875
{
873
- const int nth = 32 ;
876
+ const int nth = MIN ( 32 , ne00) ;
874
877
875
878
if (ne00%4 == 0 ) {
876
879
[encoder setComputePipelineState: ctx->pipeline_soft_max_4];
@@ -1134,7 +1137,7 @@ void ggml_metal_graph_compute(
1134
1137
float eps;
1135
1138
memcpy (&eps, dst->op_params , sizeof (float ));
1136
1139
1137
- const int nth = 512 ;
1140
+ const int nth = MIN ( 512 , ne00) ;
1138
1141
1139
1142
[encoder setComputePipelineState: ctx->pipeline_rms_norm];
1140
1143
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
@@ -1153,7 +1156,7 @@ void ggml_metal_graph_compute(
1153
1156
float eps;
1154
1157
memcpy (&eps, dst->op_params , sizeof (float ));
1155
1158
1156
- const int nth = 256 ;
1159
+ const int nth = MIN ( 256 , ne00) ;
1157
1160
1158
1161
[encoder setComputePipelineState: ctx->pipeline_norm];
1159
1162
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
@@ -1212,7 +1215,7 @@ void ggml_metal_graph_compute(
1212
1215
{
1213
1216
GGML_ASSERT (ne10 == ne02);
1214
1217
1215
- // const int n_past = ((int32_t *) dst->op_params)[0];
1218
+ const int n_past = ((int32_t *) dst->op_params )[0 ];
1216
1219
const int n_dims = ((int32_t *) dst->op_params )[1 ];
1217
1220
const int mode = ((int32_t *) dst->op_params )[2 ];
1218
1221
@@ -1221,7 +1224,12 @@ void ggml_metal_graph_compute(
1221
1224
memcpy (&freq_base, (int32_t *) dst->op_params + 4 , sizeof (float ));
1222
1225
memcpy (&freq_scale, (int32_t *) dst->op_params + 5 , sizeof (float ));
1223
1226
1224
- [encoder setComputePipelineState: ctx->pipeline_rope];
1227
+ switch (src0->type ) {
1228
+ case GGML_TYPE_F32: [encoder setComputePipelineState: ctx->pipeline_rope_f32]; break ;
1229
+ case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_rope_f16]; break ;
1230
+ default : GGML_ASSERT (false );
1231
+ };
1232
+
1225
1233
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1226
1234
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1227
1235
[encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
@@ -1241,7 +1249,7 @@ void ggml_metal_graph_compute(
1241
1249
[encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 16 ];
1242
1250
[encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 17 ];
1243
1251
[encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 18 ];
1244
- // [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1252
+ [encoder setBytes: &n_past length: sizeof ( int ) atIndex: 19 ];
1245
1253
[encoder setBytes: &n_dims length: sizeof ( int ) atIndex: 20 ];
1246
1254
[encoder setBytes: &mode length: sizeof ( int ) atIndex: 21 ];
1247
1255
[encoder setBytes: &freq_base length: sizeof (float ) atIndex: 22 ];
@@ -1253,7 +1261,7 @@ void ggml_metal_graph_compute(
1253
1261
case GGML_OP_CPY:
1254
1262
case GGML_OP_CONT:
1255
1263
{
1256
- const int nth = 32 ;
1264
+ const int nth = MIN ( 1024 , ne00) ;
1257
1265
1258
1266
switch (src0t) {
1259
1267
case GGML_TYPE_F32:
0 commit comments