Add ExllamaV2 Inference Framework Support. #2455

Merged 1 commit on Oct 9, 2023
61 changes: 61 additions & 0 deletions docs/exllamaV2.md
@@ -0,0 +1,61 @@
# ExllamaV2 GPTQ Inference Framework

Integrated the [ExllamaV2](https://github.com/turboderp/exllamav2) customized kernels into FastChat to provide **faster** GPTQ inference speed.

**Note: Exllama does not yet support the embedding REST API.**

## Install ExllamaV2

Setup environment (please refer to [this link](https://github.com/turboderp/exllamav2#how-to) for more details):

```bash
git clone https://github.com/turboderp/exllamav2
cd exllamav2
pip install -e .
```
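
If the build succeeds, a quick import check confirms the extension is usable (a minimal sanity test added here for convenience; the class names are the ones FastChat's loader imports):

```bash
# Should print "ok" without raising an ImportError.
python3 -c "from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer; print('ok')"
```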

Chat with the CLI:
```bash
python3 -m fastchat.serve.cli \
--model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
--enable-exllama
```

Start model worker:
```bash
# Download quantized model from huggingface
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g

# Load model with default configuration (max sequence length 4096, no GPU split setting).
python3 -m fastchat.serve.model_worker \
--model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
--enable-exllama

# Load model with max sequence length 2048, allocating 18 GB to CUDA:0 and 24 GB to CUDA:1.
python3 -m fastchat.serve.model_worker \
--model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
--enable-exllama \
--exllama-max-seq-len 2048 \
--exllama-gpu-split 18,24
```
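
To smoke-test the worker end to end, you can query it through FastChat's usual OpenAI-compatible stack. This is a sketch added for illustration; it assumes the standard controller and `openai_api_server` are also running on their default ports, and that the model name matches the checkpoint directory name above:

```bash
# In separate terminals (standard FastChat serving setup):
#   python3 -m fastchat.serve.controller
#   python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "vicuna-7B-1.1-GPTQ-4bit-128g",
        "messages": [{"role": "user", "content": "Hello! Who are you?"}]
      }'
```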

## Performance

Reference: https://github.com/turboderp/exllamav2#performance


| Model | Mode | Size | grpsz | act | V1: 3090Ti | V1: 4090 | V2: 3090Ti | V2: 4090 |
|------------|--------------|-------|-------|-----|------------|----------|------------|-------------|
| Llama | GPTQ | 7B | 128 | no | 143 t/s | 173 t/s | 175 t/s | **195** t/s |
| Llama | GPTQ | 13B | 128 | no | 84 t/s | 102 t/s | 105 t/s | **110** t/s |
| Llama | GPTQ | 33B | 128 | yes | 37 t/s | 45 t/s | 45 t/s | **48** t/s |
| OpenLlama | GPTQ | 3B | 128 | yes | 194 t/s | 226 t/s | 295 t/s | **321** t/s |
| CodeLlama | EXL2 4.0 bpw | 34B | - | - | - | - | 42 t/s | **48** t/s |
| Llama2 | EXL2 3.0 bpw | 7B | - | - | - | - | 195 t/s | **224** t/s |
| Llama2 | EXL2 4.0 bpw | 7B | - | - | - | - | 164 t/s | **197** t/s |
| Llama2 | EXL2 5.0 bpw | 7B | - | - | - | - | 144 t/s | **160** t/s |
| Llama2 | EXL2 2.5 bpw | 70B | - | - | - | - | 30 t/s | **35** t/s |
| TinyLlama | EXL2 3.0 bpw | 1.1B | - | - | - | - | 536 t/s | **635** t/s |
| TinyLlama | EXL2 4.0 bpw | 1.1B | - | - | - | - | 509 t/s | **590** t/s |
27 changes: 27 additions & 0 deletions fastchat/model/model_adapter.py
@@ -29,12 +29,14 @@
from fastchat.constants import CPU_ISA
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
from fastchat.modules.awq import AWQConfig, load_awq_quantized
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.compression import load_compress_model
from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
from fastchat.model.model_chatglm import generate_stream_chatglm
from fastchat.model.model_codet5p import generate_stream_codet5p
from fastchat.model.model_falcon import generate_stream_falcon
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.model.monkey_patch_non_inplace import (
    replace_llama_attn_with_non_inplace_operations,
)
@@ -155,6 +157,7 @@ def load_model(
    cpu_offloading: bool = False,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    revision: str = "main",
    debug: bool = False,
):
@@ -279,6 +282,9 @@ def load_model(
        else:
            model.to(device)
        return model, tokenizer
    elif exllama_config:
        model, tokenizer = load_exllama_model(model_path, exllama_config)
        return model, tokenizer
    kwargs["revision"] = revision

    if dtype is not None:  # Overwrite dtype if it is provided in the arguments.
@@ -325,13 +331,17 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
    is_falcon = "rwforcausallm" in model_type
    is_codet5p = "codet5p" in model_type
    is_peft = "peft" in model_type
    is_exllama = "exllama" in model_type

    if is_chatglm:
        return generate_stream_chatglm
    elif is_falcon:
        return generate_stream_falcon
    elif is_codet5p:
        return generate_stream_codet5p
    elif is_exllama:
        return generate_stream_exllama

    elif peft_share_base_weights and is_peft:
        # Return a curried stream function that loads the right adapter
        # according to the model_name available in this context. This ensures
@@ -453,6 +463,23 @@ def add_model_args(parser):
        default=-1,
        help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.",
    )
    parser.add_argument(
        "--enable-exllama",
        action="store_true",
        help="Used for exllamav2. Enable the ExLlamaV2 inference framework.",
    )
    parser.add_argument(
        "--exllama-max-seq-len",
        type=int,
        default=4096,
        help="Used for exllamav2. Max sequence length to use for the exllamav2 framework; default is 4096.",
    )
    parser.add_argument(
        "--exllama-gpu-split",
        type=str,
        default=None,
        help="Used for exllamav2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7",
    )


def remove_parent_directory_name(model_path):
76 changes: 76 additions & 0 deletions fastchat/model/model_exllama.py
@@ -0,0 +1,76 @@
import sys
import torch
import gc
from typing import Dict


def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    try:
        from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]

    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()

    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

    max_new_tokens = int(params.get("max_new_tokens", 256))

    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))

    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)

    generated_tokens = 0
    if echo:
        output = prompt
    else:
        output = ""
    while True:
        chunk, eos, _ = generator.stream()
        output += chunk
        generated_tokens += 1
        if generated_tokens == max_new_tokens:
            finish_reason = "length"
            break
        elif eos:
            finish_reason = "stop"
            break
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": None,
        }

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generated_tokens,
            "total_tokens": prompt_tokens + generated_tokens,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()
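
For reference, the following is a minimal editorial sketch (not part of the diff) of how a caller could drive this streaming generator directly. It loads the example checkpoint from the docs via `load_exllama_model`, defined in `fastchat/modules/exllama.py` shown next; a CUDA device and the quantized weights on disk are assumed.

```python
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model

# Load the example GPTQ checkpoint from the docs with the default 4096-token context.
config = ExllamaConfig(max_seq_len=4096, gpu_split=None)
model, tokenizer = load_exllama_model("models/vicuna-7B-1.1-GPTQ-4bit-128g", config)

params = {
    "prompt": "USER: Briefly introduce yourself. ASSISTANT:",
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 64,
    "echo": False,  # return only the completion, not the prompt
}

out = None
for out in generate_stream_exllama(model, tokenizer, params, device="cuda", context_len=4096):
    pass  # each yielded dict carries the cumulative "text" plus token usage so far

print(out["text"])
print(out["usage"], out["finish_reason"])
```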
46 changes: 46 additions & 0 deletions fastchat/modules/exllama.py
@@ -0,0 +1,46 @@
from dataclasses import dataclass, field
import sys


@dataclass
class ExllamaConfig:
    max_seq_len: int
    gpu_split: str = None


class ExllamaModel:
    def __init__(self, exllama_model, exllama_cache):
        self.model = exllama_model
        self.cache = exllama_cache
        self.config = self.model.config


def load_exllama_model(model_path, exllama_config: ExllamaConfig):
    try:
        from exllamav2 import (
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
            ExLlamaV2,
            ExLlamaV2Cache,
        )
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    exllamav2_config = ExLlamaV2Config()
    exllamav2_config.model_dir = model_path
    exllamav2_config.prepare()
    exllamav2_config.max_seq_len = exllama_config.max_seq_len

    exllama_model = ExLlamaV2(exllamav2_config)
    tokenizer = ExLlamaV2Tokenizer(exllamav2_config)

    split = None
    if exllama_config.gpu_split:
        split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
    exllama_model.load(split)

    exllama_cache = ExLlamaV2Cache(exllama_model)
    model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)

    return model, tokenizer
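
A short usage sketch for the multi-GPU path (an editorial illustration, not part of the diff), mirroring the `--exllama-max-seq-len 2048 --exllama-gpu-split 18,24` example from the docs above:

```python
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model

# gpu_split is the same comma-separated GB string the CLI flag accepts;
# load_exllama_model parses it into per-GPU float allocations before loading.
config = ExllamaConfig(max_seq_len=2048, gpu_split="18,24")
model, tokenizer = load_exllama_model("models/vicuna-7B-1.1-GPTQ-4bit-128g", config)
print(type(model).__name__, model.config.max_seq_len)  # -> ExllamaModel 2048
```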
10 changes: 9 additions & 1 deletion fastchat/serve/cli.py
@@ -31,6 +31,7 @@
from fastchat.model.model_adapter import add_model_args
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.serve.inference import ChatIO, chat_loop
from fastchat.utils import str_to_torch_dtype

@@ -195,7 +196,13 @@ def main(args):
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        os.environ["XPU_VISIBLE_DEVICES"] = args.gpus

    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
        )
    else:
        exllama_config = None
    if args.style == "simple":
        chatio = SimpleChatIO(args.multiline)
    elif args.style == "rich":
@@ -230,6 +237,7 @@ def main(args):
            wbits=args.awq_wbits,
            groupsize=args.awq_groupsize,
        ),
        exllama_config=exllama_config,
        revision=args.revision,
        judge_sent_end=args.judge_sent_end,
        debug=args.debug,
3 changes: 3 additions & 0 deletions fastchat/serve/inference.py
@@ -37,6 +37,7 @@
)
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length


@@ -302,6 +303,7 @@ def chat_loop(
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
@@ -318,6 +320,7 @@
        cpu_offloading=cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        revision=revision,
        debug=debug,
    )
20 changes: 18 additions & 2 deletions fastchat/serve/model_worker.py
@@ -53,6 +53,8 @@
    get_context_length,
    str_to_torch_dtype,
)
from fastchat.modules.exllama import ExllamaConfig
from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length


worker_id = str(uuid.uuid4())[:8]
@@ -170,8 +172,12 @@ def get_status(self):

    def count_token(self, params):
        prompt = params["prompt"]
        input_ids = self.tokenizer(prompt).input_ids
        input_echo_len = len(input_ids)

        try:
            input_ids = self.tokenizer(prompt).input_ids
            input_echo_len = len(input_ids)
        except TypeError:
            input_echo_len = self.tokenizer.num_tokens(prompt)

        ret = {
            "count": input_echo_len,
@@ -201,6 +207,7 @@ def __init__(
        cpu_offloading: bool = False,
        gptq_config: Optional[GptqConfig] = None,
        awq_config: Optional[AWQConfig] = None,
        exllama_config: Optional[ExllamaConfig] = None,
        stream_interval: int = 2,
        conv_template: Optional[str] = None,
        embed_in_truncate: bool = False,
@@ -228,6 +235,7 @@ def __init__(
            cpu_offloading=cpu_offloading,
            gptq_config=gptq_config,
            awq_config=awq_config,
            exllama_config=exllama_config,
        )
        self.device = device
        if self.tokenizer.pad_token == None:
@@ -514,6 +522,13 @@ def create_model_worker():
        wbits=args.awq_wbits,
        groupsize=args.awq_groupsize,
    )
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
        )
    else:
        exllama_config = None

    worker = ModelWorker(
        args.controller_address,
@@ -531,6 +546,7 @@ def create_model_worker():
        cpu_offloading=args.cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        stream_interval=args.stream_interval,
        conv_template=args.conv_template,
        embed_in_truncate=args.embed_in_truncate,