
Commit c0c04af

manually rebase

Signed-off-by: xin3he <xin3.he@intel.com>

1 parent 3622eab commit c0c04af

File tree

29 files changed: +1111 -1235 lines changed

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 5 additions & 5 deletions

@@ -230,8 +230,8 @@ def get_user_model():
 
     # 3.x api
     if args.approach == 'weight_only':
-        from neural_compressor.torch import RTNWeightQuantConfig, GPTQConfig, quantize
-        from neural_compressor.torch.utils.utility import get_double_quant_config
+        from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize
+        from neural_compressor.torch.utils import get_double_quant_config
         weight_sym = True if args.woq_scheme == "sym" else False
         double_quant_config_dict = get_double_quant_config(args.double_quant_type, weight_sym=weight_sym)
@@ -243,9 +243,9 @@ def get_user_model():
                     "enable_mse_search": args.woq_enable_mse_search,
                 }
             )
-            quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict)
+            quant_config = RTNConfig.from_dict(double_quant_config_dict)
         else:
-            quant_config = RTNWeightQuantConfig(
+            quant_config = RTNConfig(
                 weight_dtype=args.woq_dtype,
                 weight_bits=args.woq_bits,
                 weight_group_size=args.woq_group_size,
@@ -257,7 +257,7 @@ def get_user_model():
                 double_quant_sym=args.double_quant_sym,
                 double_quant_group_size=args.double_quant_group_size,
            )
-            quant_config.set_local("lm_head", RTNWeightQuantConfig(weight_dtype="fp32"))
+            quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
         user_model = quantize(
             model=user_model, quant_config=quant_config
         )
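The user-visible change in this example: `RTNWeightQuantConfig` is renamed to `RTNConfig`, and the public entry points now live under `neural_compressor.torch.quantization`. A minimal sketch of the renamed API on a toy model (the module and argument values are illustrative, not taken from the script):

    import torch
    from neural_compressor.torch.quantization import RTNConfig, quantize

    class TinyLM(torch.nn.Module):
        """Toy stand-in for the example's Hugging Face model."""
        def __init__(self):
            super().__init__()
            self.backbone = torch.nn.Linear(64, 64)
            self.lm_head = torch.nn.Linear(64, 1000)

        def forward(self, x):
            return self.lm_head(self.backbone(x))

    quant_config = RTNConfig(weight_bits=4, weight_group_size=32)
    # Keep the LM head in full precision, as the example script does.
    quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
    user_model = quantize(model=TinyLM(), quant_config=quant_config)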

neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py

Lines changed: 22 additions & 21 deletions

@@ -127,23 +127,20 @@ def __init__(
                     dtype=self.float_type,
                 ).to(device),
             )
-            self.scales = self.scales.T
             self.register_buffer(
                 "qweight",
                 torch.zeros(
                     (math.ceil(in_features / self.n_pack), out_features),
                     dtype=self.compression_dtype,
                 ).to(device),
             )
-            self.qweight = self.qweight.T
             self.register_buffer(
                 "qzeros",
                 torch.zeros(
                     (math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
                     dtype=self.compression_dtype,
                 ).to(device),
             )
-            self.qzeros = self.qzeros.T
             self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
         else:
             self.compression_dtype = compression_dtype
@@ -193,6 +190,10 @@ def __init__(
             self.bias = None
 
     def pack(self, int_weight, scale, zp, bias):
+        if self.use_optimum_format:
+            self.scales = self.scales.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
+            self.qzeros = self.qzeros.t_().contiguous()
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
@@ -206,8 +207,8 @@ def pack(self, int_weight, scale, zp, bias):
         assert scale.shape == self.scales.shape, "Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.T
-            self.qweight = self.qweight.T
+            int_weight = int_weight.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
         origin_shape = int_weight.shape
         target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -223,15 +224,15 @@ def pack(self, int_weight, scale, zp, bias):
                 tmp[:, e] = tmp[:, e] << (self.bits * e)
                 self.qweight[:, j] |= tmp[:, e]
         if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.T
+            self.qweight = self.qweight.t_().contiguous()
 
         if zp is not None:
             zp = zp.to(self.device)
             if self.use_optimum_format:
                 zp -= 1
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
-                self.qzeros = self.qzeros.T
+                zp = zp.t_().contiguous()
+                self.qzeros = self.qzeros.t_().contiguous()
             assert hasattr(self, "qzeros"), "zp is not set when initializing."
             target_shape = self.qzeros.shape
             for j in range(target_shape[1]):
@@ -243,16 +244,16 @@ def pack(self, int_weight, scale, zp, bias):
                     tmp[:, e] = tmp[:, e] << (self.bits * e)
                     self.qzeros[:, j] |= tmp[:, e]
             if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.T
+                self.qzeros = self.qzeros.t_().contiguous()
         if self.use_optimum_format:
-            self.scales = self.scales.T
-            self.qweight = self.qweight.T
-            self.qzeros = self.qzeros.T
+            self.scales = self.scales.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
+            self.qzeros = self.qzeros.t_().contiguous()
 
     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        scales = self.scales.T if self.use_optimum_format else self.scales
-        qweight = self.qweight.T if self.use_optimum_format else self.qweight
+        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
 
         device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -264,8 +265,8 @@ def recover(self):
         # unpack weight
         weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.T
-            qweight = qweight.T
+            weight = weight.t_().contiguous()
+            qweight = qweight.t_().contiguous()
         origin_shape = weight.shape
         target_shape = qweight.shape
         for j in range(target_shape[1]):
@@ -280,7 +281,7 @@ def recover(self):
                 tmp &= mask  # remove sign bit
                 weight[:, index] = tmp.type(weight_dtype)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.T
+            weight = weight.t_().contiguous()
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
@@ -290,10 +291,10 @@ def recover(self):
         if hasattr(self, "qzeros"):
             zp_dtype = self.compression_dtype  # to avoid overflow when weight-zp
             zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
-                qzeros = qzeros.T
+                zp = zp.t_().contiguous()
+                qzeros = qzeros.t_().contiguous()
             origin_shape = zp.shape
             target_shape = qzeros.shape
             for j in range(target_shape[1]):
@@ -307,7 +308,7 @@ def recover(self):
                 tmp &= mask
                 zp[:, index] = tmp.type(zp_dtype)
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
+                zp = zp.t_().contiguous()
         if self.use_optimum_format:
             # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
             zp += 1
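Throughout this wrapper, lazy `.T` views give way to in-place `t_()` followed by `contiguous()`, so scales and packed buffers are materialized in a contiguous layout instead of being left as strided views. A standalone illustration of the difference (variable names are ours):

    import torch

    x = torch.arange(6).reshape(2, 3)
    view = x.T                   # transposed view: shares storage, non-contiguous
    print(view.is_contiguous())  # False

    y = torch.arange(6).reshape(2, 3)
    y = y.t_().contiguous()      # transpose in place, then copy into contiguous memory
    print(y.is_contiguous())     # True

Because `t_()` mutates its tensor, `pack()` now transposes the optimum-format buffers into its working orientation once up front and back again at the end, rather than `__init__` reassigning `.T` views at construction time.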

neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 21 additions & 21 deletions

@@ -327,9 +327,9 @@ def __init__(
 
     def pack(self, int_weight, scale, zp, bias, g_idx=None):
         if self.use_optimum_format:
-            self.scales = self.scales.T
-            self.qweight = self.qweight.T
-            self.qzeros = self.qzeros.T
+            self.scales = self.scales.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
+            self.qzeros = self.qzeros.t_().contiguous()
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
@@ -350,8 +350,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
         assert scale.shape == self.scales.shape, "Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.T
-            self.qweight = self.qweight.T
+            int_weight = int_weight.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
         origin_shape = int_weight.shape
         target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -367,15 +367,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
                 tmp[:, e] = tmp[:, e] << (self.bits * e)
                 self.qweight[:, j] |= tmp[:, e]
         if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.T
+            self.qweight = self.qweight.t_().contiguous()
 
         if zp is not None:
             zp = zp.to(self.device)
             if self.use_optimum_format:
                 zp -= 1
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
-                self.qzeros = self.qzeros.T
+                zp = zp.t_().contiguous()
+                self.qzeros = self.qzeros.t_().contiguous()
             assert hasattr(self, "qzeros"), "zp is not set when initializing."
             target_shape = self.qzeros.shape
             for j in range(target_shape[1]):
@@ -387,16 +387,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
                     tmp[:, e] = tmp[:, e] << (self.bits * e)
                     self.qzeros[:, j] |= tmp[:, e]
             if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.T
+                self.qzeros = self.qzeros.t_().contiguous()
         if self.use_optimum_format:
-            self.scales = self.scales.T
-            self.qweight = self.qweight.T
-            self.qzeros = self.qzeros.T
+            self.scales = self.scales.t_().contiguous()
+            self.qweight = self.qweight.t_().contiguous()
+            self.qzeros = self.qzeros.t_().contiguous()
 
     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        scales = self.scales.T if self.use_optimum_format else self.scales
-        qweight = self.qweight.T if self.use_optimum_format else self.qweight
+        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
 
         device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -411,8 +411,8 @@ def recover(self):
         # unpack weight
         weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.T
-            qweight = qweight.T
+            weight = weight.t_().contiguous()
+            qweight = qweight.t_().contiguous()
         origin_shape = weight.shape
         target_shape = qweight.shape
         for j in range(target_shape[1]):
@@ -427,7 +427,7 @@ def recover(self):
                 tmp &= mask  # remove sign bit
                 weight[:, index] = tmp.type(weight_dtype)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.T
+            weight = weight.t_().contiguous()
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
@@ -437,10 +437,10 @@ def recover(self):
         if hasattr(self, "qzeros"):
             zp_dtype = self.compression_dtype  # to avoid overflow when weight-zp
             zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
-                qzeros = qzeros.T
+                zp = zp.t_().contiguous()
+                qzeros = qzeros.t_().contiguous()
             origin_shape = zp.shape
             target_shape = qzeros.shape
             for j in range(target_shape[1]):
@@ -454,7 +454,7 @@ def recover(self):
                 tmp &= mask
                 zp[:, index] = tmp.type(zp_dtype)
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.T
+                zp = zp.t_().contiguous()
         if self.use_optimum_format:
             # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
             zp += 1
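The shift-and-OR loops that `pack()` and `recover()` run over `qweight`/`qzeros` are a standard bit-packing scheme. A toy round trip under assumed parameters (4-bit values in an int32 container; names are ours, not the wrapper's):

    import torch

    bits = 4
    n_pack = 32 // bits  # 8 values per int32, mirroring the wrapper's n_pack

    vals = torch.randint(0, 2**bits, (2, n_pack), dtype=torch.int32)

    # Pack: shift each value into its slot and OR it in, as pack() does.
    packed = torch.zeros(2, 1, dtype=torch.int32)
    for e in range(n_pack):
        packed[:, 0] |= vals[:, e] << (bits * e)

    # Unpack: shift back and mask off neighboring slots and the sign bit, as recover() does.
    mask = torch.tensor(2**bits - 1, dtype=torch.int32)
    unpacked = torch.stack([(packed[:, 0] >> (bits * e)) & mask for e in range(n_pack)], dim=1)
    assert torch.equal(unpacked, vals)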

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 10 additions & 10 deletions

@@ -429,7 +429,7 @@ def rtn_quantize(
         if num_bits <= 0:
             logger.info(f"Skip {name}")
             continue
-        weight = m.weight.T if group_dim == 0 else m.weight
+        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
         if enable_mse_search:
             quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
         if return_int:
@@ -447,8 +447,8 @@
             )
             if group_dim == 0:
                 weight.transpose_(0, 1)
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
+            scale = scale.t_().contiguous() if group_dim == 0 else scale
+            zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
             new_module = WeightOnlyLinear(
                 m.in_features,
                 m.out_features,
@@ -651,18 +651,18 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):
     if zp is not None:
         zp = zp.to(device)
     if group_size == -1:
-        return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp)
+        return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_()
     int_weight = torch.zeros(weight.shape).to(device)
     leng = weight.shape[1] // group_size
     tail_flag = False if weight.shape[1] % group_size == 0 else True
     for i in range(leng):
-        int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1)
+        int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
         if zp is not None:
-            int_weight_tmp += zp[:, i].unsqueeze(1)
-        int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp)
+            int_weight_tmp.add_(zp[:, i].unsqueeze(1))
+        int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_())
     if tail_flag:
-        int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1)
+        int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1))
         if zp is not None:
-            int_weight_tmp += zp[:, -1].unsqueeze(1)
-        int_weight[:, leng * group_size :] = torch.round(int_weight_tmp)
+            int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
+        int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
     return int_weight
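The arithmetic hunks in `quant_weight_w_scale` swap out-of-place operations for their in-place variants (`div_`, `add_`, `round_`, `copy_`), which avoids allocating a fresh temporary per group. The trade-off, in a minimal sketch (shapes and names are illustrative):

    import torch

    weight = torch.randn(4096, 4096)
    scale = weight.abs().amax(dim=1, keepdim=True) / 7  # toy per-row scale

    # Out of place: `weight / scale` allocates a temporary; `weight` is untouched.
    int_weight = torch.round(weight / scale)

    # In place: reuses weight's storage, but weight itself is overwritten.
    int_weight = weight.div_(scale).round_()

Note that the in-place version mutates the `weight` argument, so a caller that still needs the original tensor must pass a copy.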

neural_compressor/common/base_config.py

Lines changed: 2 additions & 2 deletions

@@ -180,7 +180,7 @@ def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
         self.local_config[operator_name] = config
         return self
 
-    def to_dict(self, params_list=[], operator2str=None):
+    def to_dict(self):
         result = {}
         global_config = self.get_params_dict()
         if bool(self.local_config):
@@ -200,7 +200,7 @@ def get_params_dict(self):
         return result
 
     @classmethod
-    def from_dict(cls, config_dict, str2operator=None):
+    def from_dict(cls, config_dict):
         """Construct config from a dict.
 
         Args:
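With the unused `params_list`, `operator2str`, and `str2operator` parameters dropped, (de)serialization becomes a plain dict round trip. A hedged sketch, assuming `RTNConfig` inherits these methods from `BaseConfig` (its use of `from_dict` in the example script suggests it does):

    from neural_compressor.torch.quantization import RTNConfig

    config = RTNConfig(weight_bits=4, weight_group_size=32)
    config_dict = config.to_dict()               # plain dict of the config's parameters
    restored = RTNConfig.from_dict(config_dict)  # reconstruct an equivalent config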

neural_compressor/common/utility.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@
 # config name
 BASE_CONFIG = "base_config"
 COMPOSABLE_CONFIG = "composable_config"
-RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
+RTN = "rtn"
 STATIC_QUANT = "static_quant"
 GPTQ = "gptq"
 FP8_QUANT = "fp8_quant"

neural_compressor/tensorflow/utils.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def register_algo(name):
 
     Usage example:
         @register_algo(name=example_algo)
-        def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
+        def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
             ...
 
     Args:
         name (str): The name under which the algorithm function will be registered.
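The docstring documents a name-to-function registry decorator. A generic sketch of that pattern (our own minimal version, not the library's actual implementation):

    ALGOS = {}

    def register_algo(name):
        def decorator(algo):
            ALGOS[name] = algo  # record the algorithm under its registered name
            return algo
        return decorator

    @register_algo(name="rtn")
    def rtn_quantize(model, quant_config):
        """Placeholder algorithm body."""
        return model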

neural_compressor/torch/__init__.py

Lines changed: 0 additions & 14 deletions

@@ -11,17 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from neural_compressor.torch.utils.utility import register_algo
-from neural_compressor.torch.algorithms import rtn_quantize_entry, gptq_quantize_entry
-
-from neural_compressor.torch.quantization import (
-    quantize,
-    RTNWeightQuantConfig,
-    get_default_rtn_config,
-    GPTQConfig,
-    get_default_gptq_config,
-)
-
-from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.autotune import autotune, get_default_tune_config
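With these re-exports gone from the package `__init__`, callers import from the submodules directly, as the updated example script does:

    # Before this commit:
    #   from neural_compressor.torch import RTNWeightQuantConfig, quantize

    # After this commit:
    from neural_compressor.torch.quantization import RTNConfig, quantize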

neural_compressor/torch/algorithms/__init__.py

Lines changed: 4 additions & 2 deletions

@@ -13,5 +13,7 @@
 # limitations under the License.
 
 
-from neural_compressor.torch.algorithms.weight_only_algos import rtn_quantize_entry
-from neural_compressor.torch.algorithms.weight_only_algos import gptq_quantize_entry
+from .weight_only import (
+    rtn_quantize,
+    gptq_quantize,
+)
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# Demo of algorithm usage w/o INC
