-
Notifications
You must be signed in to change notification settings - Fork 2
/
ash.py
120 lines (86 loc) · 2.78 KB
/
ash.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import numpy as np
import torch
import torch.nn.functional as F
def get_msp_score(logits):
    """Maximum softmax probability (MSP) score per sample.

    Args:
        logits: (N, C) tensor of classifier logits.

    Returns:
        numpy array of shape (N,) with the largest softmax probability per row.
    """
    probs = F.softmax(logits, dim=1)
    return probs.max(dim=1).values.detach().cpu().numpy()
def get_energy_score(logits):
    """Energy-based OOD score: logsumexp over the class dimension, per sample.

    Args:
        logits: (N, C) tensor of classifier logits.

    Returns:
        numpy array of shape (N,) with one energy score per row.
    """
    energies = torch.logsumexp(logits.detach().cpu(), dim=1)
    return energies.numpy()
def get_score(logits, method):
    """Dispatch to an OOD scoring function by name.

    Args:
        logits: (N, C) tensor of classifier logits.
        method: one of "msp" or "energy".

    Returns:
        numpy array of shape (N,) of per-sample scores.

    Raises:
        ValueError: if ``method`` is not a supported scoring method.
        (Previously this called ``exit(...)``, which raises SystemExit and
        kills the interpreter — not catchable in normal error handling.)
    """
    if method == "msp":
        return get_msp_score(logits)
    if method == "energy":
        return get_energy_score(logits)
    raise ValueError(f'Unsupported scoring method: {method!r}')
def ash_b(x, percentile=65):
    """ASH-B (binarize): keep the top (100 - percentile)% activations per
    sample, replacing each survivor with the constant s1/k so the per-sample
    activation sum is preserved; everything else is zeroed.

    Note: ``x`` is modified in place (the flattened view shares its storage)
    and also returned.

    Args:
        x: activation tensor of shape (B, C, H, W).
        percentile: percentage of activations to prune, in [0, 100].
    """
    assert x.dim() == 4
    assert 0 <= percentile <= 100
    batch = x.shape[0]
    # per-sample activation sum before pruning
    sum_before = x.sum(dim=[1, 2, 3])
    n_activations = x.shape[1:].numel()
    keep = n_activations - int(np.round(n_activations * percentile / 100.0))
    flat = x.view(batch, -1)
    _, top_idx = torch.topk(flat, keep, dim=1)
    # each survivor gets the same value, chosen to preserve the original sum
    uniform_fill = (sum_before / keep).unsqueeze(dim=1).expand(-1, keep)
    flat.zero_().scatter_(dim=1, index=top_idx, src=uniform_fill)
    return x
def ash_p(x, percentile=65):
    """ASH-P (prune): zero out all but the top (100 - percentile)% of
    activations per sample, keeping the surviving values unchanged.

    Note: ``x`` is modified in place (the flattened view shares its storage)
    and also returned.

    Args:
        x: activation tensor of shape (B, C, H, W).
        percentile: percentage of activations to prune, in [0, 100].
    """
    assert x.dim() == 4
    assert 0 <= percentile <= 100
    batch = x.shape[0]
    n_activations = x.shape[1:].numel()
    keep = n_activations - int(np.round(n_activations * percentile / 100.0))
    flat = x.view(batch, -1)
    top_vals, top_idx = torch.topk(flat, keep, dim=1)
    # topk copies the survivors out, so zeroing then scattering is safe
    flat.zero_().scatter_(dim=1, index=top_idx, src=top_vals)
    return x
def ash_s(x, percentile=65):
    """ASH-S (scale): prune the bottom ``percentile``% of activations per
    sample, then sharpen the survivors by multiplying with exp(s1 / s2),
    where s1 and s2 are the per-sample sums before and after pruning.

    Note: pruning mutates ``x`` in place via a flattened view; the returned
    tensor is a new, rescaled tensor.

    Args:
        x: activation tensor of shape (B, C, H, W).
        percentile: percentage of activations to prune, in [0, 100].
    """
    assert x.dim() == 4
    assert 0 <= percentile <= 100
    batch = x.shape[0]
    # per-sample activation sum before pruning
    sum_before = x.sum(dim=[1, 2, 3])
    n_activations = x.shape[1:].numel()
    keep = n_activations - int(np.round(n_activations * percentile / 100.0))
    flat = x.view(batch, -1)
    top_vals, top_idx = torch.topk(flat, keep, dim=1)
    flat.zero_().scatter_(dim=1, index=top_idx, src=top_vals)
    # per-sample activation sum after pruning
    sum_after = x.sum(dim=[1, 2, 3])
    # sharpening factor, broadcast over (C, H, W)
    sharpen = torch.exp((sum_before / sum_after)[:, None, None, None])
    return x * sharpen
def ash_rand(x, percentile=65, r1=0, r2=10):
    """ASH-RAND: keep the positions of the top (100 - percentile)% activations
    per sample, but replace their values with uniform random noise in
    [r1, r2); everything else is zeroed.

    Note: ``x`` is modified in place (the flattened view shares its storage)
    and also returned. Uses the global torch RNG.

    Args:
        x: activation tensor of shape (B, C, H, W).
        percentile: percentage of activations to prune, in [0, 100].
        r1, r2: bounds of the uniform distribution for the kept positions.
    """
    assert x.dim() == 4
    assert 0 <= percentile <= 100
    batch = x.shape[0]
    n_activations = x.shape[1:].numel()
    keep = n_activations - int(np.round(n_activations * percentile / 100.0))
    flat = x.view(batch, -1)
    top_vals, top_idx = torch.topk(flat, keep, dim=1)
    # overwrite the survivors' values (in place) with uniform noise
    random_vals = top_vals.uniform_(r1, r2)
    flat.zero_().scatter_(dim=1, index=top_idx, src=random_vals)
    return x
def react(x, threshold):
    """ReAct: rectify activations by clamping them at ``threshold`` from above.

    Returns a new tensor; the input is not modified.
    """
    return x.clamp(max=threshold)
def react_and_ash(x, clip_threshold, pruning_percentile):
    """ReAct clamping followed by ASH-S pruning and sharpening.

    Args:
        x: activation tensor of shape (B, C, H, W).
        clip_threshold: upper bound applied to activations before pruning.
        pruning_percentile: percentage of activations for ash_s to prune.
    """
    clipped = torch.clamp(x, max=clip_threshold)
    return ash_s(clipped, pruning_percentile)
def apply_ash(x, method):
    """Parse a method spec string and apply the corresponding shaping function.

    Spec formats (parameters separated by '@'):
        "react_and_ash@<threshold>@<percentile>"
        "react@<threshold>"
        "ash_b@<p>" / "ash_p@<p>" / "ash_s@<p>" / "ash_rand@<p>"
    Any other string returns ``x`` unchanged.

    Previously the function name was resolved with ``eval()`` on the method
    string — an injection risk and an idiom violation. Fixed with direct
    calls and an explicit dispatch table; an unknown "ash*" name now raises
    KeyError instead of NameError.
    """
    if method.startswith('react_and_ash@'):
        _, threshold, percentile = method.split('@')
        return react_and_ash(x, float(threshold), int(percentile))
    if method.startswith('react@'):
        _, threshold = method.split('@')
        return react(x, float(threshold))
    if method.startswith('ash'):
        fn_name, percentile = method.split('@')
        ash_fns = {'ash_b': ash_b, 'ash_p': ash_p, 'ash_s': ash_s, 'ash_rand': ash_rand}
        return ash_fns[fn_name](x, int(percentile))
    return x