forked from mit-han-lab/nunchaku
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_ckpt.py
103 lines (90 loc) · 4.4 KB
/
convert_ckpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import os
import torch
from safetensors.torch import save_file
from tqdm import tqdm
def get_args() -> argparse.Namespace:
    """Parse command-line arguments for the checkpoint converter.

    Returns:
        argparse.Namespace with ``input_path`` (required positional) and
        ``output_path`` (``-o``/``--output-path``, defaults to
        ``"output.safetensors"``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("input_path", type=str, help="Path to checkpoint file")
    cli.add_argument(
        "-o",
        "--output-path",
        type=str,
        default="output.safetensors",
        help="Path to save the output checkpoint file",
    )
    return cli.parse_args()
def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor:
    """Reorder a fused ``[shift, scale]`` tensor to ``[scale, shift]``.

    The tensor is split into two equal halves along dim 0 and the halves are
    concatenated back in reverse order. Used for adaLN modulation parameters
    whose storage order differs between the source and target layouts.
    """
    halves = weight.chunk(2, dim=0)
    return torch.cat(halves[::-1], dim=0)
def main():
    """Convert a FLUX-style checkpoint with fused LoRA weights into the
    diffusers transformer key layout and save it as a safetensors file.

    The input checkpoint is expected to be a nested dict of the form
    ``{"img_in": {...}, "final_layer": {...}, "lora": {flat_key: tensor}}``
    — assumed from the key accesses below; TODO confirm against the
    producing training script.
    """
    args = get_args()
    # NOTE(review): torch.load unpickles arbitrary objects unless
    # weights_only=True is passed — only run this on trusted checkpoints.
    original_state_dict = torch.load(args.input_path, map_location="cpu")
    # Non-LoRA top-level modules: patch embedder and final output layer.
    new_state_dict = {
        "transformer.x_embedder.weight": original_state_dict["img_in"]["weight"],
        "transformer.x_embedder.bias": original_state_dict["img_in"]["bias"],
        # adaLN modulation is stored shift-first; the target layout wants
        # scale-first, hence the swap.
        "transformer.norm_out.linear.weight": swap_scale_shift(
            original_state_dict["final_layer"]["adaLN_modulation.1.weight"]
        ),
        "transformer.norm_out.linear.bias": swap_scale_shift(
            original_state_dict["final_layer"]["adaLN_modulation.1.bias"]
        ),
        "transformer.proj_out.weight": original_state_dict["final_layer"]["linear.weight"],
        "transformer.proj_out.bias": original_state_dict["final_layer"]["linear.bias"],
    }
    # Drop the already-converted entries; only the "lora" sub-dict remains to
    # be processed.
    original_state_dict.pop("img_in")
    original_state_dict.pop("final_layer")
    original_lora_state_dict = original_state_dict["lora"]
    for k, v in tqdm(original_lora_state_dict.items()):
        if "double_blocks" in k:
            new_k = k.replace("double_blocks", "transformer.transformer_blocks").replace(".default", "")
            if "qkv" in new_k:
                # Fused qkv projection: lora_A (down-projection) is shared by
                # q/k/v so it is copied three times; lora_B (up-projection) is
                # split into three equal row chunks, one per projection.
                for i, p in enumerate(["q", "k", "v"]):
                    if "lora_A" in new_k:
                        # Copy the tensor
                        new_k2 = new_k.replace("img_attn.qkv", f"attn.to_{p}")
                        new_k2 = new_k2.replace("txt_attn.qkv", f"attn.add_{p}_proj")
                        new_state_dict[new_k2] = v.clone()
                    else:
                        assert "lora_B" in new_k
                        assert v.shape[0] % 3 == 0
                        chunk_size = v.shape[0] // 3
                        new_k2 = new_k.replace("img_attn.qkv", f"attn.to_{p}")
                        new_k2 = new_k2.replace("txt_attn.qkv", f"attn.add_{p}_proj")
                        new_state_dict[new_k2] = v[i * chunk_size : (i + 1) * chunk_size]
            else:
                # Plain (non-fused) modules: straight key renames.
                # NOTE(review): txt_mlp.0 / txt_mlp.2 keys are not remapped
                # here — if the checkpoint contains text-branch MLP LoRA,
                # those keys pass through unrenamed; verify this is intended.
                new_k = new_k.replace("img_mod.lin", "norm1.linear")
                new_k = new_k.replace("txt_mod.lin", "norm1_context.linear")
                new_k = new_k.replace("img_mlp.0", "ff.net.0.proj")
                new_k = new_k.replace("img_mlp.2", "ff.net.2")
                new_state_dict[new_k] = v
        else:
            assert "single_blocks" in k
            new_k = k.replace("single_blocks", "transformer.single_transformer_blocks").replace(".default", "")
            if "linear1" in k:
                # linear1 fuses q/k/v plus the MLP input projection ("i");
                # lora_A is copied per target, lora_B is sliced row-wise at
                # running offsets.
                start = 0
                for i, p in enumerate(["q", "k", "v", "i"]):
                    if "lora_A" in new_k:
                        if p == "i":
                            new_k2 = new_k.replace("linear1", "proj_mlp")
                        else:
                            new_k2 = new_k.replace("linear1", f"attn.to_{p}")
                        new_state_dict[new_k2] = v.clone()
                    else:
                        if p == "i":
                            new_k2 = new_k.replace("linear1", "proj_mlp")
                        else:
                            new_k2 = new_k.replace("linear1", f"attn.to_{p}")
                        # 3072 rows per attention projection, 12288 for the
                        # MLP input — presumably the standard FLUX hidden
                        # sizes; TODO confirm for other model variants.
                        chunk_size = 12288 if p == "i" else 3072
                        new_state_dict[new_k2] = v[start : start + chunk_size]
                        start += chunk_size
            elif "linear2" in k:
                new_k = new_k.replace("linear2", "proj_out")
                # NOTE(review): "modulation_lin" cannot occur in a key that
                # already contains "linear2"; this replace looks like dead
                # code (the real handling below uses "modulation.lin") —
                # confirm and remove.
                new_k = new_k.replace("modulation_lin", ".norm.linear")
                new_state_dict[new_k] = v
            else:
                assert "modulation.lin" in k
                new_k = new_k.replace("modulation.lin", "norm.linear")
                new_state_dict[new_k] = v
    # Ensure the destination directory exists before writing the output file.
    os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
    save_file(new_state_dict, args.output_path)
# Script entry point: run the conversion when executed directly.
if __name__ == "__main__":
    main()