add rwkv5 model and fix rwkv4 init prompt bug #1275
Conversation
support rwkv5
CC: @Hzfengsy
There is a confirmed bug on CUDA devices, due to an all-reduce compilation error. However, the model works well on other platforms. Should we merge this now, or wait for the CUDA fix?
Would be good to confirm the CUDA error and suggest a quick solution, to ensure things work across the board.
I confirmed two bugs on CUDA devices. Unfortunately, I have no idea how to fix either of them.

cuBLAS BYOC

To reproduce the issue, please follow the instructions below, and make sure cuBLAS is enabled:

```bash
git clone git@github.com:GiantPandaCV/mlc-llm.git mlc-llm-rwkv && cd mlc-llm-rwkv
python -m mlc_llm.build --hf-path RWKV/rwkv-5-world-1b5 --target cuda --quantization q0f16 --use-cache=0 --build-model-only
```

Cross Thread Reduction Codegen

The current codegen failed on …
cpp/llm_chat.cc (outdated)

```diff
@@ -615,7 +615,7 @@ class LLMChat {
   std::vector<int32_t> encoded = this->tokenizer_->Encode(all_prompt);
   tokens.insert(tokens.end(), encoded.begin(), encoded.end());
   if (this->sliding_window_ != -1 ||  // There is no max window size if we use sliding window
-      this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) {
+      this->total_seq_len_ + (int)tokens.size() + gen_mean_gen_len < this->max_window_size_) {
```
Use `static_cast<int64_t>(tokens.size())`.
OK. I think this is quite a serious bug that has troubled me for more than two weeks. The quantized int4 version of rwkv5 seemed to give very unintelligent responses, and it was only today that I thought to print out the prompt. I then discovered that all the code after line 618 was never executed, and finally pinpointed this issue. Now the quantized int4 version of rwkv5 can also generate text normally, and output quality has improved in the other quantization modes for RWKV as well.
Can you elaborate a bit on why the static cast to int is needed here? Do we involve some negative numbers in computing this?
This error can show up whenever the input prompt is relatively long. Before `static_cast<int64_t>` was used to convert `tokens.size()`, the expression `this->total_seq_len_ + tokens.size() + gen_mean_gen_len` was evaluated in unsigned arithmetic, because `tokens.size()` returns an unsigned `size_t` and the usual arithmetic conversions promote the signed operands to unsigned. The comparison with `this->max_window_size_` then also happens in unsigned arithmetic, so a negative `max_window_size_` (such as -1) wraps around to a huge unsigned value and the comparison erroneously returned true.
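For illustration, here is a minimal standalone C++ sketch of the signed/unsigned comparison pitfall described above; the field names and values are stand-ins, not mlc-llm code:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Stand-in values for the fields in llm_chat.cc; the numbers are made up.
  int64_t total_seq_len = 10;
  int mean_gen_len = 128;
  int64_t max_window_size = -1;  // "-1" encodes "no window limit" (e.g. RWKV)
  std::vector<int32_t> tokens(5);

  // tokens.size() is an unsigned size_t, so the usual arithmetic conversions
  // promote the whole left-hand side to unsigned 64-bit. max_window_size (-1)
  // is then converted to 2^64 - 1, and the comparison is true for any prompt.
  std::cout << (total_seq_len + tokens.size() + mean_gen_len < max_window_size)
            << "\n";  // prints 1 (true) -- the buggy behaviour

  // Casting the size to a signed type keeps the whole comparison signed,
  // so 143 < -1 is correctly false and the re-encoding path can run.
  std::cout << (total_seq_len + static_cast<int64_t>(tokens.size()) + mean_gen_len <
                max_window_size)
            << "\n";  // prints 0 (false)
  return 0;
}
```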
OK, I feel this is a strange way to think about it. Given `max_window_size_ == -1`, we should check for it explicitly; that means there is no out-of-bound condition and we do not need to re-encode (i.e. run the code after). Would be good for @Hzfengsy to take a look as well.
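A rough sketch of the explicit check being suggested, pulled out into a standalone helper; the function shape and names are assumptions for illustration, not the merged fix:

```cpp
#include <cstdint>
#include <vector>

// Sketch only: decide whether the already-encoded prompt fits, with an
// explicit special case for "no window limit" instead of relying on the
// signed comparison against -1.
bool PromptFits(int64_t total_seq_len, int64_t sliding_window,
                int64_t max_window_size, int64_t mean_gen_len,
                const std::vector<int32_t>& tokens) {
  if (max_window_size == -1) {
    // No maximum window (e.g. RWKV): nothing can go out of bound,
    // so there is never a need to re-encode.
    return true;
  }
  if (sliding_window != -1) {
    // There is no max window size if we use a sliding window.
    return true;
  }
  return total_seq_len + static_cast<int64_t>(tokens.size()) + mean_gen_len <
         max_window_size;
}
```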
Okay, there also seems to be a bug in the handling of the RWKV system prompt. I expect each interaction with an RWKV model to include the system prompt along with the current text, because this series of models (rwkv4/5/6) has higher requirements for prompts. Currently, only the first round of dialogue includes the system prompt, and it is forgotten in subsequent rounds.
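A minimal sketch of the expected behaviour; the helper name and separator are hypothetical, not taken from mlc-llm:

```cpp
#include <string>

// Hypothetical helper illustrating the expectation above: the RWKV system
// prompt is prepended on every dialogue round, not only the first one.
std::string BuildRoundPrompt(const std::string& system_prompt,
                             const std::string& round_text) {
  return system_prompt + "\n\n" + round_text;
}
```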
I found some var-to-var bindings in the model, for example in the decode function.

The codegen issue is fixed by apache/tvm#16175, so the model can now compile successfully. However, such var-to-var bindings can still break fusion and pattern matching; for better performance, I'd recommend updating the model definition to eliminate them.
q8fp16_1 display: (screenshot elided)