Skip to content

Commit

Permalink
propagagte g_idx with perm
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed Jul 11, 2024
1 parent bc08e8d commit f203537
Showing 1 changed file with 3 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ def compress(
self.H = self.H[perm][:, perm]
invperm = torch.argsort(perm)

g_idx = torch.Tensor([i // group_size for i in range(self.columns)]).to(
device=invperm.device
)
g_idx = g_idx[invperm]
g_idx = torch.Tensor(
[perm[i] // group_size for i in range(self.columns)]
).to(device=invperm.device)
self.layer.weight_g_idx.data = g_idx

Losses = torch.zeros(self.rows, device=self.dev)
Expand Down

0 comments on commit f203537

Please sign in to comment.