
Commit aa33c46

use inplace op for weight_only

Signed-off-by: xin3he <xin3.he@intel.com>

1 parent: 5ad4fa3

9 files changed: 124 additions, 114 deletions
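Note on the change: the diffs below swap out-of-place transposes (.T) and arithmetic (weight / scale + zp) for in-place variants (t_().contiguous(), div_, add_, round_, copy_) so that weight-only quantization allocates fewer temporaries. A minimal standalone sketch of the two transpose styles, using an illustrative tensor rather than code from this commit:

import torch

scales = torch.randn(8, 4)

# Out-of-place: .T returns a non-contiguous view; a later op that needs
# contiguous memory will materialize an extra copy of the data.
scales_view = scales.T

# In-place pattern used in this commit: t_() swaps the strides of the same
# storage, and .contiguous() materializes the transposed layout once.
scales = scales.t_().contiguous()
print(scales.shape)  # torch.Size([4, 8])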


neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py

Lines changed: 22 additions & 21 deletions
@@ -127,23 +127,20 @@ def __init__(
                     dtype=self.float_type,
                 ).to(device),
             )
-            self.scales = self.scales.T
             self.register_buffer(
                 "qweight",
                 torch.zeros(
                     (math.ceil(in_features / self.n_pack), out_features),
                     dtype=self.compression_dtype,
                 ).to(device),
             )
-            self.qweight = self.qweight.T
             self.register_buffer(
                 "qzeros",
                 torch.zeros(
                     (math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
                     dtype=self.compression_dtype,
                 ).to(device),
             )
-            self.qzeros = self.qzeros.T
             self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
         else:
             self.compression_dtype = compression_dtype
@@ -193,6 +190,10 @@ def __init__(
             self.bias = None

     def pack(self, int_weight, scale, zp, bias):
+        if self.use_optimum_format:
+            self.scales = self.scales.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
+            self.qzeros = self.qzeros.t_().contiguous()
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
@@ -206,8 +207,8 @@ def pack(self, int_weight, scale, zp, bias):
         assert scale.shape == self.scales.shape, "Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.T
-            self.qweight = self.qweight.T
+            int_weight = int_weight.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
         origin_shape = int_weight.shape
         target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -223,15 +224,15 @@ def pack(self, int_weight, scale, zp, bias):
                 tmp[:, e] = tmp[:, e] << (self.bits * e)
                 self.qweight[:, j] |= tmp[:, e]
         if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.T
+            self.qweight = self.qweight.t_().contiguous()

         if zp is not None:
             zp = zp.to(self.device)
             if self.use_optimum_format:
                 zp -= 1
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
-                self.qzeros = self.qzeros.T
+                zp = zp.t_().contiguous()
+                self.qzeros = self.qzeros.t_().contiguous()
             assert hasattr(self, "qzeros"), "zp is not set when initializing."
             target_shape = self.qzeros.shape
             for j in range(target_shape[1]):
@@ -243,16 +244,16 @@ def pack(self, int_weight, scale, zp, bias):
                     tmp[:, e] = tmp[:, e] << (self.bits * e)
                     self.qzeros[:, j] |= tmp[:, e]
             if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.T
+                self.qzeros = self.qzeros.t_().contiguous()
         if self.use_optimum_format:
-            self.scales = self.scales.T
-            self.qweight = self.qweight.T
-            self.qzeros = self.qzeros.T
+            self.scales = self.scales.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
+            self.qzeros = self.qzeros.t_().contiguous()

     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        scales = self.scales.T if self.use_optimum_format else self.scales
-        qweight = self.qweight.T if self.use_optimum_format else self.qweight
+        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight

         device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -264,8 +265,8 @@ def recover(self):
         # unpack weight
         weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.T
-            qweight = qweight.T
+            weight = weight.t_().contiguous()
+            qweight = qweight.t_().contiguous()
         origin_shape = weight.shape
         target_shape = qweight.shape
         for j in range(target_shape[1]):
@@ -280,7 +281,7 @@ def recover(self):
                 tmp &= mask # remove sign bit
                 weight[:, index] = tmp.type(weight_dtype)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.T
+            weight = weight.t_().contiguous()
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
@@ -290,10 +291,10 @@ def recover(self):
         if hasattr(self, "qzeros"):
             zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
             zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
-                qzeros = qzeros.T
+                zp = zp.t_().contiguous()
+                qzeros = qzeros.t_().contiguous()
             origin_shape = zp.shape
             target_shape = qzeros.shape
             for j in range(target_shape[1]):
@@ -307,7 +308,7 @@ def recover(self):
                     tmp &= mask
                     zp[:, index] = tmp.type(zp_dtype)
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
+                zp = zp.t_().contiguous()
             if self.use_optimum_format:
                 # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
                 zp += 1
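For context on the pack()/recover() loops touched above: each packed column of qweight holds n_pack low-bit values OR-ed into one wider integer, and recover() shifts them back out. A self-contained sketch of that scheme; bits, n_pack, and the shapes here are illustrative, not taken from WeightOnlyLinear:

import math
import torch

bits = 4
n_pack = 32 // bits                                   # values packed into one int32
mask = torch.tensor(2**bits - 1, dtype=torch.int32)

int_weight = torch.randint(0, 2**bits, (2, 8), dtype=torch.int32)
qweight = torch.zeros((2, math.ceil(int_weight.shape[1] / n_pack)), dtype=torch.int32)

# pack: mask each low-bit value, shift it to its slot, OR it into the packed column
for j in range(qweight.shape[1]):
    tmp = int_weight[:, j * n_pack : (j + 1) * n_pack].clone()
    for e in range(tmp.shape[1]):
        tmp[:, e] &= mask
        tmp[:, e] = tmp[:, e] << (bits * e)
        qweight[:, j] |= tmp[:, e]

# recover: shift each slot back down and mask off the bits above it
recovered = torch.zeros_like(int_weight)
for j in range(qweight.shape[1]):
    for e in range(n_pack):
        index = j * n_pack + e
        if index >= int_weight.shape[1]:
            continue
        recovered[:, index] = (qweight[:, j] >> (bits * e)) & mask

assert torch.equal(recovered, int_weight)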

neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 21 additions & 21 deletions
@@ -327,9 +327,9 @@ def __init__(

     def pack(self, int_weight, scale, zp, bias, g_idx=None):
         if self.use_optimum_format:
-            self.scales = self.scales.T
-            self.qweight = self.qweight.T
-            self.qzeros = self.qzeros.T
+            self.scales = self.scales.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
+            self.qzeros = self.qzeros.t_().contiguous()
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
@@ -350,8 +350,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
         assert scale.shape == self.scales.shape, "Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.T
-            self.qweight = self.qweight.T
+            int_weight = int_weight.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
         origin_shape = int_weight.shape
         target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -367,15 +367,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
                 tmp[:, e] = tmp[:, e] << (self.bits * e)
                 self.qweight[:, j] |= tmp[:, e]
         if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.T
+            self.qweight = self.qweight.t_().contiguous()

         if zp is not None:
             zp = zp.to(self.device)
             if self.use_optimum_format:
                 zp -= 1
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
-                self.qzeros = self.qzeros.T
+                zp = zp.t_().contiguous()
+                self.qzeros = self.qzeros.t_().contiguous()
             assert hasattr(self, "qzeros"), "zp is not set when initializing."
             target_shape = self.qzeros.shape
             for j in range(target_shape[1]):
@@ -387,16 +387,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
                     tmp[:, e] = tmp[:, e] << (self.bits * e)
                     self.qzeros[:, j] |= tmp[:, e]
             if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.T
+                self.qzeros = self.qzeros.t_().contiguous()
         if self.use_optimum_format:
-            self.scales = self.scales.T
-            self.qweight = self.qweight.T
-            self.qzeros = self.qzeros.T
+            self.scales = self.scales.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
+            self.qzeros = self.qzeros.t_().contiguous()

     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        scales = self.scales.T if self.use_optimum_format else self.scales
-        qweight = self.qweight.T if self.use_optimum_format else self.qweight
+        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight

         device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -411,8 +411,8 @@ def recover(self):
         # unpack weight
         weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.T
-            qweight = qweight.T
+            weight = weight.t_().contiguous()
+            qweight = qweight.t_().contiguous()
         origin_shape = weight.shape
         target_shape = qweight.shape
         for j in range(target_shape[1]):
@@ -427,7 +427,7 @@ def recover(self):
                 tmp &= mask # remove sign bit
                 weight[:, index] = tmp.type(weight_dtype)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.T
+            weight = weight.t_().contiguous()
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
@@ -437,10 +437,10 @@ def recover(self):
         if hasattr(self, "qzeros"):
             zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
             zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
-                qzeros = qzeros.T
+                zp = zp.t_().contiguous()
+                qzeros = qzeros.t_().contiguous()
             origin_shape = zp.shape
             target_shape = qzeros.shape
             for j in range(target_shape[1]):
@@ -454,7 +454,7 @@ def recover(self):
                     tmp &= mask
                     zp[:, index] = tmp.type(zp_dtype)
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
+                zp = zp.t_().contiguous()
             if self.use_optimum_format:
                 # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
                 zp += 1
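The same in-place transpose pattern is applied to this second WeightOnlyLinear wrapper. One hedged aside (standard PyTorch behavior, not something added by this commit): t_().contiguous() may return a new tensor, but assigning it back to a name created with register_buffer keeps the buffer registration, which is what pack() relies on. A tiny sketch:

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("qweight", torch.zeros(4, 2, dtype=torch.int32))

m = Demo()
# Same pattern as pack(): transpose in place, re-materialize, reassign the buffer.
m.qweight = m.qweight.t_().contiguous()
print(dict(m.named_buffers())["qweight"].shape)  # torch.Size([2, 4])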

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 10 additions & 10 deletions
@@ -427,7 +427,7 @@ def rtn_quantize(
         if num_bits <= 0:
             logger.info(f"Skip {name}")
             continue
-        weight = m.weight.T if group_dim == 0 else m.weight
+        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
         if enable_mse_search:
             quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
         if return_int:
@@ -445,8 +445,8 @@ def rtn_quantize(
             )
             if group_dim == 0:
                 weight.transpose_(0, 1)
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
+            scale = scale.t_().contiguous() if group_dim == 0 else scale
+            zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
             new_module = WeightOnlyLinear(
                 m.in_features,
                 m.out_features,
@@ -649,18 +649,18 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):
     if zp is not None:
         zp = zp.to(device)
     if group_size == -1:
-        return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp)
+        return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_()
     int_weight = torch.zeros(weight.shape).to(device)
     leng = weight.shape[1] // group_size
     tail_flag = False if weight.shape[1] % group_size == 0 else True
     for i in range(leng):
-        int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1)
+        int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
         if zp is not None:
-            int_weight_tmp += zp[:, i].unsqueeze(1)
-        int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp)
+            int_weight_tmp.add_(zp[:, i].unsqueeze(1))
+        int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_())
     if tail_flag:
-        int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1)
+        int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1))
         if zp is not None:
-            int_weight_tmp += zp[:, -1].unsqueeze(1)
-            int_weight[:, leng * group_size :] = torch.round(int_weight_tmp)
+            int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
+        int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
     return int_weight
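A standalone sketch of the in-place group-wise rounding that quant_weight_w_scale now uses; shapes and values are made up, and group_size is assumed to divide the column count. div_/add_/round_ operate on a sliced view of weight instead of allocating temporaries (so the input weight is mutated), and copy_ writes into the preallocated output:

import torch

weight = torch.randn(4, 8)
scale = torch.rand(4, 2) + 0.1     # one scale per 4-column group
zp = torch.full((4, 2), 8.0)
group_size = 4

int_weight = torch.zeros_like(weight)
for i in range(weight.shape[1] // group_size):
    cols = slice(i * group_size, (i + 1) * group_size)
    chunk = weight[:, cols].div_(scale[:, i].unsqueeze(1))  # view of weight, divided in place
    chunk.add_(zp[:, i].unsqueeze(1))
    int_weight[:, cols].copy_(chunk.round_())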

neural_compressor/torch/algorithms/weight_only/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .utility import *
 from .rtn import rtn_quantize
 from .gptq import gptq_quantize

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 5 additions & 5 deletions
@@ -128,7 +128,7 @@ def rtn_quantize(
             continue
         logger.debug(f"RTN quantized module:{name, m}")
         logger.debug(log_msg)
-        weight = m.weight.T if group_dim == 0 else m.weight
+        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
         if use_mse_search:
             quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
         if export_compressed_model:
@@ -143,9 +143,9 @@ def rtn_quantize(
                 full_range=use_full_range,
                 **double_quant_config,
             )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
+            int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight
+            scale = scale.t_().contiguous() if group_dim == 0 else scale
+            zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
             from neural_compressor.torch.quantization.layers import WeightOnlyLinear

             new_module = WeightOnlyLinear(
@@ -175,6 +175,6 @@ def rtn_quantize(
                 full_range=use_full_range,
                 **double_quant_config,
             )
-            weight = weight.T if group_dim == 0 else weight
+            weight = weight.t_().contiguous() if group_dim == 0 else weight
             m.weight.data.copy_(weight)
     return model
