Merge pull request huggingface#15 from PanQiWei/support_moss
Support MOSS model
PanQiWei authored Apr 25, 2023
2 parents 73606a3 + 419160b commit 1c98811
Showing 8 changed files with 29 additions and 9 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -2,6 +2,7 @@
 An easy-to-use model quantization package with user-friendly apis, based on GPTQ algorithm.
 
 ## News or Update
+- 2023-04-25 - (News&Update) - [MOSS](https://github.com/OpenLMLab/MOSS), an open-source tool-augmented conversational language model from Fudan University, is now supported for quantization in AutoGPTQ.
 - 2023-04-23 - (Update) - Support evaluation on multiple (down-stream) tasks such as: language-modeling, text-classification, text-summarization.
 - 2023-04-22 - (News) - qwopqwop200's [AutoGPTQ-triton](https://github.com/qwopqwop200/AutoGPTQ-triton) provides faster inference with quantized models; everyone with access to triton, try it and enjoy!
 - 2023-04-20 - (News) - AutoGPTQ is automatically compatible with Stability-AI's newly released `gpt_neox` type model family [StableLM](https://github.com/Stability-AI/StableLM).
@@ -25,7 +26,7 @@ pip install .[llama]
 ```
 
 ## Supported Models
-Currently, `auto_gptq` supports: `bloom`, `gpt_neox`, `gptj`, `llama` and `opt`; more CausalLMs will come soon!
+Currently, `auto_gptq` supports: `bloom`, `gpt_neox`, `gptj`, `llama`, `moss` and `opt`; more CausalLMs will come soon!
 
 ## Supported Evaluation Tasks
 Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!
7 changes: 4 additions & 3 deletions auto_gptq/modeling/_base.py
@@ -303,14 +303,15 @@ def skip(*args, **kwargs):
         torch.nn.init.uniform_ = skip
         torch.nn.init.normal_ = skip
 
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")
 
         # enforce some values despite user specified
         model_init_kwargs["device_map"] = None
         model_init_kwargs["torch_dtype"] = torch.bfloat16 if bf16 else torch.float16
         model_init_kwargs["low_cpu_mem_usage"] = False
+        model_init_kwargs["trust_remote_code"] = True
 
         model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
         model_config = model.config.to_dict()
@@ -335,7 +336,7 @@ def from_quantized(
         use_safetensors: bool = False
     ):
         """load quantized model from local disk"""
-        config = AutoConfig.from_pretrained(save_dir)
+        config = AutoConfig.from_pretrained(save_dir, trust_remote_code=True)
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")
 
@@ -356,7 +357,7 @@ def skip(*args, **kwargs):
 
         transformers.modeling_utils._init_weights = False
         torch.set_default_dtype(torch.half)
-        model = AutoModelForCausalLM.from_config(config)
+        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
         torch.set_default_dtype(torch.float)
         model = model.eval()
         layers = find_layers(model)
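A note on the recurring change in this file: every `AutoConfig` / `AutoModelForCausalLM` call gains `trust_remote_code=True` because MOSS ships its modeling code inside the model repository rather than in `transformers` itself, so loading it requires opting in to code from the hub. A minimal sketch of a plain (unquantized) load, assuming `fnlp/moss-moon-003-sft` as the checkpoint id (the id is not taken from this diff):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# The checkpoint id below is a hypothetical example; any MOSS repo works the
# same way. Without trust_remote_code=True, transformers would refuse the
# unknown "moss" model_type instead of importing the repo's bundled code.
config = AutoConfig.from_pretrained("fnlp/moss-moon-003-sft", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("fnlp/moss-moon-003-sft", trust_remote_code=True)
```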
2 changes: 1 addition & 1 deletion auto_gptq/modeling/_const.py
@@ -6,7 +6,7 @@
 CPU = device("cpu")
 CUDA = device("cuda:0")
 
-SUPPORTED_MODELS = ["bloom", "gptj", "gpt_neox", "opt"]
+SUPPORTED_MODELS = ["bloom", "gptj", "gpt_neox", "opt", "moss"]
 if parse_version(transformers_version) >= parse_version("v4.28.0"):
     SUPPORTED_MODELS.append("llama")
 
2 changes: 1 addition & 1 deletion auto_gptq/modeling/_utils.py
@@ -39,7 +39,7 @@ def pack_model(model, quantizers, bits, group_size):
 
 
 def check_and_get_model_type(model_dir):
-    config = AutoConfig.from_pretrained(model_dir)
+    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
     if config.model_type not in SUPPORTED_MODELS:
         raise TypeError(f"{config.model_type} isn't supported yet.")
     model_type = config.model_type
4 changes: 3 additions & 1 deletion auto_gptq/modeling/auto.py
@@ -4,6 +4,7 @@
 from .gpt_neox import GPTNeoXGPTQForCausalLM
 from .gptj import GPTJGPTQForCausalLM
 from .llama import LlamaGPTQForCausalLM
+from .moss import MOSSGPTQForCausalLM
 from .opt import OPTGPTQForCausalLM
 
 
@@ -12,7 +13,8 @@
     "gpt_neox": GPTNeoXGPTQForCausalLM,
     "gptj": GPTJGPTQForCausalLM,
     "llama": LlamaGPTQForCausalLM,
-    "opt": OPTGPTQForCausalLM
+    "opt": OPTGPTQForCausalLM,
+    "moss": MOSSGPTQForCausalLM
 }
 
 
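Taken together with the `_utils.py` change above, this mapping is what routes a checkpoint to its GPTQ wrapper: the `model_type` read from the config indexes the dict to pick a class. A rough sketch of that dispatch (the dict's variable name is truncated out of this diff, so it is passed in as a parameter here):

```python
from auto_gptq.modeling._utils import check_and_get_model_type

def pick_gptq_class(model_dir: str, model_map: dict):
    # check_and_get_model_type validates against SUPPORTED_MODELS and returns
    # e.g. "moss", which selects MOSSGPTQForCausalLM from the map above.
    model_type = check_and_get_model_type(model_dir)
    return model_map[model_type]
```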
12 changes: 12 additions & 0 deletions auto_gptq/modeling/moss.py
@@ -0,0 +1,12 @@
+from ._base import *
+
+
+class MOSSGPTQForCausalLM(BaseGPTQForCausalLM):
+    layers_block_name = "transformer.h"
+    outside_layer_modules = ["transformer.wte", "transformer.drop", "transformer.ln_f"]
+    inside_layer_modules = [
+        ["attn.qkv_proj"],
+        ["attn.out_proj"],
+        ["mlp.fc_in"],
+        ["mlp.fc_out"]
+    ]
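For readers new to these wrapper classes: `layers_block_name` names the stack of transformer blocks that gets quantized block by block, `outside_layer_modules` lists the embedding/dropout/final-norm modules that sit outside that stack and stay in full precision, and `inside_layer_modules` groups each block's Linear submodules in execution order (MOSS uses a GPT-J-style fused `qkv_proj`). This reading follows the conventions of the other model files in this package rather than anything stated in the diff itself.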
6 changes: 5 additions & 1 deletion examples/quantization/quant_with_alpaca.py
@@ -82,7 +82,11 @@ def main():
     parser.add_argument("--fast_tokenizer", action="store_true")
     args = parser.parse_args()
 
-    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_dir, use_fast=args.fast_tokenizer)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.pretrained_model_dir,
+        use_fast=args.fast_tokenizer,
+        trust_remote_code=True
+    )
     model = AutoGPTQForCausalLM.from_pretrained(
         args.pretrained_model_dir,
         quantize_config=BaseQuantizeConfig(bits=args.bits, group_size=args.group_size)
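Putting the pieces together, quantizing a MOSS checkpoint end to end follows the same shape as the example script above. A trimmed sketch, with a hypothetical checkpoint id and a one-line calibration set standing in for the alpaca data (`quantize` and `save_quantized` follow this package's documented example usage):

```python
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "fnlp/moss-moon-003-sft"  # hypothetical MOSS checkpoint id

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, trust_remote_code=True)
model = AutoGPTQForCausalLM.from_pretrained(
    pretrained_model_dir,
    quantize_config=BaseQuantizeConfig(bits=4, group_size=128)
)

# A real run would tokenize many calibration samples; one stands in here.
examples = [tokenizer("auto_gptq is an easy-to-use model quantization package.")]
model.quantize(examples)
model.save_quantized("moss-moon-003-sft-gptq-4bit")
```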
2 changes: 1 addition & 1 deletion setup.py
@@ -31,7 +31,7 @@
 setup(
     name="auto_gptq",
     packages=find_packages(),
-    version="v0.0.3",
+    version="v0.0.4-dev",
     install_requires=requirements,
     extras_require=extras_require,
     ext_modules=extensions,
