Added support for bitsandbytes quantization of models (#33)
* added support for bitsandbytes quantization of models
AyushSawant18588 authored Jan 5, 2024
1 parent 15debb4 commit 1996c11
Showing 7 changed files with 67 additions and 6 deletions.
22 changes: 22 additions & 0 deletions llm/handler.py
@@ -104,8 +104,30 @@ def initialize(self, context):
         self.tokenizer.padding_side = "left"
         logger.info("Tokenizer loaded successfully")

+        quantize_bits = None
+        if os.environ.get("NAI_QUANTIZATION"):
+            quantize_bits = int(self.get_env_value("NAI_QUANTIZATION"))
+
+        if quantize_bits == 4:
+            bnb_config = transformers.BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=False,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+            quantization_config = bnb_config
+            logger.info("Loading Model with %s bit Quantization", quantize_bits)
+        elif quantize_bits == 8:
+            bnb_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
+            quantization_config = bnb_config
+            logger.info("Loading Model with %s bit Quantization", quantize_bits)
+        else:
+            quantization_config = None
+            logger.info("Loading Model with bfloat16 data type")
+
         self.model = transformers.AutoModelForCausalLM.from_pretrained(
             model_dir,
+            quantization_config=quantization_config,
             torch_dtype=torch.bfloat16, # Load model weights in bfloat16
             device_map=self.device_map,
             local_files_only=True,
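
Note (not part of the commit): a minimal standalone sketch of the same loading logic, assuming a locally downloaded HuggingFace model; the model path and the NAI_QUANTIZATION value are placeholders.

import os

import torch
import transformers

model_dir = "/path/to/local/model"  # placeholder, not from the repository
quantize_bits = int(os.environ.get("NAI_QUANTIZATION", "0") or 0)  # e.g. 4, 8, or unset

if quantize_bits == 4:
    # 4-bit NF4 quantization with bfloat16 compute, as configured in handler.py
    quantization_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
elif quantize_bits == 8:
    quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
else:
    quantization_config = None  # plain bfloat16 weights, the pre-existing behaviour

# Requires the bitsandbytes package and a CUDA GPU whenever quantization_config is set.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_dir,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # the handler passes self.device_map instead
    local_files_only=True,
)
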
2 changes: 1 addition & 1 deletion llm/model_config.json
@@ -77,4 +77,4 @@
"response_timeout": 2000
}
}
}
}
3 changes: 2 additions & 1 deletion llm/requirements.txt
@@ -9,4 +9,5 @@ accelerate==0.22.0
 nvgpu==0.10.0
 torchserve==0.8.2
 torch-model-archiver==0.8.1
-einops==0.6.1
+einops==0.6.1
+bitsandbytes==0.41.1
8 changes: 7 additions & 1 deletion llm/run.sh
@@ -10,16 +10,18 @@ helpFunction()
echo -e "\t-v HuggingFace repository version (optional)"
echo -e "\t-d Absolute path of input data folder (optional)"
echo -e "\t-a Absolute path to the Model Store directory"
echo -e "\t-q BitsAndBytes Quantization Precision (4 or 8) (optional)"
exit 1 # Exit script after printing help
}

while getopts ":n:v:d:a:o:r" opt;
while getopts ":n:v:d:q:a:" opt;
do
case "$opt" in
n ) model_name="$OPTARG" ;;
v ) repo_version="$OPTARG" ;;
d ) data="$OPTARG" ;;
a ) model_store="$OPTARG" ;;
q ) quantize_bits="$OPTARG" ;;
? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
esac
done
@@ -53,6 +55,10 @@ function create_execution_cmd()
if [ ! -z "$data" ] ; then
cmd+=" --data $data"
fi

if [ ! -z "$quantize_bits" ] ; then
cmd+=" --quantize_bits $quantize_bits"
fi
}

function inference_exec_vm(){
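
For illustration only (the model name and paths are placeholders, not taken from the repository), a quantized run could then be started with:

bash llm/run.sh -n <model_name> -a /absolute/path/to/model-store -q 4

Omitting -q keeps the existing bfloat16 behaviour.
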
7 changes: 7 additions & 0 deletions llm/torchserve_run.py
@@ -116,6 +116,7 @@ def run_inference(params: argparse.Namespace) -> None:
     if params.data:
         check_if_path_exists(params.data, "Input data folder", is_dir=True)

+    ts.set_model_precision(params.quantize_bits)
     create_folder_if_not_exists(
         os.path.join(os.path.dirname(__file__), "utils", params.gen_folder_name)
     )
@@ -219,5 +220,11 @@ def cleanup(gen_folder: str, ts_stop: bool = True, ts_cleanup: bool = True) -> N
         metavar="model_store",
         help="absolute path to the model store directory",
     )
+    parser.add_argument(
+        "--quantize_bits",
+        type=str,
+        default="",
+        help="BitsAndBytes Quantization Precision (4 or 8)",
+    )
     args = parser.parse_args()
     torchserve_run(args)
8 changes: 5 additions & 3 deletions llm/utils/generate_data_model.py
@@ -122,9 +122,11 @@ def validate_hf_token(self) -> None:
             and self.repo_info.hf_token is None
         ):
             print(
-                "## Error: HuggingFace Hub token is required for llama download."
-                " Please specify it using --hf_token=<your token>. "
-                "Refer https://huggingface.co/docs/hub/security-tokens"
+                (
+                    "HuggingFace Hub token is required for llama download. "
+                    "Please specify it using --hf_token=<your token> argument "
+                    ". Refer https://huggingface.co/docs/hub/security-tokens"
+                )
             )
             sys.exit(1)

23 changes: 23 additions & 0 deletions llm/utils/tsutils.py
@@ -8,10 +8,12 @@
 different operating systems.
 """
 import os
+import sys
 import platform
 import time
 import json
 from typing import Tuple, Dict
+import torch
 import requests
 from utils.inference_data_model import InferenceDataModel, TorchserveStartData
 from utils.system_utils import check_if_path_exists
@@ -198,6 +200,27 @@ def set_model_params(model_name: str) -> None:
         del os.environ[param_name]

+
+def set_model_precision(quantize_bits: int) -> None:
+    """
+    This function reads the precision to which the model weights are to be
+    quantized and sets it as environment variable for the handler
+    to read.
+    Args:
+        quantize_bits (int): BitsAndBytes Quantization Precision.
+    """
+    if quantize_bits and int(quantize_bits) not in [4, 8]:
+        print(
+            "## Quantization precision bits should be either 4 or 8. Default precision used is 16"
+        )
+        sys.exit(1)
+    elif quantize_bits and not torch.cuda.is_available():
+        print("## BitsAndBytes Quantization requires GPUs")
+        sys.exit(1)
+    else:
+        os.environ["NAI_QUANTIZATION"] = quantize_bits
+
+
 def get_params_for_registration(model_name: str) -> Tuple[str, str, str, str]:
     """
     This function reads registration parameters from model_config.json returns them.
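
A rough sketch of how the new pieces fit together (not part of the commit; the import path and the example value are assumptions based on the package layout shown above): torchserve_run.py forwards the CLI value to set_model_precision(), which exports it for the handler to read.

import os

import utils.tsutils as ts  # import path assumed from the utils package above

# The value arrives as a string from the --quantize_bits argument.
ts.set_model_precision("4")  # exits unless the value is 4 or 8 and a CUDA GPU is available

# Later, inside the handler's initialize(), the choice is read back:
quantize_bits = int(os.environ["NAI_QUANTIZATION"])  # -> 4, selects the NF4 BitsAndBytesConfig path
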
