Merge llama_config.py into llama_model.py #1189

Merged
31 changes: 0 additions & 31 deletions python/mlc_chat/compiler/model/llama_config.py

This file was deleted.

3 changes: 1 addition & 2 deletions python/mlc_chat/compiler/model/llama_loader.py
@@ -8,8 +8,7 @@

 from ..loader import ExternMapping
 from ..quantization import Quantization
-from .llama_config import LlamaConfig
-from .llama_model import LlamaForCasualLM
+from .llama_model import LlamaConfig, LlamaForCasualLM


 def huggingface(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
32 changes: 30 additions & 2 deletions python/mlc_chat/compiler/model/llama_model.py
@@ -2,14 +2,42 @@
 Implementation for Llama2 architecture.
 TODO: add docstring
 """
+import dataclasses
 import math
-from typing import Optional
+from typing import Any, Dict, Optional

 from tvm import te, tir
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import Tensor, op

-from .llama_config import LlamaConfig
+from ...support.config import ConfigBase
+
+
+@dataclasses.dataclass
+class LlamaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
+    """Configuration of the Llama model."""
+
+    hidden_act: str
+    hidden_size: int
+    intermediate_size: int
+    num_attention_heads: int
+    num_hidden_layers: int
+    rms_norm_eps: float
+    vocab_size: int
+    max_sequence_length: int = 2048
+    position_embedding_base: int = 10000
+    num_key_value_heads: int = 0
+    head_dim: int = 0
+    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+    def __post_init__(self):
+        if self.num_key_value_heads == 0:
+            self.num_key_value_heads = self.num_attention_heads
+        if self.head_dim == 0:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        assert self.num_attention_heads % self.num_key_value_heads == 0
+        assert self.head_dim * self.num_attention_heads == self.hidden_size


 # pylint: disable=invalid-name,missing-docstring
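The relocated LlamaConfig fills in derived attributes in __post_init__: num_key_value_heads defaults to num_attention_heads, and head_dim to hidden_size // num_attention_heads, whenever they are left at 0. A minimal sketch of that behavior, using values from the llama2_7b preset further down (the import path is assumed from this PR's package layout):

from mlc_chat.compiler.model.llama_model import LlamaConfig

# Required fields only; the derived fields keep their 0 defaults until __post_init__ runs.
config = LlamaConfig(
    hidden_act="silu",
    hidden_size=4096,
    intermediate_size=11008,
    num_attention_heads=32,
    num_hidden_layers=32,
    rms_norm_eps=1e-05,
    vocab_size=32000,
)
assert config.num_key_value_heads == 32  # defaulted to num_attention_heads
assert config.head_dim == 128            # defaulted to hidden_size // num_attention_heads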
3 changes: 1 addition & 2 deletions python/mlc_chat/compiler/model/llama_quantization.py
@@ -6,8 +6,7 @@

 from ..loader import QuantizeMapping
 from ..quantization import GroupQuantize
-from .llama_config import LlamaConfig
-from .llama_model import LlamaForCasualLM
+from .llama_model import LlamaConfig, LlamaForCasualLM


 def group_quant(
146 changes: 70 additions & 76 deletions python/mlc_chat/compiler/model/model.py
@@ -6,7 +6,7 @@

 from ..loader import ExternMapping, QuantizeMapping
 from ..quantization.quantization import Quantization
-from . import llama_config, llama_loader, llama_model, llama_quantization
+from . import llama_loader, llama_model, llama_quantization

 ModelConfig = Any
 """A ModelConfig is an object that represents a model architecture. It is required to have
@@ -54,7 +54,7 @@ class Model:
     "llama": Model(
         name="llama",
         model=llama_model.LlamaForCasualLM,
-        config=llama_config.LlamaConfig,
+        config=llama_model.LlamaConfig,
         source={
             "huggingface-torch": llama_loader.huggingface,
             "huggingface-safetensor": llama_loader.huggingface,
@@ -66,78 +66,72 @@
 }

 MODEL_PRESETS: Dict[str, Any] = {
-    "llama2_7b": llama_config.LlamaConfig.from_dict(
-        {
-            "architectures": ["LlamaForCausalLM"],
-            "bos_token_id": 1,
-            "eos_token_id": 2,
-            "hidden_act": "silu",
-            "hidden_size": 4096,
-            "initializer_range": 0.02,
-            "intermediate_size": 11008,
-            "max_position_embeddings": 2048,
-            "model_type": "llama",
-            "num_attention_heads": 32,
-            "num_hidden_layers": 32,
-            "num_key_value_heads": 32,
-            "pad_token_id": 0,
-            "pretraining_tp": 1,
-            "rms_norm_eps": 1e-05,
-            "rope_scaling": None,
-            "tie_word_embeddings": False,
-            "torch_dtype": "float16",
-            "transformers_version": "4.31.0.dev0",
-            "use_cache": True,
-            "vocab_size": 32000,
-        }
-    ),
-    "llama2_13b": llama_config.LlamaConfig.from_dict(
-        {
-            "_name_or_path": "meta-llama/Llama-2-13b-hf",
-            "architectures": ["LlamaForCausalLM"],
-            "bos_token_id": 1,
-            "eos_token_id": 2,
-            "hidden_act": "silu",
-            "hidden_size": 5120,
-            "initializer_range": 0.02,
-            "intermediate_size": 13824,
-            "max_position_embeddings": 2048,
-            "model_type": "llama",
-            "num_attention_heads": 40,
-            "num_hidden_layers": 40,
-            "num_key_value_heads": 40,
-            "pad_token_id": 0,
-            "pretraining_tp": 2,
-            "rms_norm_eps": 1e-05,
-            "rope_scaling": None,
-            "tie_word_embeddings": False,
-            "torch_dtype": "float16",
-            "transformers_version": "4.31.0.dev0",
-            "use_cache": True,
-            "vocab_size": 32000,
-        }
-    ),
-    "llama2_70b": llama_config.LlamaConfig.from_dict(
-        {
-            "architectures": ["LlamaForCausalLM"],
-            "bos_token_id": 1,
-            "eos_token_id": 2,
-            "hidden_act": "silu",
-            "hidden_size": 8192,
-            "initializer_range": 0.02,
-            "intermediate_size": 28672,
-            "max_position_embeddings": 2048,
-            "model_type": "llama",
-            "num_attention_heads": 64,
-            "num_hidden_layers": 80,
-            "num_key_value_heads": 8,
-            "pad_token_id": 0,
-            "rms_norm_eps": 1e-05,
-            "tie_word_embeddings": False,
-            "torch_dtype": "float16",
-            "transformers_version": "4.31.0.dev0",
-            "use_cache": True,
-            "vocab_size": 32000,
-        }
-    ),
+    "llama2_7b": {
+        "architectures": ["LlamaForCausalLM"],
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "silu",
+        "hidden_size": 4096,
+        "initializer_range": 0.02,
+        "intermediate_size": 11008,
+        "max_position_embeddings": 2048,
+        "model_type": "llama",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 32,
+        "num_key_value_heads": 32,
+        "pad_token_id": 0,
+        "pretraining_tp": 1,
+        "rms_norm_eps": 1e-05,
+        "rope_scaling": None,
+        "tie_word_embeddings": False,
+        "torch_dtype": "float16",
+        "transformers_version": "4.31.0.dev0",
+        "use_cache": True,
+        "vocab_size": 32000,
+    },
+    "llama2_13b": {
+        "_name_or_path": "meta-llama/Llama-2-13b-hf",
+        "architectures": ["LlamaForCausalLM"],
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "silu",
+        "hidden_size": 5120,
+        "initializer_range": 0.02,
+        "intermediate_size": 13824,
+        "max_position_embeddings": 2048,
+        "model_type": "llama",
+        "num_attention_heads": 40,
+        "num_hidden_layers": 40,
+        "num_key_value_heads": 40,
+        "pad_token_id": 0,
+        "pretraining_tp": 2,
+        "rms_norm_eps": 1e-05,
+        "rope_scaling": None,
+        "tie_word_embeddings": False,
+        "torch_dtype": "float16",
+        "transformers_version": "4.31.0.dev0",
+        "use_cache": True,
+        "vocab_size": 32000,
+    },
+    "llama2_70b": {
+        "architectures": ["LlamaForCausalLM"],
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "silu",
+        "hidden_size": 8192,
+        "initializer_range": 0.02,
+        "intermediate_size": 28672,
+        "max_position_embeddings": 2048,
+        "model_type": "llama",
+        "num_attention_heads": 64,
+        "num_hidden_layers": 80,
+        "num_key_value_heads": 8,
+        "pad_token_id": 0,
+        "rms_norm_eps": 1e-05,
+        "tie_word_embeddings": False,
+        "torch_dtype": "float16",
+        "transformers_version": "4.31.0.dev0",
+        "use_cache": True,
+        "vocab_size": 32000,
+    },
 }
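With the presets now stored as plain dicts, callers build the config object themselves via the config class registered on the Model entry, as the updated tests below show. A minimal sketch of the new flow (assuming ConfigBase.from_dict tolerates extra keys such as bos_token_id, e.g. by collecting them into the kwargs field):

from mlc_chat.compiler.model.model import MODELS, MODEL_PRESETS

model_info = MODELS["llama"]                   # registry entry for the Llama architecture
preset = MODEL_PRESETS["llama2_7b"]            # now a plain dict, no longer a LlamaConfig
config = model_info.config.from_dict(preset)   # model_info.config is llama_model.LlamaConfig
model = model_info.model(config)               # instantiates LlamaForCasualLM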
4 changes: 3 additions & 1 deletion tests/python/model/test_llama.py
@@ -6,7 +6,9 @@

 @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"])
 def test_llama2_creation(model_name: str):
-    model = MODELS["llama"].model(MODEL_PRESETS[model_name])
+    model_info = MODELS["llama"]
+    config = model_info.config.from_dict(MODEL_PRESETS[model_name])
+    model = model_info.model(config)
     mod, named_params = model.export_tvm(
         spec=model.get_default_spec(),  # type: ignore
     )
5 changes: 3 additions & 2 deletions tests/python/model/test_llama_quantization.py
@@ -18,8 +18,9 @@
     ],
 )
 def test_llama2_group_quantization(model_name: str, quant_name: str):
-    config = MODEL_PRESETS[model_name]
-    model, quant_map = MODELS["llama"].quantize["group-quant"](config, QUANTIZATION[quant_name])
+    model_info = MODELS["llama"]
+    config = model_info.config.from_dict(MODEL_PRESETS[model_name])
+    model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name])
     assert "model.embed_tokens.weight" in quant_map.param_map
     assert isinstance(
         model.model.embed_tokens,  # type: ignore[attr-defined]