Commit 10a5c7e
Fix (gpfq): memory management
i-colbert committed Oct 8, 2024
1 parent: 2088225
Showing 1 changed file with 3 additions and 2 deletions.
src/brevitas/graph/gpfq.py (5 changes: 3 additions & 2 deletions)
@@ -244,17 +244,18 @@ def single_layer_update(self):
             weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
         # We don't need the full Hessian, just its diagonal
         # Summing over the batch dimension
-        self.H_diag = self.quant_input.transpose(2, 1).square().sum(2)
+        H_diag = self.quant_input.transpose(2, 1).square().sum(2)
         permutation_list = []
         for group_index in range(self.groups):
             if self.act_order:
                 # Re-order the Hessian diagonal so that weights associated with
                 # higher-magnitude activations are quantized first
-                perm = torch.argsort(self.H_diag[group_index, :], descending=True)
+                perm = torch.argsort(H_diag[group_index, :], descending=True)
             else:
                 # No permutation; the permutation tensor is an ordered index
                 perm = torch.tensor(range(weight.shape[-1]), device=dev)
             permutation_list.append(perm)
+        del H_diag
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
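The change itself is small: the Hessian diagonal is computed into a local H_diag instead of the attribute self.H_diag, and the tensor is deleted once the permutation list is built, so its memory can be reclaimed before the quantization loop that follows in the same method. Below is a minimal, self-contained sketch of that pattern; the PermutationBuilder class, its tensor layout, and build_permutations are illustrative assumptions, not Brevitas APIs.

import torch


class PermutationBuilder:
    """Hypothetical stand-in for per-layer GPFQ state; not the Brevitas API."""

    def __init__(self, quant_input: torch.Tensor, act_order: bool = True):
        # Assumed layout: (groups, samples, columns)
        self.quant_input = quant_input
        self.groups = quant_input.shape[0]
        self.act_order = act_order

    def build_permutations(self) -> list:
        num_columns = self.quant_input.shape[2]
        # Local, not self.H_diag: the tensor is only needed to order columns,
        # so it should not outlive this method (or the object, as an attribute would).
        H_diag = self.quant_input.transpose(2, 1).square().sum(2)  # (groups, columns)
        permutation_list = []
        for group_index in range(self.groups):
            if self.act_order:
                # Columns tied to higher-magnitude activations come first
                perm = torch.argsort(H_diag[group_index, :], descending=True)
            else:
                perm = torch.arange(num_columns)
            permutation_list.append(perm)
        # Drop the last reference so the memory can be reclaimed before any
        # long-running work that follows in the same scope.
        del H_diag
        return permutation_list


x = torch.randn(2, 32, 8)  # 2 groups, 32 samples, 8 columns
perms = PermutationBuilder(x).build_permutations()
print([p.tolist() for p in perms])

Note that del only drops the local name's reference; the memory is actually freed (or returned to PyTorch's CUDA caching allocator) once no other references remain, which is the point of the fix: H_diag no longer escapes via self, so nothing keeps it alive through the rest of single_layer_update.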
