Upload model and trainers
emirceyani committed Jun 8, 2021
1 parent 66cc74d commit 3ecf59a
Showing 6 changed files with 914 additions and 0 deletions.
115 changes: 115 additions & 0 deletions model/gat_readout.py
@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15 * torch.ones_like(e)  # mask out non-edges before the softmax
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        # Builds all pairwise concatenations [Wh_i || Wh_j], shape (N, N, 2 * out_features)
        N = Wh.size()[0]
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
        Wh_repeated_alternating = Wh.repeat(N, 1)
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)

        return all_combinations_matrix.view(N, N, 2 * self.out_features)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GAT(nn.Module):
    def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = nn.ModuleList(
            [GraphAttentionLayer(feat_dim, hidden_dim1, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)])
        self.out_att = GraphAttentionLayer(hidden_dim1 * nheads, hidden_dim2, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        embeddings = F.elu(self.out_att(x, adj))  # Node embeddings

        return embeddings


class Readout(nn.Module):
    """
    This module learns a single graph-level representation for a molecule given GNN-generated node embeddings
    """

    def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats):
        super(Readout, self).__init__()
        self.attr_dim = attr_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_cats = num_cats

        self.layer1 = nn.Linear(attr_dim + embedding_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.output = nn.Linear(output_dim, num_cats)
        self.act = nn.ReLU()

    def forward(self, node_features, node_embeddings):
        combined_rep = torch.cat((node_features, node_embeddings),
                                 dim=1)  # Concatenate initial node attributes with node embeddings from the GAT
        hidden_rep = self.act(self.layer1(combined_rep))
        graph_rep = self.act(self.layer2(hidden_rep))  # Generate the final graph-level embedding

        logits = torch.mean(self.output(graph_rep), dim=0)  # Generate logits for multi-label classification

        return logits


class GatMoleculeNet(nn.Module):
    """
    Network that consolidates GAT + Readout into a single nn.Module
    """

    def __init__(self, feat_dim, gat_hidden_dim1, node_embedding_dim, gat_dropout, gat_alpha, gat_nheads, readout_hidden_dim, graph_embedding_dim,
                 num_categories):
        super(GatMoleculeNet, self).__init__()
        self.gat = GAT(feat_dim, gat_hidden_dim1, node_embedding_dim, gat_dropout, gat_alpha, gat_nheads)
        self.readout = Readout(feat_dim, node_embedding_dim, readout_hidden_dim, graph_embedding_dim, num_categories)

    def forward(self, adj_matrix, feature_matrix):
        node_embeddings = self.gat(feature_matrix, adj_matrix)
        logits = self.readout(feature_matrix, node_embeddings)
        return logits
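
Usage sketch (not part of this commit): a minimal forward pass through GatMoleculeNet on a random dense molecular graph. The dimensions, the import path (assuming the repository root is on PYTHONPATH), and the random adjacency are placeholders chosen only to show the expected tensor shapes.

import torch

from model.gat_readout import GatMoleculeNet

# Hypothetical dimensions, chosen only for illustration.
model = GatMoleculeNet(feat_dim=32, gat_hidden_dim1=64, node_embedding_dim=64,
                       gat_dropout=0.3, gat_alpha=0.2, gat_nheads=2,
                       readout_hidden_dim=64, graph_embedding_dim=64,
                       num_categories=12)

num_atoms = 20
adj = (torch.rand(num_atoms, num_atoms) > 0.8).float()      # random dense adjacency
adj = ((adj + adj.t()) > 0).float() + torch.eye(num_atoms)  # symmetrize and add self-loops
feature_matrix = torch.randn(num_atoms, 32)                 # one feature row per atom

logits = model(adj, feature_matrix)  # forward(adj_matrix, feature_matrix)
print(logits.shape)                  # torch.Size([12]): one logit per task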
89 changes: 89 additions & 0 deletions model/sage_readout.py
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn


class GraphSage(nn.Module):
    """
    GraphSAGE model (https://arxiv.org/abs/1706.02216) to learn the role of atoms in molecules inductively.
    Transforms input features into a fixed-length embedding in a vector space. The embedding captures the role.
    """

    def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout):
        super(GraphSage, self).__init__()

        self.feat_dim = feat_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2

        self.layer1 = nn.Linear(2 * feat_dim, hidden_dim1, bias=False)
        self.layer2 = nn.Linear(2 * hidden_dim1, hidden_dim2, bias=False)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, forest, feature_matrix):
        feat_0 = feature_matrix[forest[0]]  # Of shape torch.Size([|B|, feat_dim])
        feat_1 = feature_matrix[forest[1]]  # Of shape torch.Size([|B|, fanouts[0], feat_dim])

        # Depth 1
        x = feature_matrix[forest[1]].mean(dim=1)  # Of shape torch.Size([|B|, feat_dim])
        feat_0 = torch.cat((feat_0, x), dim=1)     # Of shape torch.Size([|B|, 2 * feat_dim])
        feat_0 = self.relu(self.layer1(feat_0))    # Of shape torch.Size([|B|, hidden_dim1])
        feat_0 = self.dropout(feat_0)

        # Depth 2
        x = feature_matrix[forest[2]].mean(dim=1)  # Of shape torch.Size([|B| * fanouts[0], feat_dim])
        feat_1 = torch.cat((feat_1.reshape(-1, self.feat_dim), x),
                           dim=1)                  # Of shape torch.Size([|B| * fanouts[0], 2 * feat_dim])
        feat_1 = self.relu(self.layer1(feat_1))    # Of shape torch.Size([|B| * fanouts[0], hidden_dim1])
        feat_1 = self.dropout(feat_1)

        # Combine
        feat_1 = feat_1.reshape(forest[0].shape[0], -1, self.hidden_dim1).mean(
            dim=1)                                            # Of shape torch.Size([|B|, hidden_dim1])
        combined = torch.cat((feat_0, feat_1), dim=1)         # Of shape torch.Size([|B|, 2 * hidden_dim1])
        embeddings = self.relu(self.layer2(combined))         # Of shape torch.Size([|B|, hidden_dim2])

        return embeddings


class Readout(nn.Module):
    """
    This module learns a single graph-level representation for a molecule given GraphSAGE-generated embeddings
    """
    def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats):
        super(Readout, self).__init__()
        self.attr_dim = attr_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_cats = num_cats

        self.layer1 = nn.Linear(attr_dim + embedding_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.output = nn.Linear(output_dim, num_cats)
        self.act = nn.ReLU()

    def forward(self, node_features, node_embeddings):
        combined_rep = torch.cat((node_features, node_embeddings),
                                 dim=1)  # Concatenate initial node attributes with embeddings from GraphSAGE
        hidden_rep = self.act(self.layer1(combined_rep))
        graph_rep = self.act(self.layer2(hidden_rep))  # Generate the final graph-level embedding

        logits = torch.mean(self.output(graph_rep), dim=0)  # Generate logits for multi-label classification

        return logits


class SageMoleculeNet(nn.Module):
    """
    Network that consolidates Sage + Readout into a single nn.Module
    """
    def __init__(self, feat_dim, sage_hidden_dim1, node_embedding_dim, sage_dropout, readout_hidden_dim, graph_embedding_dim, num_categories):
        super(SageMoleculeNet, self).__init__()
        self.sage = GraphSage(feat_dim, sage_hidden_dim1, node_embedding_dim, sage_dropout)
        self.readout = Readout(feat_dim, node_embedding_dim, readout_hidden_dim, graph_embedding_dim, num_categories)

    def forward(self, forest, feature_matrix):
        node_embeddings = self.sage(forest, feature_matrix)
        logits = self.readout(feature_matrix, node_embeddings)
        return logits
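
Usage sketch (not part of this commit): the layout of the `forest` argument below is inferred from the indexing in GraphSage.forward; the actual neighbor sampler lives elsewhere in the repository, and every size here is a made-up placeholder. Because the readout concatenates node features with node embeddings row by row, every atom is treated as a root, so the batch dimension equals the number of atoms.

import torch

from model.sage_readout import SageMoleculeNet

# Hypothetical sizes, chosen only for illustration.
num_atoms, feat_dim = 30, 32
fanout0, fanout1 = 5, 5

feature_matrix = torch.randn(num_atoms, feat_dim)

# Inferred layout of the sampled forest:
#   forest[0]: root node indices,                shape (|B|,)
#   forest[1]: sampled neighbors of the roots,   shape (|B|, fanouts[0])
#   forest[2]: sampled neighbors of those,       shape (|B| * fanouts[0], fanouts[1])
forest = [
    torch.arange(num_atoms),
    torch.randint(0, num_atoms, (num_atoms, fanout0)),
    torch.randint(0, num_atoms, (num_atoms * fanout0, fanout1)),
]

model = SageMoleculeNet(feat_dim=feat_dim, sage_hidden_dim1=64, node_embedding_dim=64,
                        sage_dropout=0.3, readout_hidden_dim=64,
                        graph_embedding_dim=64, num_categories=12)

logits = model(forest, feature_matrix)
print(logits.shape)  # torch.Size([12])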
184 changes: 184 additions & 0 deletions training/gat_readout_trainer.py
@@ -0,0 +1,184 @@
import logging

import numpy as np
import torch
import wandb
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from torch import trace, inverse
from scipy.linalg import fractional_matrix_power

from FedML.fedml_core.trainer.model_trainer import ModelTrainer


# Trainer for MoleculeNet. The evaluation metric is ROC-AUC (PRC-AUC when args.metric == 'prc-auc')

class GatMoleculeNetTrainer(ModelTrainer):

    def get_model_params(self):
        return self.model.cpu().state_dict()

    def set_model_params(self, model_parameters):
        logging.info("set_model_params")
        self.model.load_state_dict(model_parameters)

    def train(self, train_data, device, args):
        model = self.model
        if args.is_mtl:
            self.model.omega_corr.to(device)
        model.to(device)
        model.train()

        test_data = None
        try:
            test_data = self.test_data
        except AttributeError:
            pass

        criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
        if args.client_optimizer == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
            # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
            # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

        max_test_score = 0
        best_model_params = {}
        eps = 1e-5
        for epoch in range(args.epochs):
            for mol_idxs, (adj_matrix, feature_matrix, label, mask, cli_mask) in enumerate(train_data):
                # Skip molecules that have no labels
                mask = mask.to(device=device, dtype=torch.float32, non_blocking=True)
                cli_mask = cli_mask.to(device=device, dtype=torch.float32, non_blocking=True) if cli_mask is not None else None
                mask = mask * cli_mask if cli_mask is not None else mask
                if torch.all(mask == 0).item():
                    continue

                optimizer.zero_grad()

                adj_matrix = adj_matrix.to(device=device, dtype=torch.float32, non_blocking=True)
                feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True)
                label = label.to(device=device, dtype=torch.float32, non_blocking=True)

                logits = model(adj_matrix, feature_matrix)
                clf_loss = criterion(logits, label) * mask
                clf_loss = clf_loss.sum() / mask.sum()

                if args.is_mtl:
                    # Multi-task penalty: task_reg * tr(W^T (Omega + eps*I)^{-1} W) plus weight decay on W
                    W = self.model.readout.output.weight
                    lhs = torch.mm(W.t(), torch.inverse(self.model.omega_corr + eps * torch.eye(self.model.omega_corr.shape[0],
                                                                                                dtype=self.model.omega_corr.dtype,
                                                                                                device=self.model.omega_corr.device)))
                    trace_in = torch.mm(lhs, W)
                    loss = clf_loss + self.task_reg * torch.trace(trace_in) + args.wd * 0.5 * torch.norm(W) ** 2
                    loss.backward()
                    optimizer.step()
                    with torch.no_grad():
                        # Closed-form refresh of the task-covariance: Omega <- (W W^T)^{1/2} / tr((W W^T)^{1/2})
                        W = self.model.readout.output.weight
                        mul = torch.mm(W, W.t())
                        fr_pow = fractional_matrix_power(mul.cpu(), 1 / 2)
                        self.model.omega_corr = torch.nn.Parameter(torch.Tensor(fr_pow / np.trace(fr_pow)).to(device))
                        self.model.omega_corr.requires_grad = False
                else:
                    clf_loss.backward()
                    optimizer.step()

                if ((mol_idxs + 1) % args.frequency_of_the_test == 0) or (mol_idxs == len(train_data) - 1):
                    if test_data is not None:
                        test_score, _ = self.test(self.test_data, device, args)
                        print('Epoch = {}, Iter = {}/{}: Test Score = {}'.format(epoch, mol_idxs + 1, len(train_data), test_score))
                        if test_score > max_test_score:
                            max_test_score = test_score
                            best_model_params = {k: v.cpu() for k, v in model.state_dict().items()}
                        print('Current best = {}'.format(max_test_score))

            self.task_reg *= args.task_reg_decay

        return max_test_score, best_model_params

    def test(self, test_data, device, args):
        logging.info("----------test--------")
        model = self.model
        model.eval()
        model.to(device)

        with torch.no_grad():
            y_pred = []
            y_true = []
            masks = []
            for mol_idx, (adj_matrix, feature_matrix, label, mask, _) in enumerate(test_data):
                adj_matrix = adj_matrix.to(device=device, dtype=torch.float32, non_blocking=True)
                feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True)

                logits = model(adj_matrix, feature_matrix)

                y_pred.append(logits.cpu().numpy())
                y_true.append(label.cpu().numpy())
                masks.append(mask.numpy())

        y_pred = np.array(y_pred)
        y_true = np.array(y_true)
        masks = np.array(masks)

        results = []
        for label in range(masks.shape[1]):
            valid_idxs = np.nonzero(masks[:, label])
            truth = y_true[valid_idxs, label].flatten()
            pred = y_pred[valid_idxs, label].flatten()

            if np.all(truth == 0.0) or np.all(truth == 1.0):
                results.append(float('nan'))
            else:
                if args.metric == 'prc-auc':
                    precision, recall, _ = precision_recall_curve(truth, pred)
                    score = auc(recall, precision)
                else:
                    score = roc_auc_score(truth, pred)

                results.append(score)

        score = np.nanmean(results)

        return score, model

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        logging.info("----------test_on_the_server--------")

        model_list, score_list = [], []
        for client_idx in test_data_local_dict.keys():
            test_data = test_data_local_dict[client_idx]
            score, model = self.test(test_data, device, args)
            for idx in range(len(model_list)):
                self._compare_models(model, model_list[idx])
            model_list.append(model)
            score_list.append(score)
            if args.dataset != "pcba":
                logging.info('Client {}, Test ROC-AUC score = {}'.format(client_idx, score))
                wandb.log({"Client {} Test/ROC-AUC".format(client_idx): score})
            else:
                logging.info('Client {}, Test PRC-AUC score = {}'.format(client_idx, score))
                wandb.log({"Client {} Test/PRC-AUC".format(client_idx): score})
        avg_score = np.mean(np.array(score_list))
        if args.dataset != "pcba":
            logging.info('Test ROC-AUC Score = {}'.format(avg_score))
            wandb.log({"Test/ROC-AUC": avg_score})
        else:
            logging.info('Test PRC-AUC Score = {}'.format(avg_score))
            wandb.log({"Test/PRC-AUC": avg_score})
        return True

    def _compare_models(self, model_1, model_2):
        models_differ = 0
        for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
            if torch.equal(key_item_1[1], key_item_2[1]):
                pass
            else:
                models_differ += 1
                if key_item_1[0] == key_item_2[0]:
                    logging.info('Mismatch found at %s', key_item_1[0])
                else:
                    raise Exception
        if models_differ == 0:
            logging.info('Models match perfectly! :)')
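
Note (not part of this commit): in the is_mtl branch the readout's output weights W are coupled across tasks through omega_corr. The loss adds a penalty of the form task_reg * tr(W^T (Omega + eps*I)^{-1} W), and after each optimizer step Omega is refreshed in closed form as (W W^T)^{1/2} normalized to unit trace. The standalone sketch below spells out those two steps with plain tensors; the sizes and the identity initialization of Omega are assumptions made only for illustration.

import numpy as np
import torch
from scipy.linalg import fractional_matrix_power

# Hypothetical sizes: 12 tasks, 64-dimensional graph embedding.
num_tasks, emb_dim = 12, 64
W = torch.randn(num_tasks, emb_dim)       # stands in for model.readout.output.weight
omega = torch.eye(num_tasks) / num_tasks  # assumed initial task-covariance with unit trace
eps, task_reg = 1e-5, 1e-3

# Penalty added to the classification loss:
#   task_reg * tr(W^T (Omega + eps * I)^{-1} W)
penalty = task_reg * torch.trace(W.t() @ torch.inverse(omega + eps * torch.eye(num_tasks)) @ W)

# Closed-form refresh of Omega after the optimizer step:
#   Omega <- (W W^T)^{1/2} / tr((W W^T)^{1/2})
mul = (W @ W.t()).numpy()
fr_pow = np.real(fractional_matrix_power(mul, 0.5))  # W W^T is PSD, so the root is real
omega = torch.tensor(fr_pow / np.trace(fr_pow), dtype=torch.float32)

print(penalty.item(), torch.trace(omega).item())  # trace(Omega) stays 1 by construction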