Skip to content

Commit

Permalink
fix gan
Browse files Browse the repository at this point in the history
  • Loading branch information
kayzliu committed Jan 31, 2024
1 parent fd2a662 commit 0ebcf5a
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 78 deletions.
36 changes: 20 additions & 16 deletions pygod/detector/adone.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,27 +198,31 @@ def forward_model(self, data):
s = data.s.to(self.device)
edge_index = data.edge_index.to(self.device)

x_, s_, h_a, h_s, dna, dns, dis_a, dis_s = self.model(x, s, edge_index)

loss_d = - torch.mean(torch.log(1 - dis_a[:batch_size])
+ torch.log(dis_s[:batch_size]))

loss_g, oa, os, oc = self.model.loss_func(x[:batch_size],
x_[:batch_size],
s[:batch_size],
s_[:batch_size],
h_a[:batch_size],
h_s[:batch_size],
dna[:batch_size],
dns[:batch_size],
dis_a[:batch_size],
dis_s[:batch_size])
x_, s_, h_a, h_s, dna, dns = self.model(x, s, edge_index)

loss_d = self.model.loss_func_d(h_a[:batch_size].detach(),
h_s[:batch_size].detach())

self.opt_in.zero_grad()
loss_d.backward()
self.opt_in.step()

self.epoch_loss_in += loss_d.item() * batch_size

loss_g, oa, os, oc = self.model.loss_func_g(x[:batch_size],
x_[:batch_size],
s[:batch_size],
s_[:batch_size],
h_a[:batch_size],
h_s[:batch_size],
dna[:batch_size],
dns[:batch_size])

self.attribute_score_[node_idx[:batch_size]] = oa.detach().cpu()
self.structural_score_[node_idx[:batch_size]] = os.detach().cpu()
self.combined_score_[node_idx[:batch_size]] = oc.detach().cpu()

return (loss_g, loss_d), ((oa + os + oc) / 3).detach().cpu()
return loss_g, ((oa + os + oc) / 3).detach().cpu()

def decision_function(self, data, label=None):
if data is not None:
Expand Down
69 changes: 20 additions & 49 deletions pygod/detector/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,39 +439,31 @@ def fit(self, data, label=None):
self.model = self.init_model(**self.kwargs)
if self.compile_model:
self.model = compile(self.model)
if self.gan:
opt_g = torch.optim.Adam(self.model.generator.parameters(),
lr=self.lr,
weight_decay=self.weight_decay)
opt_d = torch.optim.Adam(self.model.discriminator.parameters(),
lr=self.lr,
weight_decay=self.weight_decay)
else:
if not self.gan:
optimizer = torch.optim.Adam(self.model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay)
else:
self.opt_in = torch.optim.Adam(self.model.inner.parameters(),
lr=self.lr,
weight_decay=self.weight_decay)
optimizer = torch.optim.Adam(self.model.outer.parameters(),
lr=self.lr,
weight_decay=self.weight_decay)

self.model.train()
self.decision_score_ = torch.zeros(data.x.shape[0])
for epoch in range(self.epoch):
start_time = time.time()
epoch_loss = 0
if self.gan:
epoch_loss_g = 0
epoch_loss_d = 0
else:
epoch_loss = 0
self.epoch_loss_in = 0
for sampled_data in loader:
batch_size = sampled_data.batch_size
node_idx = sampled_data.n_id

loss, score = self.forward_model(sampled_data)

if self.gan:
epoch_loss_g += loss[0].item() * batch_size
epoch_loss_d += loss[1].item() * batch_size
else:
epoch_loss += loss.item() * batch_size

epoch_loss += loss.item() * batch_size
if self.save_emb:
if type(self.emb) is tuple:
self.emb[0][node_idx[:batch_size]] = \
Expand All @@ -483,23 +475,13 @@ def fit(self, data, label=None):
self.model.emb[:batch_size].cpu()
self.decision_score_[node_idx[:batch_size]] = score

if self.gan:
opt_g.zero_grad()
loss[0].backward()
opt_g.step()
opt_d.zero_grad()
loss[0].backward()
opt_d.step()
else:
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
loss.backward()
optimizer.step()

loss_value = epoch_loss / data.x.shape[0]
if self.gan:
loss_value = (self.epoch_loss_g / data.x.shape[0],
self.epoch_loss_d / data.x.shape[0])
else:
loss_value = epoch_loss / data.x.shape[0]
loss_value = (self.epoch_loss_in / data.x.shape[0], loss_value)
logger(epoch=epoch,
loss=loss_value,
score=self.decision_score_,
Expand Down Expand Up @@ -527,11 +509,7 @@ def decision_function(self, data, label=None):
else:
self.emb = torch.zeros(data.x.shape[0], self.hid_dim)
start_time = time.time()
if self.gan:
test_loss_g = 0
test_loss_d = 0
else:
test_loss = 0
test_loss = 0
for sampled_data in loader:
loss, score = self.forward_model(sampled_data)
batch_size = sampled_data.batch_size
Expand All @@ -546,19 +524,12 @@ def decision_function(self, data, label=None):
self.emb[node_idx[:batch_size]] = \
self.model.emb[:batch_size].cpu()

if self.gan:
test_loss_g += loss[0].item() * batch_size
test_loss_d = loss[1].item() * batch_size
else:
test_loss = loss.item() * batch_size

test_loss = loss.item() * batch_size
outlier_score[node_idx[:batch_size]] = score

loss_value = test_loss / data.x.shape[0]
if self.gan:
loss_value = (test_loss_g / data.x.shape[0],
test_loss_d / data.x.shape[0])
else:
loss_value = test_loss / data.x.shape[0]
loss_value = (self.epoch_loss_in / data.x.shape[0], loss_value)

logger(loss=loss_value,
score=outlier_score,
Expand Down
15 changes: 11 additions & 4 deletions pygod/detector/gaan.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def init_model(self, **kwargs):
**kwargs).to(self.device)

def forward_model(self, data):
batch_size = data.batch_size
node_idx = data.n_id
x = data.x.to(self.device)
s = data.s.to(self.device)
Expand All @@ -171,9 +172,15 @@ def forward_model(self, data):
x_, a, a_ = self.model(x, noise)

loss_g = self.model.loss_func_g(a_[edge_index[0], edge_index[1]])
loss_d = self.model.loss_func_ed(a[edge_index[0], edge_index[1]],
a_[edge_index[0], edge_index[
1]].detach())
self.opt_in.zero_grad()
loss_g.backward()
self.opt_in.step()

self.epoch_loss_in += loss_g.item() * batch_size

loss = self.model.loss_func_ed(a[edge_index[0], edge_index[1]],
a_[edge_index[0], edge_index[
1]].detach())

score = self.model.score_func(x=x,
x_=x_,
Expand All @@ -183,4 +190,4 @@ def forward_model(self, data):
pos_weight_s=1,
bce_s=True)

return (loss_g, loss_d), score.detach().cpu()
return loss, score.detach().cpu()
38 changes: 29 additions & 9 deletions pygod/nn/adone.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def __init__(self,
dropout=dropout,
act=torch.tanh)
self.emb = None
self.inner = self.discriminator
self.outer = self.generator

def forward(self, x, s, edge_index):
"""
Expand Down Expand Up @@ -126,15 +128,13 @@ def forward(self, x, s, edge_index):
Structure discriminator score.
"""
x_, s_, h_a, h_s, dna, dns = self.generator(x, s, edge_index)
dis_a = torch.sigmoid(self.discriminator(h_a))
dis_s = torch.sigmoid(self.discriminator(h_s))
self.emb = (h_a, h_s)

return x_, s_, h_a, h_s, dna, dns, dis_a, dis_s
return x_, s_, h_a, h_s, dna, dns

def loss_func(self, x, x_, s, s_, h_a, h_s, dna, dns, dis_a, dis_s):
def loss_func_g(self, x, x_, s, s_, h_a, h_s, dna, dns):
"""
Loss function for AdONE.
Generator loss function for AdONE.
Parameters
----------
Expand All @@ -154,10 +154,6 @@ def loss_func(self, x, x_, s, s_, h_a, h_s, dna, dns, dis_a, dis_s):
Attribute neighbor distance.
dns : torch.Tensor
Structure neighbor distance.
dis_a : torch.Tensor
Attribute discriminator score.
dis_s : torch.Tensor
Structure discriminator score.
Returns
-------
Expand Down Expand Up @@ -198,6 +194,8 @@ def loss_func(self, x, x_, s, s_, h_a, h_s, dna, dns, dis_a, dis_s):
# equation 3
loss_hom_s = torch.mean(torch.log(torch.pow(os, -1)) * dns)

dis_a = torch.sigmoid(self.discriminator(h_a))
dis_s = torch.sigmoid(self.discriminator(h_s))
# equation 12
loss_alg = torch.mean(torch.log(torch.pow(oc, -1))
* (torch.log(1 - dis_a) + torch.log(dis_s)))
Expand All @@ -211,6 +209,28 @@ def loss_func(self, x, x_, s, s_, h_a, h_s, dna, dns, dis_a, dis_s):

return loss, oa, os, oc

def loss_func_d(self, h_a, h_s):
"""
Discriminator loss function for AdONE.
Parameters
----------
h_a : torch.Tensor
Attribute hidden embeddings.
h_s : torch.Tensor
Structure hidden embeddings.
Returns
-------
loss : torch.Tensor
Loss value.
"""
# equation 11
dis_a = torch.sigmoid(self.discriminator(h_a))
dis_s = torch.sigmoid(self.discriminator(h_s))
loss = - torch.mean(torch.log(1 - dis_a) + torch.log(dis_s))
return loss

@staticmethod
def process_graph(data):
"""
Expand Down
3 changes: 3 additions & 0 deletions pygod/nn/gaan.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def __init__(self,
self.emb = None
self.score_func = double_recon_loss

self.inner = self.generator
self.outer = self.discriminator

def forward(self, x, noise):
"""
Forward computation.
Expand Down

0 comments on commit 0ebcf5a

Please sign in to comment.