diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py index c47969893..ebd4f2c97 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py @@ -100,6 +100,24 @@ def compress( W[:, dead] = 0 g_idx = None + # if hasattr(self.layer, "quantization_scheme"): + # quant_scheme = self.layer.quantization_scheme + # actorder = quant_scheme.weights.actorder + + # if actorder: + # group_size = quant_scheme.weights.group_size + # perm = torch.argsort(torch.diag(self.H), descending=True) + # # W = W[:, perm] + # # self.H = self.H[perm][:, perm] + # invperm = torch.argsort(perm) + + # # g_idx = torch.Tensor( + # # [perm[i] // group_size for i in range(self.columns)] + # # ).to(device=invperm.device) + # g_idx = torch.Tensor( + # [i // group_size for i in range(self.columns)] + # ).to(device=invperm.device) + # self.layer.weight_g_idx.data = g_idx if hasattr(self.layer, "quantization_scheme"): quant_scheme = self.layer.quantization_scheme actorder = quant_scheme.weights.actorder @@ -114,8 +132,11 @@ def compress( g_idx = torch.Tensor( [perm[i] // group_size for i in range(self.columns)] ).to(device=invperm.device) + # g_idx = torch.Tensor( + # [i // group_size for i in range(self.columns)] + # ).to(device=invperm.device) self.layer.weight_g_idx.data = g_idx - + Losses = torch.zeros(self.rows, device=self.dev) damp = percdamp * torch.mean(torch.diag(self.H)) @@ -200,24 +221,22 @@ def compress( altered_qargs = copy(quant_scheme.weights) altered_qargs.strategy = QuantizationStrategy.CHANNEL - # apply g_idx if g_idx is not None: - # scale and zp already transformed by group_size - # extract first index of group_size - indices_to_extract = torch.arange( - 0, g_idx.shape[0], group_size + q = fake_quantize( + q, + scale[:, int(g_idx[column_idx])], + zero_point[:, int(g_idx[column_idx])], + altered_qargs, + ) + + else: + + q = fake_quantize( + q, + scale[:, input_dim_group], + zero_point[:, input_dim_group], + altered_qargs, ) - grouped_indicies = g_idx[indices_to_extract].int() - - scale = scale[:, grouped_indicies] - zero_point = zero_point[:, grouped_indicies] - - q = fake_quantize( - q, - scale[:, input_dim_group], - zero_point[:, input_dim_group], - altered_qargs, - ) Q1[:, i] = q Losses1[:, i] = (w - q) ** 2 / d**2 @@ -244,6 +263,7 @@ def compress( if actorder: W = W[:, invperm] + self.H = self.H[perm][:, perm] if isinstance(self.layer, transformers.Conv1D): W = W.t()