quant_affine.py (forked from clevercool/SQuant)
143 lines (124 loc) · 4.95 KB
# *
# @file Different utility functions
# Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
# All rights reserved.
# This file is part of ZeroQ repository.
#
# ZeroQ is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ZeroQ is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
# *
import math
import numpy as np
from torch.autograd import Function, Variable
import torch


def clamp(input, min, max, inplace=False):
    """
    Clamp tensor input to (min, max).
    input: input tensor to be clamped
    """
    if inplace:
        input.clamp_(min, max)
        return input
    return torch.clamp(input, min, max)
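
# Usage sketch (illustrative values, not from the original repository):
#   x = torch.tensor([-2.0, 0.5, 3.0])
#   clamp(x, -1.0, 1.0)                # -> tensor([-1.0000, 0.5000, 1.0000])
#   clamp(x, -1.0, 1.0, inplace=True)  # same result, but modifies x in place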


def linear_quantize(input, scale, zero_point, inplace=False):
    """
    Quantize single-precision input tensor to integers with the given scaling factor and zero point.
    input: single-precision input tensor to be quantized
    scale: scaling factor for quantization
    zero_point: shift for quantization
    """
    # reshape scale and zero point for convolutional weights and activations
    if len(input.shape) == 4:
        scale = scale.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    # reshape scale and zero point for linear weights
    elif len(input.shape) == 2:
        scale = scale.view(-1, 1)
        zero_point = zero_point.view(-1, 1)
    # map the single-precision input to integer values with the given scale and zero point
    if inplace:
        input.mul_(scale).sub_(zero_point).round_()
        return input
    return (scale * input - zero_point)
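
# Affine quantization sketch (illustrative scale/zero point, not from the
# original repository): q = round(scale * x - zero_point).
#   x = torch.tensor([-1.0, 0.0, 1.0])
#   linear_quantize(x, torch.tensor(127.5), torch.tensor(0.0)).round()
#   # -> tensor([-128., 0., 128.])  (clamping to the integer range is left to
#   # the caller, e.g. AsymmetricQuantFunction below)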


def linear_dequantize(input, scale, zero_point, inplace=False):
    """
    Map integer input tensor back to floating-point values with the given scaling factor and zero point.
    input: integer input tensor to be mapped
    scale: scaling factor for quantization
    zero_point: shift for quantization
    """
    # reshape scale and zero point for convolutional weights and activations
    if len(input.shape) == 4:
        scale = scale.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    # reshape scale and zero point for linear weights
    elif len(input.shape) == 2:
        scale = scale.view(-1, 1)
        zero_point = zero_point.view(-1, 1)
    # map the integer input back to floating-point values with the given scaling factor and zero point
    if inplace:
        input.add_(zero_point).div_(scale)
        return input
    return (input + zero_point) / scale
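
# Quantize/dequantize round trip (sketch; scale and zero point as produced by
# asymmetric_linear_quantization_params below, values illustrative):
#   x = torch.tensor([-0.8, 0.1, 0.6])
#   scale, zp = torch.tensor(127.5), torch.tensor(0.0)
#   q = linear_quantize(x, scale, zp).round()   # values on the integer grid
#   x_hat = linear_dequantize(q, scale, zp)     # approximately recovers x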


def asymmetric_linear_quantization_params(num_bits,
                                           saturation_min,
                                           saturation_max,
                                           integral_zero_point=True,
                                           signed=True):
    """
    Compute the scaling factor and zero point with the given quantization range.
    num_bits: bit-width of the quantized values
    saturation_min: lower bound for quantization range
    saturation_max: upper bound for quantization range
    """
    n = 2 ** num_bits - 1
    scale = n / torch.clamp((saturation_max - saturation_min), min=1e-8)
    zero_point = scale * saturation_min
    if integral_zero_point:
        if isinstance(zero_point, torch.Tensor):
            zero_point = zero_point.round()
        else:
            zero_point = float(round(zero_point))
    if signed:
        zero_point += 2**(num_bits - 1)
    return scale, zero_point
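
# Worked example (sketch, values chosen for illustration): for 8-bit signed
# quantization of the range [-1, 1],
#   n = 2**8 - 1 = 255
#   scale = 255 / (1 - (-1)) = 127.5
#   zero_point = round(127.5 * -1.0) + 2**7 = -128 + 128 = 0
# so real values map to the integer grid via q = round(127.5 * x).
#   asymmetric_linear_quantization_params(8, torch.tensor(-1.0), torch.tensor(1.0))
#   # -> (tensor(127.5000), tensor(0.))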


class AsymmetricQuantFunction(Function):
    """
    Quantize the given floating-point values with the given range and bit-setting.
    Currently only supports inference; back-propagation is not implemented.
    """
    @staticmethod
    def forward(ctx, x, k, x_min, x_max):
        """
        x: single-precision value to be quantized
        k: bit-setting for x
        x_min: lower bound for quantization range
        x_max: upper bound for quantization range
        """
        signed = x_min < 0
        scale, zero_point = asymmetric_linear_quantization_params(
            k, x_min, x_max, integral_zero_point=True, signed=signed)
        # print(zero_point)
        new_quant_x = linear_quantize(x, scale, zero_point, inplace=False).round()
        n = 2**(k - 1)
        if signed:
            new_quant_x = torch.clamp(new_quant_x, -n, n - 1)
        else:
            new_quant_x = torch.clamp(new_quant_x, 0, 2 * n - 1)
        quant_x = linear_dequantize(new_quant_x,
                                    scale,
                                    zero_point,
                                    inplace=False)
        return torch.autograd.Variable(quant_x)

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError
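

# Minimal usage sketch (not part of the original file; run as a script to try
# the quantize-dequantize path end to end, 8 bits over the range [-1, 1]).
if __name__ == "__main__":
    x = torch.tensor([-1.5, -0.3, 0.0, 0.7, 2.0])
    x_min, x_max = torch.tensor(-1.0), torch.tensor(1.0)
    x_q = AsymmetricQuantFunction.apply(x, 8, x_min, x_max)
    # values outside [-1, 1] saturate to the extreme levels; the rest are
    # rounded to the nearest representable step of size 1 / 127.5
    print(x_q)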