diff --git a/pygod/detector/adone.py b/pygod/detector/adone.py index 1f56989..94d247f 100644 --- a/pygod/detector/adone.py +++ b/pygod/detector/adone.py @@ -198,27 +198,31 @@ def forward_model(self, data): s = data.s.to(self.device) edge_index = data.edge_index.to(self.device) - x_, s_, h_a, h_s, dna, dns, dis_a, dis_s = self.model(x, s, edge_index) - - loss_d = - torch.mean(torch.log(1 - dis_a[:batch_size]) - + torch.log(dis_s[:batch_size])) - - loss_g, oa, os, oc = self.model.loss_func(x[:batch_size], - x_[:batch_size], - s[:batch_size], - s_[:batch_size], - h_a[:batch_size], - h_s[:batch_size], - dna[:batch_size], - dns[:batch_size], - dis_a[:batch_size], - dis_s[:batch_size]) + x_, s_, h_a, h_s, dna, dns = self.model(x, s, edge_index) + + loss_d = self.model.loss_func_d(h_a[:batch_size].detach(), + h_s[:batch_size].detach()) + + self.opt_in.zero_grad() + loss_d.backward() + self.opt_in.step() + + self.epoch_loss_in += loss_d.item() * batch_size + + loss_g, oa, os, oc = self.model.loss_func_g(x[:batch_size], + x_[:batch_size], + s[:batch_size], + s_[:batch_size], + h_a[:batch_size], + h_s[:batch_size], + dna[:batch_size], + dns[:batch_size]) self.attribute_score_[node_idx[:batch_size]] = oa.detach().cpu() self.structural_score_[node_idx[:batch_size]] = os.detach().cpu() self.combined_score_[node_idx[:batch_size]] = oc.detach().cpu() - return (loss_g, loss_d), ((oa + os + oc) / 3).detach().cpu() + return loss_g, ((oa + os + oc) / 3).detach().cpu() def decision_function(self, data, label=None): if data is not None: diff --git a/pygod/detector/base.py b/pygod/detector/base.py index 778ec19..7c056cf 100644 --- a/pygod/detector/base.py +++ b/pygod/detector/base.py @@ -439,39 +439,31 @@ def fit(self, data, label=None): self.model = self.init_model(**self.kwargs) if self.compile_model: self.model = compile(self.model) - if self.gan: - opt_g = torch.optim.Adam(self.model.generator.parameters(), - lr=self.lr, - weight_decay=self.weight_decay) - opt_d = torch.optim.Adam(self.model.discriminator.parameters(), - lr=self.lr, - weight_decay=self.weight_decay) - else: + if not self.gan: optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) + else: + self.opt_in = torch.optim.Adam(self.model.inner.parameters(), + lr=self.lr, + weight_decay=self.weight_decay) + optimizer = torch.optim.Adam(self.model.outer.parameters(), + lr=self.lr, + weight_decay=self.weight_decay) self.model.train() self.decision_score_ = torch.zeros(data.x.shape[0]) for epoch in range(self.epoch): start_time = time.time() + epoch_loss = 0 if self.gan: - epoch_loss_g = 0 - epoch_loss_d = 0 - else: - epoch_loss = 0 + self.epoch_loss_in = 0 for sampled_data in loader: batch_size = sampled_data.batch_size node_idx = sampled_data.n_id loss, score = self.forward_model(sampled_data) - - if self.gan: - epoch_loss_g += loss[0].item() * batch_size - epoch_loss_d += loss[1].item() * batch_size - else: - epoch_loss += loss.item() * batch_size - + epoch_loss += loss.item() * batch_size if self.save_emb: if type(self.emb) is tuple: self.emb[0][node_idx[:batch_size]] = \ @@ -483,23 +475,13 @@ def fit(self, data, label=None): self.model.emb[:batch_size].cpu() self.decision_score_[node_idx[:batch_size]] = score - if self.gan: - opt_g.zero_grad() - loss[0].backward() - opt_g.step() - opt_d.zero_grad() - loss[0].backward() - opt_d.step() - else: - optimizer.zero_grad() - loss.backward() - optimizer.step() + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_value = epoch_loss / data.x.shape[0] if self.gan: - loss_value = (self.epoch_loss_g / data.x.shape[0], - self.epoch_loss_d / data.x.shape[0]) - else: - loss_value = epoch_loss / data.x.shape[0] + loss_value = (self.epoch_loss_in / data.x.shape[0], loss_value) logger(epoch=epoch, loss=loss_value, score=self.decision_score_, @@ -527,11 +509,7 @@ def decision_function(self, data, label=None): else: self.emb = torch.zeros(data.x.shape[0], self.hid_dim) start_time = time.time() - if self.gan: - test_loss_g = 0 - test_loss_d = 0 - else: - test_loss = 0 + test_loss = 0 for sampled_data in loader: loss, score = self.forward_model(sampled_data) batch_size = sampled_data.batch_size @@ -546,19 +524,12 @@ def decision_function(self, data, label=None): self.emb[node_idx[:batch_size]] = \ self.model.emb[:batch_size].cpu() - if self.gan: - test_loss_g += loss[0].item() * batch_size - test_loss_d = loss[1].item() * batch_size - else: - test_loss = loss.item() * batch_size - + test_loss = loss.item() * batch_size outlier_score[node_idx[:batch_size]] = score + loss_value = test_loss / data.x.shape[0] if self.gan: - loss_value = (test_loss_g / data.x.shape[0], - test_loss_d / data.x.shape[0]) - else: - loss_value = test_loss / data.x.shape[0] + loss_value = (self.epoch_loss_in / data.x.shape[0], loss_value) logger(loss=loss_value, score=outlier_score, diff --git a/pygod/detector/gaan.py b/pygod/detector/gaan.py index cc9fa91..da3aecf 100644 --- a/pygod/detector/gaan.py +++ b/pygod/detector/gaan.py @@ -161,6 +161,7 @@ def init_model(self, **kwargs): **kwargs).to(self.device) def forward_model(self, data): + batch_size = data.batch_size node_idx = data.n_id x = data.x.to(self.device) s = data.s.to(self.device) @@ -171,9 +172,15 @@ def forward_model(self, data): x_, a, a_ = self.model(x, noise) loss_g = self.model.loss_func_g(a_[edge_index[0], edge_index[1]]) - loss_d = self.model.loss_func_ed(a[edge_index[0], edge_index[1]], - a_[edge_index[0], edge_index[ - 1]].detach()) + self.opt_in.zero_grad() + loss_g.backward() + self.opt_in.step() + + self.epoch_loss_in += loss_g.item() * batch_size + + loss = self.model.loss_func_ed(a[edge_index[0], edge_index[1]], + a_[edge_index[0], edge_index[ + 1]].detach()) score = self.model.score_func(x=x, x_=x_, @@ -183,4 +190,4 @@ def forward_model(self, data): pos_weight_s=1, bce_s=True) - return (loss_g, loss_d), score.detach().cpu() + return loss, score.detach().cpu() diff --git a/pygod/nn/adone.py b/pygod/nn/adone.py index 178e99c..2c529e8 100644 --- a/pygod/nn/adone.py +++ b/pygod/nn/adone.py @@ -92,6 +92,8 @@ def __init__(self, dropout=dropout, act=torch.tanh) self.emb = None + self.inner = self.discriminator + self.outer = self.generator def forward(self, x, s, edge_index): """ @@ -126,15 +128,13 @@ def forward(self, x, s, edge_index): Structure discriminator score. """ x_, s_, h_a, h_s, dna, dns = self.generator(x, s, edge_index) - dis_a = torch.sigmoid(self.discriminator(h_a)) - dis_s = torch.sigmoid(self.discriminator(h_s)) self.emb = (h_a, h_s) - return x_, s_, h_a, h_s, dna, dns, dis_a, dis_s + return x_, s_, h_a, h_s, dna, dns - def loss_func(self, x, x_, s, s_, h_a, h_s, dna, dns, dis_a, dis_s): + def loss_func_g(self, x, x_, s, s_, h_a, h_s, dna, dns): """ - Loss function for AdONE. + Generator loss function for AdONE. Parameters ---------- @@ -154,10 +154,6 @@ def loss_func(self, x, x_, s, s_, h_a, h_s, dna, dns, dis_a, dis_s): Attribute neighbor distance. dns : torch.Tensor Structure neighbor distance. - dis_a : torch.Tensor - Attribute discriminator score. - dis_s : torch.Tensor - Structure discriminator score. Returns ------- @@ -198,6 +194,8 @@ def loss_func(self, x, x_, s, s_, h_a, h_s, dna, dns, dis_a, dis_s): # equation 3 loss_hom_s = torch.mean(torch.log(torch.pow(os, -1)) * dns) + dis_a = torch.sigmoid(self.discriminator(h_a)) + dis_s = torch.sigmoid(self.discriminator(h_s)) # equation 12 loss_alg = torch.mean(torch.log(torch.pow(oc, -1)) * (torch.log(1 - dis_a) + torch.log(dis_s))) @@ -211,6 +209,28 @@ def loss_func(self, x, x_, s, s_, h_a, h_s, dna, dns, dis_a, dis_s): return loss, oa, os, oc + def loss_func_d(self, h_a, h_s): + """ + Discriminator loss function for AdONE. + + Parameters + ---------- + h_a : torch.Tensor + Attribute hidden embeddings. + h_s : torch.Tensor + Structure hidden embeddings. + + Returns + ------- + loss : torch.Tensor + Loss value. + """ + # equation 11 + dis_a = torch.sigmoid(self.discriminator(h_a)) + dis_s = torch.sigmoid(self.discriminator(h_s)) + loss = - torch.mean(torch.log(1 - dis_a) + torch.log(dis_s)) + return loss + @staticmethod def process_graph(data): """ diff --git a/pygod/nn/gaan.py b/pygod/nn/gaan.py index dba9449..126e4a0 100644 --- a/pygod/nn/gaan.py +++ b/pygod/nn/gaan.py @@ -74,6 +74,9 @@ def __init__(self, self.emb = None self.score_func = double_recon_loss + self.inner = self.generator + self.outer = self.discriminator + def forward(self, x, noise): """ Forward computation.