
Commit f77536f

FL33TW00D, pcuenca, 1duo, Josh Newnham, and junpeiz committed

Mistral work

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Yuduo Wu <yuduo@apple.com>
Co-authored-by: Josh Newnham <jnewnham@apple.com>
Co-authored-by: Junpei Zhou <junpei_zhou@apple.com>
2 parents e72d032 + 5778901 commit f77536f

25 files changed: +1936 −943 lines

Examples/Mistral7B/README.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@

### Export Mistral 7B Instruct v0.3

```shell
✗ python export.py

Loading checkpoint shards: 100%|███████████████████████████| 3/3 [00:12<00:00, 4.11s/it]
Converting PyTorch Frontend ==> MIL Ops: 100%|███| 5575/5575 [00:02<00:00, 2440.66 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 7.12 passes/s]
Running MIL default pipeline: 100%|█████████████████| 79/79 [02:36<00:00, 1.98s/ passes]
Running MIL backend_mlprogram pipeline: 100%|███████| 12/12 [00:00<00:00, 22.90 passes/s]
Running compression: 100%|███████████████████████████| 296/296 [03:04<00:00, 1.60 ops/s]
...
```

### Generate Text

```shell
✗ swift run transformers "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage

Best recommendations for a place to visit in Paris in August 2024:

1. Palace of Versailles: This iconic palace is a must-visit. It's a short train ride from Paris and offers a glimpse into the opulence of the French monarchy.

2. Eiffel Tower: No trip to Paris is complete without a visit to the Eiffel Tower. You can take an elevator ride to the top for a stunning view of the city.

3. Louvre Museum: Home to thousands of works of art, including the Mona Lisa and the Winged Victory of Samothrace, the Louvre is a cultural treasure.
```

Examples/Mistral7B/export.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@

```python
import logging
import os
import warnings
from typing import List, Optional, Tuple

import coremltools as ct
import numpy as np
import torch
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import (
    MISTRAL_ATTENTION_CLASSES,
    MistralAttention,
    MistralConfig,
    MistralForCausalLM,
    apply_rotary_pos_emb,
    repeat_kv,
)

warnings.filterwarnings("ignore")
logging.getLogger("coremltools").setLevel(logging.ERROR)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.3"
METADATA_TOKENIZER: str = "co.huggingface.exporters.name"


class SliceUpdateKeyValueCache(Cache):
    def __init__(
        self,
        shape: Tuple[int, ...],
        device="cpu",
        dtype=torch.float32,
    ) -> None:
        """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim)."""
        super().__init__()
        self.past_seen_tokens: int = 0
        self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)

    def update(
        self,
        k_state: torch.Tensor,
        v_state: torch.Tensor,
        layer_idx: int,
        slice_indices: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]).
        Return slice of key/value cache tensors from [0, slice_indices[1]).
        """
        if len(slice_indices) != 2:
            raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.")
        begin, end = slice_indices
        self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
        self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
        return k_cache, v_cache

    def get_seq_length(self, _: int | None = 0) -> int:
        """Get the sequence length of the cache."""
        return self.past_seen_tokens


class SliceUpdateMistralAttention(MistralAttention):
    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
        super().__init__(config=config, layer_idx=layer_idx)

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor | None, ...]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
            1, 2
        )
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Slice update key/value cache
        end_step = attention_mask.shape[-1]
        key_states, value_states = past_key_value.update(
            key_states,
            value_states,
            self.layer_idx,
            slice_indices=(end_step - q_len, end_step),
        )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, None


class StatefulMistralForCausalLM(torch.nn.Module):
    def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None:
        super().__init__()

        # Custom attention implementation for the stateful slice-update key/value cache; override
        # "sdpa" to comply with transformers.modeling_utils._autoset_attn_implementation
        MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention
        self.model = MistralForCausalLM.from_pretrained(model_path)

        # Register KV cache buffers to be recognized as Core ML states
        config: MistralConfig = self.model.config
        self.kv_cache_shape: Tuple[int, ...] = (
            config.num_hidden_layers,
            batch_size,
            config.num_key_value_heads,
            max_context_size,
            config.hidden_size // config.num_attention_heads,
        )
        self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
        self.register_buffer("keyCache", self.kv_cache.k_cache)
        self.register_buffer("valueCache", self.kv_cache.v_cache)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        causal_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Compute past seen tokens used for updating key/value cache slices
        self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
        return self.model(
            input_ids,
            attention_mask=causal_mask,
            past_key_values=self.kv_cache,
            use_cache=True,
        ).logits


def export() -> None:
    # Construct model from transformers and trace to TorchScript
    max_context_size: int = 2048
    torch_model = StatefulMistralForCausalLM(MODEL_ID, max_context_size=max_context_size)
    torch_model.eval()
    input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32)
    causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32)
    traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask])

    # Convert traced TorchScript to Core ML format
    query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
    end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
    inputs: List[ct.TensorType] = [
        ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"),
        ct.TensorType(
            shape=(1, 1, query_length, end_step_dim),
            dtype=np.float16,
            name="causalMask",
        ),
    ]
    outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")]
    states: List[ct.StateType] = [
        ct.StateType(
            wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
            name="keyCache",
        ),
        ct.StateType(
            wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
            name="valueCache",
        ),
    ]

    # Convert model with FP16 precision
    mlmodel_fp16: ct.MLModel = ct.convert(
        traced_model,
        inputs=inputs,
        outputs=outputs,
        states=states,
        minimum_deployment_target=ct.target.iOS18,
        skip_model_load=True,
    )

    # Block-wise quantize model weights to int4
    op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
        mode="linear_symmetric",
        dtype="int4",
        granularity="per_block",
        block_size=32,
    )
    config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
    mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)
    mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
    mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage")


if __name__ == "__main__":
    export()
```
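A quick way to verify the exported package (not part of this commit) is to inspect its interface and metadata with coremltools. This is a minimal sketch that assumes the export above has already written `StatefulMistral7BInstructInt4.mlpackage` to the working directory:

```python
import coremltools as ct

# Load only the model spec; skip compiling/loading the full weights
model = ct.models.MLModel(
    "StatefulMistral7BInstructInt4.mlpackage", skip_model_load=True
)
desc = model.get_spec().description

# export.py declares inputs "inputIds"/"causalMask" and output "logits",
# and stores the tokenizer repo id under "co.huggingface.exporters.name"
print([inp.name for inp in desc.input])
print([out.name for out in desc.output])
print(dict(desc.metadata.userDefined))
```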

Examples/Mistral7B/generate.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@

```python
import argparse
from typing import Dict, Generator, List, Tuple

import numpy as np
from coremltools.models import MLModel
from transformers import AutoTokenizer

from export import METADATA_TOKENIZER


def load(model_path: str) -> Tuple[MLModel, AutoTokenizer]:
    """Load a Core ML model and corresponding tokenizer."""
    model: MLModel = MLModel(model_path)
    description = model.get_spec().description
    if METADATA_TOKENIZER not in description.metadata.userDefined:
        raise ValueError("Model metadata does not contain tokenizer path.")
    tokenizer_path: str = description.metadata.userDefined[METADATA_TOKENIZER]
    tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    return model, tokenizer


def get_next_token(model: MLModel, prompt_tokens: np.ndarray) -> Generator[int, None, None]:
    """Generate a sequence of tokens with naive greedy decoding."""

    def sample(logits: np.ndarray) -> int:
        """Perform greedy decoding on the logits array to get the next token."""
        return int(np.argmax(logits[0][-1], axis=-1))

    def inference(model: MLModel, input_ids: np.ndarray, num_past_tokens: int) -> np.ndarray:
        """Perform inference with the given model and input data."""
        causal_mask: np.ndarray = np.triu(
            np.full(
                (1, 1, input_ids.shape[-1], num_past_tokens + input_ids.shape[-1]),
                fill_value=-np.inf if num_past_tokens == 0 else 0,
            ),
            k=1,
        ).astype(np.float16)
        outputs: Dict[str, np.ndarray] = model.predict(
            data={"inputIds": input_ids, "causalMask": causal_mask},
            state=kv_cache_state,
        )
        return outputs["logits"]

    kv_cache_state = model.make_state()
    logits: np.ndarray = inference(model, input_ids=prompt_tokens, num_past_tokens=0)
    token: int = sample(logits=logits)
    num_past_tokens: int = prompt_tokens.shape[-1]

    while True:
        yield token
        logits: np.ndarray = inference(
            model,
            input_ids=np.array([[token]], dtype=np.int32),
            num_past_tokens=num_past_tokens,
        )
        token: int = sample(logits=logits)
        num_past_tokens += 1


def generate(
    model: MLModel,
    prompt: str,
    tokenizer: AutoTokenizer,
    max_new_tokens: int,
) -> str:
    prompt_tokens: np.ndarray = tokenizer(prompt, return_tensors="np").input_ids
    extend_tokens: List[int] = []
    for i, token in enumerate(get_next_token(model, prompt_tokens=prompt_tokens.astype(np.int32))):
        if token == tokenizer.eos_token_id or i == max_new_tokens:
            break
        extend_tokens.append(token)
    return tokenizer.decode(prompt_tokens[0].tolist() + extend_tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_path", type=str)
    parser.add_argument("--prompt", type=str, default="Hello")
    parser.add_argument("--max_new_tokens", type=int, default=128)
    args = parser.parse_args()
    model, tokenizer = load(args.model_path)
    extend_text: str = generate(
        model,
        prompt=args.prompt,
        tokenizer=tokenizer,
        max_new_tokens=args.max_new_tokens,
    )
    print(extend_text)
```
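Given the argparse interface above, a local smoke test of the Python pipeline might look like this (assuming the int4 package exported earlier sits in the same directory):

```shell
python generate.py StatefulMistral7BInstructInt4.mlpackage \
    --prompt "Best recommendations for a place to visit in Paris in August 2024:" \
    --max_new_tokens 128
```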
Examples/Mistral7B/requirements.txt

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@

```
coremltools==8.0b1
numpy==1.26.4
torch==2.3.1
tqdm==4.66.4
transformers==4.42.3
sentencepiece==0.2.0
```
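These pins reproduce the export environment. Assuming the file is checked in as `requirements.txt` next to `export.py`, a fresh environment for running the two scripts can be set up with:

```shell
python -m venv .venv && source .venv/bin/activate
pip install -r Examples/Mistral7B/requirements.txt
```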

Package.swift

Lines changed: 3 additions & 5 deletions
```diff
@@ -5,7 +5,7 @@ import PackageDescription
 
 let package = Package(
     name: "swift-transformers",
-    platforms: [.iOS(.v16), .macOS(.v13)],
+    platforms: [.iOS("18.0"), .macOS("15.0")],
     products: [
         .library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]),
         .executable(name: "transformers", targets: ["TransformersCLI"]),
@@ -23,13 +23,11 @@ let package = Package(
         .executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
         .target(name: "Hub", resources: [.process("FallbackConfigs")]),
         .target(name: "Tokenizers", dependencies: ["Hub"]),
-        .target(name: "TensorUtils"),
-        .target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
-        .target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
+        .target(name: "Generation", dependencies: ["Tokenizers"]),
+        .target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
         .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
         .testTarget(name: "HubTests", dependencies: ["Hub"]),
         .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
-        .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
         .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
         .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"])
     ]
```

README.md

Lines changed: 4 additions & 2 deletions
````diff
@@ -1,8 +1,10 @@
-# `swift-transformers`
+# `swift-transformers` - preview edition
 
 [![Unit Tests](https://github.com/huggingface/swift-transformers/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/huggingface/swift-transformers/actions/workflows/unit-tests.yml)
 [![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fhuggingface%2Fswift-transformers%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/huggingface/swift-transformers)
 [![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fhuggingface%2Fswift-transformers%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/huggingface/swift-transformers)
 
+This preview edition of `swift-transformers` includes support for cutting-edge features released at WWDC 2024, such as stateful models and Mistral 7B.
+
 This is a collection of utilities to help adopt language models in Swift apps. It tries to follow the Python `transformers` API and abstractions whenever possible, but it also aims to provide an idiomatic Swift interface and does not assume prior familiarity with [`transformers`](https://github.com/huggingface/transformers) or [`tokenizers`](https://github.com/huggingface/tokenizers).
 
 
@@ -57,7 +59,7 @@ To use `swift-transformers` with SwiftPM, you can add this to your `Package.swift`:
 
 ```swift
 dependencies: [
-    .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.5")
+    .package(url: "https://github.com/huggingface/swift-transformers", branch: "preview")
 ]
 ```
````
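For consumers, the `Transformers` library product declared in `Package.swift` is the import surface. A minimal client manifest against the preview branch might look like the sketch below; the `MyApp` package and target names are hypothetical, and the tools version is assumed to be one that accepts string-based platform specifiers:

```swift
// swift-tools-version: 5.10
import PackageDescription

let package = Package(
    name: "MyApp",  // hypothetical client package
    platforms: [.iOS("18.0"), .macOS("15.0")],  // match the preview platform floor
    dependencies: [
        .package(url: "https://github.com/huggingface/swift-transformers", branch: "preview")
    ],
    targets: [
        .executableTarget(
            name: "MyApp",
            dependencies: [.product(name: "Transformers", package: "swift-transformers")]
        )
    ]
)
```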