-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathweights.py
58 lines (49 loc) · 1.53 KB
/
weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import List, NamedTuple
import mlx
import mlx.nn
import mlx.core as mx
from pathlib import Path
class LayerWeights(NamedTuple):
    """Immutable bundle of one transformer layer's parameters.

    Field names mirror the checkpoint parameter names
    (e.g. ``layers.{i}.attention.wq.weight``) loaded in ``load_weights``.
    """
    wq: mx.array  # attention query projection weight
    wk: mx.array  # attention key projection weight
    wv: mx.array  # attention value projection weight
    wo: mx.array  # attention output projection weight
    w1: mx.array  # feed-forward first projection (gate) weight
    w2: mx.array  # feed-forward down projection weight
    w3: mx.array  # feed-forward up projection weight
    ffn_norm: mx.array  # pre-feed-forward RMS/Layer norm weight — exact norm type not visible here
    attention_norm: mx.array  # pre-attention RMS/Layer norm weight — exact norm type not visible here
class XfmrWeights(NamedTuple):
    """Full transformer weight set: embeddings, final norm, output head,
    and one ``LayerWeights`` entry per layer."""
    tok_embeddings: mx.array  # token embedding table ('tok_embeddings.weight')
    norm: mx.array  # final normalization weight ('norm.weight')
    output: mx.array  # output (unembedding) projection ('output.weight')
    layer_weights: List[LayerWeights]  # per-layer weights, index = layer number
def load_weights(ckpt_dir: Path, n_layers: int = 16) -> XfmrWeights:
    """Load transformer weights from a directory of ``.npy`` files.

    Each file in *ckpt_dir* must be named after its parameter with a
    ``.npy`` suffix (e.g. ``layers.0.attention.wq.weight.npy``); the
    suffix is stripped to recover the parameter name. MLX will use the
    Metal GPU by default.

    Args:
        ckpt_dir: Directory containing one ``.npy`` file per parameter.
        n_layers: Number of transformer layers to assemble (default 16).

    Returns:
        An ``XfmrWeights`` holding the token embeddings, final norm,
        output projection, and per-layer weights.

    Raises:
        KeyError: If an expected parameter file is missing from *ckpt_dir*.
    """
    # Path.stem strips exactly the final suffix (".npy"), preserving the
    # dotted parameter name. Unlike the previous manual split on "/", it
    # is portable across OS path separators.
    w = {file.stem: mx.load(str(file)) for file in ckpt_dir.glob("*.npy")}

    layer_weights = [
        LayerWeights(
            wq=w[f'layers.{i}.attention.wq.weight'],
            wk=w[f'layers.{i}.attention.wk.weight'],
            wv=w[f'layers.{i}.attention.wv.weight'],
            wo=w[f'layers.{i}.attention.wo.weight'],
            w1=w[f'layers.{i}.feed_forward.w1.weight'],
            w2=w[f'layers.{i}.feed_forward.w2.weight'],
            w3=w[f'layers.{i}.feed_forward.w3.weight'],
            ffn_norm=w[f'layers.{i}.ffn_norm.weight'],
            attention_norm=w[f'layers.{i}.attention_norm.weight'],
        )
        for i in range(n_layers)
    ]

    return XfmrWeights(
        tok_embeddings=w['tok_embeddings.weight'],
        norm=w['norm.weight'],
        output=w['output.weight'],
        layer_weights=layer_weights,
    )