|
| 1 | +import logging |
| 2 | +import os |
| 3 | +import warnings |
| 4 | +from typing import List, Optional, Tuple |
| 5 | + |
| 6 | +import coremltools as ct |
| 7 | +import numpy as np |
| 8 | +import torch |
| 9 | +from transformers.cache_utils import Cache |
| 10 | +from transformers.models.mistral.modeling_mistral import ( |
| 11 | + MISTRAL_ATTENTION_CLASSES, |
| 12 | + MistralAttention, |
| 13 | + MistralConfig, |
| 14 | + MistralForCausalLM, |
| 15 | + apply_rotary_pos_emb, |
| 16 | + repeat_kv, |
| 17 | +) |
| 18 | + |
| 19 | +warnings.filterwarnings("ignore") |
| 20 | +logging.getLogger("coremltools").setLevel(logging.ERROR) |
| 21 | +os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| 22 | + |
| 23 | +# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3 |
| 24 | +MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.3" |
| 25 | +METADATA_TOKENIZER: str = "co.huggingface.exporters.name" |
| 26 | + |
| 27 | + |
| 28 | +class SliceUpdateKeyValueCache(Cache): |
| 29 | + def __init__( |
| 30 | + self, |
| 31 | + shape: Tuple[int, ...], |
| 32 | + device="cpu", |
| 33 | + dtype=torch.float32, |
| 34 | + ) -> None: |
| 35 | + """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim).""" |
| 36 | + super().__init__() |
| 37 | + self.past_seen_tokens: int = 0 |
| 38 | + self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device) |
| 39 | + self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device) |
| 40 | + |
| 41 | + def update( |
| 42 | + self, |
| 43 | + k_state: torch.Tensor, |
| 44 | + v_state: torch.Tensor, |
| 45 | + layer_idx: int, |
| 46 | + slice_indices: torch.LongTensor, |
| 47 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 48 | + """ |
| 49 | + Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]). |
| 50 | + Return slice of key/value cache tensors from [0, slice_indices[1]). |
| 51 | + """ |
| 52 | + if len(slice_indices) != 2: |
| 53 | + raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.") |
| 54 | + begin, end = slice_indices |
| 55 | + self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state |
| 56 | + self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state |
| 57 | + k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :] |
| 58 | + v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :] |
| 59 | + return k_cache, v_cache |
| 60 | + |
| 61 | + def get_seq_length(self, _: int | None = 0) -> int: |
| 62 | + """Get the sequence length of the cache.""" |
| 63 | + return self.past_seen_tokens |
| 64 | + |
| 65 | + |
| 66 | +class SliceUpdateMistralAttention(MistralAttention): |
| 67 | + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): |
| 68 | + super().__init__(config=config, layer_idx=layer_idx) |
| 69 | + |
| 70 | + @torch.no_grad() |
| 71 | + def forward( |
| 72 | + self, |
| 73 | + hidden_states: torch.Tensor, |
| 74 | + attention_mask: torch.Tensor, |
| 75 | + position_ids: Optional[torch.LongTensor] = None, |
| 76 | + past_key_value: Optional[Cache] = None, |
| 77 | + **kwargs, |
| 78 | + ) -> Tuple[torch.Tensor | None, ...]: |
| 79 | + bsz, q_len, _ = hidden_states.size() |
| 80 | + |
| 81 | + query_states = self.q_proj(hidden_states) |
| 82 | + key_states = self.k_proj(hidden_states) |
| 83 | + value_states = self.v_proj(hidden_states) |
| 84 | + |
| 85 | + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 86 | + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( |
| 87 | + 1, 2 |
| 88 | + ) |
| 89 | + value_states = value_states.view( |
| 90 | + bsz, q_len, self.num_key_value_heads, self.head_dim |
| 91 | + ).transpose(1, 2) |
| 92 | + |
| 93 | + cos, sin = self.rotary_emb(value_states, position_ids) |
| 94 | + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
| 95 | + |
| 96 | + # Slice update key/value cache |
| 97 | + end_step = attention_mask.shape[-1] |
| 98 | + key_states, value_states = past_key_value.update( |
| 99 | + key_states, |
| 100 | + value_states, |
| 101 | + self.layer_idx, |
| 102 | + slice_indices=(end_step - q_len, end_step), |
| 103 | + ) |
| 104 | + |
| 105 | + key_states = repeat_kv(key_states, self.num_key_value_groups) |
| 106 | + value_states = repeat_kv(value_states, self.num_key_value_groups) |
| 107 | + |
| 108 | + attn_output = torch.nn.functional.scaled_dot_product_attention( |
| 109 | + query_states, |
| 110 | + key_states, |
| 111 | + value_states, |
| 112 | + attn_mask=attention_mask, |
| 113 | + ) |
| 114 | + |
| 115 | + attn_output = attn_output.transpose(1, 2).contiguous() |
| 116 | + attn_output = attn_output.view(bsz, q_len, self.hidden_size) |
| 117 | + attn_output = self.o_proj(attn_output) |
| 118 | + return attn_output, None, None |
| 119 | + |
| 120 | + |
| 121 | +class StatefulMistralForCausalLM(torch.nn.Module): |
| 122 | + def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None: |
| 123 | + super().__init__() |
| 124 | + |
| 125 | + # Custom attention implementation for stateful slice update key/value cache, override |
| 126 | + # "sdpa" to compliance with transformers.modeling_utils._autoset_attn_implementation |
| 127 | + MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention |
| 128 | + self.model = MistralForCausalLM.from_pretrained(model_path) |
| 129 | + |
| 130 | + # Register KV cache buffers to be recognized as Core ML states |
| 131 | + config: MistralConfig = self.model.config |
| 132 | + self.kv_cache_shape: Tuple[int, ...] = ( |
| 133 | + config.num_hidden_layers, |
| 134 | + batch_size, |
| 135 | + config.num_key_value_heads, |
| 136 | + max_context_size, |
| 137 | + config.hidden_size // config.num_attention_heads, |
| 138 | + ) |
| 139 | + self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape) |
| 140 | + self.register_buffer("keyCache", self.kv_cache.k_cache) |
| 141 | + self.register_buffer("valueCache", self.kv_cache.v_cache) |
| 142 | + |
| 143 | + @torch.no_grad() |
| 144 | + def forward( |
| 145 | + self, |
| 146 | + input_ids: torch.LongTensor, |
| 147 | + causal_mask: torch.Tensor, |
| 148 | + ) -> torch.Tensor: |
| 149 | + # Compute past seen tokens used for updating key/value cache slices |
| 150 | + self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1] |
| 151 | + return self.model( |
| 152 | + input_ids, |
| 153 | + attention_mask=causal_mask, |
| 154 | + past_key_values=self.kv_cache, |
| 155 | + use_cache=True, |
| 156 | + ).logits |
| 157 | + |
| 158 | + |
| 159 | +def export() -> None: |
| 160 | + # Construct model from transformers and trace to TorchScript |
| 161 | + max_context_size: int = 2048 |
| 162 | + torch_model = StatefulMistralForCausalLM(MODEL_ID, max_context_size=max_context_size) |
| 163 | + torch_model.eval() |
| 164 | + input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32) |
| 165 | + causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32) |
| 166 | + traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask]) |
| 167 | + |
| 168 | + # Convert traced TorchScript to Core ML format |
| 169 | + query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1) |
| 170 | + end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1) |
| 171 | + inputs: List[ct.TensorType] = [ |
| 172 | + ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"), |
| 173 | + ct.TensorType( |
| 174 | + shape=(1, 1, query_length, end_step_dim), |
| 175 | + dtype=np.float16, |
| 176 | + name="causalMask", |
| 177 | + ), |
| 178 | + ] |
| 179 | + outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")] |
| 180 | + states: List[ct.StateType] = [ |
| 181 | + ct.StateType( |
| 182 | + wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16), |
| 183 | + name="keyCache", |
| 184 | + ), |
| 185 | + ct.StateType( |
| 186 | + wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16), |
| 187 | + name="valueCache", |
| 188 | + ), |
| 189 | + ] |
| 190 | + |
| 191 | + # Convert model with FP16 precision |
| 192 | + mlmodel_fp16: ct.MLModel = ct.convert( |
| 193 | + traced_model, |
| 194 | + inputs=inputs, |
| 195 | + outputs=outputs, |
| 196 | + states=states, |
| 197 | + minimum_deployment_target=ct.target.iOS18, |
| 198 | + skip_model_load=True, |
| 199 | + ) |
| 200 | + |
| 201 | + # Block-wise quantize model weights to int4 |
| 202 | + op_config = ct.optimize.coreml.OpLinearQuantizerConfig( |
| 203 | + mode="linear_symmetric", |
| 204 | + dtype="int4", |
| 205 | + granularity="per_block", |
| 206 | + block_size=32, |
| 207 | + ) |
| 208 | + config = ct.optimize.coreml.OptimizationConfig(global_config=op_config) |
| 209 | + mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config) |
| 210 | + mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID}) |
| 211 | + mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage") |
| 212 | + |
| 213 | + |
| 214 | +if __name__ == "__main__": |
| 215 | + export() |
0 commit comments