-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathprune.py
65 lines (56 loc) · 2.66 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torch import nn
import numpy as np
def prune_linear(model, all_linear_data, threshold):
    """Zero out low-magnitude neurons of the ``nn.Linear`` layers in *model*.

    Rows of *all_linear_data* whose ``SmallestAverageMagnitude`` is at most
    *threshold* are selected. Each selected row names a module (``Layer``),
    a neuron kind (``NeuronType``) and an index (``ColumnIndex``):

    * ``"linear_input"``: zero ``weight[:, index]`` (one input feature).
    * ``"linear_output"``: zero ``weight[index, :]`` (one output feature) and,
      when the row's ``AverageMagnitude`` is also at most *threshold*, the
      matching bias entry.

    Rows with any other ``NeuronType`` are ignored.

    Args:
        model: Module whose ``named_modules()`` names are matched against the
            ``Layer`` column. It is modified in place.
        all_linear_data: pandas DataFrame with the columns ``Layer``,
            ``NeuronType``, ``ColumnIndex``, ``SmallestAverageMagnitude`` and
            ``AverageMagnitude``.
        threshold: Magnitude cut-off (inclusive).

    Returns:
        The same *model* object, after pruning.
    """
    selected = all_linear_data[
        all_linear_data["SmallestAverageMagnitude"] <= threshold
    ]
    bias_flags = selected["AverageMagnitude"] <= threshold

    # Group the work per layer once, instead of rescanning the frame for
    # every module of the model.
    plan = {}
    for layer, kind, index, drop_bias in zip(
        selected["Layer"],
        selected["NeuronType"],
        selected["ColumnIndex"],
        bias_flags,
    ):
        plan.setdefault(layer, []).append((kind, int(index), bool(drop_bias)))

    for name, m in model.named_modules():
        if not isinstance(m, nn.Linear):
            continue
        for kind, index, drop_bias in plan.get(name, ()):
            # nn.Linear stores its weight as (out_features, in_features).
            if kind == "linear_input":
                # An input neuron is a column of the weight matrix.
                m.weight.data[:, index] = 0
            elif kind == "linear_output":
                # An output neuron is a row of the weight matrix.
                m.weight.data[index, :] = 0
                # The bias is indexed by output feature, so it is only
                # touched here; layers built with bias=False have none.
                if drop_bias and m.bias is not None:
                    m.bias.data[index] = 0
    # Return the pruned model
    return model
def prune_activations(model, all_activation_data, threshold, mapping):
    """Zero the linear-layer outputs that feed low-magnitude activations.

    Rows of *all_activation_data* whose ``AverageMagnitude`` is at most
    *threshold* are selected. The ``Layer`` value of each row is turned into
    the name of an ``nn.Linear`` module by replacing ``mapping[0]`` with
    ``mapping[1]``; in that module the weight row ``ColumnIndex`` and, when
    the layer has a bias, the bias entry ``ColumnIndex`` are set to zero.

    Example mappings (from the original notes):
        for opt model 1.3 mapping=['activation_fn', 'fc1']
        for phi-1.3 mapping=['act', 'fc1']

    Args:
        model: Module to prune in place.
        all_activation_data: pandas DataFrame with the columns ``Layer``,
            ``ColumnIndex`` and ``AverageMagnitude``.
        threshold: Magnitude cut-off (inclusive).
        mapping: Two-item sequence ``(activation_token, linear_token)`` used
            to translate activation module names into linear module names.

    Returns:
        The same *model* object, after pruning.
    """
    activation_token, linear_token = mapping[0], mapping[1]
    selected = all_activation_data[
        all_activation_data["AverageMagnitude"] <= threshold
    ]

    # Map each target linear layer to the indices to prune. Building the
    # mapping in the forward direction avoids reversing the string
    # replacement, which does not round-trip when the linear token also
    # occurs elsewhere in the module path.
    indices_by_linear = {}
    for layer, index in zip(selected["Layer"], selected["ColumnIndex"]):
        linear_name = layer.replace(activation_token, linear_token)
        indices_by_linear.setdefault(linear_name, []).append(int(index))

    for name, m in model.named_modules():
        if not isinstance(m, nn.Linear):
            continue
        indices = indices_by_linear.get(name)
        if not indices:
            continue
        # Zero out the weight rows producing the pruned activations
        m.weight.data[indices, :] = 0
        # Zero out the matching bias entries (absent when bias=False)
        if m.bias is not None:
            m.bias.data[indices] = 0
    return model
def prune_embedding(model, normalized_counts, special_tokens, threshold, names):
    """Zero the weight rows of rarely used tokens in the named modules.

    Every index whose value in *normalized_counts* is at most *threshold* is
    pruned, except the indices listed in *special_tokens*, which are always
    kept. For each ``nn.Embedding`` or ``nn.Linear`` module whose name is in
    *names*, the corresponding rows of ``weight`` are set to zero.

    Note (from the original author): this will slightly change perplexity as
    lm_head is automatically handled.

    Args:
        model: Module to prune in place.
        normalized_counts: 1-D tensor of per-token scores. It is not modified.
        special_tokens: Iterable of token indices that must never be pruned.
        threshold: Score cut-off (inclusive).
        names: Container of module names (as reported by
            ``model.named_modules()``) to prune.

    Returns:
        The same *model* object, after pruning.
    """
    # Work on a copy: the original code aliased the argument and overwrote
    # the caller's tensor with inf at the special-token positions.
    counts = normalized_counts.clone()
    protected = list(special_tokens)
    if protected:
        # inf can never be <= threshold, so these tokens are always kept.
        counts[protected] = float("inf")
    pruning_map = torch.where(counts <= threshold)[0]

    for name, m in model.named_modules():
        if isinstance(m, (nn.Embedding, nn.Linear)) and (name in names):
            # Index on the weight's own device so a CPU score tensor can be
            # used with a model that lives on an accelerator.
            rows = pruning_map.to(m.weight.device)
            m.weight.data[rows, :] = 0
    return model