Skip to content

Commit f018669

Browse files
committedNov 17, 2024
metal : GGML_OP_NORM
1 parent b438ff7 commit f018669

File tree

3 files changed

+79
-44
lines changed

3 files changed

+79
-44
lines changed
 

‎ggml/src/ggml-common.h

+7
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,13 @@ typedef struct {
643643
uint64_t nb1;
644644
} ggml_metal_kargs_mul_mv_id;
645645

646+
typedef struct {
647+
int32_t ne00;
648+
int32_t ne00_4;
649+
uint64_t nb01;
650+
float eps;
651+
} ggml_metal_kargs_norm;
652+
646653
typedef struct {
647654
int32_t ne00;
648655
int32_t ne00_4;

‎ggml/src/ggml-metal/ggml-metal.m

+21-8
Original file line numberDiff line numberDiff line change
@@ -2685,22 +2685,35 @@ static void ggml_metal_encode_node(
26852685
} break;
26862686
case GGML_OP_NORM:
26872687
{
2688+
GGML_ASSERT(ne00 % 4 == 0);
26882689
GGML_ASSERT(ggml_is_contiguous_1(src0));
26892690

26902691
float eps;
26912692
memcpy(&eps, dst->op_params, sizeof(float));
26922693

2693-
const int nth = MIN(256, ne00);
2694-
26952694
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
26962695

2696+
int nth = 32; // SIMD width
2697+
2698+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
2699+
nth *= 2;
2700+
}
2701+
2702+
nth = MIN(nth, ne00/4);
2703+
2704+
ggml_metal_kargs_norm args = {
2705+
/*.ne00 =*/ ne00,
2706+
/*.ne00_4 =*/ ne00/4,
2707+
/*.nb01 =*/ nb01,
2708+
/*.eps =*/ eps,
2709+
};
2710+
26972711
[encoder setComputePipelineState:pipeline];
2698-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2699-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2700-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2701-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2702-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
2703-
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
2712+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
2713+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2714+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2715+
2716+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
27042717

27052718
const int64_t nrows = ggml_nrows(src0);
27062719

‎ggml/src/ggml-metal/ggml-metal.metal

+51-36
Original file line numberDiff line numberDiff line change
@@ -1241,53 +1241,68 @@ kernel void kernel_ssm_scan_f32(
12411241
}
12421242

12431243
kernel void kernel_norm(
1244-
device const void * src0,
1245-
device float * dst,
1246-
constant int64_t & ne00,
1247-
constant uint64_t & nb01,
1248-
constant float & eps,
1249-
threadgroup float * sum [[threadgroup(0)]],
1250-
uint tgpig[[threadgroup_position_in_grid]],
1251-
uint tpitg[[thread_position_in_threadgroup]],
1252-
uint ntg[[threads_per_threadgroup]]) {
1253-
device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
1254-
// MEAN
1255-
// parallel sum
1256-
sum[tpitg] = 0.0f;
1257-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1258-
sum[tpitg] += x[i00];
1244+
constant ggml_metal_kargs_norm & args,
1245+
device const char * src0,
1246+
device char * dst,
1247+
threadgroup float * shmem_f32 [[threadgroup(0)]],
1248+
uint tgpig[[threadgroup_position_in_grid]],
1249+
ushort tpitg[[thread_position_in_threadgroup]],
1250+
ushort sgitg[[simdgroup_index_in_threadgroup]],
1251+
ushort tiisg[[thread_index_in_simdgroup]],
1252+
ushort ntg[[threads_per_threadgroup]]) {
1253+
if (sgitg == 0) {
1254+
shmem_f32[tiisg] = 0.0f;
12591255
}
1260-
// reduce
1256+
1257+
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
1258+
1259+
float4 sumf4(0.0f);
1260+
1261+
float sumf = 0.0f;
1262+
1263+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
1264+
sumf4 += x[i00];
1265+
}
1266+
sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
1267+
sumf = simd_sum(sumf);
1268+
12611269
threadgroup_barrier(mem_flags::mem_threadgroup);
1262-
for (uint i = ntg/2; i > 0; i /= 2) {
1263-
if (tpitg < i) {
1264-
sum[tpitg] += sum[tpitg + i];
1265-
}
1266-
threadgroup_barrier(mem_flags::mem_threadgroup);
1270+
1271+
if (tiisg == 0) {
1272+
shmem_f32[sgitg] = sumf;
12671273
}
1268-
const float mean = sum[0] / ne00;
12691274

1270-
// recenter and VARIANCE
12711275
threadgroup_barrier(mem_flags::mem_threadgroup);
1272-
device float * y = dst + tgpig*ne00;
1273-
sum[tpitg] = 0.0f;
1274-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1276+
1277+
sumf = shmem_f32[tiisg];
1278+
sumf = simd_sum(sumf);
1279+
1280+
const float mean = sumf/args.ne00;
1281+
1282+
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
1283+
1284+
sumf = 0.0f;
1285+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
12751286
y[i00] = x[i00] - mean;
1276-
sum[tpitg] += y[i00] * y[i00];
1287+
sumf += dot(y[i00], y[i00]);
12771288
}
1289+
sumf = simd_sum(sumf);
12781290

1279-
// reduce
12801291
threadgroup_barrier(mem_flags::mem_threadgroup);
1281-
for (uint i = ntg/2; i > 0; i /= 2) {
1282-
if (tpitg < i) {
1283-
sum[tpitg] += sum[tpitg + i];
1284-
}
1285-
threadgroup_barrier(mem_flags::mem_threadgroup);
1292+
1293+
if (tiisg == 0) {
1294+
shmem_f32[sgitg] = sumf;
12861295
}
1287-
const float variance = sum[0] / ne00;
12881296

1289-
const float scale = 1.0f/sqrt(variance + eps);
1290-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1297+
threadgroup_barrier(mem_flags::mem_threadgroup);
1298+
1299+
sumf = shmem_f32[tiisg];
1300+
sumf = simd_sum(sumf);
1301+
1302+
const float variance = sumf/args.ne00;
1303+
1304+
const float scale = 1.0f/sqrt(variance + args.eps);
1305+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
12911306
y[i00] = y[i00] * scale;
12921307
}
12931308
}

0 commit comments

Comments
 (0)