[MODEL] Add mllama support (#401)
* layer_type now supports List[str]

* add mllama support

* check that the layer is not MllamaCrossAttentionDecoderLayer

* TODO: need an image dataset for vision quantization

* Update mllama.py

* comment on mllama's repeating 4-layer group structure

---------

Co-authored-by: LRL-ModelCloud <lrl@modelcloud.ai>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
3 people authored Sep 26, 2024
1 parent 4b9506f commit 4921d68
Showing 5 changed files with 34 additions and 3 deletions.
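
For context, a minimal usage sketch of the new model support follows. It is not part of this commit: the model id, calibration text, and the `GPTQModel.from_pretrained` / `quantize` / `save_quantized` calls are assumptions based on the library's usual text-model flow. As the commit notes, vision quantization still needs an image calibration dataset, and cross-attention layers are skipped for now.

```python
# Hypothetical sketch: quantize an mllama checkpoint with text-only calibration.
# The model id and config values below are illustrative assumptions.
from transformers import AutoTokenizer
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # example mllama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tiny text-only calibration set; the cross-attention (vision) layers are skipped.
calibration = [tokenizer("gptqmodel adds mllama support in this commit")]

model = GPTQModel.from_pretrained(model_id, QuantizeConfig(bits=4, group_size=128))
model.quantize(calibration)
model.save_quantized("Llama-3.2-11B-Vision-Instruct-gptq-4bit")
```
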
1 change: 1 addition & 0 deletions gptqmodel/models/__init__.py
@@ -37,3 +37,4 @@
from .starcoder2 import Starcoder2GPTQ
from .xverse import XverseGPTQ
from .yi import YiGPTQ
from .mllama import MLlamaGPTQ
1 change: 1 addition & 0 deletions gptqmodel/models/_const.py
@@ -57,6 +57,7 @@ def get_device_by_type(type_value: str):
"deepseek_v2",
"exaone",
"grinmoe",
"mllama",
]

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
2 changes: 2 additions & 0 deletions gptqmodel/models/auto.py
@@ -43,6 +43,7 @@
from .starcoder2 import Starcoder2GPTQ
from .xverse import XverseGPTQ
from .yi import YiGPTQ
from .mllama import MLlamaGPTQ

MODEL_MAP = {
"bloom": BloomGPTQ,
@@ -85,6 +86,7 @@
"deepseek_v2": DeepSeekV2GPTQ,
"exaone": ExaoneGPTQ,
"grinmoe": GrinMOEGPTQ,
"mllama": MLlamaGPTQ,
}


11 changes: 8 additions & 3 deletions gptqmodel/models/base.py
@@ -23,6 +23,7 @@
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.modeling_utils import no_init_weights, shard_checkpoint
from transformers.models.mllama.modeling_mllama import MllamaCrossAttentionDecoderLayer
from transformers.utils.generic import ContextManagers

from ..nn_modules.qlinear.qlinear_qbits import QBitsQuantLinear, qbits_dtype
@@ -64,7 +65,7 @@ class BaseGPTQModel(nn.Module):
# node holding all the repeating layers
layers_node: str = None
# repeating layer type
layer_type: str = None
layer_type: Union[List[str], str] = None
# for each repeating layer there are multiple modules within each layer
layer_modules: List[List[str]] = None

@@ -440,6 +441,10 @@ def store_input_hook(_, args, kwargs):
for i in layer_pb:
layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
layer = layers[i]
if isinstance(layer, MllamaCrossAttentionDecoderLayer):
# TODO FIXME: we currently do not support quantizing the cross-attention layers (pixel_values)
continue

force_layer_back_to_cpu = False
if get_device(layer) == CPU:
move_to(layer, CUDA_0)
@@ -1356,14 +1361,14 @@ def skip(*args, **kwargs):
max_memory = accelerate.utils.get_balanced_memory(
model=model,
max_memory=max_memory,
no_split_module_classes=[cls.layer_type],
no_split_module_classes=[cls.layer_type] if isinstance(cls.layer_type, str) else cls.layer_type,
low_zero=(device_map == "balanced_low_0"),
)
if not isinstance(device_map, dict):
device_map = accelerate.infer_auto_device_map(
model,
max_memory=max_memory,
no_split_module_classes=[cls.layer_type],
no_split_module_classes=[cls.layer_type] if isinstance(cls.layer_type, str) else cls.layer_type,
)

load_checkpoint_in_model = False
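
The change above is essentially a small normalization of `layer_type` before it is handed to accelerate's device-map utilities, which always expect a list of class names. A minimal standalone sketch of that pattern (the helper name is made up for illustration):

```python
from typing import List, Union

def to_no_split_classes(layer_type: Union[List[str], str]) -> List[str]:
    # layer_type may be a single class name or a list of class names;
    # accelerate's no_split_module_classes always expects a list.
    return [layer_type] if isinstance(layer_type, str) else layer_type

assert to_no_split_classes("LlamaDecoderLayer") == ["LlamaDecoderLayer"]
assert to_no_split_classes(
    ["MllamaSelfAttentionDecoderLayer", "MllamaCrossAttentionDecoderLayer"]
) == ["MllamaSelfAttentionDecoderLayer", "MllamaCrossAttentionDecoderLayer"]
```
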
22 changes: 22 additions & 0 deletions gptqmodel/models/mllama.py
@@ -0,0 +1,22 @@
from .base import BaseGPTQModel

# TODO FIXME: we currently do not support quantizing the cross-attention layers (pixel_values)
class MLlamaGPTQ(BaseGPTQModel):
# Non-repeating layers at the root level: same level as `layers_node`
# Excluding `layers_node`.
base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]

# Below describes all the repeating layers in this transformer model
# `language_model.model.layers` is the node/module that holds all the repeating layers: the parent node for all n layers.
layers_node = "language_model.model.layers"
# Mllama has two types of repeating layers, which repeat in groups of 4: layers 0-2 (the first 3) are text self-attention layers, and layer 3 (the 4th) is a cross-attention layer for vision
layer_type = ["MllamaSelfAttentionDecoderLayer", "MllamaCrossAttentionDecoderLayer"]
# Inside each `MllamaSelfAttentionDecoderLayer` there are many internal modules.
# List them in the order they are executed in the model's forward() code.
# Many models share the same execution order: attention (q_k_v) projections, attention (output) projection, mlp (n) projections
layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.o_proj"],
["mlp.up_proj", "mlp.gate_proj"],
["mlp.down_proj"],
]
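
To make the 4-layer group structure described above concrete, here is a small hypothetical helper (not part of this commit): given an already-loaded mllama model, it separates the text self-attention layers that get quantized from the cross-attention layers that are skipped. The function name is an illustrative assumption; the import and module path come from the diff above.

```python
from transformers.models.mllama.modeling_mllama import MllamaCrossAttentionDecoderLayer

def split_mllama_layers(model):
    """Return (text_layer_indices, cross_attention_layer_indices) for an mllama model."""
    layers = model.language_model.model.layers  # the `layers_node` defined above
    text_idx, cross_idx = [], []
    for i, layer in enumerate(layers):
        if isinstance(layer, MllamaCrossAttentionDecoderLayer):
            cross_idx.append(i)  # skipped during quantization (needs pixel_values)
        else:
            text_idx.append(i)   # MllamaSelfAttentionDecoderLayer: quantized
    return text_idx, cross_idx
```
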
