bin_to_st.py
import os
import torch
from safetensors.torch import save_file
ckpt = "path_to/pytorch_model.bin"
vista_bin = torch.load(ckpt, map_location="cpu") # only contains model weights
for k in list(vista_bin.keys()):  # merge LoRA weights (if exist) for inference
    if "adapter_down" in k:
        # Locate the matching up-projection and the pretrained weight this LoRA pair belongs to.
        if "q_adapter_down" in k:
            up_k = k.replace("q_adapter_down", "q_adapter_up")
            pretrain_k = k.replace("q_adapter_down", "to_q")
        elif "k_adapter_down" in k:
            up_k = k.replace("k_adapter_down", "k_adapter_up")
            pretrain_k = k.replace("k_adapter_down", "to_k")
        elif "v_adapter_down" in k:
            up_k = k.replace("v_adapter_down", "v_adapter_up")
            pretrain_k = k.replace("v_adapter_down", "to_v")
        else:
            up_k = k.replace("out_adapter_down", "out_adapter_up")
            if "model_ema" in k:
                pretrain_k = k.replace("out_adapter_down", "to_out0")
            else:
                pretrain_k = k.replace("out_adapter_down", "to_out.0")
        # Fold the low-rank update into the pretrained weight: W <- W + up @ down.
        lora_weights = vista_bin[up_k] @ vista_bin[k]
        del vista_bin[k]
        del vista_bin[up_k]
        vista_bin[pretrain_k] = vista_bin[pretrain_k] + lora_weights

for k in list(vista_bin.keys()):  # remove the "_forward_module." prefix added by the training wrapper
    if "_forward_module" in k and "decay" not in k and "num_updates" not in k:
        vista_bin[k.replace("_forward_module.", "")] = vista_bin[k]
        del vista_bin[k]

for k in list(vista_bin.keys()):  # combine EMA weights
    if "model_ema" in k:
        # EMA keys are stored as "model_ema.<name with dots removed>", so match them against
        # the online keys ("model.<name>") after stripping the prefixes and the dots.
        orig_k = None
        for kk in list(vista_bin.keys()):
            if "model_ema" not in kk and k[10:] == kk[6:].replace(".", ""):
                orig_k = kk
        assert orig_k is not None
        vista_bin[orig_k] = vista_bin[k]
        del vista_bin[k]
        print("Replace", orig_k, "with", k)

vista_st = dict()
for k in list(vista_bin.keys()):  # copy into a plain dict for safetensors export
    vista_st[k] = vista_bin[k]

os.makedirs("ckpts", exist_ok=True)
save_file(vista_st, "ckpts/vista.safetensors")
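
A minimal sketch of checking the exported checkpoint, assuming the "ckpts/vista.safetensors" path written above; load_file is the safetensors counterpart of save_file and returns a plain dict of tensors.

# Sketch: verify the converted checkpoint (path assumed from the script above).
from safetensors.torch import load_file

state_dict = load_file("ckpts/vista.safetensors")
print(len(state_dict), "tensors exported")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)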