[Gemma] Gemma model support
This PR brings Gemma model support.
Right now it supports the `q0f16`, `q0f32`, and `q4f16_1` quantization
modes for both the 7B and 2B variants in MLC Chat.

We are testing unquantized Gemma for MLC Serve; any required changes
will be submitted separately.
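For reference, a minimal usage sketch (not part of this diff) showing how a converted
Gemma build could be loaded through the Python `ChatModule` API; the local weight
directory name below is an assumption and should point at a model converted with one
of the supported quantization modes:

from mlc_chat import ChatModule

# Hypothetical local path to converted Gemma weights (q4f16_1 shown as an example).
cm = ChatModule(model="dist/gemma-2b-it-q4f16_1-MLC")
print(cm.generate(prompt="What is the capital of France?"))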

---

Co-authored-by: Rick Zhou <rickzhoucmu@gmail.com>
Co-authored-by: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com>
3 people committed Feb 21, 2024
1 parent c81ef04 commit 5cc8e8a
Showing 11 changed files with 626 additions and 6 deletions.
23 changes: 22 additions & 1 deletion cpp/conv_templates.cc
@@ -700,6 +700,25 @@ Conversation StableLM2() {
return conv;
}

Conversation GemmaInstruction() {
Conversation conv;
conv.name = "gemma_instruction";
conv.system = "";
conv.roles = {"<start_of_turn>user", "<start_of_turn>model"};
conv.messages = {};
conv.offset = 0;
conv.separator_style = SeparatorStyle::kSepRoleMsg;
conv.seps = {"<end_of_turn>\n"};
conv.role_msg_sep = "\n";
conv.role_empty_sep = "\n";
// TODO(mlc-team): add eos to mlc-chat-config
// and remove eos from stop token setting.
conv.stop_tokens = {1, 107}; // <eos> and <end_of_turn>
conv.stop_str = "<end_of_turn>";
conv.add_bos = true;
return conv;
}
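// Illustrative only (not part of the original change): with kSepRoleMsg, a single
// user turn rendered by this template is expected to look roughly like
//
//   <start_of_turn>user
//   Hello!<end_of_turn>
//   <start_of_turn>model
//
// with BOS prepended (add_bos) and generation stopping at <eos> or <end_of_turn>.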

} // namespace

using ConvFactory = Conversation (*)();
@@ -738,7 +757,9 @@ Conversation Conversation::FromTemplate(const std::string& name) {
{"phi-2", Phi2},
{"qwen", ChatML},
{"stablelm-2", StableLM2},
{"baichuan", ChatML}};
{"baichuan", ChatML},
{"gemma_instruction", GemmaInstruction},
};
auto it = factory.find(name);
if (it == factory.end()) {
LOG(FATAL) << "Unknown conversation template: " << name;
2 changes: 2 additions & 0 deletions python/mlc_chat/chat_module.py
@@ -194,6 +194,8 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
local_id: Optional[str] = None
conv_template: Optional[str] = None
temperature: Optional[float] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
repetition_penalty: Optional[float] = None
top_p: Optional[float] = None
mean_gen_len: Optional[int] = None
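The two new fields follow the common presence/frequency-penalty semantics: tokens that
have already appeared in the generated text receive a flat one-time penalty (presence)
plus a penalty proportional to how often they appeared (frequency). A minimal sketch of
that adjustment, assuming per-token logits; this is not taken from the PR's runtime code:

from collections import Counter
from typing import Dict, List

def apply_penalties(
    logits: Dict[int, float],
    generated: List[int],
    presence_penalty: float,
    frequency_penalty: float,
) -> Dict[int, float]:
    """Apply OpenAI-style presence and frequency penalties to a logit table."""
    counts = Counter(generated)  # how often each token id has been generated so far
    out = dict(logits)
    for token_id, count in counts.items():
        if token_id in out:
            # Frequency penalty scales with the count; presence penalty is applied
            # once for any token that has appeared at least once.
            out[token_id] -= count * frequency_penalty + presence_penalty
    return out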
5 changes: 5 additions & 0 deletions python/mlc_chat/interface/gen_config.py
@@ -40,6 +40,8 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
shift_fill_factor: float = None
# Configuration of text generation
temperature: float = None
presence_penalty: float = None
frequency_penalty: float = None
repetition_penalty: float = None
top_p: float = None
# Conversation template
@@ -58,6 +60,8 @@ def apply_defaults(self) -> None:
"bos_token_id": 1,
"eos_token_id": 2,
"temperature": 0.7,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"repetition_penalty": 1.0,
"top_p": 0.95,
"mean_gen_len": 128,
@@ -224,4 +228,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
"custom", # for web-llm only
"phi-2",
"stablelm-2",
"gemma_instruction",
}
Empty file.
121 changes: 121 additions & 0 deletions python/mlc_chat/model/gemma/gemma_loader.py
@@ -0,0 +1,121 @@
"""
This file specifies how MLC's Gemma parameters map from other formats, for example
HuggingFace PyTorch and HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_chat.loader import ExternMapping
from mlc_chat.quantization import Quantization

from .gemma_model import GemmaConfig, GemmaForCausalLM


def huggingface(model_config: GemmaConfig, quantization: Quantization) -> ExternMapping:
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
the names of HuggingFace PyTorch parameters.
Parameters
----------
model_config : GemmaConfig
The configuration of the Gemma model.
quantization : Quantization
The quantization configuration.
Returns
-------
param_map : ExternMapping
The parameter mapping from MLC to HuggingFace PyTorch.
"""
model = GemmaForCausalLM(model_config)
if quantization is not None:
model.to(quantization.model_dtype)
_, _named_params, _ = model.export_tvm( # type: ignore[misc]
spec=model.get_default_spec(),
allow_extern=True,
)
named_parameters = dict(_named_params)

mapping = ExternMapping()

for i in range(model_config.num_hidden_layers):
# Add QKV in self attention
attn = f"model.layers.{i}.self_attn"
mlc_name = f"{attn}.qkv_proj.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[
f"{attn}.q_proj.weight",
f"{attn}.k_proj.weight",
f"{attn}.v_proj.weight",
],
functools.partial(
lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
dtype=mlc_param.dtype,
),
)
# Add gates in MLP
mlp = f"model.layers.{i}.mlp"
mlc_name = f"{mlp}.gate_up_proj.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[
f"{mlp}.gate_proj.weight",
f"{mlp}.up_proj.weight",
],
functools.partial(
lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
dtype=mlc_param.dtype,
),
)
# Modify RMS layernorm weights, since Gemma model adds 1 to the weights
# We add 1 to the weights here for efficiency purpose
mlc_name = f"model.layers.{i}.input_layernorm.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(
lambda x, dtype: (x + 1).astype(dtype),
dtype=named_parameters[mlc_name].dtype,
),
)

mlc_name = f"model.layers.{i}.post_attention_layernorm.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(
lambda x, dtype: (x + 1).astype(dtype),
dtype=named_parameters[mlc_name].dtype,
),
)

mlc_name = "model.norm.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(
lambda x, dtype: (x + 1).astype(dtype),
dtype=named_parameters[mlc_name].dtype,
),
)

for mlc_name, mlc_param in named_parameters.items():
if mlc_name not in mapping.param_map:
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(
lambda x, dtype: x.astype(dtype),
dtype=mlc_param.dtype,
),
)
return mapping
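For context (not part of the diff): Gemma's reference RMSNorm scales the normalized
activations by (1 + weight), so the checkpoint stores the scale offset by one. Folding
the +1 into the stored parameter, as the mapping above does, lets the runtime apply a
conventional RMSNorm multiply with no extra per-token work. A small numpy sketch of the
equivalence, under that assumption:

import numpy as np

def rms_norm(x: np.ndarray, scale: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Conventional RMSNorm: normalize by the root mean square, then apply a learned scale.
    x_norm = x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x_norm * scale

x = np.random.randn(4, 8).astype("float32")
w = np.random.randn(8).astype("float32")  # layernorm weight as stored in the checkpoint

reference = rms_norm(x, 1.0 + w)  # Gemma reference: rmsnorm(x) * (1 + weight)
folded = rms_norm(x, w + 1.0)     # loader stores (weight + 1); runtime uses plain RMSNorm
assert np.allclose(reference, folded)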