# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor


# https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
# NOTE: zero padding is removed so this function can work with 4-bit qmap
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    """
    Creates the dynamic quantization map.

    The dynamic data type is made up of a dynamic exponent and
    fraction. As the exponent increases from 0 to -7, the number
    of bits available for the fraction shrinks.

    This is a generalization of the dynamic type where a certain
    number of the bits can be reserved for the linear quantization
    region (the fraction). max_exponent_bits determines the maximum
    number of exponent bits.

    For more details see
    [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561)
    """
    data = []
    # these are additional items that come from the case
    # where all the exponent bits are zero and no
    # indicator bit is present
    non_sign_bits = total_bits - (1 if signed else 1)
    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
    for i in range(max_exponent_bits):
        fraction_items = int(
            2 ** (i + non_sign_bits - max_exponent_bits) + 1
            if signed
            else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1,
        )
        boundaries = torch.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
        if signed:
            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

        if additional_items > 0:
            boundaries = torch.linspace(0.1, 1, additional_items + 1)
            means = (boundaries[:-1] + boundaries[1:]) / 2.0
            data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
            if signed:
                data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

    data.append(0)
    data.append(1.0)

    assert len(data) == 2**total_bits

    data.sort()
    return data
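

# Illustrative sketch (not part of the upstream file): how the map is typically
# materialized as a lookup table for the quantizers below. The 4-bit arguments
# (3 exponent bits, 4 total bits) are an assumption; they produce exactly 16
# entries only because the zero padding was removed above.
def _demo_qmaps():
    qmap_8bit = torch.tensor(create_dynamic_map(signed=True))  # 256 values in [-1, 1]
    qmap_4bit = torch.tensor(create_dynamic_map(True, max_exponent_bits=3, total_bits=4))  # 16 values
    return qmap_8bit, qmap_4bit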


def scale_tensor(input: Tensor, block_size: int):
    """Scale tensor so that max(abs(input)) = 1 within each block of block_size elements."""
    shape = input.shape

    # section 2.1 from https://arxiv.org/abs/2110.02861
    input = input.view(-1, block_size)
    scale = input.abs().amax(-1).clip(1e-12)
    input = input / scale.view(-1, 1)

    return input.view(shape), scale
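

# Illustrative sketch (assumes numel is divisible by block_size, which the
# view(-1, block_size) above requires): each block is normalized independently,
# so a single outlier only affects its own block's scale.
def _demo_scale_tensor():
    x = torch.randn(4096)
    x_scaled, scale = scale_tensor(x, block_size=256)
    assert scale.shape == (4096 // 256,)
    assert x_scaled.abs().max() <= 1.0
    # undo the scaling block by block
    x_restored = (x_scaled.view(-1, 256) * scale.view(-1, 1)).view(4096)
    assert torch.allclose(x_restored, x)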


def quantize_8bit_with_qmap(input: Tensor, qmap: Tensor):
    # GPU-friendly binary search
    # https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/
    codes = torch.where(input >= qmap[128], 128, 0)
    codes += torch.where(input >= qmap[codes + 64], 64, 0)
    codes += torch.where(input >= qmap[codes + 32], 32, 0)
    codes += torch.where(input >= qmap[codes + 16], 16, 0)
    codes += torch.where(input >= qmap[codes + 8], 8, 0)
    codes += torch.where(input >= qmap[codes + 4], 4, 0)
    codes += torch.where(input >= qmap[codes + 2], 2, 0)
    codes += torch.where(input >= qmap[codes + 1], 1, 0)

    # rounding: snap to the nearer of the two bracketing qmap entries
    codes_up = (codes + 1).clip(max=255)
    val_down = qmap[codes]
    val_up = qmap[codes_up]
    residual = input - val_down
    codes = torch.where(residual >= (val_up - val_down) * 0.5, codes_up, codes)

    return codes.to(torch.uint8)
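

# Sanity-check sketch (an assumption for illustration, not from the original
# file): the branchless search plus nearest-rounding above should agree with a
# brute-force nearest-entry lookup, except at exact halfway ties.
def _demo_quantize_8bit():
    qmap = torch.tensor(create_dynamic_map(signed=True))
    x = torch.rand(1024) * 2 - 1  # values in [-1, 1], as produced by scale_tensor
    codes = quantize_8bit_with_qmap(x, qmap)
    brute_force = (x.unsqueeze(-1) - qmap).abs().argmin(-1).to(torch.uint8)
    assert (codes == brute_force).float().mean() > 0.99  # allow rare halfway ties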


def quantize_4bit_with_qmap(input: Tensor, qmap: Tensor):
    # GPU-friendly binary search
    # https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/
    codes = torch.where(input >= qmap[8], 8, 0)
    codes += torch.where(input >= qmap[codes + 4], 4, 0)
    codes += torch.where(input >= qmap[codes + 2], 2, 0)
    codes += torch.where(input >= qmap[codes + 1], 1, 0)

    # rounding: snap to the nearer of the two bracketing qmap entries
    codes_up = (codes + 1).clip(max=15)
    val_down = qmap[codes]
    val_up = qmap[codes_up]
    residual = input - val_down
    codes = torch.where(residual >= (val_up - val_down) * 0.5, codes_up, codes)

    return codes.to(torch.uint8)
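

# Packing sketch (an assumption for illustration; the exact layout used by
# callers of this file may differ): each 4-bit code above still occupies a full
# uint8, so two codes are usually packed per byte for real memory savings.
def _demo_pack_4bit(codes: Tensor) -> Tensor:
    assert codes.numel() % 2 == 0
    flat = codes.view(-1)
    # even-indexed code goes in the high nibble, odd-indexed in the low nibble
    return (flat[::2] << 4) | flat[1::2]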


def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
    # torch.compile() cannot use uint8 as index
    out = qmap[codes.int()].view(scale.shape[0], -1) * scale.view(-1, 1)
    return out.view(codes.shape)
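

# End-to-end sketch (illustrative, assuming a signed 8-bit qmap): scale each
# block into [-1, 1], quantize to codes, then dequantize; the reconstruction
# error is bounded by half the widest qmap gap (~0.007) times each block scale.
def _demo_roundtrip():
    qmap = torch.tensor(create_dynamic_map(signed=True))
    x = torch.randn(64, 64)
    x_scaled, scale = scale_tensor(x, block_size=64)
    codes = quantize_8bit_with_qmap(x_scaled, qmap)
    x_hat = dequant_with_qmap(codes, qmap, scale)
    assert x_hat.shape == x.shape
    assert (x - x_hat).abs().max() < 0.02 * x.abs().max()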


def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
    # For an FP32 number      [a31, ..., a16, a15, ..., a0] to be converted to BF16
    # - Round towards zero:   [a31, ..., a16,   0, ...,  0]
    # - Round away from zero: [a31, ..., a16+1, 0, ...,  0]
    # (since the value can be negative, we use round towards/away from zero instead of round up/down)
    #
    # For stochastic rounding, we round away from zero with the probability of
    # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
    #
    # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
    rand_16bit = torch.randint(
        0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
    )
    x_f32_bits = x_f32.view(torch.int32)
    x_fraction = x_f32_bits & 0xFFFF  # lower 16 bits
    x_bf16_towards_zero = x_f32_bits & 0xFFFF0000  # upper 16 bits

    x_f32_bits = torch.where(
        rand_16bit < x_fraction,  # this is True with the probability of p_fraction
        x_bf16_towards_zero + 0x10000,  # this might overflow, which will result in UB due to signed integer
        x_bf16_towards_zero,
    )
    # alternative, slightly faster
    # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000

    return x_f32_bits.view(torch.float32).bfloat16()
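

# Unbiasedness sketch (illustrative; the test value and sample count are
# assumptions): averaged over many draws, stochastic rounding recovers the
# FP32 value, while deterministic round-to-nearest keeps a fixed bias.
def _demo_stochastic_rounding():
    torch.manual_seed(0)
    x = torch.full((10_000,), 1.0 + 2**-10)  # not representable in bf16
    sr_mean = _fp32_to_bf16_sr(x).float().mean()  # should be close to 1.0 + 2^-10
    rn_mean = x.bfloat16().float().mean()  # rounds down to exactly 1.0
    assert (sr_mean - x[0]).abs() < (rn_mean - x[0]).abs()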