Add files via upload
GNN network file updated.
engrodawood authored Jun 5, 2023
1 parent 786197a commit 2fc7977
Showing 1 changed file with 9 additions and 24 deletions.
33 changes: 9 additions & 24 deletions workspace/model/gnn.py
@@ -1,9 +1,9 @@
from utils import *
from torch.nn import BatchNorm1d
from torch.nn import Sequential, Linear, ReLU, Tanh, LeakyReLU, ELU, SELU, GELU,Sigmoid
from torch_geometric.nn import GINConv, EdgeConv,PNAConv, DynamicEdgeConv, global_add_pool, global_mean_pool, global_max_pool
from tqdm import tqdm
import torch.nn as nn
from utils import *
# %% Graph Neural Network


@@ -61,36 +61,28 @@ def __init__(self, dim_features, dim_target, layers=[16, 16, 8], pooling='max',
self.first_h = Sequential(Linear(dim_features, out_emb_dim), BatchNorm1d(out_emb_dim),GELU())
self.linears.append(
Sequential(
#Updated latest
Linear(out_emb_dim, self.featd), BatchNorm1d(self.featd),ELU(),
Linear(self.featd, dim_target), BatchNorm1d(dim_target),#ELU(),
Linear(out_emb_dim, dim_target), BatchNorm1d(dim_target),ELU(),
)
)

else:
input_emb_dim = self.embeddings_dim[layer-1]
self.linears.append(
Sequential(
# Linear(out_emb_dim, dim_target), BatchNorm1d(dim_target),ELU(),
# Linear(self.featd, dim_target), BatchNorm1d(dim_target),
Linear(out_emb_dim, self.featd), BatchNorm1d(self.featd),ELU(),
Linear(self.featd, dim_target), BatchNorm1d(dim_target),
Linear(out_emb_dim, dim_target), BatchNorm1d(dim_target),ELU(),
)
)
if conv == 'GINConv':
subnet = Sequential(
Linear(input_emb_dim, out_emb_dim), BatchNorm1d(out_emb_dim), ELU(),
#Linear(self.featd, out_emb_dim), BatchNorm1d(out_emb_dim)#, ELU(),
Linear(input_emb_dim, self.featd), BatchNorm1d(self.featd), ELU(),
Linear(self.featd, out_emb_dim), BatchNorm1d(out_emb_dim)#, ELU(),

)
self.nns.append(subnet)
# Eq. 4.2 eps=100, train_eps=False
# import pdb; pdb.set_trace()
self.convs.append(GINConv(self.nns[-1], **kwargs))
elif conv == 'EdgeConv':
subnet = Sequential(
Linear(2*input_emb_dim,out_emb_dim), BatchNorm1d(out_emb_dim), ELU(),
# Linear(self.featd,out_emb_dim),BatchNorm1d(out_emb_dim)#,ELU(),
)
self.nns.append(subnet)
# DynamicEdgeConv#EdgeConv aggr='mean'
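
Because the rendered diff has lost its +/- markers, the old two-step projections through `self.featd` and the new single-`Linear` variants appear side by side above. A minimal sketch of how one layer reads if the single-`Linear` lines are taken as the surviving code (an assumption; the sizes below are made up for illustration, and this is not the repository's full class):

```python
# Illustrative sketch of the simplified wiring suggested by this hunk:
# one Linear head straight to dim_target, and a single-Linear GIN subnet.
import torch
from torch.nn import Sequential, Linear, BatchNorm1d, ELU, GELU
from torch_geometric.nn import GINConv

dim_features, out_emb_dim, dim_target = 128, 16, 2  # assumed example sizes

# Input embedding, unchanged by the commit.
first_h = Sequential(Linear(dim_features, out_emb_dim),
                     BatchNorm1d(out_emb_dim), GELU())

# Per-layer prediction head after the edit: one Linear to the target size.
head = Sequential(Linear(out_emb_dim, dim_target),
                  BatchNorm1d(dim_target), ELU())

# GIN message-passing subnet after the edit: a single Linear block.
subnet = Sequential(Linear(out_emb_dim, out_emb_dim),
                    BatchNorm1d(out_emb_dim), ELU())
conv = GINConv(subnet)
```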
@@ -118,15 +110,12 @@ def forward(self, data):
import torch.nn.functional as F
for layer in range(self.no_layers):
if layer == 0:
# Uncomment this line and use smaller size for node level features if you are
# facing GPU memory problem.
#x = self.first_h(x)
x = self.first_h(x)
z = self.linears[layer](x)
Z += z
dout = F.dropout(pooling(z, batch),
p=self.dropout, training=self.training)
out += dout
#
else:
x = self.convs[layer-1](x, edge_index)
if not self.gembed:
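
The context lines kept in this hunk show the pattern the forward pass follows at layer 0: embed with `first_h`, project with the layer's head, accumulate the node-level logits into `Z`, then pool with dropout into the running graph-level output `out`. A minimal sketch of that readout pattern under assumed shapes, with `global_max_pool` standing in for `self.pooling`:

```python
# Sketch of the layer-0 readout kept by this hunk. Tensors and the Linear
# stand-in for self.linears[0] are made-up examples, not the repo's code.
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_max_pool

x = torch.randn(6, 16)                    # 6 nodes with 16-dim embeddings
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs of three nodes each
head = torch.nn.Linear(16, 2)             # stand-in for self.linears[0]

z = head(x)                               # per-node logits
Z = z                                     # running sum of node-level logits
out = F.dropout(global_max_pool(z, batch),
                p=0.5, training=True)     # running sum of graph-level logits
```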
@@ -317,8 +306,6 @@ def _pair_train(self, train_loader, optimizer, clipping=None):
if clipping is not None: # Clip gradient before updating weights
torch.nn.utils.clip_grad_norm_(model.parameters(), clipping)
optimizer.step()
# print(pair_count)
#import pdb;pdb.set_trace()
return acc_all, loss_all
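
The two lines dropped in this hunk were commented-out debugging statements; what survives is the standard clip-then-step update. A minimal sketch of that pattern with a hypothetical model and loss:

```python
# Sketch of the clip-then-step pattern kept by this hunk: clip the global
# gradient norm only when a threshold is given, then apply the update.
import torch

model = torch.nn.Linear(4, 2)                    # hypothetical model
optimizer = torch.optim.Adam(model.parameters())
clipping = 1.0

loss = model(torch.randn(8, 4)).sum()
optimizer.zero_grad()
loss.backward()
if clipping is not None:  # clip gradient before updating weights
    torch.nn.utils.clip_grad_norm_(model.parameters(), clipping)
optimizer.step()
```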

def classify_graphs(self, loader):
@@ -327,15 +314,13 @@ def classify_graphs(self, loader):
R = np.full_like(np.zeros((n_classes, 2)),np.nan)
for i in range(n_classes):
try:
vidx = ~np.isnan(Y[:,i])
R[i] = np.array(
[calc_roc_auc(Y[vidx, i], Z[vidx, i]), calc_pr(Y[vidx, i], Z[vidx, i])])
[calc_roc_auc(Y[:, i], Z[:, i]), calc_pr(Y[:, i], Z[:, i])])
except:
import pdb; pdb.set_trace()
print('Only one class')
# Rf = R[R[:,0]!=-1]
loss = 0
#
# import pdb; pdb.set_trace()
return np.nanmedian(R,0)[0], loss,np.nanmedian(R,0)[1]
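
Both the masked (`vidx`) and unmasked metric calls are visible above because the diff markers were stripped; the `~np.isnan` masking reads as the variant that handles partially labelled multi-label targets. A minimal sketch of that loop, assuming the masked path is the surviving one, with sklearn metrics standing in for the repo's `calc_roc_auc` / `calc_pr` helpers (an assumption, not the actual utils):

```python
# Sketch of the per-class scoring loop with NaN masking: score each class
# only on rows that carry a label for it, then take the median over classes.
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

Y = np.array([[1.0, np.nan], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])  # labels
Z = np.random.rand(4, 2)                                           # scores
n_classes = Y.shape[1]

R = np.full((n_classes, 2), np.nan)
for i in range(n_classes):
    vidx = ~np.isnan(Y[:, i])            # keep rows labelled for class i
    if len(np.unique(Y[vidx, i])) > 1:   # AUROC needs both classes present
        R[i] = [roc_auc_score(Y[vidx, i], Z[vidx, i]),
                average_precision_score(Y[vidx, i], Z[vidx, i])]

auroc, aupr = np.nanmedian(R, 0)         # medians over classes, ignoring NaN
```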

def train(self, train_loader, max_epochs=100, optimizer=torch.optim.Adam, scheduler=None, clipping=None,
@@ -456,4 +441,4 @@ def train(self, train_loader, max_epochs=100, optimizer=torch.optim.Adam, scheduler=None, clipping=None,
test_pr = test_pr_at_best_val_acc

Q.reverse()
return Q, train_loss, train_acc, val_loss, np.round(val_acc, 2), test_loss, np.round(test_acc, 2), val_pr, test_pr
