Skip to content

Commit

Permalink
[Feature][Kernel] Support bitsandbytes quantization and QLoRA (vllm-p…
Browse files Browse the repository at this point in the history
  • Loading branch information
chenqianfzh authored Jun 1, 2024
1 parent 37464a0 commit b9c0605
Show file tree
Hide file tree
Showing 11 changed files with 752 additions and 8 deletions.
140 changes: 140 additions & 0 deletions examples/lora_with_quantization_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.
Requires HuggingFace credentials for access.
"""

import gc
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the fixed list of demo requests used by this example.

    Args:
        lora_path: Location of the LoRA adapter handed to every
            ``LoRARequest`` created here.

    Returns:
        A list of ``(prompt, sampling_params, lora_request)`` tuples. The
        first entry has ``None`` as its LoRA request (quantization only);
        the other three attach a ``LoRARequest`` pointing at ``lora_path``.
    """

    def greedy_params() -> SamplingParams:
        # A fresh object per prompt so that no two requests share parameters.
        return SamplingParams(temperature=0.0,
                              logprobs=1,
                              prompt_logprobs=1,
                              max_tokens=128)

    # (prompt text, adapter name); a name of None means "no LoRA adapter".
    plan = [
        ("My name is", None),
        ("my name is", "lora-test-1"),
        ("The capital of USA is", "lora-test-2"),
        ("The capital of France is", "lora-test-3"),
    ]

    prompts = []
    for text, adapter_name in plan:
        params = greedy_params()
        lora_request = None
        if adapter_name is not None:
            # Every LoRA request uses adapter id 1 and the same path.
            lora_request = LoRARequest(adapter_name, 1, lora_path)
        prompts.append((text, params, lora_request))
    return prompts


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Submit prompts to the engine and print each finished output.

    At most one pending prompt is added per iteration, followed by a single
    ``engine.step()``. The loop ends once ``test_prompts`` is empty and the
    engine reports no unfinished requests.

    Args:
        engine: Engine that receives the requests and is stepped.
        test_prompts: ``(prompt, sampling_params, lora_request)`` tuples.
            Entries are popped from the front, so the caller's list is
            emptied by this function.
    """
    next_id = 0

    while True:
        # Only ask the engine once there is nothing left to submit.
        if not test_prompts and not engine.has_unfinished_requests():
            break

        if test_prompts:
            text, params, lora = test_prompts.pop(0)
            engine.add_request(str(next_id),
                               text,
                               params,
                               lora_request=lora)
            next_id += 1

        step_outputs: List[RequestOutput] = engine.step()
        for result in step_outputs:
            if not result.finished:
                continue
            print("----------------------------------------------------")
            print(f"Prompt: {result.prompt}")
            print(f"Output: {result.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Create an ``LLMEngine`` with LoRA enabled for a quantized model.

    Args:
        model: Model name or path passed to ``EngineArgs``.
        quantization: Quantization method name; ``"bitsandbytes"`` selects
            the QLoRA-specific settings.
        lora_repo: LoRA adapter repo. Only used in the bitsandbytes case,
            where it is passed as ``qlora_adapter_name_or_path``.

    Returns:
        The engine built by ``LLMEngine.from_engine_args``.
    """
    # Settings common to every configuration. ``enforce_eager`` should only
    # be set on GPUs with limited memory.
    options = dict(model=model,
                   quantization=quantization,
                   enable_lora=True,
                   enforce_eager=True)

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) quantizes the model while
        # loading it, using some config info from the LoRA adapter repo, so
        # load_format and qlora_adapter_name_or_path must be set as well.
        options.update(qlora_adapter_name_or_path=lora_repo,
                       load_format="bitsandbytes",
                       max_lora_rank=64)
    else:
        options["max_loras"] = 4

    return LLMEngine.from_engine_args(EngineArgs(**options))


def main():
    """Run the demo prompts against each quantization configuration in turn.

    For every configuration an engine is created, the LoRA adapter snapshot
    is downloaded, the prompts are processed, and the engine is released
    before the next configuration starts.
    """
    test_configs = [
        {
            "name": "qlora_inference_example",
            "model": "huggyllama/llama-7b",
            "quantization": "bitsandbytes",
            "lora_repo": "timdettmers/qlora-flan-7b",
        },
        {
            "name": "AWQ_inference_with_lora_example",
            "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
            "quantization": "awq",
            "lora_repo": "jashing/tinyllama-colorist-lora",
        },
        {
            "name": "GPTQ_inference_with_lora_example",
            "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
            "quantization": "gptq",
            "lora_repo": "jashing/tinyllama-colorist-lora",
        },
    ]

    for config in test_configs:
        banner = (
            f"~~~~~~~~~~~~~~~~ Running: {config['name']} ~~~~~~~~~~~~~~~~")
        print(banner)

        adapter_repo = config["lora_repo"]
        engine = initialize_engine(config["model"], config["quantization"],
                                   adapter_repo)
        adapter_dir = snapshot_download(repo_id=adapter_repo)
        process_requests(engine, create_test_prompts(adapter_dir))

        # Drop the engine and free cached GPU memory before the next config.
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ aiohttp

# Multimodal
pillow

# quantization
bitsandbytes==0.42.0
80 changes: 80 additions & 0 deletions tests/quantization/test_bitsandbytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
'''Tests whether bitsandbytes computation is enabled correctly.
Run `pytest tests/quantization/test_bitsandbytes.py`.
'''
import pytest
import torch

from vllm import SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

# Compute capability of the current CUDA device, encoded as
# major * 10 + minor (for example (8, 0) becomes 80). Querying the capability
# without a usable CUDA device fails and would abort test collection, so fall
# back to 0; the skipif guard on the test then skips it instead.
if torch.cuda.is_available():
    _major, _minor = torch.cuda.get_device_capability()
    capability = _major * 10 + _minor
else:
    capability = 0


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
    reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
    """Load a model with bitsandbytes and check weight dtypes and outputs.

    Verifies that the MLP and self-attention projections of the first layer
    are stored as ``torch.uint8``, that the head, embedding and layer-norm
    weights are not, and that greedy generation yields the expected text.

    Args:
        vllm_runner: Callable that builds the model under test (presumably a
            pytest fixture defined outside this file).
    """
    llm = vllm_runner('huggyllama/llama-7b',
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      enforce_eager=True)

    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
    first_layer = model.model.layers[0]

    # check the weights in MLP & SelfAttention are quantized to torch.uint8
    quantized_modules = {
        'gate_up_proj': first_layer.mlp.gate_up_proj,
        'down_proj': first_layer.mlp.down_proj,
        'o_proj': first_layer.self_attn.o_proj,
        'qkv_proj': first_layer.self_attn.qkv_proj,
    }
    for name, module in quantized_modules.items():
        qweight = module.qweight
        assert qweight.dtype == torch.uint8, (
            f'Expected {name} dtype torch.uint8 but got {qweight.dtype}')

    # some weights should not be quantized
    unquantized_modules = {
        'lm_head': model.lm_head,
        'embed_tokens': model.model.embed_tokens,
        'input_layernorm': first_layer.input_layernorm,
        'post_attention_layernorm': first_layer.post_attention_layernorm,
    }
    for name, module in unquantized_modules.items():
        # The message names the module actually checked.
        assert module.weight.dtype != torch.uint8, (
            f'{name} weight dtype should not be torch.uint8')

    # check the output of the model is expected
    sampling_params = SamplingParams(temperature=0.0,
                                     logprobs=1,
                                     prompt_logprobs=1,
                                     max_tokens=8)

    prompts = ['That which does not kill us', 'To be or not to be,']
    expected_outputs = [
        'That which does not kill us makes us stronger.',
        'To be or not to be, that is the question.'
    ]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    assert len(outputs) == len(prompts)

    for output, expected in zip(outputs, expected_outputs):
        # compare the first line of the output
        actual_output = output[1][0].split('\n', 1)[0]
        expected_output = expected.split('\n', 1)[0]
        assert actual_output == expected_output, (
            f'Expected: {expected_output}, but got: {actual_output}')
9 changes: 8 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ def verify_with_parallel_config(
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")

if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1
or parallel_config.pipeline_parallel_size > 1):
raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.")

def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
Expand Down Expand Up @@ -327,7 +333,7 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
def get_num_attention_heads(self,
                            parallel_config: "ParallelConfig") -> int:
    """Return the attention-head count divided by the tensor parallel size.

    Args:
        parallel_config: Object providing ``tensor_parallel_size``.

    Returns:
        ``hf_text_config.num_attention_heads`` floor-divided by
        ``parallel_config.tensor_parallel_size``.
    """
    total_heads = self.hf_text_config.num_attention_heads
    return total_heads // parallel_config.tensor_parallel_size

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
Expand Down Expand Up @@ -487,6 +493,7 @@ class LoadFormat(str, enum.Enum):
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"
BITSANDBYTES = "bitsandbytes"


@dataclass
Expand Down
38 changes: 35 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class EngineArgs:
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None

qlora_adapter_name_or_path: Optional[str] = None

def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
Expand Down Expand Up @@ -159,7 +161,8 @@ def add_cli_args(
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
'bitsandbytes'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
Expand All @@ -173,7 +176,9 @@ def add_cli_args(
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'section for more information.\n')
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
'--dtype',
type=str,
Expand Down Expand Up @@ -543,7 +548,10 @@ def add_cli_args(
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics"
"tag will take the first one.")

parser.add_argument('--qlora-adapter-name-or-path',
type=str,
default=None,
help='Name or path of the QLoRA adapter.')
return parser

@classmethod
Expand All @@ -555,6 +563,23 @@ def from_cli_args(cls, args: argparse.Namespace):
return engine_args

def create_engine_config(self, ) -> EngineConfig:

# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
raise ValueError(
"BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}")

if (self.load_format == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")

device_config = DeviceConfig(self.device)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
Expand Down Expand Up @@ -622,6 +647,13 @@ def create_engine_config(self, ) -> EngineConfig:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None

if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
Expand Down
Loading

0 comments on commit b9c0605

Please sign in to comment.