27 changes: 27 additions & 0 deletions Examples/Mistral7B/README.md
@@ -0,0 +1,27 @@
### Export Mistral 7B Instruct v0.3

```shell
✗ uv run export.py

Loading checkpoint shards: 100%|███████████████████████████| 3/3 [00:12<00:00, 4.11s/it]
Converting PyTorch Frontend ==> MIL Ops: 100%|███| 5575/5575 [00:02<00:00, 2440.66 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 7.12 passes/s]
Running MIL default pipeline: 100%|█████████████████| 79/79 [02:36<00:00, 1.98s/ passes]
Running MIL backend_mlprogram pipeline: 100%|███████| 12/12 [00:00<00:00, 22.90 passes/s]
Running compression: 100%|███████████████████████████| 296/296 [03:04<00:00, 1.60 ops/s]
...
```
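
Once the export finishes, the resulting `StatefulMistral7BInstructInt4.mlpackage` can be sanity-checked without loading the full weights; a minimal sketch using coremltools (output elided):

```python
# Inspect the exported package's declared interface without compiling the model.
import coremltools as ct

mlmodel = ct.models.MLModel("StatefulMistral7BInstructInt4.mlpackage", skip_model_load=True)
print(mlmodel.get_spec().description.input)   # inputIds, causalMask
print(mlmodel.get_spec().description.output)  # logits
```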

### Generate Text

```shell
✗ swift run transformers-cli "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage

Best recommendations for a place to visit in Paris in August 2024:

1. Palace of Versailles: This iconic palace is a must-visit. It's a short train ride from Paris and offers a glimpse into the opulence of the French monarchy.

2. Eiffel Tower: No trip to Paris is complete without a visit to the Eiffel Tower. You can take an elevator ride to the top for a stunning view of the city.

3. Louvre Museum: Home to thousands of works of art, including the Mona Lisa and the Winged Victory of Samothrace, the Louvre is a cultural treasure.
```
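
### Generate Text from Python

`generate.py` (added alongside `export.py`) exposes `load` and `generate` helpers; a minimal sketch of calling them directly, assuming the `.mlpackage` produced by the export step above:

```python
# Minimal sketch: drive the Core ML model through the helpers in generate.py.
from generate import load, generate

model, tokenizer = load("StatefulMistral7BInstructInt4.mlpackage")
text = generate(
    model,
    prompt="Best recommendations for a place to visit in Paris in August 2024:",
    tokenizer=tokenizer,
    max_new_tokens=128,
)
print(text)
```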
230 changes: 230 additions & 0 deletions Examples/Mistral7B/export.py
@@ -0,0 +1,230 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "coremltools",
# "numpy",
# "sentencepiece",
# "torch",
# "tqdm",
# "transformers",
# ]
# ///
import logging
import os
import warnings
from typing import List, Optional, Tuple

import coremltools as ct
import numpy as np
import torch
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import (
MISTRAL_ATTENTION_CLASSES,
MistralAttention,
MistralConfig,
MistralForCausalLM,
apply_rotary_pos_emb,
repeat_kv,
)

warnings.filterwarnings("ignore")
logging.getLogger("coremltools").setLevel(logging.ERROR)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.3"
METADATA_TOKENIZER: str = "co.huggingface.exporters.name"


class SliceUpdateKeyValueCache(Cache):
def __init__(
self,
shape: Tuple[int, ...],
device="cpu",
dtype=torch.float32,
) -> None:
"""KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim)."""
super().__init__()
self.past_seen_tokens: int = 0
self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)

def update(
self,
k_state: torch.Tensor,
v_state: torch.Tensor,
layer_idx: int,
slice_indices: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]).
Return slice of key/value cache tensors from [0, slice_indices[1]).
"""
if len(slice_indices) != 2:
raise ValueError(f"Expected a tuple of two integers [start, end), got {slice_indices=}.")
begin, end = slice_indices
self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
return k_cache, v_cache

def get_seq_length(self, _: int | None = 0) -> int:
"""Get the sequence length of the cache."""
return self.past_seen_tokens


class SliceUpdateMistralAttention(MistralAttention):
def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
super().__init__(config=config, layer_idx=layer_idx)

@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
**kwargs,
) -> Tuple[torch.Tensor | None, ...]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
1, 2
)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

# Slice-update the key/value cache: the mask's last dimension is past tokens + current query length,
# so the new key/value states land in [end_step - q_len, end_step).
end_step = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx,
slice_indices=(end_step - q_len, end_step),
)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
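# Return the (attn_output, attn_weights, past_key_value) tuple the decoder layer expects;
# attention weights are not materialized and the cache lives in Core ML state, so both are None.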
return attn_output, None, None


class StatefulMistralForCausalLM(torch.nn.Module):
def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None:
super().__init__()

# Use a custom attention implementation with a stateful slice-update key/value cache; register it
# under the "sdpa" key so transformers.modeling_utils._autoset_attn_implementation selects it
MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention
self.model = MistralForCausalLM.from_pretrained(model_path)

# Register KV cache buffers to be recognized as Core ML states
config: MistralConfig = self.model.config
self.kv_cache_shape: Tuple[int, ...] = (
config.num_hidden_layers,
batch_size,
config.num_key_value_heads,
max_context_size,
config.hidden_size // config.num_attention_heads,
)
self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
self.register_buffer("keyCache", self.kv_cache.k_cache)
self.register_buffer("valueCache", self.kv_cache.v_cache)

@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
causal_mask: torch.Tensor,
) -> torch.Tensor:
# Compute past seen tokens used for updating key/value cache slices
self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
return self.model(
input_ids,
attention_mask=causal_mask,
past_key_values=self.kv_cache,
use_cache=True,
).logits


def export() -> None:
# Construct model from transformers and trace to TorchScript
max_context_size: int = 2048
torch_model = StatefulMistralForCausalLM(MODEL_ID, max_context_size=max_context_size)
torch_model.eval()
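# Placeholder example inputs used only for tracing; the flexible runtime shapes
# are declared below via ct.RangeDim during Core ML conversion.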
input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32)
causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32)
traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask])
kv_cache_shape = torch_model.kv_cache_shape
del torch_model

# Convert traced TorchScript to Core ML format
query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
inputs: List[ct.TensorType] = [
ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"),
ct.TensorType(
shape=(1, 1, query_length, end_step_dim),
dtype=np.float16,
name="causalMask",
),
]
outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")]
states: List[ct.StateType] = [
ct.StateType(
wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16),
name="keyCache",
),
ct.StateType(
wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16),
name="valueCache",
),
]

# Convert the model with FP16 precision; stateful models require the iOS 18 deployment target
mlmodel_fp16: ct.MLModel = ct.convert(
traced_model,
inputs=inputs,
outputs=outputs,
states=states,
minimum_deployment_target=ct.target.iOS18,
skip_model_load=True,
)
del traced_model

# Block-wise quantize model weights to int4
op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
mode="linear_symmetric",
dtype="int4",
granularity="per_block",
block_size=32,
)
config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)
mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
del mlmodel_fp16
mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage")


if __name__ == "__main__":
export()
99 changes: 99 additions & 0 deletions Examples/Mistral7B/generate.py
@@ -0,0 +1,99 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "coremltools",
# "numpy",
# "sentencepiece",
# "torch",
# "tqdm",
# "transformers",
# ]
# ///
import argparse
from typing import Dict, Generator, List, Tuple

import numpy as np
from coremltools.models import MLModel
from transformers import AutoTokenizer

from export import METADATA_TOKENIZER


def load(model_path: str) -> Tuple[MLModel, AutoTokenizer]:
"""Load a Core ML model and corresponding tokenizer."""
model: MLModel = MLModel(model_path)
description = model.get_spec().description
if METADATA_TOKENIZER not in description.metadata.userDefined:
raise ValueError("Model metadata does not contain tokenizer path.")
tokenizer_path: str = description.metadata.userDefined[METADATA_TOKENIZER]
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
return model, tokenizer


def get_next_token(model: MLModel, prompt_tokens: np.ndarray) -> Generator[int, None, None]:
"""Generate a sequence of tokens with naive greedy decoding."""

def sample(logits: np.ndarray) -> int:
"""Perform greedy decoding on the logits array to get the next token."""
return int(np.argmax(logits[0][-1], axis=-1))

def inference(model: MLModel, input_ids: np.ndarray, num_past_tokens: int) -> np.ndarray:
"""Perform inference with the given model and input data."""
causal_mask: np.ndarray = np.triu(
np.full(
(1, 1, input_ids.shape[-1], num_past_tokens + input_ids.shape[-1]),
fill_value=-np.inf if num_past_tokens == 0 else 0,
),
k=1,
).astype(np.float16)
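# The causal mask above has shape (1, 1, query_length, num_past_tokens + query_length):
# on the prompt pass the upper triangle is -inf to hide future positions; on
# single-token decode steps there is nothing to mask, so the mask is all zeros.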
outputs: Dict[str, np.ndarray] = model.predict(
data={"inputIds": input_ids, "causalMask": causal_mask},
state=kv_cache_state,
)
return outputs["logits"]

kv_cache_state = model.make_state()
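# The state above carries the key/value cache tensors across successive predict() calls.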
logits: np.ndarray = inference(model, input_ids=prompt_tokens, num_past_tokens=0)
token: int = sample(logits=logits)
num_past_tokens: int = prompt_tokens.shape[-1]

while True:
yield token
logits: np.ndarray = inference(
model,
input_ids=np.array([[token]], dtype=np.int32),
num_past_tokens=num_past_tokens,
)
token: int = sample(logits=logits)
num_past_tokens += 1


def generate(
model: MLModel,
prompt: str,
tokenizer: AutoTokenizer,
max_new_tokens: int,
) -> str:
prompt_tokens: np.ndarray = tokenizer(prompt, return_tensors="np").input_ids
extend_tokens: List[int] = []
for i, token in enumerate(get_next_token(model, prompt_tokens=prompt_tokens.astype(np.int32))):
if token == tokenizer.eos_token_id or i == max_new_tokens:
break
extend_tokens.append(token)
return tokenizer.decode(prompt_tokens[0].tolist() + extend_tokens)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("model_path", type=str)
parser.add_argument("--prompt", type=str, default="Hello")
parser.add_argument("--max_new_tokens", type=int, default=128)
args = parser.parse_args()
model, tokenizer = load(args.model_path)
extend_text: str = generate(
model,
prompt=args.prompt,
tokenizer=tokenizer,
max_new_tokens=args.max_new_tokens,
)
print(extend_text)
6 changes: 6 additions & 0 deletions Examples/Mistral7B/requirements.txt
@@ -0,0 +1,6 @@
coremltools
numpy
torch
tqdm
transformers
sentencepiece
24 changes: 24 additions & 0 deletions Examples/transformers-cli/Package.swift
@@ -0,0 +1,24 @@
// swift-tools-version: 6.2
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
name: "transformers-cli",
platforms: [.iOS(.v18), .macOS(.v15)],
dependencies: [
.package(path: "../.."),
// If you copy this manifest as a template, use the following line instead
//.package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"),
],
targets: [
.executableTarget(
name: "transformers-cli",
dependencies: [
.product(name: "Transformers", package: "swift-transformers"),
.product(name: "ArgumentParser", package: "swift-argument-parser"),
]
)
]
)