diff --git a/pygod/nn/ocgnn.py b/pygod/nn/ocgnn.py index 38ced96..ae18624 100644 --- a/pygod/nn/ocgnn.py +++ b/pygod/nn/ocgnn.py @@ -108,6 +108,12 @@ def loss_func(self, emb): score : torch.Tensor Outlier scores of shape :math:`N` with gradients. """ + if self.warmup > 0: + with torch.no_grad(): + self.warmup -= 1 + self.c = torch.mean(emb, 0) + self.c[(abs(self.c) < self.eps) & (self.c < 0)] = -self.eps + self.c[(abs(self.c) < self.eps) & (self.c > 0)] = self.eps dist = torch.sum(torch.pow(emb - self.c, 2), 1) score = dist - self.r ** 2 @@ -115,10 +121,6 @@ def loss_func(self, emb): if self.warmup > 0: with torch.no_grad(): - self.warmup -= 1 self.r = torch.quantile(torch.sqrt(dist), 1 - self.beta) - self.c = torch.mean(emb, 0) - self.c[(abs(self.c) < self.eps) & (self.c < 0)] = -self.eps - self.c[(abs(self.c) < self.eps) & (self.c > 0)] = self.eps return loss, score