Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : support RWKV v6 models #8980

Merged
merged 53 commits into from
Sep 1, 2024
Merged

Conversation

MollySophia
Copy link
Contributor

@MollySophia MollySophia commented Aug 11, 2024

This should fix #846.

Added:

ggml:

llama.cpp:

  • rwkv_world tokenizer support (by @LaylBongers)
  • convert_hf_to_gguf.py support for converting RWKV v6 HF models
  • RWKV v6 graph building

TODO:

@github-actions github-actions bot added python python script changes ggml changes relating to the ggml tensor library for machine learning labels Aug 11, 2024
@compilade compilade self-requested a review August 11, 2024 02:30
Copy link
Collaborator

@compilade compilade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few things I've noticed. I'll review this more deeply in the next days.

src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
convert_hf_to_gguf.py Outdated Show resolved Hide resolved
convert_hf_to_gguf.py Outdated Show resolved Hide resolved
@MollySophia MollySophia force-pushed the for-upstream branch 2 times, most recently from 487fb6d to 9bf958f Compare August 11, 2024 04:11
convert_hf_to_gguf.py Outdated Show resolved Hide resolved
convert_hf_to_gguf.py Outdated Show resolved Hide resolved
convert_hf_to_gguf.py Outdated Show resolved Hide resolved
convert_hf_to_gguf.py Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
@MollySophia MollySophia force-pushed the for-upstream branch 3 times, most recently from ecf84ca to e7d35a3 Compare August 13, 2024 09:20
src/llama.cpp Outdated Show resolved Hide resolved
@MollySophia
Copy link
Contributor Author

MollySophia commented Aug 23, 2024

Synchronized the changes and made it working again after #8526 being merged.
This PR should be ready for review again now :D
@compilade Could you take a look when convenient?

Copy link
Collaborator

@compilade compilade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm impressed that ggml_rwkv_wkv only takes around 2% of the CPU time during inference of the 1.6B RWKV-v6 model (when measured with perf record --call-graph=lbr).

I have some styling comments, some suggestions, and I also found some problems.

src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Show resolved Hide resolved
convert_hf_to_gguf.py Show resolved Hide resolved
convert_hf_to_gguf.py Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Show resolved Hide resolved
@MollySophia
Copy link
Contributor Author

I'm impressed that ggml_rwkv_wkv only takes around 2% of the CPU time during inference of the 1.6B RWKV-v6 model (when measured with perf record --call-graph=lbr).

I have some styling comments, some suggestions, and I also found some problems.

Indeed. I did consider writing a metal kernel for wkv, but it turned out that wkv kernels didn't eat much cpu time.
I've also tried modfying current rwkv_wkv impl with GGML_SIMD macros, but the speed was almost the same. (Clang already did optimizations like vectorization, so writing manually may not be that necessary)

@MollySophia MollySophia force-pushed the for-upstream branch 2 times, most recently from 8e2e9aa to a8db247 Compare August 25, 2024 09:36
src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
MollySophia and others added 2 commits August 28, 2024 10:20
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
MollySophia and others added 11 commits August 28, 2024 10:22
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: compilade <git@compilade.net>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
…t tensors

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
@ggerganov
Copy link
Owner

Lets look to merge soon. @MollySophia Which HF model do you recommend to run a few tests with this branch?

@MollySophia
Copy link
Contributor Author

Lets look to merge soon. @MollySophia Which HF model do you recommend to run a few tests with this branch?

https://huggingface.co/RWKV/v6-Finch-1B6-HF should be enough for testing the functionalities.
https://huggingface.co/RWKV/v6-Finch-7B-HF/tree/main or the 3B one should be working too

@ggerganov
Copy link
Owner

I've updated the tokenizer to use a true for string search (7004323). With this change the time for tokenizing wiki.test dropped from 27s to 40ms on my Mac.

@ggerganov ggerganov requested a review from compilade August 30, 2024 10:31
Copy link
Collaborator

@compilade compilade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW What's next for this PR?

@MollySophia It looks ready for me, at least. Nice work!

There's some potential division by zero with hparams.rescale_every_n_layers which I think should be fixed before merging.

Improvements to ggml_rwkv_wkv (if relevant) can be done later in a follow-up PR, so I think this will be ready to merge.

ggml/src/ggml.c Show resolved Hide resolved
ggml/src/ggml.c Outdated Show resolved Hide resolved
src/llama.cpp Outdated Show resolved Hide resolved
MollySophia and others added 2 commits August 31, 2024 11:59
Co-authored-by: compilade <git@compilade.net>
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
@ggerganov ggerganov merged commit 8f1d81a into ggerganov:master Sep 1, 2024
54 checks passed
dsx1986 pushed a commit to dsx1986/llama.cpp that referenced this pull request Oct 29, 2024
* convert_hf_to_gguf: Add support for RWKV v6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add RWKV tokenization

* Fix build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Do not use special tokens when matching in RWKV tokenizer

* Fix model loading

* Add (broken) placeholder graph builder for RWKV

* Add workaround for kv cache

* Add logits conversion to rwkv5

* Add rwkv5 layer norms

* Add time mix KVRG & correct merge mistake

* Add remaining time mix parameters

* Add time mix output loading

* Add placeholder llm_build_time_mix

* Fix build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Load more tensors for rwkv v6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix rwkv tokenizer

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: Add unary operator Exp

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* RWKV v6 graph building

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add ``rescale_every_n_layers`` parameter

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add ``wkv.head_size`` key for RWKV

so it doesn't reuse Mamba ssm parameters

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix offloading layers to CUDA

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix parallel inferencing for RWKV

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Remove trailing whitespaces

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* build_rwkv: Avoid using inplace operations

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* convert_hf_to_gguf: rwkv: Avoid using ``eval``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* ggml: Add backward computation for unary op ``exp``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* build_rwkv6: Simplify graph

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Detect model.type

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Fix tensor loading for 7B/14B models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Fix group_norm assertion failure with Metal

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Clean up

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add quantization tensor exclusion

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Use the new advanced batch splits

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update src/llama.cpp

Co-authored-by: compilade <git@compilade.net>

* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``

Co-authored-by: compilade <git@compilade.net>

* llama: rwkv6: Apply code style and misc changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* converter: Use class name ``Rwkv6Model``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Make use of key ``feed_forward_length``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Remove unused nodes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Apply code format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add lora for some supported tensors

Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* rwkv : speed-up tokenization using trie

* minor : style + indentation

* llama: rwkv6: Avoid division by zero

Co-authored-by: compilade <git@compilade.net>

* ggml: rwkv_wkv: Avoid copying the state

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Layl Bongers <3094382+LaylBongers@users.noreply.github.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
* convert_hf_to_gguf: Add support for RWKV v6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add RWKV tokenization

* Fix build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Do not use special tokens when matching in RWKV tokenizer

* Fix model loading

* Add (broken) placeholder graph builder for RWKV

* Add workaround for kv cache

* Add logits conversion to rwkv5

* Add rwkv5 layer norms

* Add time mix KVRG & correct merge mistake

* Add remaining time mix parameters

* Add time mix output loading

* Add placeholder llm_build_time_mix

* Fix build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Load more tensors for rwkv v6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix rwkv tokenizer

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: Add unary operator Exp

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* RWKV v6 graph building

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add ``rescale_every_n_layers`` parameter

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add ``wkv.head_size`` key for RWKV

so it doesn't reuse Mamba ssm parameters

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix offloading layers to CUDA

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix parallel inferencing for RWKV

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Remove trailing whitespaces

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* build_rwkv: Avoid using inplace operations

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* convert_hf_to_gguf: rwkv: Avoid using ``eval``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* ggml: Add backward computation for unary op ``exp``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* build_rwkv6: Simplify graph

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Detect model.type

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Fix tensor loading for 7B/14B models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Fix group_norm assertion failure with Metal

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Clean up

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add quantization tensor exclusion

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Use the new advanced batch splits

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update src/llama.cpp

Co-authored-by: compilade <git@compilade.net>

* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``

Co-authored-by: compilade <git@compilade.net>

* llama: rwkv6: Apply code style and misc changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* converter: Use class name ``Rwkv6Model``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Make use of key ``feed_forward_length``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Remove unused nodes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Apply code format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add lora for some supported tensors

Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* rwkv : speed-up tokenization using trie

* minor : style + indentation

* llama: rwkv6: Avoid division by zero

Co-authored-by: compilade <git@compilade.net>

* ggml: rwkv_wkv: Avoid copying the state

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Layl Bongers <3094382+LaylBongers@users.noreply.github.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024
* convert_hf_to_gguf: Add support for RWKV v6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add RWKV tokenization

* Fix build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Do not use special tokens when matching in RWKV tokenizer

* Fix model loading

* Add (broken) placeholder graph builder for RWKV

* Add workaround for kv cache

* Add logits conversion to rwkv5

* Add rwkv5 layer norms

* Add time mix KVRG & correct merge mistake

* Add remaining time mix parameters

* Add time mix output loading

* Add placeholder llm_build_time_mix

* Fix build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Load more tensors for rwkv v6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix rwkv tokenizer

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: Add unary operator Exp

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* RWKV v6 graph building

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add ``rescale_every_n_layers`` parameter

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add ``wkv.head_size`` key for RWKV

so it doesn't reuse Mamba ssm parameters

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix offloading layers to CUDA

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix parallel inferencing for RWKV

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Remove trailing whitespaces

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* build_rwkv: Avoid using inplace operations

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* convert_hf_to_gguf: rwkv: Avoid using ``eval``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* ggml: Add backward computation for unary op ``exp``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* build_rwkv6: Simplify graph

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Detect model.type

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Fix tensor loading for 7B/14B models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Fix group_norm assertion failure with Metal

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Clean up

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add quantization tensor exclusion

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Use the new advanced batch splits

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update src/llama.cpp

Co-authored-by: compilade <git@compilade.net>

* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``

Co-authored-by: compilade <git@compilade.net>

* llama: rwkv6: Apply code style and misc changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* converter: Use class name ``Rwkv6Model``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Make use of key ``feed_forward_length``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Remove unused nodes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Apply code format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: rwkv6: Add lora for some supported tensors

Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* rwkv : speed-up tokenization using trie

* minor : style + indentation

* llama: rwkv6: Avoid division by zero

Co-authored-by: compilade <git@compilade.net>

* ggml: rwkv_wkv: Avoid copying the state

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Layl Bongers <3094382+LaylBongers@users.noreply.github.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning python python script changes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

llama : add RWKV models support
5 participants