import torch
import torch.nn as nn
import torch.nn.functional as F


class RoundWithGradient(torch.autograd.Function):
    """Binarize the soft-quantized value to {-1, +1} in the forward pass,
    with a straight-through gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Map x to roughly [0, 1] using its value range, round to {0, 1},
        # then rescale to {-1, +1}.
        delta = torch.max(x) - torch.min(x)
        x = (x / delta + 0.5)
        return x.round() * 2 - 1

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: pass the incoming gradient through unchanged.
        return g


class DSQConv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 momentum=0.1,
                 num_bit=8, QInput=True, bSetQ=True):
        super(DSQConv, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.num_bit = num_bit
        self.quan_input = QInput
        self.bit_range = 2 ** self.num_bit - 1
        self.is_quan = bSetQ
        self.momentum = momentum
        if self.is_quan:
            # Clipping bounds start at large (int32-scale) values and are optimized by backpropagation.
            # Weight
            self.uW = nn.Parameter(data=torch.tensor(2 ** 31 - 1).float())
            self.lW = nn.Parameter(data=torch.tensor((-1) * (2 ** 32)).float())
            self.register_buffer('running_uw', torch.tensor([self.uW.data]))  # init with uW
            self.register_buffer('running_lw', torch.tensor([self.lW.data]))  # init with lW
            self.alphaW = nn.Parameter(data=torch.tensor(0.2).float())
            # Bias
            if self.bias is not None:
                self.uB = nn.Parameter(data=torch.tensor(2 ** 31 - 1).float())
                self.lB = nn.Parameter(data=torch.tensor((-1) * (2 ** 32)).float())
                self.register_buffer('running_uB', torch.tensor([self.uB.data]))  # init with uB
                self.register_buffer('running_lB', torch.tensor([self.lB.data]))  # init with lB
                self.alphaB = nn.Parameter(data=torch.tensor(0.2).float())
            # Activation (input)
            if self.quan_input:
                self.uA = nn.Parameter(data=torch.tensor(2 ** 31 - 1).float())
                self.lA = nn.Parameter(data=torch.tensor((-1) * (2 ** 32)).float())
                self.register_buffer('running_uA', torch.tensor([self.uA.data]))  # init with uA
                self.register_buffer('running_lA', torch.tensor([self.lA.data]))  # init with lA
                self.alphaA = nn.Parameter(data=torch.tensor(0.2).float())

    def clipping(self, x, upper, lower):
        # Differentiable clamp into [lower, upper]; ReLU keeps gradients flowing to the learned bounds.
        # clip lower
        x = x + F.relu(lower - x)
        # clip upper
        x = x - F.relu(x - upper)
        return x

    def phi_function(self, x, mi, alpha, delta):
        # DSQ soft quantization: phi(x) = s * tanh(k * (x - mi)),
        # with s = 1 / (1 - alpha) and k = log(2 / alpha - 1) / delta.
        # alpha must stay below 2, otherwise log(2 / alpha - 1) is undefined.
        # torch.full_like keeps the clamp constant on the same device and dtype as alpha.
        alpha = torch.where(alpha >= 2.0, torch.full_like(alpha, 2.0), alpha)
        s = 1 / (1 - alpha)
        k = (2 / alpha - 1).log() * (1 / delta)
        x = (((x - mi) * k).tanh()) * s
        return x

    def sgn(self, x):
        # Hard rounding with a straight-through gradient.
        x = RoundWithGradient.apply(x)
        return x

    def dequantize(self, x, lower_bound, delta, interval):
        # x is in {-1, +1}: (x + 1) / 2 selects the lower (0) or upper (1) edge of the
        # quantization interval; adding the interval index maps back to the real value range.
        x = ((x + 1) / 2 + interval) * delta + lower_bound
        return x
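
    # Summary of the quantization pipeline applied to weights, bias, and activations in
    # forward() below (a recap of the methods above, not extra functionality):
    #   1. clipping(): clamp the tensor to the (moving-average) learned bounds
    #   2. delta = (max - min) / (2**num_bit - 1), interval = floor((x - min) / delta)
    #   3. phi_function(): soft, tanh-shaped quantization around the interval midpoint mi
    #   4. sgn(): hard rounding with a straight-through gradient
    #   5. dequantize(): map the rounded value back to the original range for the convolution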

    def forward(self, x):
        if self.is_quan:
            # Weight part
            # moving average of the learned clipping bounds
            if self.training:
                cur_running_lw = self.running_lw.mul(1 - self.momentum).add((self.momentum) * self.lW)
                cur_running_uw = self.running_uw.mul(1 - self.momentum).add((self.momentum) * self.uW)
            else:
                cur_running_lw = self.running_lw
                cur_running_uw = self.running_uw

            Qweight = self.clipping(self.weight, cur_running_uw, cur_running_lw)
            cur_max = torch.max(Qweight)
            cur_min = torch.min(Qweight)
            delta = (cur_max - cur_min) / (self.bit_range)
            interval = (Qweight - cur_min) // delta
            mi = (interval + 0.5) * delta + cur_min
            Qweight = self.phi_function(Qweight, mi, self.alphaW, delta)
            Qweight = self.sgn(Qweight)
            Qweight = self.dequantize(Qweight, cur_min, delta, interval)

            Qbias = self.bias
            # Bias
            if self.bias is not None:
                # self.running_lB.mul_(1 - self.momentum).add_((self.momentum) * self.lB)
                # self.running_uB.mul_(1 - self.momentum).add_((self.momentum) * self.uB)
                if self.training:
                    cur_running_lB = self.running_lB.mul(1 - self.momentum).add((self.momentum) * self.lB)
                    cur_running_uB = self.running_uB.mul(1 - self.momentum).add((self.momentum) * self.uB)
                else:
                    cur_running_lB = self.running_lB
                    cur_running_uB = self.running_uB

                Qbias = self.clipping(self.bias, cur_running_uB, cur_running_lB)
                cur_max = torch.max(Qbias)
                cur_min = torch.min(Qbias)
                delta = (cur_max - cur_min) / (self.bit_range)
                interval = (Qbias - cur_min) // delta
                mi = (interval + 0.5) * delta + cur_min
                Qbias = self.phi_function(Qbias, mi, self.alphaB, delta)
                Qbias = self.sgn(Qbias)
                Qbias = self.dequantize(Qbias, cur_min, delta, interval)

            # Input (activation)
            Qactivation = x
            if self.quan_input:
                if self.training:
                    cur_running_lA = self.running_lA.mul(1 - self.momentum).add((self.momentum) * self.lA)
                    cur_running_uA = self.running_uA.mul(1 - self.momentum).add((self.momentum) * self.uA)
                else:
                    cur_running_lA = self.running_lA
                    cur_running_uA = self.running_uA

                Qactivation = self.clipping(x, cur_running_uA, cur_running_lA)
                cur_max = torch.max(Qactivation)
                cur_min = torch.min(Qactivation)
                delta = (cur_max - cur_min) / (self.bit_range)
                interval = (Qactivation - cur_min) // delta
                mi = (interval + 0.5) * delta + cur_min
                Qactivation = self.phi_function(Qactivation, mi, self.alphaA, delta)
                Qactivation = self.sgn(Qactivation)
                Qactivation = self.dequantize(Qactivation, cur_min, delta, interval)

            output = F.conv2d(Qactivation, Qweight, Qbias, self.stride, self.padding, self.dilation, self.groups)
        else:
            output = F.conv2d(x, self.weight, self.bias, self.stride,
                              self.padding, self.dilation, self.groups)

        return output
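

# Minimal usage sketch: DSQConv is a drop-in replacement for nn.Conv2d with the same
# constructor arguments plus the quantization options. The layer sizes, bit width, and
# input shape below are illustrative assumptions, not values prescribed by this module.
if __name__ == "__main__":
    layer = DSQConv(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1,
                    num_bit=4, QInput=True, bSetQ=True)
    dummy = torch.randn(2, 3, 32, 32)  # NCHW batch of random activations
    out = layer(dummy)                 # training mode: weights, bias, and input are soft-quantized
    print(out.shape)                   # expected: torch.Size([2, 16, 32, 32])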