[CORE] Add vLLM Backend for FORMAT.GPTQ (ModelCloud#190)
* add vllm load support

* add sglang

* fix vllm load model show kv_caches error

* revert sglang

* mod clean up

* Update base.py

* Update base.py

* Update base.py

* Update test_vllm.py

* Update vllm.py

* Update base.py

* Update vllm.py

* add convert_hf_params_to_vllm and clean up

* format code

* mod clean up

* mod clean up

---------

Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
PZS-ModelCloud and Qubitium authored Jul 10, 2024
1 parent a39fa93 commit 40308cd
Showing 7 changed files with 139 additions and 6 deletions.
2 changes: 1 addition & 1 deletion gptqmodel/models/__init__.py
@@ -16,6 +16,7 @@
from .gpt_neox import GPTNeoXGPTQ
from .gptj import GPTJGPTQ
from .internlm import InternLMGPTQ
from .internlm2 import InternLM2GPTQ
from .llama import LlamaGPTQ
from .longllama import LongLlamaGPTQ
from .mistral import MistralGPTQ
@@ -33,4 +34,3 @@
from .starcoder2 import Starcoder2GPTQ
from .xverse import XverseGPTQ
from .yi import YiGPTQ
from .internlm2 import InternLM2GPTQ
2 changes: 1 addition & 1 deletion gptqmodel/models/auto.py
@@ -19,6 +19,7 @@
from .gpt_neox import GPTNeoXGPTQ
from .gptj import GPTJGPTQ
from .internlm import InternLMGPTQ
from .internlm2 import InternLM2GPTQ
from .llama import LlamaGPTQ
from .longllama import LongLlamaGPTQ
from .minicpm import MiniCPMGPTQ
@@ -37,7 +38,6 @@
from .starcoder2 import Starcoder2GPTQ
from .xverse import XverseGPTQ
from .yi import YiGPTQ
from .internlm2 import InternLM2GPTQ

MODEL_MAP = {
"bloom": BloomGPTQ,
29 changes: 27 additions & 2 deletions gptqmodel/models/base.py
@@ -33,6 +33,7 @@
get_module_by_name_suffix, get_moe_layer_modules, gptqmodel_post_init, make_quant,
move_to, nested_move_to, pack_model, simple_dispatch_model, verify_model_hash,
verify_sharded_model_hashes)
from ..utils.vllm import load_model_by_vllm, vllm_generate
from ..version import __version__
from ._const import CPU, CUDA_0, DEVICE, SUPPORTED_MODELS

@@ -590,8 +591,12 @@ def forward(self, *args, **kwargs):

def generate(self, **kwargs):
"""shortcut for model.generate"""
with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
return self.model.generate(**kwargs)
if hasattr(self.model.config, "model_type") and self.model.config.model_type == "vllm":
with torch.inference_mode():
return vllm_generate(self.model, **kwargs)
else:
with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
return self.model.generate(**kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
@@ -944,6 +949,26 @@ def from_quantized(
if not isinstance(quantize_config, QuantizeConfig):
quantize_config = QuantizeConfig.from_quant_config(quantize_config, format)

if backend == BACKEND.VLLM:
if quantize_config.format != FORMAT.GPTQ:
raise ValueError(f"{backend} backend only supports FORMAT.GPTQ: actual = {quantize_config.format}")

model = load_model_by_vllm(
model=model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)

model.config = model.llm_engine.model_config
model.config.model_type = "vllm"

return cls(
model,
quantized=True,
quantize_config=quantize_config,
qlinear_kernel=None,
)

if quantize_config.format == FORMAT.MARLIN:
# format marlin requires marlin kernel
if backend != BACKEND.MARLIN and backend != BACKEND.AUTO:
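For context, a minimal usage sketch of the new loading path added to from_quantized() above, mirroring tests/test_vllm.py below (the checkpoint id is the one used in that test; any FORMAT.GPTQ checkpoint should work in its place):

from gptqmodel import BACKEND, GPTQModel

# Load a FORMAT.GPTQ checkpoint through the new vLLM backend.
# Other formats raise ValueError, per the check added above.
model = GPTQModel.from_quantized(
    "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    device="cuda:0",
    backend=BACKEND.VLLM,
)

# generate() now detects model.config.model_type == "vllm" and routes to
# vllm_generate(); HF-style kwargs are converted to SamplingParams internally.
outputs = model.generate(prompts=["The capital of France is"], temperature=0.8, top_p=0.95)
print(outputs[0].outputs[0].text)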
2 changes: 1 addition & 1 deletion gptqmodel/models/internlm2.py
@@ -11,4 +11,4 @@ class InternLM2GPTQ(BaseGPTQModel):

["feed_forward.w1", "feed_forward.w3"],
["feed_forward.w2"],
]
]
2 changes: 1 addition & 1 deletion gptqmodel/utils/backend.py
@@ -11,7 +11,7 @@ class BACKEND(Enum):
MARLIN = 6
BITBLAS = 7
QBITS = 8

VLLM = 9

def get_backend(backend: str):
try:
61 changes: 61 additions & 0 deletions gptqmodel/utils/vllm.py
@@ -0,0 +1,61 @@
import logging

try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
from typing import Any, Dict

VLLM_INSTALL_HINT = "vLLM not installed. Please install via `pip install -U vllm`."

def convert_hf_params_to_vllm(hf_params: Dict[str, Any]) -> SamplingParams:

params = {
'n': hf_params.get('num_return_sequences', 1),
'repetition_penalty': hf_params.get('repetition_penalty', 1.0),
'temperature': hf_params.get('temperature', 1.0),
'top_k': hf_params.get('top_k', -1),
'top_p': hf_params.get('top_p', 1.0),
'max_tokens': hf_params.get('max_length', 16),
'min_tokens': hf_params.get('min_length', 0),
'early_stopping': hf_params.get('early_stopping', False),
'length_penalty': hf_params.get('length_penalty', 1.0),
'stop_token_ids': [hf_params.get('eos_token_id'), None],
}
return SamplingParams(**params)

def load_model_by_vllm(
model,
**kwargs,
):
if not VLLM_AVAILABLE:
raise ValueError(VLLM_INSTALL_HINT)

model = LLM(
model=model,
**kwargs,
)

return model

def vllm_generate(
model,
**kwargs,
):
if not VLLM_AVAILABLE:
raise ValueError(VLLM_INSTALL_HINT)

prompts = kwargs.pop("prompts", None)
sampling_params = kwargs.pop("sampling_params", None)

if not isinstance(sampling_params, SamplingParams):
hf_params = {key: kwargs[key] for key in [
'num_return_sequences', 'repetition_penalty', 'temperature',
'top_k', 'top_p', 'max_length', 'min_length',
'early_stopping', 'length_penalty', 'eos_token_id'
] if key in kwargs}
sampling_params = convert_hf_params_to_vllm(hf_params)

outputs = model.generate(prompts, sampling_params)
return outputs
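As a reference for convert_hf_params_to_vllm() above, a small example of how HF-style generation kwargs map onto vLLM SamplingParams (the values here are illustrative only):

# Illustrative input; keys not supplied fall back to the defaults above
# (e.g. top_k=-1, min_tokens=0, length_penalty=1.0).
hf_params = {
    "num_return_sequences": 1,   # -> n
    "repetition_penalty": 1.1,   # -> repetition_penalty
    "temperature": 0.8,          # -> temperature
    "top_p": 0.95,               # -> top_p
    "max_length": 64,            # -> max_tokens
    "eos_token_id": 2,           # -> stop_token_ids
}
sampling_params = convert_hf_params_to_vllm(hf_params)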
47 changes: 47 additions & 0 deletions tests/test_vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch

import unittest # noqa: E402

from gptqmodel import BACKEND, GPTQModel # noqa: E402
from vllm import SamplingParams # noqa: E402


class TestLoadVLLM(unittest.TestCase):
MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

def test_load_vllm(self):
model = GPTQModel.from_quantized(
self.MODEL_ID,
device="cuda:0",
backend=BACKEND.VLLM,
)
outputs = model.generate(
prompts=self.prompts,
sampling_params=self.sampling_params,
)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

outputs_param = model.generate(
prompts=self.prompts,
temperature=0.8,
top_p=0.95,
)
for output in outputs_param:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

self.assertTrue(outputs is not None)
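Assuming vLLM is installed and a CUDA device is available, the new test can be run with, for example:

pip install -U vllm
python -m pytest tests/test_vllm.py -s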
