Skip to content

Commit

Permalink
fix GAAN
Browse files Browse the repository at this point in the history
  • Loading branch information
kayzliu committed Feb 1, 2024
1 parent bed2c9c commit a9a98b2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 18 deletions.
28 changes: 14 additions & 14 deletions pygod/detector/gaan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import torch.nn.functional as F
from torch_geometric.nn import MLP
from torch_geometric.utils import to_dense_adj

from ..nn import GAANBase
from . import DeepDetector
Expand Down Expand Up @@ -110,7 +111,7 @@ def __init__(self,
epoch=100,
gpu=-1,
batch_size=0,
num_neigh=0,
num_neigh=-1,
weight=0.5,
verbose=0,
save_emb=False,
Expand All @@ -120,16 +121,16 @@ def __init__(self,
self.noise_dim = noise_dim
self.weight = weight

if num_neigh != 0:
warnings.warn('MLP in GAAN does not use neighbor information.')
num_neigh = 0
# self.num_layers is 1 for sample one hop neighbors
# In GAAN, self.model_layers is for model layers
self.model_layers = num_layers

if backbone is not None:
warnings.warn('GAAN can only use MLP as the backbone.')

super(GAAN, self).__init__(
hid_dim=hid_dim,
num_layers=num_layers,
num_layers=1,
dropout=dropout,
weight_decay=weight_decay,
act=act,
Expand All @@ -155,7 +156,7 @@ def init_model(self, **kwargs):
return GAANBase(in_dim=self.in_dim,
noise_dim=self.noise_dim,
hid_dim=self.hid_dim,
num_layers=self.num_layers,
num_layers=self.model_layers,
dropout=self.dropout,
act=self.act,
**kwargs).to(self.device)
Expand All @@ -171,21 +172,20 @@ def forward_model(self, data):

x_, a, a_ = self.model(x, noise)

loss_g = self.model.loss_func_g(a_[edge_index[0], edge_index[1]])
loss_g = self.model.loss_func_g(a_[edge_index])
self.opt_in.zero_grad()
loss_g.backward()
self.opt_in.step()

self.epoch_loss_in += loss_g.item() * batch_size

loss = self.model.loss_func_ed(a[edge_index[0], edge_index[1]],
a_[edge_index[0], edge_index[
1]].detach())
loss = self.model.loss_func_ed(a[edge_index],
a_[edge_index].detach())

score = self.model.score_func(x=x,
x_=x_,
s=s[:, node_idx],
s_=a,
score = self.model.score_func(x=x[:batch_size],
x_=x_[:batch_size],
s=s[:batch_size, node_idx],
s_=a[:batch_size],
weight=self.weight,
pos_weight_s=1,
bce_s=True)
Expand Down
1 change: 0 additions & 1 deletion pygod/nn/gaan.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def forward(self, x, noise):

@staticmethod
def loss_func_g(a_):

loss_g = F.binary_cross_entropy(a_, torch.ones_like(a_))
return loss_g

Expand Down
3 changes: 0 additions & 3 deletions pygod/test/test_gaan.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,3 @@ def test_sample(self):
def test_params(self):
with assert_warns(UserWarning):
GAAN(backbone=GIN)

with assert_warns(UserWarning):
GAAN(num_neigh=2)

0 comments on commit a9a98b2

Please sign in to comment.