Skip to content

Commit b0c5fe6

Browse files
ggerganov authored and olexiyb committed
metal : fix build errors and kernel sig after ggml-org#2268 (ggml-org#3898)
1 parent cdd4a93 commit b0c5fe6

File tree

2 files changed

+40
-33
lines changed

2 files changed

+40
-33
lines changed

ggml-metal.m

+29 -28
Original file line number | Diff line number | Diff line change
@@ -1419,34 +1419,35 @@ void ggml_metal_graph_compute(
14191419
default: GGML_ASSERT(false);
14201420
};
14211421

1422-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1423-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1424-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1425-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1426-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
1427-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
1428-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
1429-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
1430-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
1431-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
1432-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
1433-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
1434-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
1435-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
1436-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
1437-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
1438-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
1439-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
1440-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1441-
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1442-
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
1443-
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
1444-
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
1445-
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
1446-
[encoder setBytes:&ext_factor length:sizeof(float) atIndex:24];
1447-
[encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
1448-
[encoder setBytes:&beta_fast length:sizeof(float) atIndex:26];
1449-
[encoder setBytes:&beta_slow length:sizeof(float) atIndex:27];
1422+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1423+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1424+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1425+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1426+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
1427+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
1428+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
1429+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
1430+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
1431+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
1432+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
1433+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
1434+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
1435+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
1436+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
1437+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
1438+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
1439+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
1440+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1441+
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1442+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
1443+
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
1444+
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
1445+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
1446+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
1447+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
1448+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
1449+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
1450+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
14501451

14511452
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
14521453
} break;

ggml-metal.metal

+11 -5
Original file line number | Diff line number | Diff line change
@@ -1070,20 +1070,20 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
10701070
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
10711071
static void rope_yarn(
10721072
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1073-
float * cos_theta, float * sin_theta
1073+
thread float * cos_theta, thread float * sin_theta
10741074
) {
10751075
// Get n-d rotational scaling corrected for extrapolation
10761076
float theta_interp = freq_scale * theta_extrap;
10771077
float theta = theta_interp;
10781078
if (ext_factor != 0.0f) {
1079-
ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
1079+
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
10801080
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
10811081

10821082
// Get n-d magnitude scaling corrected for interpolation
1083-
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
1083+
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
10841084
}
1085-
*cos_theta = cosf(theta) * mscale;
1086-
*sin_theta = sinf(theta) * mscale;
1085+
*cos_theta = cos(theta) * mscale;
1086+
*sin_theta = sin(theta) * mscale;
10871087
}
10881088

10891089
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
@@ -1123,8 +1123,13 @@ typedef void (rope_t)(
11231123
constant int & n_past,
11241124
constant int & n_dims,
11251125
constant int & mode,
1126+
constant int & n_orig_ctx,
11261127
constant float & freq_base,
11271128
constant float & freq_scale,
1129+
constant float & ext_factor,
1130+
constant float & attn_factor,
1131+
constant float & beta_fast,
1132+
constant float & beta_slow,
11281133
uint tiitg[[thread_index_in_threadgroup]],
11291134
uint3 tptg[[threads_per_threadgroup]],
11301135
uint3 tgpig[[threadgroup_position_in_grid]]);
@@ -1153,6 +1158,7 @@ kernel void kernel_rope(
11531158
constant int & n_past,
11541159
constant int & n_dims,
11551160
constant int & mode,
1161+
constant int & n_orig_ctx,
11561162
constant float & freq_base,
11571163
constant float & freq_scale,
11581164
constant float & ext_factor,

0 commit comments

Comments
 (0)