
[Fix] fix quantization arg when using marlin #3319

Merged (6 commits) on Mar 13, 2024

Conversation

@DreamTeamWangbowen (Contributor) commented Mar 11, 2024

Fixes #3331: fix the quantization argument when using a Marlin-format model.

DreamTeamWangbowen changed the title from "when using marlin, fix quantization argument" to "[fix]when using marlin, fix quantization argument" on Mar 11, 2024
DreamTeamWangbowen changed the title from "[fix]when using marlin, fix quantization argument" to "[Fix] when using marlin, fix quantization argument" on Mar 11, 2024
@DreamTeamWangbowen (Contributor Author) commented Mar 11, 2024

@zhuohan123 @WoosukKwon Could you please help merge this?

@DreamTeamWangbowen (Contributor Author)

After the fix is merged, the model can run normally.

@esmeetu (Collaborator) commented Mar 11, 2024

Thanks for your contribution, @DreamTeamWangbowen! Could you fix the code style following CONTRIBUTING.md?

DreamTeamWangbowen changed the title from "[Fix] when using marlin, fix quantization argument" to "[Fix] fix quantization arg when using marlin" on Mar 12, 2024
@DreamTeamWangbowen (Contributor Author)

Thanks for your contribution, @DreamTeamWangbowen! Could you fix the code style following CONTRIBUTING.md?

Okay, I will submit a new commit and fix it.

@esmeetu (Collaborator) commented Mar 12, 2024

Please use sh format.sh to format your code. Then I can merge this.

@WoosukKwon (Collaborator)

@DreamTeamWangbowen Do we need this btw? IIUC, the Marlin kernel is automatically used for GPTQ models when the condition is met (act_order=False, etc.).

@DreamTeamWangbowen (Contributor Author)

Please use sh format.sh to format your code. Then I can merge this.

I've finished formatting my code.

@DreamTeamWangbowen (Contributor Author) commented Mar 12, 2024

@DreamTeamWangbowen Do we need this btw? IIUC, the Marlin kernel is automatically used for GPTQ models when the condition is met (act_order=False, etc.).

Yes, we need it. I did not find the act_order parameter in the code or the model configuration file.

The model address I use is https://huggingface.co/neuralmagic/Nous-Hermes-2-Yi-34B-marlin

@WoosukKwon (Collaborator)

Yes, we need it. I did not find the act_order parameter in the code or the model configuration file.
The model address I use is https://huggingface.co/neuralmagic/Nous-Hermes-2-Yi-34B-marlin

@DreamTeamWangbowen IIUC, the Marlin kernel should be automatically used (without specifying quantization=marlin):

vllm/vllm/config.py

Lines 174 to 176 in 654865e

if (hf_quant_method == "gptq"
        and "is_marlin_format" in hf_quant_config
        and hf_quant_config["is_marlin_format"]):

While the condition is not actually about act_order (sorry for the wrong information), the model configuration file meets the above condition.
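
For reference, a sketch of the kind of quantization config this check matches, written as a Python dict; the is_marlin_format key comes from the snippet above, while the other keys and all values are illustrative assumptions, not copied from the actual model:

# Illustrative only: the rough shape of hf_quant_config for a Marlin-serialized GPTQ checkpoint.
hf_quant_config = {
    "quant_method": "gptq",    # underlying quantization method (assumed key name)
    "is_marlin_format": True,  # set when the checkpoint is serialized for the Marlin kernel
    "group_size": 128,         # Marlin's current requirement, mentioned later in this thread
    "desc_act": False,         # i.e. act_order=False
}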

@DreamTeamWangbowen (Contributor Author)

Yes, we need it. I did not find the act_order parameter in the code or the model configuration file.
The model address I use is https://huggingface.co/neuralmagic/Nous-Hermes-2-Yi-34B-marlin

@DreamTeamWangbowen IIUC, the Marlin kernel should be automatically used (without specifying quantization=marlin):

vllm/vllm/config.py

Lines 174 to 176 in 654865e

if (hf_quant_method == "gptq"
        and "is_marlin_format" in hf_quant_config
        and hf_quant_config["is_marlin_format"]):

While the condition is not actually about act_order (sorry for the wrong information), the model configuration file meets the above condition.

If I do not specify marlin but specify gptq, the following error occurs:

File "/usr/local/lib/python3.10/dist-packages/vllm/config.py", line 130, in init
self._verify_quantization()
File "/usr/local/lib/python3.10/dist-packages/vllm/config.py", line 204, in _verify_quantization
raise ValueError(
ValueError: Quantization method specified in the model config (marlin) does not match the quantization method specified in the quantization argument (gptq).
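
For context, a minimal reproduction sketch using the vLLM Python offline API (an assumption about how the engine is invoked here; the server path hits the same check):

# Reproduction sketch: explicitly passing quantization="gptq" for a checkpoint
# whose config is detected as Marlin format trips the mismatch check above.
from vllm import LLM

llm = LLM(model="neuralmagic/Nous-Hermes-2-Yi-34B-marlin",
          quantization="gptq")  # raises the ValueError shown above (before this PR)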

@DreamTeamWangbowen (Contributor Author) commented Mar 12, 2024

Yes, we need it. I did not find the act_order parameter in the code or the model configuration file.
The model address I use is https://huggingface.co/neuralmagic/Nous-Hermes-2-Yi-34B-marlin

@DreamTeamWangbowen IIUC, the Marlin kernel should be automatically used (without specifying quantization=marlin):

vllm/vllm/config.py

Lines 174 to 176 in 654865e

if (hf_quant_method == "gptq"
        and "is_marlin_format" in hf_quant_config
        and hf_quant_config["is_marlin_format"]):

While the condition is not actually about act_order (sorry for the wrong information), the model configuration file meets the above condition.

Look at the code here, lines 178 to 185; there is a check here:

vllm/vllm/config.py

Lines 178 to 185 in 654865e

if self.quantization is None:
    self.quantization = hf_quant_method
elif self.quantization != hf_quant_method:
    raise ValueError(
        "Quantization method specified in the model config "
        f"({hf_quant_method}) does not match the quantization "
        f"method specified in the `quantization` argument "
        f"({self.quantization}).")

Therefore we need to add marlin to the allowed values of the quantization argument.

There may be another way to modify it: at line 177, where hf_quant_method is set to "marlin", also change a user-specified self.quantization="gptq" to self.quantization="marlin":

hf_quant_method = "marlin"
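
A minimal sketch of that alternative (not the exact change that was merged; the surrounding logic is paraphrased from the config.py snippets quoted above):

# Sketch: when a GPTQ checkpoint is serialized in Marlin format, also accept a
# user-supplied quantization="gptq" instead of raising the mismatch error below.
if (hf_quant_method == "gptq"
        and hf_quant_config.get("is_marlin_format", False)):
    hf_quant_method = "marlin"
    if self.quantization == "gptq":
        self.quantization = "marlin"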

@esmeetu (Collaborator) commented Mar 12, 2024

@DreamTeamWangbowen I think you can run your server without passing the -q or --quantization argument, since vLLM will detect the quantization method from config.json.

@DreamTeamWangbowen (Contributor Author)

@DreamTeamWangbowen I think you can run your server without passing the -q or --quantization argument, since vLLM will detect the quantization method from config.json.

Yes, but if the quantization parameter is explicitly specified as gptq or marlin, an error occurs when running the server.

@robertgshaw2-redhat (Collaborator) commented Mar 12, 2024

@DreamTeamWangbowen

Marlin kernels use a special serialization format that is different from exllama's, so the model must be saved on disk in Marlin format to be loaded by vLLM. vLLM currently does not support converting formats on the fly (though this is something we are working on).

I added the functionality to save models in Marlin format to AutoGPTQ. Here's an example:

Also here's a model I saved in this format:

I will add a doc to vLLM about this

Note: Marlin currently requires group_size=128 and act_order=False. We are working on expanding this

@WoosukKwon FYI

Passing marlin to this argument will not work
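
As a rough illustration of the AutoGPTQ export path mentioned above, a hedged sketch; the use_marlin flag and the source repo name are assumptions and may differ between AutoGPTQ versions:

# Hedged sketch: load an existing GPTQ checkpoint with the Marlin kernel enabled
# (which repacks the weights), then save it so the config carries the Marlin format.
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "some-org/some-model-gptq",  # hypothetical repo quantized with group_size=128, act_order=False
    device="cuda:0",
    use_marlin=True,             # assumed flag enabling the Marlin kernel/serialization
)
model.save_quantized("./some-model-marlin")  # write the Marlin-format checkpoint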

@DreamTeamWangbowen (Contributor Author)

@DreamTeamWangbowen

Marlin kernels use a special serialization format that is different from exllama's, so the model must be saved on disk in Marlin format to be loaded by vLLM. vLLM currently does not support converting formats on the fly (though this is something we are working on).

I added the functionality to save models in Marlin format to AutoGPTQ. Here's an example:

Also here's a model I saved in this format:

I will add a doc to vLLM about this

Note: Marlin currently requires group_size=128 and act_order=False. We are working on expanding this

@WoosukKwon FYI

Passing marlin to this argument will not work

Yeah, the model I use is in Marlin format: https://huggingface.co/neuralmagic/Nous-Hermes-2-Yi-34B-marlin

@robertgshaw2-redhat (Collaborator)

@DreamTeamWangbowen
Marlin kernels use a special serialization format that is different from exllama's, so the model must be saved on disk in Marlin format to be loaded by vLLM. vLLM currently does not support converting formats on the fly (though this is something we are working on).
I added the functionality to save models in Marlin format to AutoGPTQ. Here's an example:

Also here's a model I saved in this format:

I will add a doc to vLLM about this
Note: Marlin currently requires group_size=128 and act_order=False. We are working on expanding this
@WoosukKwon FYI
Passing marlin to this argument will not work

Yeah, the model I use is in Marlin format: https://huggingface.co/neuralmagic/Nous-Hermes-2-Yi-34B-marlin

Nice - vLLM will use Marlin by default if you pass this model. You do not need to pass the --quantization argument explicitly.
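
A minimal usage sketch of that, using the offline Python API as a stand-in for however the server is launched:

# vLLM detects the Marlin serialization from the checkpoint's config, so no
# quantization argument is needed.
from vllm import LLM, SamplingParams

llm = LLM(model="neuralmagic/Nous-Hermes-2-Yi-34B-marlin")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)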

@DreamTeamWangbowen (Contributor Author) commented Mar 13, 2024

@DreamTeamWangbowen
Marlin kernels use a special serialization format that is different from exllama's, so the model must be saved on disk in Marlin format to be loaded by vLLM. vLLM currently does not support converting formats on the fly (though this is something we are working on).
I added the functionality to save models in Marlin format to AutoGPTQ. Here's an example:

Also here's a model I saved in this format:

I will add a doc to vLLM about this
Note: Marlin currently requires group_size=128 and act_order=False. We are working on expanding this
@WoosukKwon FYI
Passing marlin to this argument will not work

Yeah, the model I use is in Marlin format: https://huggingface.co/neuralmagic/Nous-Hermes-2-Yi-34B-marlin

Nice - vLLM will use Marlin by default if you pass this model. You do not need to pass the --quantization argument explicitly.

Yes, the Marlin model I am using has act_order=False and group_size=128.

But my understanding is that the -q argument is how one tells vLLM which quantization method to use, such as awq or gptq, so I added marlin as an option, for example to tell vLLM to use MarlinLinearMethod.

@WoosukKwon (Collaborator) commented Mar 13, 2024

@DreamTeamWangbowen IIUC, Marlin is not a quantization method. It's a fast kernel implementation for GPTQ.

I've updated the PR so that it fixes the bug where quantization="gptq" raises an error when Marlin is enabled.

WoosukKwon merged commit b167109 into vllm-project:main on Mar 13, 2024
3 checks passed
@DreamTeamWangbowen (Contributor Author) commented Mar 13, 2024

@DreamTeamWangbowen IIUC, Marlin is not a quantization method. It's a fast kernel implementation for GPTQ.

I've updated the PR so that it fixes the bug where quantization="gptq" raises an error when Marlin is enabled.

Yeah, you are right, thank you very much. :)

starmpcc pushed a commit to starmpcc/vllm that referenced this pull request Mar 14, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>