Merge llama_config.CONFIG into MODEL_PRESETS (mlc-ai#1188)
junrushao authored Nov 4, 2023
1 parent 4716704 commit 9d20575
Showing 4 changed files with 105 additions and 93 deletions.
python/mlc_chat/compiler/model/llama_config.py (77 changes: 0 additions, 77 deletions)

@@ -29,80 +29,3 @@ def __post_init__(self):
         self.head_dim = self.hidden_size // self.num_attention_heads
         assert self.num_attention_heads % self.num_key_value_heads == 0
         assert self.head_dim * self.num_attention_heads == self.hidden_size
-
-    @staticmethod
-    def from_predefined(name: str) -> "LlamaConfig":
-        """Create a LlamaConfig from a predefined configuration."""
-        return LlamaConfig.from_dict(CONFIG[name])
-
-
-CONFIG = {
-    "llama2_7b": {
-        "architectures": ["LlamaForCausalLM"],
-        "bos_token_id": 1,
-        "eos_token_id": 2,
-        "hidden_act": "silu",
-        "hidden_size": 4096,
-        "initializer_range": 0.02,
-        "intermediate_size": 11008,
-        "max_position_embeddings": 2048,
-        "model_type": "llama",
-        "num_attention_heads": 32,
-        "num_hidden_layers": 32,
-        "num_key_value_heads": 32,
-        "pad_token_id": 0,
-        "pretraining_tp": 1,
-        "rms_norm_eps": 1e-05,
-        "rope_scaling": None,
-        "tie_word_embeddings": False,
-        "torch_dtype": "float16",
-        "transformers_version": "4.31.0.dev0",
-        "use_cache": True,
-        "vocab_size": 32000,
-    },
-    "llama2_13b": {
-        "_name_or_path": "meta-llama/Llama-2-13b-hf",
-        "architectures": ["LlamaForCausalLM"],
-        "bos_token_id": 1,
-        "eos_token_id": 2,
-        "hidden_act": "silu",
-        "hidden_size": 5120,
-        "initializer_range": 0.02,
-        "intermediate_size": 13824,
-        "max_position_embeddings": 2048,
-        "model_type": "llama",
-        "num_attention_heads": 40,
-        "num_hidden_layers": 40,
-        "num_key_value_heads": 40,
-        "pad_token_id": 0,
-        "pretraining_tp": 2,
-        "rms_norm_eps": 1e-05,
-        "rope_scaling": None,
-        "tie_word_embeddings": False,
-        "torch_dtype": "float16",
-        "transformers_version": "4.31.0.dev0",
-        "use_cache": True,
-        "vocab_size": 32000,
-    },
-    "llama2_70b": {
-        "architectures": ["LlamaForCausalLM"],
-        "bos_token_id": 1,
-        "eos_token_id": 2,
-        "hidden_act": "silu",
-        "hidden_size": 8192,
-        "initializer_range": 0.02,
-        "intermediate_size": 28672,
-        "max_position_embeddings": 2048,
-        "model_type": "llama",
-        "num_attention_heads": 64,
-        "num_hidden_layers": 80,
-        "num_key_value_heads": 8,
-        "pad_token_id": 0,
-        "rms_norm_eps": 1e-05,
-        "tie_word_embeddings": False,
-        "torch_dtype": "float16",
-        "transformers_version": "4.31.0.dev0",
-        "use_cache": True,
-        "vocab_size": 32000,
-    },
-}
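
As a quick sanity check on the presets removed above, the `__post_init__` invariants retained in llama_config.py hold for the llama2_7b numbers. The snippet below is a hedged, standalone illustration with hard-coded values taken from the diff; it is not part of this commit.

# Standalone arithmetic check using the llama2_7b values shown above (illustrative only).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 32

head_dim = hidden_size // num_attention_heads  # 4096 // 32 == 128
assert num_attention_heads % num_key_value_heads == 0
assert head_dim * num_attention_heads == hidden_size  # 128 * 32 == 4096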
python/mlc_chat/compiler/model/model.py (77 changes: 76 additions, 1 deletion)

@@ -65,4 +65,79 @@ class Model:
     )
 }
 
-MODEL_PRESETS: Dict[str, Dict[str, Any]] = llama_config.CONFIG
+MODEL_PRESETS: Dict[str, Any] = {
+    "llama2_7b": llama_config.LlamaConfig.from_dict(
+        {
+            "architectures": ["LlamaForCausalLM"],
+            "bos_token_id": 1,
+            "eos_token_id": 2,
+            "hidden_act": "silu",
+            "hidden_size": 4096,
+            "initializer_range": 0.02,
+            "intermediate_size": 11008,
+            "max_position_embeddings": 2048,
+            "model_type": "llama",
+            "num_attention_heads": 32,
+            "num_hidden_layers": 32,
+            "num_key_value_heads": 32,
+            "pad_token_id": 0,
+            "pretraining_tp": 1,
+            "rms_norm_eps": 1e-05,
+            "rope_scaling": None,
+            "tie_word_embeddings": False,
+            "torch_dtype": "float16",
+            "transformers_version": "4.31.0.dev0",
+            "use_cache": True,
+            "vocab_size": 32000,
+        }
+    ),
+    "llama2_13b": llama_config.LlamaConfig.from_dict(
+        {
+            "_name_or_path": "meta-llama/Llama-2-13b-hf",
+            "architectures": ["LlamaForCausalLM"],
+            "bos_token_id": 1,
+            "eos_token_id": 2,
+            "hidden_act": "silu",
+            "hidden_size": 5120,
+            "initializer_range": 0.02,
+            "intermediate_size": 13824,
+            "max_position_embeddings": 2048,
+            "model_type": "llama",
+            "num_attention_heads": 40,
+            "num_hidden_layers": 40,
+            "num_key_value_heads": 40,
+            "pad_token_id": 0,
+            "pretraining_tp": 2,
+            "rms_norm_eps": 1e-05,
+            "rope_scaling": None,
+            "tie_word_embeddings": False,
+            "torch_dtype": "float16",
+            "transformers_version": "4.31.0.dev0",
+            "use_cache": True,
+            "vocab_size": 32000,
+        }
+    ),
+    "llama2_70b": llama_config.LlamaConfig.from_dict(
+        {
+            "architectures": ["LlamaForCausalLM"],
+            "bos_token_id": 1,
+            "eos_token_id": 2,
+            "hidden_act": "silu",
+            "hidden_size": 8192,
+            "initializer_range": 0.02,
+            "intermediate_size": 28672,
+            "max_position_embeddings": 2048,
+            "model_type": "llama",
+            "num_attention_heads": 64,
+            "num_hidden_layers": 80,
+            "num_key_value_heads": 8,
+            "pad_token_id": 0,
+            "rms_norm_eps": 1e-05,
+            "tie_word_embeddings": False,
+            "torch_dtype": "float16",
+            "transformers_version": "4.31.0.dev0",
+            "use_cache": True,
+            "vocab_size": 32000,
+        }
+    ),
+}
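
To see how the merged presets are consumed, the updated tests below look configs up directly in MODEL_PRESETS instead of calling from_predefined. A minimal sketch of that flow follows; it is illustrative only, mirrors the test code rather than defining any new API, and assumes the mlc_chat package at the state of this commit.

# Illustrative usage mirroring the updated tests: fetch a preset LlamaConfig and build the model.
from mlc_chat.compiler import MODEL_PRESETS, MODELS

config = MODEL_PRESETS["llama2_7b"]    # already a LlamaConfig; no from_predefined() lookup
model = MODELS["llama"].model(config)  # instantiate the Llama model from the preset
mod, named_params = model.export_tvm(spec=model.get_default_spec())
for name, param in named_params:
    print(name, param.shape, param.dtype)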
tests/python/model/test_llama.py (10 changes: 5 additions, 5 deletions)

@@ -1,15 +1,15 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODELS
+from mlc_chat.compiler import MODEL_PRESETS, MODELS
 
 
 @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"])
 def test_llama2_creation(model_name: str):
-    model_info = MODELS["llama"]
-    config = model_info.config.from_predefined(model_name)
-    model = model_info.model(config)
-    mod, named_params = model.export_tvm(spec=model.get_default_spec())
+    model = MODELS["llama"].model(MODEL_PRESETS[model_name])
+    mod, named_params = model.export_tvm(
+        spec=model.get_default_spec(),  # type: ignore
+    )
     mod.show(black_format=False)
     for name, param in named_params:
         print(name, param.shape, param.dtype)
tests/python/model/test_llama_quantization.py (34 changes: 24 additions, 10 deletions)

@@ -1,7 +1,7 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODELS, QUANTIZATION
+from mlc_chat.compiler import MODEL_PRESETS, MODELS, QUANTIZATION
 from mlc_chat.compiler.quantization.group_quantization import (
     GroupQuantizeEmbedding,
     GroupQuantizeLinear,
@@ -18,22 +18,36 @@
     ],
 )
 def test_llama2_group_quantization(model_name: str, quant_name: str):
-    model_info = MODELS["llama"]
-    config = model_info.config.from_predefined(model_name)
-    model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name])
+    config = MODEL_PRESETS[model_name]
+    model, quant_map = MODELS["llama"].quantize["group-quant"](config, QUANTIZATION[quant_name])
     assert "model.embed_tokens.weight" in quant_map.param_map
-    assert isinstance(model.model.embed_tokens, GroupQuantizeEmbedding)
+    assert isinstance(
+        model.model.embed_tokens,  # type: ignore[attr-defined]
+        GroupQuantizeEmbedding,
+    )
     assert "lm_head.weight" in quant_map.param_map
-    assert isinstance(model.lm_head, GroupQuantizeLinear)
+    assert isinstance(model.lm_head, GroupQuantizeLinear)  # type: ignore[attr-defined]
     for i in range(config.num_hidden_layers):
         assert f"model.layers.{i}.self_attn.qkv_proj.weight" in quant_map.param_map
-        assert isinstance(model.model.layers[i].self_attn.qkv_proj, GroupQuantizeMultiLinear)
+        assert isinstance(
+            model.model.layers[i].self_attn.qkv_proj,  # type: ignore[attr-defined]
+            GroupQuantizeMultiLinear,
+        )
         assert f"model.layers.{i}.self_attn.o_proj.weight" in quant_map.param_map
-        assert isinstance(model.model.layers[i].self_attn.o_proj, GroupQuantizeLinear)
+        assert isinstance(
+            model.model.layers[i].self_attn.o_proj,  # type: ignore[attr-defined]
+            GroupQuantizeLinear,
+        )
         assert f"model.layers.{i}.mlp.gate_up_proj.weight" in quant_map.param_map
-        assert isinstance(model.model.layers[i].mlp.gate_up_proj, GroupQuantizeMultiLinear)
+        assert isinstance(
+            model.model.layers[i].mlp.gate_up_proj,  # type: ignore[attr-defined]
+            GroupQuantizeMultiLinear,
+        )
         assert f"model.layers.{i}.mlp.down_proj.weight" in quant_map.param_map
-        assert isinstance(model.model.layers[i].mlp.down_proj, GroupQuantizeLinear)
+        assert isinstance(
+            model.model.layers[i].mlp.down_proj,  # type: ignore[attr-defined]
+            GroupQuantizeLinear,
+        )
 
 
 if __name__ == "__main__":
