diff --git a/README.md b/README.md index 8e6f52b0..134b4c43 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrap [RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported. +[RWKV v6](https://huggingface.co/BlinkDL/rwkv-6-world) is a further improvement to the RWKV architecture, delivering better quality. RWKV v6 models are supported. + Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py). ## Quality and performance diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py index 99568449..2c413dca 100644 --- a/python/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -34,8 +34,11 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict + is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict - if is_v5_2: + if is_v6_0: + print('Detected RWKV v6.0') + elif is_v5_2: print('Detected RWKV v5.2') elif is_v5_1_or_2: print('Detected RWKV v5.1') @@ -57,13 +60,25 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t 1 if is_FP16 else 0 )) + if is_v6_0: + n_head: int = state_dict['blocks.0.att.time_faaaa'].shape[0] for k in state_dict.keys(): tensor: torch.Tensor = state_dict[k].float() if '.time_' in k: tensor = tensor.squeeze() - if is_v5_1_or_2: + if is_v6_0: + if '.time_faaaa' in k: + tensor = tensor.unsqueeze(-1) + if '.time_maa_w1' in k or '.time_decay_w' in k: + tensor = tensor.transpose(0, 1) + if '.time_maa_w2' in k: + tensor = tensor.transpose(1, 2) + if '.time_decay' in k and '_w' not in k: + tensor = tensor.reshape(n_head, -1, 1) + + elif is_v5_1_or_2: if '.time_decay' in k: if is_v5_2: tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1) @@ -105,7 +120,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t out_file.write(k_encoded) - tensor.numpy().tofile(out_file) + tensor.detach().numpy().tofile(out_file) def main() -> None: args = parse_args() diff --git a/python/merge_lora_into_ggml.py b/python/merge_lora_into_ggml.py index 39886978..d2dca333 100644 --- a/python/merge_lora_into_ggml.py +++ b/python/merge_lora_into_ggml.py @@ -13,7 +13,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file') parser.add_argument('src_path', help='Path to source rwkv.cpp model') - parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2']) + parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2, v6.0', type=str, choices=['v4', 'v5.1', 'v5.2', 'v6.0']) parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format') parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int) parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwritten with the merged model') @@ -47,7 +47,7 @@ def main() -> None: arch_version: str = args.rwkv_arch_version - if not (arch_version 
== 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2' or arch_version == 'v6.0'): raise ValueError(f'Invalid RWKV architecture version {arch_version}') print(f'Reading {args.lora_path}') @@ -108,7 +108,17 @@ def main() -> None: if '.time_' in key: replacement = replacement.squeeze() - if arch_version == 'v5.1' or arch_version == 'v5.2': + if arch_version == 'v6.0': + if '.time_faaaa' in key: + replacement = replacement.unsqueeze(-1) + if '.time_maa_w1' in key or '.time_decay_w' in key: + replacement = replacement.transpose(0, 1) + if '.time_maa_w2' in key: + n_head: int = replacement.shape[1] + replacement = replacement.transpose(1, 2) + if '.time_decay' in key and '_w' not in key: + replacement = replacement.reshape(n_head, -1, 1) + elif arch_version == 'v5.1' or arch_version == 'v5.2': if '.time_decay' in key: if arch_version == 'v5.2': replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1) diff --git a/rwkv.cpp b/rwkv.cpp index 6fae1521..84f7c1ec 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -49,6 +49,8 @@ static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit #include "rwkv_operators_wkv_v5.inc" +#include "rwkv_operators_wkv_v6.inc" + #include "rwkv_graph.inc" // API function. diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 90dda815..80e79e6d 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -320,6 +320,216 @@ static struct ggml_tensor * rwkv_att_v5( return ggml_mul_mat(ctx, layer.att_output, x); } +static struct ggml_tensor * rwkv_att_v6( + struct ggml_context * ctx, + struct ggml_tensor * x, + struct rwkv_layer layer, + struct rwkv_layer_state & state, + const int64_t head_count, + const int64_t head_size, + const uint32_t arch_version_minor +) { + size_t n_embed = x->ne[0]; + size_t sequence_length = x->ne[1]; + + x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); + + struct ggml_tensor * x_prev; + + if (sequence_length > 1) { + x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length); + x_prev = ggml_set_1d_inplace(ctx, x_prev, state.att_xx, 0); + x_prev = ggml_set_1d_inplace( + ctx, + x_prev, + ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float) + ); + } else { + x_prev = state.att_xx; + } + + // sx = x_prev - x + // xxx = x + sx * x_maa + x_prev = ggml_sub_inplace(ctx, x_prev, x); + struct ggml_tensor * xxx = ggml_add_inplace( + ctx, + ggml_mul(ctx, x_prev, layer.att_time_maa_x), + x + ); + + // xxx = tanh(xxx @ tm_w1).view(5, 1, -1) + xxx = ggml_reshape_4d( + ctx, + ggml_tanh_inplace( + ctx, + ggml_mul_mat(ctx, layer.att_time_maa_w1, xxx) + ), + layer.att_time_maa_w1->ne[1] / 5, 1, 5, sequence_length + ); + + xxx = ggml_cont( + ctx, + ggml_permute(ctx, xxx, 0, 1, 3, 2) + ); + + // xxx = torch.bmm(xxx, tm_w2) + xxx = ggml_mul_mat( + ctx, + ggml_reshape_4d( + ctx, + layer.att_time_maa_w2, + layer.att_time_maa_w2->ne[0], layer.att_time_maa_w2->ne[1], 1, 5 + ), + xxx + ); + + xxx = ggml_reshape_2d(ctx, xxx, n_embed * sequence_length, 5); + + struct ggml_tensor * mw = ggml_reshape_2d( + ctx, + ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 0)), + n_embed, sequence_length + ); + struct ggml_tensor * mk = ggml_reshape_2d( + ctx, + ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 1)), + n_embed, sequence_length + ); + struct ggml_tensor * mv = ggml_reshape_2d( + ctx, + ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 2)), + n_embed, sequence_length + ); + struct ggml_tensor * mr = ggml_reshape_2d( + ctx, + ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 3)), + n_embed, sequence_length + ); + struct ggml_tensor * mg = 
ggml_reshape_2d( + ctx, + ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 4)), + n_embed, sequence_length + ); + + struct ggml_tensor * xw = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add_inplace(ctx, mw, layer.att_time_maa_w), + x_prev + ), + x + ); + + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add_inplace(ctx, mk, layer.att_time_maa_k), + x_prev + ), + x + ); + + struct ggml_tensor * xv = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add_inplace(ctx, mv, layer.att_time_maa_v), + x_prev + ), + x + ); + + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add_inplace(ctx, mr, layer.att_time_maa_r), + x_prev + ), + x + ); + + struct ggml_tensor * xg = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add_inplace(ctx, mg, layer.att_time_maa_g), + x_prev + ), + x + ); + + state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float)); + struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length); + struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length); + struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length); + struct ggml_tensor * g = ggml_silu_inplace( + ctx, + ggml_mul_mat(ctx, layer.att_gate, xg) + ); + + struct ggml_tensor * w = ggml_mul_mat( + ctx, + layer.att_time_decay_w2, + ggml_tanh_inplace( + ctx, + ggml_mul_mat(ctx, layer.att_time_decay_w1, xw) + ) + ); + w = ggml_add_inplace( + ctx, + w, + ggml_reshape_1d(ctx, layer.att_time_decay, n_embed) + ); + + w = rwkv_exp(ctx, ggml_neg_inplace(ctx, rwkv_exp(ctx, w))); + w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, sequence_length); + + // dup is not strictly required; doing it just in case. + struct ggml_tensor * state_out = ggml_dup(ctx, state.att_heads); + + x = rwkv_wkv_v6( + ctx, + sequence_length, + n_embed, + head_count, + head_size, + x, + k, + v, + r, + layer.att_time_faaaa, + w, + state_out + ); + + state.att_heads = state_out; + + // rwkv/ggml ggml_group_norm uses eps=1e-5, while rwkv v6 uses eps=64e-5 + // Do 1/8 scale to x before group_norm for now. + x = ggml_scale_inplace(ctx, x, ggml_new_f32(ctx, 0.125)); + // ggml_group_norm considers groups in the third dimension. + x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length); + x = ggml_group_norm_inplace(ctx, x, head_count); + // Convert back to a regular vector. 
+ x = ggml_reshape_2d(ctx, x, n_embed, sequence_length); + x = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + x, + layer.att_ln_x_weight + ), + layer.att_ln_x_bias + ); + + x = ggml_mul_inplace(ctx, x, g); + + return ggml_mul_mat(ctx, layer.att_output, x); +} + static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { struct ggml_tensor * x_prev; rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); @@ -349,6 +559,35 @@ static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tens return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); } +static struct ggml_tensor * rwkv_ffn_v6(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); + x_prev = ggml_sub_inplace(ctx, x_prev, x); + + // xk = x + sx * time_maa_k + // xr = x + sx * time_maa_r + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul(ctx, x_prev, layer.ffn_time_maa_k), + x + ); + + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul(ctx, x_prev, layer.ffn_time_maa_r), + x + ); + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + + // k = torch.square(torch.relu(kw @ xk)) + struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + + // r * (vw @ k) + return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); +} + static void rwkv_create_input_and_output_views( struct ggml_context * ctx, struct rwkv_layer_state * inputs, @@ -442,8 +681,8 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu struct rwkv_layer_state state = inputs[i]; - x = model.arch_version_major >= 5 ? - ggml_add_inplace(ctx, x, rwkv_att_v5( + if (model.arch_version_major == 6) { + x = ggml_add_inplace(ctx, x, rwkv_att_v6( ctx, x, layer, @@ -451,10 +690,24 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu model.head_count, model.head_size, model.arch_version_minor - )) : - ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); + )); - x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + x = ggml_add_inplace(ctx, x, rwkv_ffn_v6(ctx, x, layer, state)); + } else { + x = model.arch_version_major >= 5 ? + ggml_add_inplace(ctx, x, rwkv_att_v5( + ctx, + x, + layer, + state, + model.head_count, + model.head_size, + model.arch_version_minor + )) : + ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); + + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + } struct rwkv_layer_state & output_state = outputs[i]; @@ -567,7 +820,17 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c struct rwkv_layer_state state = inputs[i]; - if (model.arch_version_major >= 5) { + if (model.arch_version_major == 6) { + x = ggml_add_inplace(ctx, x, rwkv_att_v6( + ctx, + x, + layer, + state, + model.head_count, + model.head_size, + model.arch_version_minor + )); + } else if (model.arch_version_major >= 5) { x = ggml_add_inplace(ctx, x, rwkv_att_v5( ctx, x, @@ -598,7 +861,11 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c } // TODO Can we skip ffn for all but the last token, the same way we skip unembedding? 
- x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + if (model.arch_version_major == 6) { + x = ggml_add_inplace(ctx, x, rwkv_ffn_v6(ctx, x, layer, state)); + } else { + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + } struct rwkv_layer_state & output_state = outputs[i]; diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index fef0ea92..ba58acb5 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -23,12 +23,29 @@ struct rwkv_layer { struct ggml_tensor * att_time_mix_g; struct ggml_tensor * att_gate; + // Added in RWKV v6. + struct ggml_tensor * att_time_maa_x; + struct ggml_tensor * att_time_maa_w; + struct ggml_tensor * att_time_maa_k; + struct ggml_tensor * att_time_maa_v; + struct ggml_tensor * att_time_maa_r; + struct ggml_tensor * att_time_maa_g; + struct ggml_tensor * att_time_maa_w1; + struct ggml_tensor * att_time_maa_w2; + struct ggml_tensor * att_time_decay_w1; + struct ggml_tensor * att_time_decay_w2; + struct ggml_tensor * ln2_weight; struct ggml_tensor * ln2_bias; // FFN. struct ggml_tensor * ffn_time_mix_k; struct ggml_tensor * ffn_time_mix_r; + + // Added in RWKV v6. + struct ggml_tensor * ffn_time_maa_k; + struct ggml_tensor * ffn_time_maa_r; + struct ggml_tensor * ffn_key; struct ggml_tensor * ffn_value; struct ggml_tensor * ffn_receptance; @@ -101,37 +118,66 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); + if (model.arch_version_major == 6) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_x"), buffer), layer.att_time_maa_x)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w"), buffer), layer.att_time_maa_w)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_k"), buffer), layer.att_time_maa_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_v"), buffer), layer.att_time_maa_v)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_r"), buffer), layer.att_time_maa_r)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_g"), buffer), layer.att_time_maa_g)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w1"), buffer), layer.att_time_maa_w1)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w2"), buffer), layer.att_time_maa_w2)); - if (model.arch_version_major >= 5 && model.arch_version_minor >= 2) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w1"), buffer), layer.att_time_decay_w1)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.att_time_decay_w2)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], 
"att.value.weight"), buffer), layer.att_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); } else { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); - } + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); + + if (model.arch_version_major >= 5 && model.arch_version_minor >= 2) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa)); + } else { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); + } - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); - if (model.arch_version_major >= 5) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); + if (model.arch_version_major >= 5) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); - if (model.arch_version_minor >= 2) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate)); + if (model.arch_version_minor >= 2) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate)); + } } } RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), 
layer.ln2_weight)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); + if (model.arch_version_major == 6) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_k"), buffer), layer.ffn_time_maa_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_r"), buffer), layer.ffn_time_maa_r)); + } else { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); + } + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); @@ -186,6 +232,11 @@ static bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model } } + if (parameters.find("blocks.0.att.time_maa_x") != parameters.end()) { + model.arch_version_major = 6; + model.arch_version_minor = 0; + } + std::unordered_map<std::string, struct ggml_tensor *> & parameters_ref = parameters; RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params( model, diff --git a/rwkv_operators_wkv_common.inc b/rwkv_operators_wkv_common.inc new file mode 100644 index 00000000..94e36aaf --- /dev/null +++ b/rwkv_operators_wkv_common.inc @@ -0,0 +1,35 @@ +// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8 +// Original code by Harrison Vanderbyl. +// TODO Fix 1. unaligned memory access on Linux with AVX2, 2. tiny-rwkv with AVX-512 +/*#ifdef __AVX512F__ + #include <immintrin.h> + #define SIMD_WIDTH 16 + #define LOAD(x) _mm512_load_ps(x) + #define STORE(x, y) _mm512_store_ps(x, y) + #define SET1(x) _mm512_set1_ps(x) + #define MULTIPLY(x, y) _mm512_mul_ps(x, y) + #define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z) +#elif __AVX2__ + #include <immintrin.h> + #define SIMD_WIDTH 8 + #define LOAD(x) _mm256_load_ps(x) + #define STORE(x, y) _mm256_store_ps(x, y) + #define SET1(x) _mm256_set1_ps(x) + #define MULTIPLY(x, y) _mm256_mul_ps(x, y) + #define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z) +#elif defined(__ARM_NEON) || defined(__ARM_NEON__) + #include <arm_neon.h> + #define SIMD_WIDTH 4 + #define LOAD(x) vld1q_f32(x) + #define STORE(x, y) vst1q_f32(x, y) + #define SET1(x) vdupq_n_f32(x) + #define MULTIPLY(x, y) vmulq_f32(x, y) + #define MULTADD(x, y, z) vmlaq_f32(z, x, y) +#else*/ + #define SIMD_WIDTH 1 + #define LOAD(x) *x + #define STORE(x, y) *x = y + #define SET1(x) x + #define MULTIPLY(x, y) x * y + #define MULTADD(x, y, z) x * y + z +//#endif diff --git a/rwkv_operators_wkv_v5.inc b/rwkv_operators_wkv_v5.inc index f9502d4a..570d4854 100644 --- a/rwkv_operators_wkv_v5.inc +++ b/rwkv_operators_wkv_v5.inc @@ -1,38 +1,4 @@ -// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8 -// Original code by Harrison Vanderbyl. -// TODO Fix 1. unaligned memory access on Linux with AVX2, 2. 
tiny-rwkv with AVX-512 -/*#ifdef __AVX512F__ - #include <immintrin.h> - #define SIMD_WIDTH 16 - #define LOAD(x) _mm512_load_ps(x) - #define STORE(x, y) _mm512_store_ps(x, y) - #define SET1(x) _mm512_set1_ps(x) - #define MULTIPLY(x, y) _mm512_mul_ps(x, y) - #define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z) -#elif __AVX2__ - #include <immintrin.h> - #define SIMD_WIDTH 8 - #define LOAD(x) _mm256_load_ps(x) - #define STORE(x, y) _mm256_store_ps(x, y) - #define SET1(x) _mm256_set1_ps(x) - #define MULTIPLY(x, y) _mm256_mul_ps(x, y) - #define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z) -#elif defined(__ARM_NEON) || defined(__ARM_NEON__) - #include <arm_neon.h> - #define SIMD_WIDTH 4 - #define LOAD(x) vld1q_f32(x) - #define STORE(x, y) vst1q_f32(x, y) - #define SET1(x) vdupq_n_f32(x) - #define MULTIPLY(x, y) vmulq_f32(x, y) - #define MULTADD(x, y, z) vmlaq_f32(z, x, y) -#else*/ - #define SIMD_WIDTH 1 - #define LOAD(x) *x - #define STORE(x, y) *x = y - #define SET1(x) x - #define MULTIPLY(x, y) x * y - #define MULTADD(x, y, z) x * y + z -//#endif +#include "rwkv_operators_wkv_common.inc" // Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L57 // Original code by Harrison Vanderbyl. diff --git a/rwkv_operators_wkv_v6.inc b/rwkv_operators_wkv_v6.inc new file mode 100644 index 00000000..0cc3e131 --- /dev/null +++ b/rwkv_operators_wkv_v6.inc @@ -0,0 +1,146 @@ +#include "rwkv_operators_wkv_common.inc" + +// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L57 +// Original code by Harrison Vanderbyl. +static void rwkv_wkv_v6_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) { + const size_t T = result->ne[1]; + const size_t C = result->ne[0]; + const size_t H = result->src[1]->ne[2]; + + float * result_data = (float *) result->data; + + memset(result_data, 0, T * C * sizeof(float)); + + float * k = (float *) result->src[1]->data; + float * v = (float *) result->src[2]->data; + float * r = (float *) result->src[3]->data; + float * time_faaaa = (float *) result->src[4]->data; + float * time_decay = (float *) result->src[5]->data; + float * state = (float *) result->src[6]->data; + + size_t t_stride = H * (C / H); + + size_t h_stride = C / H; + size_t h_stride_2d = (C / H) * (C / H); + + for (size_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + + for (size_t h = 0; h < H; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (size_t i = 0; i < C / H; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + auto k_val = SET1(k[t_h_i_offset]); + auto r_val = SET1(r[t_h_i_offset]); + auto time_faaaa_val = SET1(time_faaaa[h_i_offset]); + // RWKV v6: different time_decay for each token. 
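+ // (In v5 the decay is a per-channel constant indexed with h_i_offset; in v6 the low-rank time_decay_w1/w2 path makes it depend on the token as well, hence the t_h_i_offset index here.)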
+ auto time_decay_val = SET1(time_decay[t_h_i_offset]); + + for (size_t j = 0; j < C / H; j += SIMD_WIDTH) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + auto v_val = LOAD(&v[t_h_j_offset]); + + auto kv_val = MULTIPLY(v_val, k_val); + + auto prev_state_val = LOAD(&state[h_2d_i_j_offset]); + + auto temp_val = MULTADD(kv_val, time_faaaa_val, prev_state_val); + + auto prev_result_data = LOAD(&result_data[t_h_j_offset]); + + STORE(&result_data[t_h_j_offset], MULTADD(temp_val, r_val, prev_result_data)); + + STORE(&state[h_2d_i_j_offset], MULTADD(prev_state_val, time_decay_val, kv_val)); + } + } + } + } + + // Suppress "unused parameter" warnings. + (void) src; + (void) ith; + (void) nth; + (void) userdata; +} + +// Parameters: +// - T: sequence length +// - C: channel count, same as n_embed +// - H: head count +// - S: head size +// Shapes (in ggml order): +// - x: [C, T, 1, 1] +// - k: [1, S, H, T] +// - v: [S, 1, H, T] +// - r: [S, 1, H, T] +// - time_faaaa: [1, S, H, 1] +// - w: [1, S, H, T] +// - state: [S * S * H, 1, 1, 1] +// - result: same as x +// state will be written to. +static struct ggml_tensor * rwkv_wkv_v6( + struct ggml_context * ctx, + const size_t T, + const size_t C, + const size_t H, + const size_t S, + struct ggml_tensor * x, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * time_faaaa, + struct ggml_tensor * w, + struct ggml_tensor * state +) { + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(k->type == GGML_TYPE_F32); + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(r->type == GGML_TYPE_F32); + GGML_ASSERT(time_faaaa->type == GGML_TYPE_F32); + GGML_ASSERT(w->type == GGML_TYPE_F32); + GGML_ASSERT(state->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(r)); + GGML_ASSERT(ggml_is_contiguous(time_faaaa)); + GGML_ASSERT(ggml_is_contiguous(w)); + GGML_ASSERT(ggml_is_contiguous(state)); + + GGML_ASSERT(x->ne[0] == C && x->ne[1] == T && x->ne[2] == 1 && x->ne[3] == 1); + GGML_ASSERT(k->ne[0] == 1 && k->ne[1] == S && k->ne[2] == H && k->ne[3] == T); + GGML_ASSERT(v->ne[0] == S && v->ne[1] == 1 && v->ne[2] == H && v->ne[3] == T); + GGML_ASSERT(r->ne[0] == S && r->ne[1] == 1 && r->ne[2] == H && r->ne[3] == T); + GGML_ASSERT(w->ne[0] == 1 && w->ne[1] == S && w->ne[2] == H && w->ne[3] == T); + GGML_ASSERT(ggml_nelements(state) == S * S * H); + + k = ggml_cont_inplace(ctx, ggml_transpose(ctx, k)); + v = ggml_cont_inplace(ctx, ggml_transpose(ctx, v)); + r = ggml_cont_inplace(ctx, ggml_transpose(ctx, r)); + + struct ggml_tensor * result = ggml_map_custom1( + ctx, + x, + rwkv_wkv_v6_impl, + 1, + NULL + ); + result->src[1] = k; + result->src[2] = v; + result->src[3] = r; + result->src[4] = time_faaaa; + result->src[5] = w; + // GGML_MAX_SRC must be increased from 6 to 8 for this. 
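+ // ggml_map_custom1() stores x as result->src[0]; k, v, r, time_faaaa, w and state are attached as src[1]..src[6] so that rwkv_wkv_v6_impl() can read them.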
+ result->src[6] = state; + + return result; +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 60c783a7..506591ff 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -39,6 +39,12 @@ file(COPY tiny-rwkv-5v2-730K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY tiny-rwkv-5v2-730K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY expected-logits-5v2-730K.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-6v0-3m-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-6v0-3m-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-6v0-3m-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-6v0-3m-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY expected-logits-6v0-3m.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + rwkv_add_test(test_ggml_basics.c) rwkv_add_test(test_quantized_matmul_on_gpu.c) rwkv_add_test(test_tiny_rwkv.c) diff --git a/tests/expected-logits-6v0-3m.bin b/tests/expected-logits-6v0-3m.bin new file mode 100644 index 00000000..9504bb93 Binary files /dev/null and b/tests/expected-logits-6v0-3m.bin differ diff --git a/tests/test_quantization_format_compatibility.c b/tests/test_quantization_format_compatibility.c index 4edb3706..6ad41136 100644 --- a/tests/test_quantization_format_compatibility.c +++ b/tests/test_quantization_format_compatibility.c @@ -6,7 +6,7 @@ #include "logit_difference_validator.inc" -#define VERSION_COUNT 3 +#define VERSION_COUNT 4 int main(void) { fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); @@ -14,7 +14,8 @@ int main(void) { const char * versions[VERSION_COUNT] = { "4v0-660K", "5v1-730K", - "5v2-730K" + "5v2-730K", + "6v0-3m" }; // See the explanation of huge expected differences for v5 models in test_tiny_rwkv.c @@ -27,7 +28,10 @@ int main(void) { -018.017435F, // 5v2 +025.273308F, - +048.068733F + +048.068733F, + // 6v0 + -019.400530F, + +003.576909F }; for (int i = 0; i < VERSION_COUNT; i++) { diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index 291142a6..5b0b7769 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -6,7 +6,7 @@ #include "logit_difference_validator.inc" -#define VERSION_COUNT 3 +#define VERSION_COUNT 4 #define FORMAT_COUNT 7 int main(void) { @@ -20,7 +20,8 @@ int main(void) { const char * versions[VERSION_COUNT] = { "4v0-660K", "5v1-730K", - "5v2-730K" + "5v2-730K", + "6v0-3m" }; const char * formats[FORMAT_COUNT] = { @@ -36,13 +37,16 @@ int main(void) { const float expected_difference_sum_full[VERSION_COUNT * 2] = { // 4v0 +0.001000F, // FP32 - -0.005320F, // FP16 + -0.013652F, // FP16 // 5v1 +0.001000F, // FP32 -0.289921F, // FP16 // 5v2 +0.001000F, // FP32 - +0.206919F // FP16 + +0.455912F, // FP16 + // 6v0 + +0.001566F, // FP32 + -0.416620F // FP16 }; // *** Why the hell the expected logit difference sum for v4 models is < 1, and for v5 models it can be as high as 160? 
*** @@ -77,16 +81,22 @@ int main(void) { +061.719509F, // Q4_1 +025.273308F, // Q5_0 +048.068733F, // Q5_1 - -009.441034F // Q8_0 + -009.441034F, // Q8_0 + // 6v0 + -003.824263F, // Q4_0 + +021.939022F, // Q4_1 + -021.191444F, // Q5_0 + +003.576909F, // Q5_1 + -009.539596F // Q8_0 }; const float expected_difference_sum_quantized_FP16[VERSION_COUNT * (FORMAT_COUNT - 2)] = { // 4v0 +000.154614F, // Q4_0 -000.539827F, // Q4_1 - -000.170043F, // Q5_0 + -000.180142F, // Q5_0 +000.294953F, // Q5_1 - +000.070944F, // Q8_0 + +000.077226F, // Q8_0 // 5v1 +119.471931F, // Q4_0 -028.245888F, // Q4_1 @@ -98,7 +108,13 @@ int main(void) { +059.066830F, // Q4_1 +021.588751F, // Q5_0 +029.726818F, // Q5_1 - -007.242277F // Q8_0 + -007.242277F, // Q8_0 + // 6v0 + -003.487368F, // Q4_0 + +021.797060F, // Q4_1 + -021.271053F, // Q5_0 + +003.405264F, // Q5_1 + -009.734720F // Q8_0 }; for (int i_version = 0; i_version < VERSION_COUNT; i_version++) { diff --git a/tests/tiny-rwkv-6v0-3m-FP16.bin b/tests/tiny-rwkv-6v0-3m-FP16.bin new file mode 100644 index 00000000..ee87d077 Binary files /dev/null and b/tests/tiny-rwkv-6v0-3m-FP16.bin differ diff --git a/tests/tiny-rwkv-6v0-3m-FP32.bin b/tests/tiny-rwkv-6v0-3m-FP32.bin new file mode 100644 index 00000000..e091b0e4 Binary files /dev/null and b/tests/tiny-rwkv-6v0-3m-FP32.bin differ diff --git a/tests/tiny-rwkv-6v0-3m-Q5_0.bin b/tests/tiny-rwkv-6v0-3m-Q5_0.bin new file mode 100644 index 00000000..5c5a4d4a Binary files /dev/null and b/tests/tiny-rwkv-6v0-3m-Q5_0.bin differ diff --git a/tests/tiny-rwkv-6v0-3m-Q5_1.bin b/tests/tiny-rwkv-6v0-3m-Q5_1.bin new file mode 100644 index 00000000..b8d040e7 Binary files /dev/null and b/tests/tiny-rwkv-6v0-3m-Q5_1.bin differ
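A note on the 1/8 pre-scale before the group norm in rwkv_att_v6 above: as the comment in that function says, ggml_group_norm uses eps = 1e-5, while RWKV v6 normalizes with eps = 64e-5. Scaling x by 1/8 first compensates exactly, because Var(x/8) = Var(x)/64, so (x/8 - mean(x)/8) / sqrt(Var(x)/64 + 1e-5) = (x - mean(x)) / sqrt(Var(x) + 64e-5). A minimal NumPy sketch of this identity (a standalone illustration, not part of the patch; the group_norm helper is hypothetical and models a single group without affine parameters):

import numpy as np

def group_norm(x: np.ndarray, eps: float) -> np.ndarray:
    # Normalize one group: subtract the mean, divide by sqrt(variance + eps).
    return (x - x.mean()) / np.sqrt(x.var() + eps)

x = np.random.default_rng(0).standard_normal(64).astype(np.float32)

reference = group_norm(x, eps=64e-5)          # normalization RWKV v6 expects
rescaled  = group_norm(x * 0.125, eps=1e-5)   # what the graph computes after ggml_scale_inplace

# The two agree up to float32 rounding.
assert np.allclose(reference, rescaled, atol=1e-5)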