feat: support 8bit and fp16 for llm pipeline
kyriediculous committed Aug 5, 2024
1 parent a11391f commit 87bfe3f
Showing 2 changed files with 82 additions and 44 deletions.
124 changes: 80 additions & 44 deletions runner/app/pipelines/llm_generate.py
@@ -1,10 +1,11 @@
 import asyncio
 import logging
 import os
+import psutil
 from typing import Dict, Any, List, Optional, AsyncGenerator, Union
 
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 from app.pipelines.base import Pipeline
 from app.pipelines.utils import get_model_dir, get_torch_device
@@ -13,15 +14,84 @@
 
 logger = logging.getLogger(__name__)
 
+def get_max_memory():
+    num_gpus = torch.cuda.device_count()
+    gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
+    cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"
+    max_memory = {**gpu_memory, "cpu": cpu_memory}
+
+    logger.info(f"Max memory configuration: {max_memory}")
+    return max_memory
+
+def load_model_8bit(model_id: str, **kwargs):
+    max_memory = get_max_memory()
+
+    quantization_config = BitsAndBytesConfig(
+        load_in_8bit=True,
+        llm_int8_threshold=6.0,
+        llm_int8_has_fp16_weight=False,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        quantization_config=quantization_config,
+        device_map="auto",
+        max_memory=max_memory,
+        offload_folder="offload",
+        low_cpu_mem_usage=True,
+        **kwargs
+    )
+
+    return tokenizer, model
+
+def load_model_fp16(model_id: str, **kwargs):
+    device = get_torch_device()
+    max_memory = get_max_memory()
+
+    # Check for fp16 variant
+    local_model_path = os.path.join(get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
+    has_fp16_variant = any(".fp16.safetensors" in fname for _, _, files in os.walk(local_model_path) for fname in files)
+
+    if device != "cpu" and has_fp16_variant:
+        logger.info("Loading fp16 variant for %s", model_id)
+        kwargs["torch_dtype"] = torch.float16
+        kwargs["variant"] = "fp16"
+    elif device != "cpu":
+        kwargs["torch_dtype"] = torch.bfloat16
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
+
+    config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config
+
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config)
+
+    checkpoint_dir = snapshot_download(model_id, cache_dir=get_model_dir(), local_files_only=True)
+
+    model = load_checkpoint_and_dispatch(
+        model,
+        checkpoint_dir,
+        device_map="auto",
+        max_memory=max_memory,
+        no_split_module_classes=["LlamaDecoderLayer"],  # Adjust based on your model architecture
+        dtype=kwargs.get("torch_dtype", torch.float32),
+        offload_folder="offload",
+        offload_state_dict=True,
+    )
+
+    return tokenizer, model
+
 class LLMGeneratePipeline(Pipeline):
     def __init__(self, model_id: str):
         self.model_id = model_id
         kwargs = {
             "cache_dir": get_model_dir(),
-            "local_files_only": True
+            "local_files_only": True,
         }
         self.device = get_torch_device()
 
         # Generate the correct folder name
         folder_path = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
         self.local_model_path = os.path.join(get_model_dir(), folder_path)
@@ -30,47 +100,14 @@ def __init__(self, model_id: str):
         logger.info(f"Local model path: {self.local_model_path}")
         logger.info(f"Directory contents: {os.listdir(self.local_model_path)}")
 
-        # Check for fp16 variant
-        has_fp16_variant = any(
-            ".fp16.safetensors" in fname
-            for _, _, files in os.walk(self.local_model_path)
-            for fname in files
-        )
-        if self.device != "cpu" and has_fp16_variant:
-            logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id)
-            kwargs["torch_dtype"] = torch.float16
-            kwargs["variant"] = "fp16"
-        elif self.device != "cpu":
-            kwargs["torch_dtype"] = torch.bfloat16
+        use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"
 
-        logger.info(f"Loading model {model_id}")
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
-
-        # Load the model configuration
-        config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config
-
-        # Initialize empty weights
-        with init_empty_weights():
-            self.model = AutoModelForCausalLM.from_config(config)
-
-        # Prepare for distributed setup
-        num_gpus = torch.cuda.device_count()
-        max_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
-        max_memory["cpu"] = "24GiB"  # Adjust based on your system's RAM
-
-        logger.info(f"Max memory configuration: {max_memory}")
-
-        # Load and dispatch the model
-        self.model = load_checkpoint_and_dispatch(
-            self.model,
-            self.checkpoint_dir,
-            device_map="auto",
-            max_memory=max_memory,
-            no_split_module_classes=["LlamaDecoderLayer"],  # Adjust based on your model architecture
-            dtype=kwargs.get("torch_dtype", torch.float32),
-            offload_folder="offload",  # Optional: specify a folder for offloading
-            offload_state_dict=True,  # Optional: offload state dict to CPU
-        )
-
+        if use_8bit:
+            logger.info("Using 8-bit quantization")
+            self.tokenizer, self.model = load_model_8bit(model_id, **kwargs)
+        else:
+            logger.info("Using fp16/bf16 precision")
+            self.tokenizer, self.model = load_model_fp16(model_id, **kwargs)
+
         logger.info(f"Model loaded and distributed. Device map: {self.model.hf_device_map}")
 
@@ -91,7 +128,6 @@ def __init__(self, model_id: str):
             )
             from app.pipelines.optim.sfast import compile_model
             self.model = compile_model(self.model)
-
     async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
         conversation = []
         if system_msg:
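
With this change the pipeline's public API stays the same: __init__ picks the precision path once from the USE_8BIT environment variable, and both load_model_8bit and load_model_fp16 return a (tokenizer, model) pair. A minimal usage sketch follows; the import path mirrors the runner layout above, while the model ID is a placeholder and its weights are assumed to already sit under get_model_dir(), since loading uses local_files_only=True.

    import os

    # Select the 8-bit path before constructing the pipeline; any value other
    # than "true" falls back to load_model_fp16 (fp16/bf16).
    os.environ["USE_8BIT"] = "true"

    from app.pipelines.llm_generate import LLMGeneratePipeline

    # Placeholder model ID; substitute a model that has been downloaded locally.
    pipeline = LLMGeneratePipeline("meta-llama/Meta-Llama-3.1-8B-Instruct")
    print(pipeline.model.hf_device_map)

Because both paths pass device_map="auto" together with the per-device budget from get_max_memory, layers that do not fit on the GPUs can spill into CPU RAM or the "offload" folder, which is why psutil's available-memory reading is part of the budget.
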
2 changes: 2 additions & 0 deletions runner/requirements.txt
@@ -17,3 +17,5 @@ numpy==1.26.4
 av==12.1.0
 sentencepiece== 0.2.0
 protobuf==5.27.2
+bitsandbytes==0.43.3
+psutil==6.0.0
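
Of the two new dependencies, bitsandbytes provides the int8 kernels behind BitsAndBytesConfig(load_in_8bit=True), and psutil supplies the available system RAM that get_max_memory reports for the "cpu" entry. A rough sanity check that the 8-bit path took effect, continuing the hypothetical pipeline object from the sketch above (get_memory_footprint is a standard transformers model method):

    footprint_gib = pipeline.model.get_memory_footprint() / 1024**3
    # An int8 checkpoint should come in at roughly half the fp16 footprint.
    print(f"model footprint: {footprint_gib:.1f} GiB")
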
