Skip to content

Commit

Permalink
build_rwkv: Avoid using inplace operations
Browse files Browse the repository at this point in the history
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
  • Loading branch information
MollySophia committed Aug 11, 2024
1 parent b1b6c7e commit 487fb6d
Showing 1 changed file with 61 additions and 85 deletions.
146 changes: 61 additions & 85 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8563,36 +8563,29 @@ static struct ggml_tensor * llm_build_kv(
static struct ggml_tensor * llm_build_time_mix(
struct ggml_context * ctx,
const struct llama_layer * layer,
struct ggml_tensor * current,
struct ggml_tensor * cur,
struct ggml_tensor * x_prev,
struct ggml_tensor ** wkv_state,
struct ggml_tensor * state_seq) {
size_t n_embed = current->ne[0];
size_t n_tokens = current->ne[1];
size_t n_embed = cur->ne[0];
size_t n_tokens = cur->ne[1];
size_t head_size = layer->time_mix_first->ne[0];
size_t head_count = layer->time_mix_first->ne[1];
size_t n_kv = state_seq->ne[0];

struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
struct ggml_tensor * xxx = ggml_add_inplace(
ctx,
ggml_mul(ctx, sx, layer->time_mix_lerp_x),
current
);
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);

xxx = ggml_reshape_4d(
ctx,
ggml_tanh_inplace(
ggml_tanh(
ctx,
ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
),
layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
);

xxx = ggml_cont(
ctx,
ggml_permute(ctx, xxx, 0, 1, 3, 2)
);
xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));

xxx = ggml_mul_mat(
ctx,
Expand All @@ -8614,151 +8607,138 @@ static struct ggml_tensor * llm_build_time_mix(
struct ggml_tensor *mk = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
mk = ggml_reshape_2d(
ctx,
ggml_set_1d_inplace(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0),
ggml_set_1d(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0),
n_embed, n_tokens
);

struct ggml_tensor *mv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
mv = ggml_reshape_2d(
ctx,
ggml_set_1d_inplace(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0),
ggml_set_1d(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0),
n_embed, n_tokens
);

struct ggml_tensor *mr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
mr = ggml_reshape_2d(
ctx,
ggml_set_1d_inplace(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0),
ggml_set_1d(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0),
n_embed, n_tokens
);

struct ggml_tensor *mg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
mg = ggml_reshape_2d(
ctx,
ggml_set_1d_inplace(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0),
ggml_set_1d(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0),
n_embed, n_tokens
);

struct ggml_tensor * xw = ggml_add_inplace(
struct ggml_tensor * xw = ggml_add(
ctx,
ggml_mul_inplace(
ggml_mul(
ctx,
ggml_add(ctx, mw, layer->time_mix_lerp_w),
sx
),
current
cur
);

struct ggml_tensor * xk = ggml_add_inplace(
struct ggml_tensor * xk = ggml_add(
ctx,
ggml_mul_inplace(
ggml_mul(
ctx,
ggml_add(ctx, mk, layer->time_mix_lerp_k),
sx
),
current
cur
);

struct ggml_tensor * xv = ggml_add_inplace(
struct ggml_tensor * xv = ggml_add(
ctx,
ggml_mul_inplace(
ggml_mul(
ctx,
ggml_add(ctx, mv, layer->time_mix_lerp_v),
sx
),
current
cur
);

struct ggml_tensor * xr = ggml_add_inplace(
struct ggml_tensor * xr = ggml_add(
ctx,
ggml_mul_inplace(
ggml_mul(
ctx,
ggml_add(ctx, mr, layer->time_mix_lerp_r),
sx
),
current
cur
);

struct ggml_tensor * xg = ggml_add_inplace(
struct ggml_tensor * xg = ggml_add(
ctx,
ggml_mul_inplace(
ggml_mul(
ctx,
ggml_add(ctx, mg, layer->time_mix_lerp_g),
sx
),
current
cur
);

struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
struct ggml_tensor * g = ggml_silu_inplace(
struct ggml_tensor * g = ggml_silu(
ctx,
ggml_mul_mat(ctx, layer->time_mix_gate, xg)
);

struct ggml_tensor * w = ggml_mul_mat(
ctx,
layer->time_mix_decay_w2,
ggml_tanh_inplace(
ggml_tanh(
ctx,
ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
)
);
w = ggml_add_inplace(
ctx,
w,
ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed)
);
w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);

k = ggml_transpose(ctx, k);
v = ggml_transpose(ctx, v);
r = ggml_transpose(ctx, r);
struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state, state_seq);
current = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float));

// ggml_group_norm considers groups in the third dimension.
current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens);
current = ggml_group_norm(ctx, current, head_count, 64e-5f);
cur = ggml_reshape_4d(ctx, cur, 1, 1, n_embed, n_tokens);
cur = ggml_group_norm(ctx, cur, head_count, 64e-5f);
// Convert back to a regular vector.
current = ggml_reshape_2d(ctx, current, n_embed, n_tokens);
current = ggml_add_inplace(
ctx,
ggml_mul_inplace(
ctx,
current,
layer->time_mix_ln
),
layer->time_mix_ln_b
);
cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);

current = ggml_mul(ctx, current, g);
cur = ggml_mul(ctx, cur, g);

return ggml_mul_mat(ctx, layer->time_mix_output, current);
return ggml_mul_mat(ctx, layer->time_mix_output, cur);
}

static struct ggml_tensor * llm_build_channel_mix(
struct ggml_context * ctx,
const struct llama_layer * layer,
struct ggml_tensor * current,
struct ggml_tensor * cur,
struct ggml_tensor * x_prev) {
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
struct ggml_tensor * xk = ggml_add_inplace(
ctx,
ggml_mul(ctx, sx, layer->channel_mix_lerp_k),
current
);
struct ggml_tensor * xr = ggml_add_inplace(
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);

struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
struct ggml_tensor * k = ggml_sqr(
ctx,
ggml_mul(ctx, sx, layer->channel_mix_lerp_r),
current
ggml_relu(
ctx,
ggml_mul_mat(ctx, layer->channel_mix_key, xk)
)
);
struct ggml_tensor * r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_key, xk)));
return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
}

struct llm_build_context {
Expand Down Expand Up @@ -14165,13 +14145,12 @@ struct llm_build_context {
// Token shift state dimensions should be 2 * n_emb
GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);

// Input embeddings, start of the model after tokenizing ({n_embd, n_tokens})
ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

struct ggml_tensor * state_mask = build_inp_s_mask();
struct ggml_tensor * state_seq = build_inp_s_seq();

ggml_tensor * x = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
ggml_tensor * cur = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);

for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
const llama_layer * layer = &model.layers[layer_i];
Expand Down Expand Up @@ -14200,16 +14179,16 @@ struct llm_build_context {
struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0);
struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_embd * n_kv * ggml_element_size(kv_self.k_l[layer_i]));

struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
struct ggml_tensor * x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
struct ggml_tensor * tmp = ggml_rwkv_token_shift(ctx0, att_shift, x_norm, state_seq);
struct ggml_tensor * x_prev = ggml_reshape_2d(
ctx0,
ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
n_embd, n_tokens
);

x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
ggml_build_forward_expand(gf, x);
cur = ggml_add(ctx0, cur, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(
gf,
ggml_cpy(
Expand Down Expand Up @@ -14237,15 +14216,15 @@ struct llm_build_context {
)
);

x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
tmp = ggml_rwkv_token_shift(ctx0, ffn_shift, x_norm, state_seq);
x_prev = ggml_reshape_2d(
ctx0,
ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
n_embd, n_tokens
);
x = ggml_add(ctx0, x, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
ggml_build_forward_expand(gf, x);
cur = ggml_add(ctx0, cur, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(
gf,
ggml_cpy(
Expand Down Expand Up @@ -14279,21 +14258,18 @@ struct llm_build_context {
);

if ((layer_i + 1) % hparams.rescale_every_n_layers == 0) {
x = ggml_scale(ctx0, x, 0.5F);
cur = ggml_scale(ctx0, cur, 0.5F);
}
}

// Something related to skipping tokens, specifics unclear
ggml_tensor * inp_out_ids = build_inp_out_ids();
x = ggml_get_rows(ctx0, x, inp_out_ids);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);

// Output head, convert result vector to logits
x = llm_build_norm(ctx0, x, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
x = ggml_mul_mat(ctx0, model.output, x);
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
cur = ggml_mul_mat(ctx0, model.output, cur);

// Mark the output as being the result
cb(x, "result_output", -1);
ggml_build_forward_expand(gf, x);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);

return gf;
}
Expand Down

0 comments on commit 487fb6d

Please sign in to comment.