
add rwkv5 model and fix rwkv4 init prompt bug #1275

Closed · wants to merge 30 commits

Conversation

BBuf (Contributor) commented Nov 16, 2023

q8fp16_1 display:

[screenshots]

tqchen requested a review from Hzfengsy on Nov 16, 2023
junrushao (Member):

CC: @Hzfengsy

Hzfengsy (Member):

There is a confirmed bug on CUDA devices, due to an all-reduce compilation error. However, the model works well on other platforms. Should we merge it now, or wait for the CUDA fix?

tqchen (Contributor) commented Nov 18, 2023

It would be good to confirm the CUDA error and suggest a quick solution, to ensure things work across the board.

Hzfengsy (Member):

I confirmed two bugs on CUDA devices. Unfortunately, I have no idea about either of them.

cuBLAS BYOC

  File "tvm/src/relax/transform/fuse_ops.cc", line 882
InternalError: Check failed: (depgroup != cur_group) is false: A cyclic dependency detected between the groups lv2261 and lv2260 are in.

To reproduce the issue, please follow the instructions below, and make sure cuBLAS is enabled.

git clone git@github.com:GiantPandaCV/mlc-llm.git mlc-llm-rwkv && cd mlc-llm-rwkv
python -m mlc_llm.build --hf-path RWKV/rwkv-5-world-1b5  --target cuda --quantization q0f16 --use-cache=0 --build-model-only

Cross Thread Reduction Codegen

The current codegen fails on group_norm; a minimal reproduction script is at https://gist.github.com/Hzfengsy/74366154dfe51ea640de8b5bbe41ea4a

cpp/llm_chat.cc (outdated)

@@ -615,7 +615,7 @@ class LLMChat {
   std::vector<int32_t> encoded = this->tokenizer_->Encode(all_prompt);
   tokens.insert(tokens.end(), encoded.begin(), encoded.end());
   if (this->sliding_window_ != -1 ||  // There is no max window size if we use sliding window
-      this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) {
+      this->total_seq_len_ + (int)tokens.size() + gen_mean_gen_len < this->max_window_size_) {
Contributor:

use static_cast<int64_t>(tokens.size())

BBuf (Contributor, Author):

OK. I think this is a quite serious bug that troubled me for more than two weeks. The quantized int4 version of rwkv5 seemed to give very unintelligent responses, and it was only today that I thought to print out the prompt. I then discovered that all the code after line 618 was ineffective, and finally pinpointed this issue. Now the quantized int4 version of rwkv5 can also generate text normally, and performance has also improved in the other rwkv modes.

Contributor:

Can you elaborate a bit on why the static cast to int is needed here? Do we involve some negative numbers in computing this?

BBuf (Contributor, Author):

[screenshot]

BBuf (Contributor, Author):

This error might occur whenever the input prompt is relatively long. Before tokens.size() was converted with static_cast<int64_t>, the expression (this->total_seq_len_ + tokens.size() + gen_mean_gen_len) was evaluated in unsigned arithmetic, because tokens.size() returns an unsigned size_t. The sum could therefore wrap around, and the subsequent signed/unsigned comparison silently produced the wrong result, erroneously returning true.
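
(For context, a minimal standalone sketch of the signed/unsigned pitfall described above; the variable names are hypothetical stand-ins, not the actual llm_chat.cc members:)

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int32_t> tokens(10);
      int64_t total_seq_len = 5;
      int64_t gen_len = 10;
      int64_t max_window_size = -1;  // -1 encodes "no window limit"

      // tokens.size() is size_t (unsigned), so the whole sum is promoted to
      // unsigned, and max_window_size (-1) is converted to a huge unsigned
      // value for the comparison, which then always succeeds.
      std::cout << (total_seq_len + tokens.size() + gen_len < max_window_size)
                << "\n";  // prints 1

      // Casting the size to a signed type keeps the comparison signed, so the
      // check behaves as intended.
      std::cout << (total_seq_len + static_cast<int64_t>(tokens.size()) + gen_len <
                    max_window_size)
                << "\n";  // prints 0
    }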

Contributor:

OK. I feel this is a strange way to think about it. Given max_window_size_ == -1, we should check for it specially; that means there is no out-of-bounds and we do not need to re-encode (i.e., run the code after). Would be good for @Hzfengsy to take a look as well.
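
(For illustration, a minimal sketch of the special case suggested above; the helper name and signature are hypothetical, not code from this PR:)

    #include <cstdint>
    #include <cstddef>

    // Treat max_window_size == -1 as "unbounded": skip the length check outright
    // instead of relying on mixed signed/unsigned arithmetic to get it right.
    bool FitsInWindow(int64_t total_seq_len, size_t num_tokens, int64_t mean_gen_len,
                      int64_t max_window_size, int64_t sliding_window) {
      if (sliding_window != -1 || max_window_size == -1) {
        return true;  // sliding window in use, or no window limit: no re-encode needed
      }
      return total_seq_len + static_cast<int64_t>(num_tokens) + mean_gen_len <
             max_window_size;
    }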

BBuf (Contributor, Author):

Okay. There also seems to be a bug in the handling of the rwkv system prompts. I expect each interaction with the rwkv model to include the system prompt along with the current text, because this series of models (rwkv4, 5, 6) has higher requirements for prompts. Currently, only the first round of dialogue includes the system prompt, and the system prompt is forgotten in subsequent dialogues.

BBuf (Contributor, Author) commented Nov 27, 2023

71eefb6: before the fix:

[screenshot]

After the fix:

[screenshot]

@Hzfengsy @tqchen I think it's time to announce proper support for the rwkv model.

vinx13 (Member) commented Nov 29, 2023

I found some var-to-var bindings in the model. For example, in the decode function:

            lv2259: R.Tensor((1, 2048), dtype="float16") = R.matmul(lv2257, lv2258, out_dtype="void")
            lv2260: R.Tensor((1, 2048), dtype="float16") = lv2259

This issue is fixed by apache/tvm#16175, so the model can compile successfully. However, such bindings can still break fusion and pattern matching, so for better performance I'd recommend updating the model definition to eliminate them.

tqchen closed this on Feb 27, 2024