Skip to content

Commit

Permalink
Merge pull request #543 from Anhforth/master
Browse files: browse the repository at this point in the history
added quantize
  • Loading branch information
ftgreat authored Oct 10, 2023
2 parents 4011bd0 + dade181 commit 44cde68
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
24 changes: 14 additions & 10 deletions flagai/auto_model/auto_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ def __init__(self,
if task_name == "aquila2":
from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM
download_path = os.path.join(model_dir, model_name)


if not torch_dtype and '34b' in model_name.lower():
torch_dtype = torch.bfloat16

if not os.path.exists(download_path):
# Try to download from ModelHub
try:
Expand Down Expand Up @@ -255,24 +258,25 @@ def __init__(self,
for file_to_load in model_files:
if "pytorch_model-0" in file_to_load:
_get_checkpoint_path(download_path, file_to_load,
model_id)

if qlora_dir:
model_id)
if 'quantization_config' in kwargs:
quantization_config = kwargs['quantization_config']
elif qlora_dir:
from transformers import BitsAndBytesConfig
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
)
else:
quantization_config = None
if inference_mode:
if qlora_dir:
model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,
quantization_config=quantization_config)
else:
model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,)

model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,
quantization_config=quantization_config)
model.eval()
if not qlora_dir:
if not quantization_config:
model.to(device)
if lora_dir:
from flagai.model.tools.peft import PeftModel
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="flagai",
version="v1.8.0",
version="v1.8.1",
description="FlagAI aims to help researchers and developers to freely train and test large-scale models for NLP/CV/VL tasks.",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
Expand All @@ -19,19 +19,19 @@
install_requires=[
'nltk>=3.6.7',
'sentencepiece>=0.1.96',
'boto3==1.17.32',
'boto3>=1.17.32',
'pandas>=1.3.5',
'jieba>=0.42.1',
'scikit-learn>=1.0.2',
'tensorboard>=2.9.0',
'transformers>=4.31.0',
'datasets>=2.0.0',
'setuptools==66.0.0',
'setuptools>=66.0.0',
'protobuf==3.19.6',
'ftfy',
'Pillow>=9.3.0',
'einops>=0.3.0',
'diffusers==0.7.2',
'diffusers>=0.7.2',
'pytorch-lightning>=1.6.5',
'taming-transformers-rom1504==0.0.6',
'rouge-score',
Expand Down

0 comments on commit 44cde68

Please sign in to comment.