[Kernel][Core] Add AWQ support to the Marlin kernel #6612
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
- Comment /ready on the PR 🚀
CALL_IF(4)
CALL_IF(4)
CALL_IF(8)
CALL_IF(8)
Why are these doubled?
Good catch, removed duplicates
namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
How is this different from GPTQ? It looks similar to me at a glance
The internal unpacking is different: AWQ packs over columns, while GPTQ packs over rows. In addition, AWQ interleaves groups of 8 values (for 4-bit) or groups of 4 (for 8-bit) to be compatible with the dequantization PTX assembly.
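To make the row-vs-column difference concrete, here is a minimal NumPy sketch of the two packing layouts (illustrative only; it omits the per-word interleaving that the real repack kernel applies for the dequantization PTX):

```python
import numpy as np

PACK_FACTOR = 8  # eight 4-bit values per 32-bit word

def pack_4bit(q: np.ndarray, axis: int) -> np.ndarray:
    """Pack groups of 8 unsigned 4-bit values along `axis` into uint32 words."""
    parts = []
    for i in range(PACK_FACTOR):
        # take every 8th element starting at offset i along the packing axis
        idx = np.arange(i, q.shape[axis], PACK_FACTOR)
        parts.append(np.take(q, idx, axis=axis).astype(np.uint32) << np.uint32(4 * i))
    return np.bitwise_or.reduce(parts)

w = np.random.randint(0, 16, size=(16, 16)).astype(np.uint32)  # (k, n) int4 weights

gptq_packed = pack_4bit(w, axis=0)  # GPTQ-style: packed over rows    -> (k/8, n)
awq_packed  = pack_4bit(w, axis=1)  # AWQ-style:  packed over columns -> (k, n/8)

print(gptq_packed.shape, awq_packed.shape)  # (2, 16) (16, 2)
```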
@mgoin @robertgshaw2-neuralmagic added bfloat16 support
As I told @robertgshaw2-neuralmagic on Discord, this kind of speedup warrants a backport to AutoAWQ. Does it make sense to natively pack the weights in AutoAWQ in the Marlin format, and if so, do you have any reference code for this now that zero points are supported?
Hi @alexm-neuralmagic, I ran a benchmark with 5k ShareGPT requests on LMDeploy and SGLang, and their AWQ performance was surprisingly close, which is incredible!
@zhyncs thanks for doing these benchmarks! I would also expect AWQ to be inherently faster than GPTQ because AWQ has no activation order, especially for multi-GPU runs.
Hi @alexm-neuralmagic, thank you for your reply. I ran a gsm8k eval with lm_eval on Llama 3 8B Instruct and the AWQ model, and found a significant decrease in accuracy. Is this expected? My replication steps and results are as follows:

Server launch commands:
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct
python -m vllm.entrypoints.openai.api_server --model casperhansen/llama-3-8b-instruct-awq

Client eval commands:
lm_eval --model local-completions --tasks gsm8k --num_fewshot 8 --model_args model=meta-llama/Meta-Llama-3-8B-Instruct,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=True
lm_eval --model local-completions --tasks gsm8k --num_fewshot 8 --model_args model=casperhansen/llama-3-8b-instruct-awq,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=True
Recently we also merged a PR that increases the precision of Marlin (it uses FP32 full-precision global reductions): #6795. I would compare awq vs awq_marlin for an apples-to-apples comparison (and not compare directly to fp16, to avoid any quantization-related errors).
Hi @alexm-neuralmagic OK. I saw that the PR you mentioned was merged after the latest release. I will try again after the new version is released. Thank you.
Hi @alexm-neuralmagic I tested vLLM 0.5.4 and noticed the accuracy has worsened. Is this expected? Thanks.
@zhyncs thanks for checking. Could you please provide reproduction instructions?
ref: #6612 (comment)
@zhyncs - so the drop is from
Yes, the previous version already had accuracy issues, which this version was supposed to fix, but it ended up being even worse.
@zhyncs did you try --quantization awq (to force the original awq kernel)?
This has nothing to do with that; let's just compare it directly to fp16. The current drop in accuracy is unacceptable: the model simply can't be used for online/production serving. By the way, LMDeploy does much better here; you can test it yourself. Thanks.
The AWQ model is 4-bit quantized, so you should not expect to see the same scores between fp16 and int4. What score does LMDeploy achieve for the AWQ model on GSM8K?
I think you've misunderstood. At no point did I say or expect the scores to be the same, so I'm not sure where that reading comes from. My point is that the current AWQ Marlin implementation loses far too much accuracy. Some accuracy drop is acceptable, but the accuracy after quantization still has to be usable, and right now it is not.
I get the same accuracy scores when using
Client launch command:
lm_eval --model local-completions --tasks gsm8k --num_fewshot 8 --model_args model=casperhansen/llama-3-8b-instruct-awq,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=True
I fully understand your comparison method and your results. Perhaps we should focus on how to improve the accuracy of the current AWQ path. What do you think? @robertgshaw2-neuralmagic @alexm-neuralmagic
Yes - the source of the drop in accuracy here seems to be the quality of the model, not the correctness of the Marlin kernel. Neural Magic did not create this model and I do not know anything about how it was created. Feel free to try to improve the accuracy as you see fit. Neural Magic provides quantized checkpoints on our Hugging Face profile that are compatible with vLLM, with replication instructions for creating and evaluating the models. The scope of this work is simply to run any AWQ model as fast as possible, with accuracy scores that match the baseline implementations within numerical-precision error, which this PR accomplishes.
@zhyncs I think you are confusing "AWQ, the quantization algorithm" with "AWQ, the inference kernel". If you want to reduce the impact of "AWQ, the quantization algorithm", you can produce your own quantized checkpoint using AutoAWQ with more conservative parameters. Looking at the quantization config for that checkpoint, you can see it uses
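For reference, a minimal AutoAWQ re-quantization sketch is shown below. The model path, output directory, and quant_config values are illustrative assumptions (e.g. a smaller group size as one "more conservative" knob), not the settings of the checkpoint discussed above.

```python
# Sketch only: re-quantize a checkpoint yourself with AutoAWQ.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"   # source model (example)
quant_path = "llama-3-8b-instruct-awq-custom"        # output directory (example)
# Illustrative config: 4-bit with zero points, smaller group size than the usual 128.
quant_config = {"w_bit": 4, "q_group_size": 64, "zero_point": True, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(tokenizer, quant_config=quant_config)  # runs AWQ calibration
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```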
This drop in accuracy after quantization looks normal to me. The standard calibration dataset used was not math-related either. |
Also, running the model through Hugging Face, I get the same scores as we get in vLLM:
lm_eval --model hf --model_args pretrained=casperhansen/llama-3-8b-instruct-awq --tasks gsm8k --num_fewshot 8 --batch_size 16
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 8|exact_match|↑ |0.6967|± |0.0127|
| | |strict-match | 8|exact_match|↑ |0.6975|± |0.0127|
@mgoin You misunderstood; I was saying that the accuracy of AWQ in LMDeploy is better, not that it ran the casperhansen/llama-3-8b-instruct-awq model. The quantization in LMDeploy is generated through
@robertgshaw2-neuralmagic Thank you for providing this reference.
Okay @zhyncs, thank you for clarifying that you are talking about a different checkpoint from the one you measured here in vLLM, and that LMDeploy quantizes models itself.
Makes sense.
Hi @robertgshaw2-neuralmagic, after thinking about it, the eval results with hf for the AWQ checkpoint don't clarify anything, since the auto_awq kernel is used during lm_eval as well. Therefore, in theory, there should be no difference from AWQ's original implementation in vLLM.
This PR adds end-to-end support for AWQ quantization inside the Marlin kernel.
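As a usage sketch (the checkpoint name is only an example, and this assumes the awq_marlin quantization method name introduced here; on supported GPUs vLLM is expected to select the Marlin path automatically, while quantization="awq" forces the original kernel for comparison):

```python
# Sketch: run an AWQ checkpoint through the new Marlin path (model name illustrative).
from vllm import LLM, SamplingParams

llm = LLM(model="casperhansen/llama-3-8b-instruct-awq", quantization="awq_marlin")
# Use quantization="awq" instead to force the original AWQ kernel for comparison.

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```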
Here are initial performance results of awq_marlin (this PR) vs awq (on vLLM main) for a Llama3-70B AWQ model on 2xA100 GPUs and a Llama3-8B AWQ model on 1xA100 GPU.
Llama3-70B AWQ on 2xA100 GPUs with prompt = 512 and decode = 256
Llama3-8B AWQ on 1xA100 GPU with prompt = 1024 and decode = 512
TODOs (may be done after this PR lands):