Skip to content

Commit

Permalink
metal : alibi for arbitrary number of heads (ggerganov#3426)
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus authored and yusiwen committed Oct 7, 2023
1 parent 6730941 commit 6bd6c2f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
9 changes: 4 additions & 5 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1213,12 +1213,9 @@ void ggml_metal_graph_compute(
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

if (__builtin_popcount(n_head) != 1) {
GGML_ASSERT(false && "only power-of-two n_head implemented");
}

const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
Expand All @@ -1239,7 +1236,9 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
Expand Down
11 changes: 9 additions & 2 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,9 @@ kernel void kernel_alibi_f32(
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant float & m0,
constant float & m0,
constant float & m1,
constant int & n_heads_log2_floor,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
Expand All @@ -846,7 +848,12 @@ kernel void kernel_alibi_f32(
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
float m_k = pow(m0, i2 + 1);
float m_k;
if (i2 < n_heads_log2_floor) {
m_k = pow(m0, i2 + 1);
} else {
m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
}
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
Expand Down

0 comments on commit 6bd6c2f

Please sign in to comment.