From a9a98b21d59468e627e7ac26d67389d0d2add032 Mon Sep 17 00:00:00 2001 From: "Kay (Zekuan) Liu" Date: Wed, 31 Jan 2024 21:35:10 -0600 Subject: [PATCH] fix GAAN --- pygod/detector/gaan.py | 28 ++++++++++++++-------------- pygod/nn/gaan.py | 1 - pygod/test/test_gaan.py | 3 --- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/pygod/detector/gaan.py b/pygod/detector/gaan.py index da3aecf3..040bcab5 100644 --- a/pygod/detector/gaan.py +++ b/pygod/detector/gaan.py @@ -7,6 +7,7 @@ import warnings import torch.nn.functional as F from torch_geometric.nn import MLP +from torch_geometric.utils import to_dense_adj from ..nn import GAANBase from . import DeepDetector @@ -110,7 +111,7 @@ def __init__(self, epoch=100, gpu=-1, batch_size=0, - num_neigh=0, + num_neigh=-1, weight=0.5, verbose=0, save_emb=False, @@ -120,16 +121,16 @@ def __init__(self, self.noise_dim = noise_dim self.weight = weight - if num_neigh != 0: - warnings.warn('MLP in GAAN does not use neighbor information.') - num_neigh = 0 + # self.num_layers is 1 for sample one hop neighbors + # In GAAN, self.model_layers is for model layers + self.model_layers = num_layers if backbone is not None: warnings.warn('GAAN can only use MLP as the backbone.') super(GAAN, self).__init__( hid_dim=hid_dim, - num_layers=num_layers, + num_layers=1, dropout=dropout, weight_decay=weight_decay, act=act, @@ -155,7 +156,7 @@ def init_model(self, **kwargs): return GAANBase(in_dim=self.in_dim, noise_dim=self.noise_dim, hid_dim=self.hid_dim, - num_layers=self.num_layers, + num_layers=self.model_layers, dropout=self.dropout, act=self.act, **kwargs).to(self.device) @@ -171,21 +172,20 @@ def forward_model(self, data): x_, a, a_ = self.model(x, noise) - loss_g = self.model.loss_func_g(a_[edge_index[0], edge_index[1]]) + loss_g = self.model.loss_func_g(a_[edge_index]) self.opt_in.zero_grad() loss_g.backward() self.opt_in.step() self.epoch_loss_in += loss_g.item() * batch_size - loss = self.model.loss_func_ed(a[edge_index[0], edge_index[1]], - a_[edge_index[0], edge_index[ - 1]].detach()) + loss = self.model.loss_func_ed(a[edge_index], + a_[edge_index].detach()) - score = self.model.score_func(x=x, - x_=x_, - s=s[:, node_idx], - s_=a, + score = self.model.score_func(x=x[:batch_size], + x_=x_[:batch_size], + s=s[:batch_size, node_idx], + s_=a[:batch_size], weight=self.weight, pos_weight_s=1, bce_s=True) diff --git a/pygod/nn/gaan.py b/pygod/nn/gaan.py index 126e4a02..89680f07 100644 --- a/pygod/nn/gaan.py +++ b/pygod/nn/gaan.py @@ -109,7 +109,6 @@ def forward(self, x, noise): @staticmethod def loss_func_g(a_): - loss_g = F.binary_cross_entropy(a_, torch.ones_like(a_)) return loss_g diff --git a/pygod/test/test_gaan.py b/pygod/test/test_gaan.py index 09284201..7a7725ff 100644 --- a/pygod/test/test_gaan.py +++ b/pygod/test/test_gaan.py @@ -119,6 +119,3 @@ def test_sample(self): def test_params(self): with assert_warns(UserWarning): GAAN(backbone=GIN) - - with assert_warns(UserWarning): - GAAN(num_neigh=2)