diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 01668ef2b..efe6bab6a 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -244,17 +244,18 @@ def single_layer_update(self):
             weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
         # We don't need full Hessian, we just need the diagonal
         # Summing over batch dimension
-        self.H_diag = self.quant_input.transpose(2, 1).square().sum(2)
+        H_diag = self.quant_input.transpose(2, 1).square().sum(2)
         permutation_list = []
         for group_index in range(self.groups):
             if self.act_order:
                 # Re-order Hessian_diagonal so that weights associated to
                 # higher magnitude activations are quantized first
-                perm = torch.argsort(self.H_diag[group_index, :], descending=True)
+                perm = torch.argsort(H_diag[group_index, :], descending=True)
             else:
                 # No permutation, permutation tensor is a ordered index
                 perm = torch.tensor(range(weight.shape[-1]), device=dev)
             permutation_list.append(perm)
+        del H_diag
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
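For context, here is a minimal standalone sketch of the act-order logic this hunk touches: `H_diag` becomes a local that is freed with `del` once the per-group permutations are built, instead of lingering as an instance attribute. The function name `act_order_permutations` and the `(groups, samples, columns)` input shape are assumptions for illustration; the surrounding `single_layer_update` class machinery is omitted.

```python
import torch


def act_order_permutations(quant_input: torch.Tensor,
                           num_columns: int,
                           act_order: bool) -> list:
    """Sketch of GPFQ per-group column ordering (hypothetical helper).

    quant_input: quantized activations, assumed shape (groups, samples, columns).
    Only the diagonal of the per-group Hessian H = X^T X is needed:
    diag(H)[j] = sum over samples of x[s, j] ** 2.
    """
    dev = quant_input.device
    # Hessian diagonal per group, summing the squares over the sample dim:
    # (groups, samples, columns) -> (groups, columns, samples) -> (groups, columns)
    H_diag = quant_input.transpose(2, 1).square().sum(2)
    permutation_list = []
    for group_index in range(H_diag.shape[0]):
        if act_order:
            # Quantize columns tied to higher-magnitude activations first.
            perm = torch.argsort(H_diag[group_index, :], descending=True)
        else:
            # Identity permutation: columns kept in their original order.
            perm = torch.arange(num_columns, device=dev)
        permutation_list.append(perm)
    # Free the diagonal as soon as the permutations are built, mirroring
    # the `del H_diag` added in the diff.
    del H_diag
    return permutation_list


if __name__ == "__main__":
    x = torch.randn(2, 64, 16)  # (groups, samples, columns), made-up sizes
    perms = act_order_permutations(x, num_columns=16, act_order=True)
    print([p.shape for p in perms])  # two permutations of 16 column indices
```

Keeping `H_diag` local (and deleting it explicitly) lets the tensor be garbage-collected before the subsequent column loop, rather than surviving on `self` for the lifetime of the layer object.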