diff --git a/python/mlc_llm/conversation_template/__init__.py b/python/mlc_llm/conversation_template/__init__.py index fb01a1ef83..6f7a462de2 100644 --- a/python/mlc_llm/conversation_template/__init__.py +++ b/python/mlc_llm/conversation_template/__init__.py @@ -7,6 +7,7 @@ # model preset templates from . import ( + cohere, dolly, gemma, glm, diff --git a/python/mlc_llm/conversation_template/cohere.py b/python/mlc_llm/conversation_template/cohere.py new file mode 100644 index 0000000000..a5c2719ec5 --- /dev/null +++ b/python/mlc_llm/conversation_template/cohere.py @@ -0,0 +1,27 @@ +"""Cohere default templates""" +# pylint: disable=line-too-long + +# Referred from: https://huggingface.co/CohereForAI/aya-23-8B/blob/main/tokenizer_config.json + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Aya-23 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="aya-23", + system_template=f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{MessagePlaceholders.SYSTEM.value}<|END_OF_TURN_TOKEN|>", + system_message="You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses.", + roles={ + "user": "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", + "assistant": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + }, + seps=["<|END_OF_TURN_TOKEN|>"], + role_content_sep="", + role_empty_sep="", + system_prefix_token_ids=[5], + stop_str=["<|END_OF_TURN_TOKEN|>"], + stop_token_ids=[6, 255001], + ) +) diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index ab63d1bcf9..6addd98b3b 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -130,7 +130,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b prefill_chunk_size=model_config.prefill_chunk_size, attention_sink_size=getattr(model_config, "attention_sink_size", -1), tensor_parallel_shards=model_config.tensor_parallel_shards, - conv_template=conversation, + conv_template=conversation, # type: ignore ) # Step 2. Load `generation_config.json` and `config.json` for text-generation related configs for generation_config_filename in ["generation_config.json", "config.json"]: @@ -299,4 +299,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "llava", "hermes2_pro_llama3", "tinyllama_v1_0", + "aya-23", } diff --git a/python/mlc_llm/model/cohere/__init__.py b/python/mlc_llm/model/cohere/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/cohere/cohere_loader.py b/python/mlc_llm/model/cohere/cohere_loader.py new file mode 100644 index 0000000000..6fa19513f3 --- /dev/null +++ b/python/mlc_llm/model/cohere/cohere_loader.py @@ -0,0 +1,172 @@ +""" +This file specifies how MLC's Cohere parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .cohere_model import CohereConfig, CohereForCausalLM +from .cohere_quantization import awq_quant + + +def huggingface(model_config: CohereConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : CohereConfig + The configuration of the Cohere model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = CohereForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + def _add(mlc_name, hf_name): + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + _add(f"{attn}.out_proj.weight", f"{attn}.o_proj.weight") + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + _add(f"{mlp}.up_proj.weight", f"{mlp}.up_proj.weight") + _add(f"{mlp}.gate_proj.weight", f"{mlp}.gate_proj.weight") + _add(f"{mlp}.down_proj.weight", f"{mlp}.down_proj.weight") + # inv_freq is not used in the model + # mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping + + +# https://huggingface.co/alijawad07/aya-23-8B-AWQ-GEMM/tree/main +def awq(model_config: CohereConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : CohereConfig + The configuration of the Cohere model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. + """ + model, _ = awq_quant(model_config, quantization) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + def _add(mlc_name, hf_name): + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate( + [q, k, v], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + _add(f"{attn}.out_proj.{quantize_suffix}", f"{attn}.o_proj.{quantize_suffix}") + + # Concat gate and up in MLP + mlp = f"model.layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + _add(f"{mlp}.up_proj.{quantize_suffix}", f"{mlp}.up_proj.{quantize_suffix}") + _add(f"{mlp}.gate_proj.{quantize_suffix}", f"{mlp}.gate_proj.{quantize_suffix}") + _add(f"{mlp}.down_proj.{quantize_suffix}", f"{mlp}.down_proj.{quantize_suffix}") + + # inv_freq is not used in the model + # mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_llm/model/cohere/cohere_model.py b/python/mlc_llm/model/cohere/cohere_model.py new file mode 100644 index 0000000000..180c60ba13 --- /dev/null +++ b/python/mlc_llm/model/cohere/cohere_model.py @@ -0,0 +1,404 @@ +""" +Implementation for Aya23 architecture +TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class CohereConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Cohere Aya-23 model""" + + model_type: str # cohere + hidden_size: int + vocab_size: int + num_hidden_layers: int + num_attention_heads: int + num_key_value_heads: int + intermediate_size: int + layer_norm_eps: float + position_embedding_base: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + head_dim: int = 0 + tensor_parallel_shards: int = 1 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs["rope_theta"] + else: + self.position_embedding_base = 10000 + + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %d", + bold("prefill_chunk_size"), + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + + if self.num_key_value_heads == 0 or self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert ( + self.head_dim * self.num_attention_heads == self.hidden_size + ), "head_dim * num_attention_heads != hidden_size" + assert ( + self.num_attention_heads % self.num_key_value_heads == 0 + ), "num_attention_heads % num_key_value_heads != 0" + + +# pylint: disable=invalid-name,missing-docstring + + +class CohereMLP(nn.Module): + def __init__(self, config: CohereConfig): + super().__init__() + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) + + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_proj = nn.Linear(config.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x): + down_proj = self.down_proj(op.silu(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +# pylint: disable=invalid-name,missing-docstring + + +class CohereAttention(nn.Module): + def __init__(self, config: CohereConfig): + self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards + assert config.num_attention_heads % config.tensor_parallel_shards == 0, ( + f"num_attention_heads({config.num_attention_heads}) " + "must be divisible by tensor_parallel_shards" + ) + self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards + assert config.num_key_value_heads % config.tensor_parallel_shards == 0, ( + f"num_attention_heads({config.num_key_value_heads}) " + "must be divisible by tensor_parallel_shards" + ) + self.head_dim = config.head_dim + + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=False, + ) + self.out_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_key_value_heads + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.out_proj(output) + + +class CohereDecoderLayer(nn.Module): + def __init__(self, config: CohereConfig): + super().__init__() + self.self_attn = CohereAttention(config) + self.mlp = CohereMLP(config) + self.input_layernorm = CohereNorm(config.hidden_size, eps=config.layer_norm_eps) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_key_value_heads * hd + v = self.self_attn.num_key_value_heads * hd + i = self.mlp.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.out_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_proj, tp.ShardSingleDim("_shard_mlp_gate", segs=[i, i], dim=0)) + _set(self.mlp.up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + hidden_ln = self.input_layernorm(hidden_states) + attn = self.self_attn(hidden_ln, paged_kv_cache, layer_id) + mlp = self.mlp(hidden_ln) + hidden_states = self._apply_parallel_residual(attn, residual=hidden_states) # type: ignore + hidden_states = self._apply_parallel_residual(mlp, residual=hidden_states) # type: ignore + return hidden_states + + def _apply_parallel_residual(self, mlp_out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(mlp_out + residual / self.tensor_parallel_shards, "sum") + return mlp_out + residual + + +class CohereNorm(nn.Module): + def __init__( + self, normalized_shape: int, eps: float = 1e-5, dtype: Optional[str] = None + ) -> None: + super().__init__() + self.normalized_shape = normalized_shape + self.eps = eps + self.weight = nn.Parameter((normalized_shape,), dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + return op.layer_norm( + x, + normalized_shape=self.normalized_shape, + weight=self.weight, + bias=None, + eps=self.eps, + ) + + +class CohereEmbedding(nn.Embedding): + def lm_head_forward(self, x: nn.Tensor): + """The lm_head forwarding, which transposes the weight and multiplies + with the input tensor. + """ + weight = nn.op.permute_dims(self.weight) + return nn.op.matmul(x, weight, out_dtype="float32") + + +class CohereModel(nn.Module): + def __init__(self, config: CohereConfig): + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = CohereEmbedding("vocab_size", config.hidden_size) + self.layers = nn.ModuleList( + [CohereDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = CohereNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = input_embed + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class CohereForCausalLM(nn.Module): + # pylint: disable=too-many-instance-attributes + def __init__(self, config: CohereConfig) -> None: + super().__init__() + self.model = CohereModel(config) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + lm_logits = self.model.embed_tokens.lm_head_forward(hidden_states) + if lm_logits.dtype != "float32": + lm_logits = lm_logits.astype("float32") + return lm_logits + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): + b, s, d = x.shape # type: ignore + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + # logits = self.lm_head(hidden_states) + logits = self.model.embed_tokens.lm_head_forward(hidden_states) # type: ignore + + if logits.dtype != "float32": + logits = logits.astype("float32") + + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.model.embed_tokens.lm_head_forward(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) # type: ignore + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) # type: ignore + embeds = self.model.embed_tokens(input_ids) + return embeds + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) # type: ignore diff --git a/python/mlc_llm/model/cohere/cohere_quantization.py b/python/mlc_llm/model/cohere/cohere_quantization.py new file mode 100644 index 0000000000..ded84e1151 --- /dev/null +++ b/python/mlc_llm/model/cohere/cohere_quantization.py @@ -0,0 +1,71 @@ +"""This file specifies how MLC's Cohere parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize + +from .cohere_model import CohereConfig, CohereForCausalLM + + +def group_quant( + model_config: CohereConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Cohere-architecture model using group quantization.""" + model: nn.Module = CohereForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: CohereConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Cohere-architecture model using FasterTransformer quantization.""" + model: nn.Module = CohereForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: CohereConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a cohere-Aya model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = CohereForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: CohereConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Cohere model without quantization.""" + model: nn.Module = CohereForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/llama/llama_loader.py b/python/mlc_llm/model/llama/llama_loader.py index c166609b4c..1c52b16504 100644 --- a/python/mlc_llm/model/llama/llama_loader.py +++ b/python/mlc_llm/model/llama/llama_loader.py @@ -10,7 +10,7 @@ from mlc_llm.loader import ExternMapping from mlc_llm.quantization import Quantization -from .llama_model import LlamaConfig, LlamaForCasualLM +from .llama_model import LlamaConfig, LlamaForCausalLM from .llama_quantization import awq_quant @@ -31,7 +31,7 @@ def huggingface(model_config: LlamaConfig, quantization: Quantization) -> Extern param_map : ExternMapping The parameter mapping from MLC to HuggingFace PyTorch. """ - model = LlamaForCasualLM(model_config) + model = LlamaForCausalLM(model_config) if quantization is not None: model.to(quantization.model_dtype) _, _named_params, _ = model.export_tvm( # type: ignore[misc] diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index d395292b9a..9e786c90d5 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -218,7 +218,7 @@ def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): return hidden_states -class LlamaForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes +class LlamaForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: LlamaConfig): self.model = LlamaModel(config) self.tie_word_embeddings = config.tie_word_embeddings diff --git a/python/mlc_llm/model/llama/llama_quantization.py b/python/mlc_llm/model/llama/llama_quantization.py index 26b6e0e728..3d2ab8c8d1 100644 --- a/python/mlc_llm/model/llama/llama_quantization.py +++ b/python/mlc_llm/model/llama/llama_quantization.py @@ -14,7 +14,7 @@ PerTensorQuantize, ) -from .llama_model import LlamaConfig, LlamaForCasualLM +from .llama_model import LlamaConfig, LlamaForCausalLM def group_quant( @@ -22,7 +22,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Llama-architecture model using group quantization.""" - model: nn.Module = LlamaForCasualLM(model_config) + model: nn.Module = LlamaForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) quantization.tensor_parallel_shards = model_config.tensor_parallel_shards @@ -39,7 +39,7 @@ def ft_quant( quantization: FTQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Llama-architecture model using FasterTransformer quantization.""" - model: nn.Module = LlamaForCasualLM(model_config) + model: nn.Module = LlamaForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) model = quantization.quantize_model( @@ -55,7 +55,7 @@ def awq_quant( quantization: AWQQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Llama-architecture model using Activation-aware Weight Quantization(AWQ).""" - model: nn.Module = LlamaForCasualLM(model_config) + model: nn.Module = LlamaForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) model = quantization.quantize_model( @@ -71,7 +71,7 @@ def no_quant( quantization: NoQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Llama2 model without quantization.""" - model: nn.Module = LlamaForCasualLM(model_config) + model: nn.Module = LlamaForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) return model, quant_map @@ -82,7 +82,7 @@ def per_tensor_quant( quantization: PerTensorQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Llama-architecture model using per-tensor quantization.""" - model: nn.Module = LlamaForCasualLM(model_config) + model: nn.Module = LlamaForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) model = quantization.quantize_model( diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py index e667ef8ed4..c1126f7158 100644 --- a/python/mlc_llm/model/llava/llava_model.py +++ b/python/mlc_llm/model/llava/llava_model.py @@ -19,14 +19,14 @@ from mlc_llm.nn import PagedKVCache, RopeMode from ...support.config import ConfigBase -from ..llama.llama_model import LlamaConfig, LlamaForCasualLM +from ..llama.llama_model import LlamaConfig, LlamaForCausalLM from ..mistral.mistral_model import MistralConfig, MistralForCasualLM logger = logging.getLogger(__name__) CONFIG_MAP = {"LlamaForCausalLM": LlamaConfig, "MistralForCausalLM": MistralConfig} -ARCHITECTURE_MAP = {"LlamaForCausalLM": LlamaForCasualLM, "MistralForCausalLM": MistralForCasualLM} +ARCHITECTURE_MAP = {"LlamaForCausalLM": LlamaForCausalLM, "MistralForCausalLM": MistralForCasualLM} @dataclasses.dataclass diff --git a/python/mlc_llm/model/mixtral/mixtral_model.py b/python/mlc_llm/model/mixtral/mixtral_model.py index aedc566aa7..647f73e246 100644 --- a/python/mlc_llm/model/mixtral/mixtral_model.py +++ b/python/mlc_llm/model/mixtral/mixtral_model.py @@ -10,7 +10,7 @@ from mlc_llm.model.llama.llama_model import ( LlamaAttention, LlamaConfig, - LlamaForCasualLM, + LlamaForCausalLM, LlamaModel, ) from mlc_llm.nn import PagedKVCache @@ -176,7 +176,7 @@ def __init__(self, config: MixtralConfig): ) -class MixtralForCasualLM(LlamaForCasualLM): +class MixtralForCasualLM(LlamaForCausalLM): """Same as LlamaForCausalLM.""" def __init__(self, config: MixtralConfig): diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 550e075a92..7c0a71362c 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -11,6 +11,7 @@ from .baichuan import baichuan_loader, baichuan_model, baichuan_quantization from .bert import bert_loader, bert_model, bert_quantization from .chatglm3 import chatglm3_loader, chatglm3_model, chatglm3_quantization +from .cohere import cohere_loader, cohere_model, cohere_quantization from .eagle import eagle_loader, eagle_model, eagle_quantization from .gemma import gemma_loader, gemma_model, gemma_quantization from .gemma2 import gemma2_loader, gemma2_model, gemma2_quantization @@ -81,7 +82,7 @@ class Model: MODELS: Dict[str, Model] = { "llama": Model( name="llama", - model=llama_model.LlamaForCasualLM, + model=llama_model.LlamaForCausalLM, config=llama_model.LlamaConfig, source={ "huggingface-torch": llama_loader.huggingface, @@ -472,4 +473,18 @@ class Model: "ft-quant": starcoder2_quantization.ft_quant, }, ), + "cohere": Model( + name="cohere", + model=cohere_model.CohereForCausalLM, + config=cohere_model.CohereConfig, + source={ + "huggingface-torch": cohere_loader.huggingface, + "huggingface-safetensor": cohere_loader.huggingface, + }, + quantize={ + "no-quant": cohere_quantization.no_quant, + "group-quant": cohere_quantization.group_quant, + "ft-quant": cohere_quantization.ft_quant, + }, + ), } diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 90fe11d60a..c8c17d77ab 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -2,6 +2,8 @@ from typing import Any, Dict # pylint: disable=too-many-lines +# pylint: disable=too-many-lines + MODEL_PRESETS: Dict[str, Any] = { "llama2_7b": { "architectures": ["LlamaForCausalLM"], @@ -1263,4 +1265,29 @@ "use_cache": True, "vocab_size": 49152, }, + "aya-23": { + "architectures": ["CohereForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 5, + "eos_token_id": 255001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "layer_norm_eps": 1e-05, + "logit_scale": 0.0625, + "max_position_embeddings": 8192, + "model_type": "cohere", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pad_token_id": 0, + "rope_theta": 10000, + "torch_dtype": "float16", + "transformers_version": "4.40.0.dev0", + "use_cache": True, + "use_qk_norm": False, + "vocab_size": 256000, + }, }