From f203537e4dd7d1d5a507a1489b4c3ea06cf416fc Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Thu, 11 Jul 2024 20:33:20 +0000
Subject: [PATCH] propagate g_idx with perm

---
 .../modifiers/quantization/gptq/utils/gptq_wrapper.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
index f7bea26d8..c47969893 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -111,10 +111,9 @@ def compress(
             self.H = self.H[perm][:, perm]
             invperm = torch.argsort(perm)
 
-            g_idx = torch.Tensor([i // group_size for i in range(self.columns)]).to(
-                device=invperm.device
-            )
-            g_idx = g_idx[invperm]
+            g_idx = torch.Tensor(
+                [perm[i] // group_size for i in range(self.columns)]
+            ).to(device=invperm.device)
             self.layer.weight_g_idx.data = g_idx
 
         Losses = torch.zeros(self.rows, device=self.dev)
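
For context, a minimal standalone sketch of what the rewritten hunk computes; the `columns`, `group_size`, and random `perm` values below are made up for illustration and are not part of the patch:

import torch

columns, group_size = 8, 4
perm = torch.randperm(columns)   # column order chosen during activation ordering
invperm = torch.argsort(perm)    # inverse permutation, as in compress()

# New behavior from this patch: the group index is propagated through perm.
g_idx_new = torch.Tensor([perm[i] // group_size for i in range(columns)])

# The list comprehension is equivalent to a vectorized floor division:
assert torch.equal(g_idx_new, (perm // group_size).float())

# Old behavior, for comparison: contiguous groups mapped back through invperm.
g_idx_old = torch.Tensor([i // group_size for i in range(columns)])[invperm]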