@@ -1241,53 +1241,68 @@ kernel void kernel_ssm_scan_f32(
1241
1241
}
1242
1242
1243
1243
kernel void kernel_norm (
1244
- device const void * src0,
1245
- device float * dst,
1246
- constant int64_t & ne00,
1247
- constant uint64_t & nb01,
1248
- constant float & eps,
1249
- threadgroup float * sum [[threadgroup(0 )]],
1250
- uint tgpig[[threadgroup_position_in_grid]],
1251
- uint tpitg[[thread_position_in_threadgroup]],
1252
- uint ntg[[threads_per_threadgroup]]) {
1253
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
1254
- // MEAN
1255
- // parallel sum
1256
- sum[tpitg] = 0 .0f ;
1257
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1258
- sum[tpitg] += x[i00];
1244
+ constant ggml_metal_kargs_norm & args,
1245
+ device const char * src0,
1246
+ device char * dst,
1247
+ threadgroup float * shmem_f32 [[threadgroup(0 )]],
1248
+ uint tgpig[[threadgroup_position_in_grid]],
1249
+ ushort tpitg[[thread_position_in_threadgroup]],
1250
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1251
+ ushort tiisg[[thread_index_in_simdgroup]],
1252
+ ushort ntg[[threads_per_threadgroup]]) {
1253
+ if (sgitg == 0 ) {
1254
+ shmem_f32[tiisg] = 0 .0f ;
1259
1255
}
1260
- // reduce
1256
+
1257
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01 );
1258
+
1259
+ float4 sumf4 (0 .0f );
1260
+
1261
+ float sumf = 0 .0f ;
1262
+
1263
+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
1264
+ sumf4 += x[i00];
1265
+ }
1266
+ sumf = sumf4[0 ] + sumf4[1 ] + sumf4[2 ] + sumf4[3 ];
1267
+ sumf = simd_sum (sumf);
1268
+
1261
1269
threadgroup_barrier (mem_flags::mem_threadgroup);
1262
- for (uint i = ntg/2 ; i > 0 ; i /= 2 ) {
1263
- if (tpitg < i) {
1264
- sum[tpitg] += sum[tpitg + i];
1265
- }
1266
- threadgroup_barrier (mem_flags::mem_threadgroup);
1270
+
1271
+ if (tiisg == 0 ) {
1272
+ shmem_f32[sgitg] = sumf;
1267
1273
}
1268
- const float mean = sum[0 ] / ne00;
1269
1274
1270
- // recenter and VARIANCE
1271
1275
threadgroup_barrier (mem_flags::mem_threadgroup);
1272
- device float * y = dst + tgpig*ne00;
1273
- sum[tpitg] = 0 .0f ;
1274
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1276
+
1277
+ sumf = shmem_f32[tiisg];
1278
+ sumf = simd_sum (sumf);
1279
+
1280
+ const float mean = sumf/args.ne00 ;
1281
+
1282
+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4 ;
1283
+
1284
+ sumf = 0 .0f ;
1285
+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
1275
1286
y[i00] = x[i00] - mean;
1276
- sum[tpitg] += y[i00] * y[i00];
1287
+ sumf += dot ( y[i00], y[i00]) ;
1277
1288
}
1289
+ sumf = simd_sum (sumf);
1278
1290
1279
- // reduce
1280
1291
threadgroup_barrier (mem_flags::mem_threadgroup);
1281
- for (uint i = ntg/2 ; i > 0 ; i /= 2 ) {
1282
- if (tpitg < i) {
1283
- sum[tpitg] += sum[tpitg + i];
1284
- }
1285
- threadgroup_barrier (mem_flags::mem_threadgroup);
1292
+
1293
+ if (tiisg == 0 ) {
1294
+ shmem_f32[sgitg] = sumf;
1286
1295
}
1287
- const float variance = sum[0 ] / ne00;
1288
1296
1289
- const float scale = 1 .0f /sqrt (variance + eps);
1290
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1297
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1298
+
1299
+ sumf = shmem_f32[tiisg];
1300
+ sumf = simd_sum (sumf);
1301
+
1302
+ const float variance = sumf/args.ne00 ;
1303
+
1304
+ const float scale = 1 .0f /sqrt (variance + args.eps );
1305
+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
1291
1306
y[i00] = y[i00] * scale;
1292
1307
}
1293
1308
}
0 commit comments