-
Notifications
You must be signed in to change notification settings - Fork 13
/
tool_trim_learnable_sparsity.py
executable file
·64 lines (51 loc) · 2.74 KB
/
tool_trim_learnable_sparsity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import argparse
# Command-line interface: a single option pointing at the Megatron-style
# checkpoint iteration directory that should be trimmed.
parser = argparse.ArgumentParser(description='Trim Lana checkpoint')
parser.add_argument(
    '--ckpt_dir',
    type=str,
    default='output/checkpoints/llama-mask-only/train_iters_2000/ckpt/iter_0002000',
    help='Input checkpoint',
)
args = parser.parse_args()
def trim_ckpt(input, output):
    """Trim a learnable-sparsity checkpoint for release.

    Collapses every differentiable mask gate (``*.diff_mask.gate``) into its
    winning fixed 2:4 binary mask, stores that mask under the corresponding
    ``*.mask`` key, and drops the training-only ``*.mask`` / ``*.mask_options``
    tensors from the encoder state dict.

    Args:
        input: path to the source ``model_optim_rng.pt`` checkpoint file.
        output: path where the trimmed checkpoint is written.
    """
    # NOTE(review): torch.load deserializes pickled data; only run this tool
    # on checkpoints from a trusted source.
    ckpt = torch.load(input, map_location='cpu')
    new_encoder_state_dict = {}
    # The six possible 2-out-of-4 binary masks (2:4 structured sparsity),
    # indexed by the gate's argmax. Built directly instead of via the old
    # zeros + in-place `.data +=` pattern (same values, idiomatic code).
    # Shape (1, 6, 4) so indexing below broadcasts over the gate rows.
    mask_options = torch.tensor(
        [[[1, 1, 0, 0],
          [1, 0, 1, 0],
          [1, 0, 0, 1],
          [0, 1, 1, 0],
          [0, 1, 0, 1],
          [0, 0, 1, 1]]],
        dtype=torch.float32,
    )
    encoder = ckpt['model']['language_model']['encoder']
    for k, v in encoder.items():
        if '.diff_mask.gate' in k:
            gate = v.float()
            runtime_mask = encoder[k.replace('diff_mask.gate', 'mask')].float()
            # For each gate row, pick the mask option with the highest score
            # and reshape it to the runtime mask's layout.
            winner_mask = mask_options[
                torch.arange(mask_options.shape[0]), gate.argmax(dim=-1)
            ].view(*runtime_mask.shape)
            # set the type of winner mask the same as runtime_mask
            winner_mask = winner_mask.type_as(runtime_mask)
            new_encoder_state_dict[k.replace('diff_mask.gate', 'mask')] = winner_mask
            print("save winner mask:", k.replace('diff_mask.gate', 'mask'))
            continue
        # Drop training-only tensors; '.mask' entries are rebuilt above from
        # the gates. ('.mask' also matches '.mask_options', but the second
        # check is kept for clarity and fidelity to the original intent.)
        if '.mask' in k:
            continue
        if '.mask_options' in k:
            continue
        new_encoder_state_dict[k] = v
    ckpt['model']['language_model']['encoder'] = new_encoder_state_dict
    print(ckpt['model']['language_model']['encoder'].keys())
    torch.save(ckpt, output)
import os
import glob

# Create the sibling 'release' directory next to the input iteration dir,
# e.g. .../ckpt/iter_0002000 -> .../ckpt/release
splited_dir = args.ckpt_dir.split('/')
output_dir = os.path.join('/'.join(splited_dir[:-1]), 'release')
print(f"output_dir: {output_dir}")
os.makedirs(output_dir, exist_ok=True)

# Trim each tensor-parallel rank's checkpoint file into the release dir,
# mirroring the mp_rank_* subdirectory layout.
mp_rank_dirs = glob.glob(os.path.join(args.ckpt_dir, "mp_rank_*"))
for mp_rank_dir in mp_rank_dirs:
    ckpt_file = os.path.join(mp_rank_dir, "model_optim_rng.pt")
    output_file = ckpt_file.replace(args.ckpt_dir, output_dir)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    print(f"Trim {ckpt_file} to {output_file}")
    trim_ckpt(ckpt_file, output_file)

# Update the latest iteration marker to "release" so loaders pick up the
# trimmed copy.
# BUG FIX: the old os.path.join(*splited_dir[:-1], ...) silently dropped
# the leading '/' for absolute ckpt_dir paths, because split('/') yields
# a leading '' element and os.path.join('', 'a', ...) returns a relative
# path. Build the parent the same way output_dir is built above.
iteration_file = os.path.join('/'.join(splited_dir[:-1]), 'latest_checkpointed_iteration.txt')
print(iteration_file)
with open(iteration_file, 'w') as f:
    f.write("release")