Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
horheynm committed Jul 10, 2024
1 parent 3e7b875 commit 778b5b5
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,14 @@ def compress(
dead = torch.diag(self.H) == 0
self.H[dead, dead] = 1
W[:, dead] = 0

Losses = torch.zeros(self.rows, device=self.dev)

damp = percdamp * torch.mean(torch.diag(self.H))
diag = torch.arange(self.columns, device=self.dev)
self.H[diag, diag] += damp
self.H = torch.linalg.cholesky(self.H)
self.H = torch.cholesky_inverse(self.H)
self.H = torch.linalg.cholesky(self.H, upper=True)
Hinv = self.H


g_idx = None
if hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
actorder = quant_scheme.weights.actorder
group_size = quant_scheme.weights.group_size


if actorder:
group_size = quant_scheme.weights.group_size
perm = torch.argsort(torch.diag(self.H), descending=True)
W = W[:, perm]
self.H = self.H[perm][:, perm]
Expand All @@ -126,6 +116,16 @@ def compress(
)
g_idx = g_idx[invperm]
self.layer.weight_g_idx.data = g_idx

Losses = torch.zeros(self.rows, device=self.dev)

damp = percdamp * torch.mean(torch.diag(self.H))
diag = torch.arange(self.columns, device=self.dev)
self.H[diag, diag] += damp
self.H = torch.linalg.cholesky(self.H)
self.H = torch.cholesky_inverse(self.H)
self.H = torch.linalg.cholesky(self.H, upper=True)
Hinv = self.H

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, self.columns, blocksize):
Expand Down Expand Up @@ -209,6 +209,7 @@ def compress(
0, g_idx.shape[0], group_size
)
grouped_indicies = g_idx[indices_to_extract].int()

scale = scale[:, grouped_indicies]
zero_point = zero_point[:, grouped_indicies]

Expand Down

0 comments on commit 778b5b5

Please sign in to comment.