Skip to content

Commit

Permalink
ggml: rwkv_wkv: Avoid copying the state
Browse files Browse the repository at this point in the history
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
  • Loading branch information
MollySophia committed Aug 31, 2024
1 parent 5175375 commit 846358d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -16874,7 +16874,6 @@ static void ggml_compute_forward_rwkv_wkv_f32(
float * r = (float *) dst->src[2]->data;
float * time_faaaa = (float *) dst->src[3]->data;
float * time_decay = (float *) dst->src[4]->data;
- memcpy(state, dst->src[5]->data, (C / H) * C * n_seqs * sizeof(float));

size_t t_stride = H * (C / H);

Expand All @@ -16887,7 +16886,9 @@ static void ggml_compute_forward_rwkv_wkv_f32(
// recursive through each token
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
- float * state_cur = state + (C / H) * C * (t / (T / n_seqs));
+ size_t state_offset = (C / H) * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

for (size_t h = 0; h < H; h++) {
size_t h_offset = h * h_stride;
Expand All @@ -16911,7 +16912,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(

float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
- float prev_state_val = state_cur[h_2d_i_j_offset];
+ float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
Expand Down

0 comments on commit 846358d

Please sign in to comment.