Add initial support for RWKV v6 #174

Merged · 8 commits · Jul 2, 2024
Conversation

@MollySophia (Contributor) commented Jun 22, 2024

Add initial support for RWKV v6.
Tested sequence prefill and normal inference with RWKV v6 1.6B: it produces exactly the same text as rwkv-pip-package when generating with FP32 and top_k=0.
Precise testing of logit differences has not been done yet.

Regarding the tests: I wonder whether training a tiny-rwkv-v6 is needed (since rwkv-x060-173m-pile is still a bit large for testing use)?

Edit: I'm new to ggml :P Feel free to point out anything that can be optimized.

sequence inferencing doesn't work correctly yet

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Thanks to @cryscan

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
@saharNooby (Collaborator)

Hi!

> Regarding the tests, I wonder if the training of a tiny-rwkv-v6 is needed

Testing tiny-rwkv in the CI process serves two goals:

  • comprehensive testing before merge: you've verified that the model produces identical text for FP32 on your OS and architecture, but the tiny-rwkv tests cover all combinations of formats (including quantized formats) and the OSes and architectures available in GitHub Actions. They also compare whole logit vectors instead of generated texts (which only reflect the top token), and they enable the sanitizer, which detects memory-use mistakes that may not be visible during manual testing.
  • comprehensive testing of each subsequent change after merge: these tests ensure that no change, however small or unrelated, breaks any supported RWKV version in any format on the covered OSes and architectures.

No one can expect an engineer to test every single combination after each commit. I certainly did not expect it from myself; that's why these tests were added.
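For illustration, comparing whole logit vectors can be as simple as accumulating a difference sum over the vocabulary — a minimal sketch, assuming logits are plain FP32 arrays (the function name is illustrative, not the actual test code):

// Sum of element-wise differences between actual and reference logits.
// Near zero is expected for FP32; quantized formats are instead compared
// against a per-format expected_difference_sum threshold.
float logits_difference_sum(const float * actual, const float * expected, const int n_vocab) {
    float sum = 0.0f;
    for (int i = 0; i < n_vocab; i++) {
        sum += actual[i] - expected[i];
    }
    return sum;
}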

I would strongly recommend training tiny-rwkv v6 and integrating it into existing tests. It would take some significant time, but I believe the resulting quality is worth it :)

Unfortunately, the code that I've used to train tiny-rwkv is not open-source. Nevertheless, here is some information, including hyperparameter values, that may help get you started.

@MollySophia (Contributor, Author)

> I would strongly recommend training tiny-rwkv v6 and integrating it into existing tests. [...]

Hi,
I've trained a tiny v6 model (tiny-rwkv-6v0-1m.pth.zip) on alpaca_data_cleaned.json using the official RWKV-LM trainer for 1 epoch, with no tokenizer.
The parameters are basically the defaults from the RWKV-LM repo's script, with the modifications below:

N_LAYER="12"
N_EMBD="64"
CTX_LEN="512"
vocab_size=256
head_size=8

The tiny model can output correct words, so I guess this is enough?

@MollySophia (Contributor, Author) commented Jun 24, 2024

Update: the tiny-rwkv testing indeed revealed that there's still something wrong in my code, which I'm currently struggling to debug...
Both tiny-rwkv and RWKV v6 1.6B get a logit difference sum of about -84 to -90 compared with the results of the rwkv pip package. I guess this isn't normal, since both use FP32 computation?
@saharNooby Do you have any ideas about how to print out the values of a specific tensor on a specific layer with rwkv.cpp?

Update 1: I've corrected some buggy code, and now the difference sums of the sanity checks don't look insane :D
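(One general ggml debugging technique that answers the question above — a sketch, not code from this PR, and it assumes the ggml fork in rwkv.cpp is recent enough to have ggml_graph_get_tensor; otherwise, keep a pointer to the tensor from graph-build time:)

// While building the graph: tag the tensor of interest with a unique name.
// "cur" and the name are illustrative.
ggml_set_name(cur, "layer7.ffn.k");

// After ggml_graph_compute(): look the tensor up by name and dump a few values.
struct ggml_tensor * t = ggml_graph_get_tensor(graph, "layer7.ffn.k");
if (t != NULL && t->type == GGML_TYPE_F32) {
    for (int i = 0; i < 8 && i < ggml_nelements(t); i++) {
        fprintf(stderr, "%s[%d] = %f\n", t->name, i, ggml_get_f32_1d(t, i));
    }
}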

The expected_difference_sum values need to be determined,
after making sure the FP32 logits are nearly the same.

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
@saharNooby (Collaborator) left a review

I suggest some minor code style improvements.

Review comments (all resolved, outdated) on:
  • python/convert_pytorch_to_ggml.py
  • rwkv_graph.inc
  • rwkv_model_loading.inc
  • rwkv_operators_wkv_v6.inc (3 comments)
@saharNooby linked the issue "Support RWKV v6" on Jun 24, 2024 as one that may be closed by this pull request.
@saharNooby (Collaborator) commented Jun 24, 2024

Great work!

There are a couple of things remaining:

  • the new tiny-rwkv models need to be added to test_quantization_format_compatibility.c -- you've already added the Q5_1 and Q5_0 models, they just need to be registered (see the sketch after this comment)
  • README.md should mention that we now support v6 (probably we should simply mention v6 wherever v5 is mentioned)
  • the macOS build is failing with a ggml assertion

I will be able to approve the PR, but I'm not sure I can merge it. Four months ago I gave the repository to the RWKV Foundation and stepped down as a maintainer. Last time I checked, @LaylBongers was the new maintainer. In any case, I think the PR should be reviewed and merged by a new maintainer, whoever they are.
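(For illustration of the first item: registering models in a C test usually amounts to adding file names to an existing list. The array name and paths below are hypothetical — the actual structure of test_quantization_format_compatibility.c may differ:)

// Hypothetical shape of the model list; the new tiny-rwkv v6 files are
// appended alongside the existing entries (all names illustrative).
static const char * model_paths[] = {
    "tiny-rwkv-5v2-Q5_0.bin",
    "tiny-rwkv-5v2-Q5_1.bin",
    "tiny-rwkv-6v0-Q5_0.bin",  // new v6 entries to register
    "tiny-rwkv-6v0-Q5_1.bin",
};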

@MollySophia (Contributor, Author)

> There are a couple of things remaining: [...]

Thanks! I'll make the changes later today.

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
@MollySophia (Contributor, Author)

@saharNooby Hi! I've applied the changes mentioned above.

> the macOS build is failing with a ggml assertion

Regarding this, the assertion happens at:

static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(qk == QK5_0);

    const block_q5_0 * restrict x = vx;
    const block_q8_0 * restrict y = vy;

#if defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    uint32_t qh0;
    uint32_t qh1;

    uint64_t tmp0[4];
    uint64_t tmp1[4];

    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb      <----
    for (int i = 0; i < nb; i += 2) {
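        // ... remainder of the NEON pair-wise loop omitted ...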

Using lldb, I can see that it happens in the FFN mul_mat(vw, k), where vw has the shape [n_embed, dim_ffn = int((n_embd * 3.5) // 32 * 32)].
For the tiny-rwkv v6 I've trained, n_embed = 64, which gives dim_ffn = 224.
Then nb = n / qk = 224 / 32 = 7, which is odd and thus violates the assertion.

I wonder what's the best solution here: re-train a tiny-rwkv v6 with n_embed = 128, or fix ggml?

@saharNooby (Collaborator)

> I wonder what's the best solution here?

There is a concern that this assertion also fails for larger RWKV v6 models. If so, it would mean macOS inference is broken. Is it possible for you to verify that? (You probably just need to know the relevant dims of the models.)

If larger models are OK, then I think retraining tiny-rwkv would be quicker. The ggml used in rwkv.cpp is very outdated, and it would be complicated to update it (regardless of whether the issue is already fixed there or you merge a fix into upstream).

@saharNooby (Collaborator)

Although if this // TODO: handle odd nb is already resolved in upstream ggml, you can try copy-pasting the new ggml_vec_dot_q5_0_q8_0 into our fork of ggml.

@MollySophia (Contributor, Author) commented Jun 24, 2024

> There is a concern that this assertion also fails for larger RWKV v6 models. If so, it would mean macOS inference is broken. Is it possible for you to verify it? [...]

It should be okay as long as the relevant matmul dims are all even multiples of 32. For the ChannelMixing weights, there won't be any problem as long as n_embed >= 128. dim_att is equal to n_embed by default, which seems fine for all existing models (n_embed = [512, 1024, 2048, 4096]).
There could still be a problem in principle, though. Fixing our fork of ggml isn't really hard either, I guess: use ARM NEON for the part that is an even multiple of 32 blocks, then handle the remainder with scalar operations (see the sketch below). This still needs some time, though.
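A minimal sketch of that shape of fix (the scalar tail mirrors ggml's generic reference path for q5_0 x q8_0 dot products; untested against the rwkv.cpp fork, so treat it as an outline rather than a drop-in patch):

    float sumf = 0.0f;           // accumulator for the scalar tail
    const int nb_even = nb & ~1; // blocks covered by the pair-wise NEON loop

    for (int i = 0; i < nb_even; i += 2) {
        // ... existing NEON pair-wise code, unchanged ...
    }

    // Scalar tail: handle a trailing odd block, if any.
    for (int i = nb_even; i < nb; ++i) {
        uint32_t qh;
        memcpy(&qh, x[i].qh, sizeof(qh));

        int sumi = 0;
        for (int j = 0; j < qk/2; ++j) {
            const uint8_t xh_0 = ((qh & (1u << (j +  0))) >> (j +  0)) << 4;
            const uint8_t xh_1 =  (qh & (1u << (j + 16))) >> (j + 12);

            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; // low nibble + 5th bit
            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16; // high nibble + 5th bit

            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
        }

        sumf += GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) * (float) sumi;
    }

    // The final store then also adds the scalar contribution:
    // *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + sumf;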

> Although if this // TODO: handle odd nb is already resolved in upstream ggml, you can try copy-pasting the new ggml_vec_dot_q5_0_q8_0 into our fork of ggml.

Unfortunately, it isn't fixed upstream either :P
https://github.com/ggerganov/ggml/blob/5a4ecabfa5f503937a63acb11c6d008096ce5a1f/src/ggml-quants.c#L4636

@MollySophia (Contributor, Author)

I just tested RWKV v6 1.6B/3B/7B with Q5_0/Q5_1 on my M1 MacBook, and they all generate completions normally without hitting the assertion. So I guess I can just re-train tiny-rwkv with n_embed = 128 tomorrow :P Maybe fix ggml someday if an irregular RWKV v6 is released.
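For reference, a quick standalone check of which n_embd values give an even block count nb for the FFN matmul, using the dim_ffn formula quoted above (a sketch compiled separately from rwkv.cpp; the n_embd list comes from this thread):

#include <stdio.h>

int main(void) {
    // n_embd of the released RWKV v6 models, plus the two tiny-rwkv candidates.
    const int n_embd[] = { 64, 128, 512, 1024, 2048, 4096 };
    for (int i = 0; i < (int)(sizeof n_embd / sizeof n_embd[0]); i++) {
        const int dim_ffn = (int)(n_embd[i] * 3.5) / 32 * 32;
        const int nb = dim_ffn / 32; // block count seen by ggml_vec_dot_q5_0_q8_0
        printf("n_embd=%4d  dim_ffn=%5d  nb=%3d  even=%s\n",
               n_embd[i], dim_ffn, nb, nb % 2 == 0 ? "yes" : "NO");
    }
    return 0;
}

Only n_embd = 64 yields an odd nb (224 / 32 = 7); n_embed = 128 gives nb = 14, which is why retraining avoids the assertion.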

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
@LaylBongers (Collaborator)

I've looked over the code and ran a smoke test. Everything looks good to me. I see you've already included a tiny-rwkv for the tests by now. If there's nothing else left, I'll have it merged in.

@MollySophia (Contributor, Author)

> If there's nothing else left I'll have it merged in.

I guess there's nothing else left for now?

@PicoCreator merged commit 970a813 into RWKV:master on Jul 2, 2024 (1 check passed).
@saharNooby mentioned this pull request on Jul 5, 2024.