RWKV v6: Make outputs correct and update test values
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
MollySophia committed Jun 24, 2024
1 parent 2c14946 commit edea0c2
Showing 9 changed files with 21 additions and 14 deletions.
python/convert_pytorch_to_ggml.py (3 changes: 2 additions & 1 deletion)
@@ -60,6 +60,8 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
         1 if is_FP16 else 0
     ))
 
+    if is_v6_0:
+        n_head: int = state_dict['blocks.0.att.time_faaaa'].shape[0]
     for k in state_dict.keys():
         tensor: torch.Tensor = state_dict[k].float()
 
@@ -72,7 +74,6 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
         if '.time_maa_w1' in k or '.time_decay_w' in k:
             tensor = tensor.transpose(0, 1)
         if '.time_maa_w2' in k:
-            n_head: int = tensor.shape[1]
             tensor = tensor.transpose(1, 2)
         if '.time_decay' in k and '_w' not in k:
             tensor = tensor.reshape(n_head, -1, 1)
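
A note on these two hunks (my reading of the fix, not text from the commit): in reference RWKV v6 checkpoints, time_maa_w2 has shape (5, TIME_MIX_EXTRA_DIM, n_embd), so the removed line set n_head to the low-rank mixing width rather than the actual head count, and the later time_decay reshape used that wrong value. time_faaaa has shape (n_head, head_size), which makes its first dimension a reliable source; hoisting the assignment above the loop also removes any dependence on dict iteration order. A minimal sketch with assumed, illustrative sizes:

    import torch

    # Assumed v6 shapes for illustration; the sizes are not from the commit.
    n_head, head_size, n_embd, extra_dim = 32, 64, 2048, 96

    time_faaaa  = torch.ones(n_head, head_size)      # per-head bonus: shape[0] is n_head
    time_maa_w2 = torch.zeros(5, extra_dim, n_embd)  # low-rank token-shift mixing weight

    n_head_old = time_maa_w2.shape[1]  # what the removed line computed: extra_dim, not n_head
    n_head_new = time_faaaa.shape[0]   # what the converter now uses

    assert n_head_new == n_head
    assert n_head_old != n_head  # holds whenever extra_dim != n_head
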
rwkv_graph.inc (4 changes: 2 additions & 2 deletions)
@@ -364,7 +364,7 @@ static struct ggml_tensor * rwkv_att_v6(
             ctx,
             ggml_mul_mat(ctx, layer.att_time_maa_w1, xxx)
         ),
-        head_count, 1, 5, sequence_length
+        layer.att_time_maa_w1->ne[1] / 5, 1, 5, sequence_length
     );
 
     xxx = ggml_cont(
@@ -378,7 +378,7 @@ static struct ggml_tensor * rwkv_att_v6(
         ggml_reshape_4d(
             ctx,
             layer.att_time_maa_w2,
-            head_count, n_embed, 1, 5
+            layer.att_time_maa_w2->ne[0], layer.att_time_maa_w2->ne[1], 1, 5
         ),
         xxx
     );
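
The same mix-up drives this change (again my reading, with assumed sizes): the time_maa_w1 projection produces 5 * TIME_MIX_EXTRA_DIM values per token, one slice for each of the five mixing targets (w, k, v, r, g), so deriving the reshape dimensions from the weights themselves (ne[1] / 5, and time_maa_w2's own ne[0] and ne[1]) holds for any checkpoint, while the hard-coded head_count and n_embed only matched models where those sizes happened to coincide. A NumPy stand-in for the first reshape (note ggml lists dimensions innermost-first, so the axis order below is mirrored):

    import numpy as np

    # Assumed sizes for illustration; not values from the commit.
    seq_len, n_embd, extra_dim = 7, 2048, 96

    time_maa_w1 = np.zeros((n_embd, 5 * extra_dim))   # projection weight
    xxx = np.zeros((seq_len, n_embd)) @ time_maa_w1   # (seq_len, 5 * extra_dim)

    # Derive the per-target width from the weight, as the fixed code does with ne[1] / 5:
    per_target = time_maa_w1.shape[1] // 5
    xxx = xxx.reshape(seq_len, 5, 1, per_target)      # one slice each for w, k, v, r, g

    assert xxx.shape == (seq_len, 5, 1, extra_dim)
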
tests/CMakeLists.txt (6 changes: 6 additions & 0 deletions)
@@ -39,6 +39,12 @@ file(COPY tiny-rwkv-5v2-730K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 file(COPY tiny-rwkv-5v2-730K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 file(COPY expected-logits-5v2-730K.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 
+file(COPY tiny-rwkv-6v0-1m-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-6v0-1m-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-6v0-1m-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-6v0-1m-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY expected-logits-6v0-1m.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+
 rwkv_add_test(test_ggml_basics.c)
 rwkv_add_test(test_quantized_matmul_on_gpu.c)
 rwkv_add_test(test_tiny_rwkv.c)
Binary file modified tests/expected-logits-6v0-1m.bin
tests/test_tiny_rwkv.c (22 changes: 11 additions & 11 deletions)
@@ -46,7 +46,7 @@ int main(void) {
         +0.206919F, // FP16
         // 6v0
         +0.001000F, // FP32
-        +0.206919F  // FP16
+        -0.184410F  // FP16
     };
 
     // *** Why the hell the expected logit difference sum for v4 models is < 1, and for v5 models it can be as high as 160? ***
@@ -83,11 +83,11 @@ int main(void) {
         +048.068733F, // Q5_1
         -009.441034F, // Q8_0
         // 6v0
-        +035.271305F, // Q4_0
-        +061.719509F, // Q4_1
-        +025.273308F, // Q5_0
-        +048.068733F, // Q5_1
-        -009.441034F  // Q8_0
+        +039.715752F, // Q4_0
+        +049.779972F, // Q4_1
+        -005.838017F, // Q5_0
+        -017.046452F, // Q5_1
+        -000.220227F  // Q8_0
     };
 
     const float expected_difference_sum_quantized_FP16[VERSION_COUNT * (FORMAT_COUNT - 2)] = {
@@ -110,11 +110,11 @@
         +029.726818F, // Q5_1
         -007.242277F, // Q8_0
         // 6v0
-        +034.135971F, // Q4_0
-        +059.066830F, // Q4_1
-        +021.588751F, // Q5_0
-        +029.726818F, // Q5_1
-        -007.242277F  // Q8_0
+        +039.676075F, // Q4_0
+        +049.956646F, // Q4_1
+        -006.077929F, // Q5_0
+        -016.773785F, // Q5_1
+        -000.038582F  // Q8_0
     };
 
     for (int i_version = 0; i_version < VERSION_COUNT; i_version++) {
Binary file modified tests/tiny-rwkv-6v0-1m-FP16.bin
Binary file modified tests/tiny-rwkv-6v0-1m-FP32.bin
Binary file modified tests/tiny-rwkv-6v0-1m-Q5_0.bin
Binary file modified tests/tiny-rwkv-6v0-1m-Q5_1.bin
