Update model definition to support Flash-Decoding #177
Conversation
@@ -402,6 +403,10 @@ class BuildArgs:
            "action": "store_true",
        },
    )
    paged_kv_cache_type: str = field(
        default="vllm",
        metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"},
This new option makes `--use_vllm_attention` obsolete. Since removing it is a breaking change, I'll do that later when I integrate Flash-Decoding into `mlc-serve`. @sunggg
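One hypothetical way to keep the old flag accepted until it is removed (a sketch under assumed flag spellings, not code from this PR): map `--use_vllm_attention` onto the new option and warn about the deprecation.

```python
# Hypothetical backward-compatibility shim, not from this repo: honor the old
# --use_vllm_attention flag while --paged_kv_cache_type becomes the source of truth.
import argparse
import warnings

parser = argparse.ArgumentParser()
parser.add_argument("--use_vllm_attention", action="store_true")
parser.add_argument("--paged_kv_cache_type", default="vllm",
                    choices=["vllm", "flash-decoding"])
args = parser.parse_args()

if args.use_vllm_attention:
    warnings.warn("--use_vllm_attention is deprecated; "
                  "use --paged_kv_cache_type=vllm instead.", DeprecationWarning)
    args.paged_kv_cache_type = "vllm"
```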
Force-pushed from c9d7fac to b9e41e1.
The repetition penalty (introduced in [CTRL](https://arxiv.org/abs/1909.05858)) can help prevent the LLM from generating repetitive tokens. This PR implements the repetition penalty. Note: previously the logits softmax was performed on the GPU; this PR moves it to the CPU to accommodate the repetition penalty.
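For reference, a minimal NumPy sketch of the CTRL-style repetition penalty applied to the logits on the CPU before the softmax; the helper names and the `theta` value are illustrative, not the actual implementation in this PR.

```python
import numpy as np

def apply_repetition_penalty(logits: np.ndarray, generated_token_ids, theta: float = 1.3):
    """Penalize tokens that have already been generated (CTRL-style)."""
    penalized = logits.copy()
    for tok in set(generated_token_ids):
        # Dividing positive logits and multiplying negative ones both reduce
        # the probability of sampling the token again.
        if penalized[tok] > 0:
            penalized[tok] /= theta
        else:
            penalized[tok] *= theta
    return penalized

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Example: penalize previously generated token ids before sampling.
probs = softmax(apply_repetition_penalty(np.random.randn(32000), [1, 5, 5, 42]))
```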
LGTM. In the follow-up PR, would you share some benchmark numbers? Thank you!
This PR integrates Flash-Decoding support from apache/tvm#16474. It is a drop-in replacement for the vLLM kernel; the only difference from the vLLM-based build is the shape of the KV cache blocks. In particular, the block size for vLLM is 16, while for Flash-Decoding it is 256.
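To make the block-size difference concrete, here is a rough sketch of how the paged KV cache allocation changes; the layout and sizes below are assumptions for illustration, not the exact shapes used in this repo.

```python
# Illustrative only: with a fixed memory budget, a larger block size simply
# yields fewer, larger blocks; the rest of the cache layout stays the same.
num_layers, num_kv_heads, head_dim = 32, 32, 128   # Llama-7B-like config (assumed)
dtype_bytes = 2                                    # fp16

def kv_cache_blocks(block_size: int, budget_bytes: int):
    # Bytes needed to store K and V for one token across all layers.
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    num_blocks = budget_bytes // (bytes_per_token * block_size)
    # One block holds `block_size` token slots per layer for K (and likewise V).
    block_shape = (block_size, num_kv_heads, head_dim)
    return num_blocks, block_shape

print(kv_cache_blocks(block_size=16, budget_bytes=16 * 1024**3))    # vLLM
print(kv_cache_blocks(block_size=256, budget_bytes=16 * 1024**3))   # Flash-Decoding
```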
In addition, it supports decoding with multiple, fixed-length queries per request, which is necessary for speculative decoding. `evaluate_multi_query` from #156 can also be used for this purpose, but it supports variable-length queries per request and piggybacks on the prefill attention, which is not efficient when the number of queries is fixed and small. The changes in `run_llama_batched_vllm.py` demonstrate that the new Relax function, `decode_multi_query`, can do exactly the same thing as `evaluate_multi_query` when the query length is fixed (a conceptual sketch of the difference is included at the end of this description).

This PR only updates the model definition and the `run_llama_batched_vllm.py` example script. I'll follow up with the integration into `mlc-serve` next. Requires the latest https://github.com/octoml/tvm/tree/for-mlc-serve-jan12.
@sunggg @yelite @vinx13
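A conceptual sketch of the fixed-length versus variable-length multi-query layouts. This is not the signature of the Relax `decode_multi_query` or `evaluate_multi_query` functions; the shapes and names are assumptions for illustration only.

```python
import numpy as np

batch, query_len, num_heads, head_dim = 4, 5, 32, 128  # assumed sizes

# Fixed-length multi-query decode: every request verifies the same number of
# draft tokens, so the queries pack into a single dense tensor.
queries = np.zeros((batch, query_len, num_heads, head_dim), dtype="float16")

# Variable-length multi-query (the evaluate_multi_query case): requests
# contribute different numbers of query tokens, so they are flattened and
# addressed with per-request offsets, like prefill attention.
query_lens = [3, 5, 1, 7]
flat_queries = np.zeros((sum(query_lens), num_heads, head_dim), dtype="float16")
offsets = np.cumsum([0] + query_lens)  # start offset of each request's queries
```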