Skip to content

Commit

Permalink
bump rwkv.cpp (rwkv6 support)
Browse files Browse the repository at this point in the history
  • Loading branch information
josStorer committed Jul 25, 2024
1 parent 97a3cd3 commit baa0811
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions backend-python/convert_pytorch_to_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ def write_state_dict(

is_v5_1_or_2: bool = "blocks.0.att.ln_x.weight" in state_dict
is_v5_2: bool = "blocks.0.att.gate.weight" in state_dict
is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict

if is_v5_2:
if is_v6_0:
print('Detected RWKV v6.0')
elif is_v5_2:
print("Detected RWKV v5.2")
elif is_v5_1_or_2:
print("Detected RWKV v5.1")
Expand All @@ -81,13 +84,25 @@ def write_state_dict(
)
)

if is_v6_0:
n_head: int = state_dict['blocks.0.att.time_faaaa'].shape[0]
for k in state_dict.keys():
tensor: torch.Tensor = state_dict[k].float()

if ".time_" in k:
tensor = tensor.squeeze()

if is_v5_1_or_2:
if is_v6_0:
if '.time_faaaa' in k:
tensor = tensor.unsqueeze(-1)
if '.time_maa_w1' in k or '.time_decay_w' in k:
tensor = tensor.transpose(0, 1)
if '.time_maa_w2' in k:
tensor = tensor.transpose(1, 2)
if '.time_decay' in k and '_w' not in k:
tensor = tensor.reshape(n_head, -1, 1)

elif is_v5_1_or_2:
if ".time_decay" in k:
if is_v5_2:
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
Expand Down Expand Up @@ -131,7 +146,7 @@ def write_state_dict(

out_file.write(k_encoded)

tensor.numpy().tofile(out_file)
tensor.detach().numpy().tofile(out_file)


def main() -> None:
Expand Down
Binary file modified backend-python/rwkv_pip/cpp/librwkv.dylib
Binary file not shown.
Binary file modified backend-python/rwkv_pip/cpp/librwkv.so
Binary file not shown.
Binary file modified backend-python/rwkv_pip/cpp/rwkv.dll
Binary file not shown.

0 comments on commit baa0811

Please sign in to comment.