// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.GenAI.Core;
using TorchSharp;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core;
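
/// <summary>
/// Linear layer whose float weight can be quantized in place to int8 (<see cref="Int8"/>)
/// or to packed int4 (<see cref="Int4"/>). The forward pass dequantizes the stored buffers
/// back to float32 before the matrix multiplication.
/// </summary>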
internal class QuantizedLinear : GenAILinear, IQuantizeModule
{
    public QuantizedLinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null)
        : base(inFeatures, outFeatures, hasBias, dtype, device)
    {
    }
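
    // Illustrative usage sketch (not part of the original file): a caller holding a TorchSharp
    // model whose linear layers are QuantizedLinear instances could quantize them in place by
    // walking the module tree and calling Int8() (or Int4()) on every IQuantizeModule.
    // The `model` variable below is hypothetical and shown only for illustration.
    //
    //   foreach (var module in model.modules())
    //   {
    //       if (module is IQuantizeModule quantizable)
    //       {
    //           quantizable.Int8();
    //       }
    //   }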

    public void Int8()
    {
        if (this.weight is null)
        {
            throw new Exception("Weight is not initialized");
        }

        if (this.weight.device_type != DeviceType.META)
        {
            // if the weight is not on the meta device, the weight and bias are already loaded,
            // so we can quantize them in memory
            var timer = new System.Diagnostics.Stopwatch();
            timer.Start();

            // row-wise scale and zero point:
            // scale = 255 / (max(weight, axis=1) - min(weight, axis=1))
            var scale = 255 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values);

            // zeroPoint = -scale * min(weight, axis=1) - 128
            var zeroPoint = -scale * torch.min(this.weight, 1).values - 128;

            // round zero point to nearest integer
            zeroPoint = torch.round(zeroPoint).to(torch.int8);
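
            // Worked example (illustrative numbers, not from the source): for a weight row
            // spanning [0.0, 2.55], scale = 255 / 2.55 = 100 and zeroPoint = -100 * 0 - 128 = -128,
            // so w = 1.0 quantizes to round(1.0 * 100 - 128) = -28 and forward() later restores it
            // as (-28 - (-128)) / 100 = 1.0.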

            // assert zero point is in range [-128, 127]
            //if (torch.any(this.zeroPoint < -128).item<bool>() || torch.any(this.zeroPoint > 127).item<bool>())
            //{
            //    throw new Exception("Zero point is out of range [-128, 127]");
            //}

            // quantize weight
            var eightBitWeight = torch.round(this.weight * scale.view(-1, 1) + zeroPoint.view(-1, 1)).to(torch.int8);

            // assert weight is in range [-128, 127]
            //if (torch.any(this._8bitWeight < -128).item<bool>() || torch.any(this._8bitWeight > 127).item<bool>())
            //{
            //    throw new Exception("Weight is out of range [-128, 127]");
            //}
            timer.Stop();

            // dispose float32 weight
            this.weight.Dispose();
            this.weight = null;
            this._internal_buffers.Remove("weight");
            this.register_buffer("8bit_weight", eightBitWeight);
            this.register_buffer("zeroPoint", zeroPoint);
            this.register_buffer("scale", scale);
        }
        else
        {
            // if weight is on the meta device, we just need to create placeholders for
            // 8bit_weight, zeroPoint and scale
            var eightBitWeight = torch.zeros(this.weight.shape, dtype: torch.int8);
            var zeroPoint = torch.zeros(this.weight.shape[0], dtype: torch.int8);
            var scale = torch.zeros(this.weight.shape[0], dtype: torch.float32);
            this._internal_buffers.Remove("weight");
            this.weight = null;
            this.register_buffer("8bit_weight", eightBitWeight);
            this.register_buffer("zeroPoint", zeroPoint);
            this.register_buffer("scale", scale);
        }
    }

#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override Tensor forward(Tensor input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        if (this._internal_buffers.ContainsKey("weight"))
        {
            // not quantized yet, fall back to the plain float linear layer
            return base.forward(input);
        }
else if (this._internal_buffers.ContainsKey("8bit_weight"))
{
// 8bit quantization
using var dispose = torch.NewDisposeScope();
var weight = this.get_buffer("8bit_weight")!.to(ScalarType.Float32);
var zeroPoint = this.get_buffer("zeroPoint")!.to(ScalarType.Float32);
var scale = this.get_buffer("scale")!.to(ScalarType.Float32);
var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1);
// use float32
var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T);
if (this.bias is not null)
{
result = result + this.bias.to_type(ScalarType.Float32);
}
//result.Peek("result");
return result.to_type(input.dtype).MoveToOuterDisposeScope();
}
else if (this._internal_buffers.ContainsKey("4bit_weight"))
{
using var dispose = torch.NewDisposeScope();
var weight = this.get_buffer("4bit_weight");
var weightLower = weight! % 16;
var weightUpper = weight! / 16;
weight = torch.cat([weightUpper, weightLower], 0).to(ScalarType.Float32);
weight = weight.view(this._outFeatures, this._inFeatures);
weight -= 8;
var zeroPoint = this.get_buffer("zeroPoint");
var zeroPointLower = zeroPoint! % 16;
var zeroPointUpper = zeroPoint! / 16;
zeroPoint = torch.cat([zeroPointUpper, zeroPointLower], 0).to(ScalarType.Float32);
zeroPoint -= 8;
var scale = this.get_buffer("scale")!.to(ScalarType.Float32);
var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1);
// use float32
var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T);
if (this.bias is not null)
{
result = result + this.bias.to_type(ScalarType.Float32);
}
//result.Peek("result");
return result.to_type(input.dtype).MoveToOuterDisposeScope();
}
        else
        {
            throw new Exception("Quantization is not done yet");
        }
    }

    public void Int4()
    {
        if (this.weight is null)
        {
            throw new Exception("Weight is not initialized");
        }

        // two 4-bit values are packed into each byte, so the packed buffers hold half as many
        // elements, rounded up
        var placeHolderDim = this._outFeatures / 2 + this._outFeatures % 2;
        var fourBitWeightDim = this.weight.size(0) * this.weight.size(1);
        var fourBitWeightPlaceHolderDim = Convert.ToInt32(fourBitWeightDim / 2 + fourBitWeightDim % 2);

        if (this.weight.device_type != DeviceType.META)
        {
            using var scope = NewDisposeScope();
            var timer = new System.Diagnostics.Stopwatch();
            timer.Start();

            // row-wise scale and zero point:
            // scale = 15 / (max(weight, axis=1) - min(weight, axis=1))
            var scale = 15 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values);

            // zeroPoint = -scale * min(weight, axis=1) - 8
            var zeroPoint = -scale * torch.min(this.weight, 1).values - 8;

            // round zero point to nearest integer
            zeroPoint = torch.round(zeroPoint);

            var fourBitWeight = torch.round(this.weight * scale.view(-1, 1) + zeroPoint.view(-1, 1)).to(torch.int8);

            // shift from [-8, 7] to [0, 15] so two values can be packed into one byte
            zeroPoint = (zeroPoint + 8).to(torch.uint8);
            fourBitWeight = (fourBitWeight + 8).view(-1).to(torch.uint8);

            // torch doesn't provide an int4 type, so uint8 is used as storage and each byte packs
            // two 4-bit values: the upper nibble holds an element from the first half of the
            // flattened tensor and the lower nibble an element from the second half
            var zpPlaceHolder = zeroPoint[..placeHolderDim];
            zpPlaceHolder = zpPlaceHolder * 16 + zeroPoint[placeHolderDim..];

            // assert zero point is in range [-128, 127]
            //if (torch.any(this.zeroPoint < -128).item<bool>() || torch.any(this.zeroPoint > 127).item<bool>())
            //{
            //    throw new Exception("Zero point is out of range [-128, 127]");
            //}

            // pack the quantized weight the same way
            var fourBitWeightPlaceHolder = fourBitWeight[..fourBitWeightPlaceHolderDim];
            fourBitWeightPlaceHolder = fourBitWeightPlaceHolder * 16 + fourBitWeight[fourBitWeightPlaceHolderDim..];
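
            // Worked example (illustrative numbers, not from the source): packed byte 0xAB = 171
            // stores 10 in the upper nibble (first half) and 11 in the lower nibble (second half);
            // after forward() unpacks it and undoes the +8 shift, these become the signed 4-bit
            // values 2 and 3.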

            // assert weight is in range [-128, 127]
            //if (torch.any(this._8bitWeight < -128).item<bool>() || torch.any(this._8bitWeight > 127).item<bool>())
            //{
            //    throw new Exception("Weight is out of range [-128, 127]");
            //}

            // dispose float32 weight
            this.weight.Dispose();
            this._internal_buffers.Remove("weight");
            this.register_buffer("4bit_weight", fourBitWeightPlaceHolder.MoveToOuterDisposeScope());
            this.register_buffer("zeroPoint", zpPlaceHolder.MoveToOuterDisposeScope());
            this.register_buffer("scale", scale.MoveToOuterDisposeScope());
            timer.Stop();
        }
        else
        {
            // if weight is on the meta device, we just need to create placeholders for
            // 4bit_weight, zeroPoint and scale
            var fourBitWeight = torch.zeros(fourBitWeightPlaceHolderDim, dtype: torch.int8);
            var zeroPoint = torch.zeros(placeHolderDim, dtype: torch.int8);
            var scale = torch.zeros(this.weight.shape[0], dtype: torch.float32);
            this._internal_buffers.Remove("weight");
            this.weight = null;
            this.register_buffer("4bit_weight", fourBitWeight);
            this.register_buffer("zeroPoint", zeroPoint);
            this.register_buffer("scale", scale);
        }
    }
}