Commit 3ecf59a (parent 66cc74d): 6 changed files with 914 additions and 0 deletions.
New file (+115 lines): GAT attention layer, GAT encoder, Readout, and GatMoleculeNet
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        # Mask out attention scores for non-adjacent node pairs before the softmax
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        # Build all N x N pairwise concatenations [Wh_i || Wh_j] consumed by the attention vector a
        N = Wh.size()[0]
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
        Wh_repeated_alternating = Wh.repeat(N, 1)
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)

        return all_combinations_matrix.view(N, N, 2 * self.out_features)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GAT(nn.Module):
    def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = nn.ModuleList(
            [GraphAttentionLayer(feat_dim, hidden_dim1, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)])
        self.out_att = GraphAttentionLayer(hidden_dim1 * nheads, hidden_dim2, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        embeddings = F.elu(self.out_att(x, adj))  # Node embeddings

        return embeddings


class Readout(nn.Module):
    """
    This module learns a single graph-level representation for a molecule given GNN-generated node embeddings
    """

    def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats):
        super(Readout, self).__init__()
        self.attr_dim = attr_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_cats = num_cats

        self.layer1 = nn.Linear(attr_dim + embedding_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.output = nn.Linear(output_dim, num_cats)
        self.act = nn.ReLU()

    def forward(self, node_features, node_embeddings):
        combined_rep = torch.cat((node_features, node_embeddings),
                                 dim=1)  # Concatenate initial node attributes with the GAT-generated embeddings
        hidden_rep = self.act(self.layer1(combined_rep))
        graph_rep = self.act(self.layer2(hidden_rep))  # Generate the final graph-level embedding

        logits = torch.mean(self.output(graph_rep), dim=0)  # Mean-pool per-node logits for multi-label classification

        return logits


class GatMoleculeNet(nn.Module):
    """
    Network that consolidates GAT + Readout into a single nn.Module
    """

    def __init__(self, feat_dim, gat_hidden_dim1, node_embedding_dim, gat_dropout, gat_alpha, gat_nheads, readout_hidden_dim, graph_embedding_dim,
                 num_categories):
        super(GatMoleculeNet, self).__init__()
        self.gat = GAT(feat_dim, gat_hidden_dim1, node_embedding_dim, gat_dropout, gat_alpha, gat_nheads)
        self.readout = Readout(feat_dim, node_embedding_dim, readout_hidden_dim, graph_embedding_dim, num_categories)

    def forward(self, adj_matrix, feature_matrix):
        node_embeddings = self.gat(feature_matrix, adj_matrix)
        logits = self.readout(feature_matrix, node_embeddings)
        return logits
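
The dimensions in GatMoleculeNet tie together as follows: the GAT encoder maps (num_atoms, feat_dim) features to (num_atoms, node_embedding_dim) embeddings, and Readout concatenates those embeddings back onto the raw features before pooling to a single logit vector of length num_categories. Below is a minimal sketch of a forward pass on one random molecule; every size and hyperparameter is an illustrative assumption, not a value taken from this repository.

# Minimal usage sketch for GatMoleculeNet (all dimensions are illustrative assumptions).
import torch

num_atoms, feat_dim, num_tasks = 20, 32, 12
model = GatMoleculeNet(
    feat_dim=feat_dim,
    gat_hidden_dim1=64,        # per-head hidden size of the first GAT layer
    node_embedding_dim=64,     # output size of the second GAT layer
    gat_dropout=0.3,
    gat_alpha=0.2,             # LeakyReLU negative slope
    gat_nheads=4,
    readout_hidden_dim=64,
    graph_embedding_dim=64,
    num_categories=num_tasks,
)

adj = (torch.rand(num_atoms, num_atoms) > 0.7).float()   # random dense 0/1 adjacency
adj = ((adj + adj.t()) > 0).float()                      # symmetrize
adj.fill_diagonal_(1.0)                                  # add self-loops
features = torch.rand(num_atoms, feat_dim)

logits = model(adj, features)       # shape: (num_tasks,)
probs = torch.sigmoid(logits)       # per-task probabilities for multi-label classification

The trainer file in this commit consumes these logits with torch.nn.BCEWithLogitsLoss(reduction='none'), masked per task.
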
New file (+89 lines): GraphSage encoder, Readout, and SageMoleculeNet
import torch
import torch.nn as nn


class GraphSage(nn.Module):
    """
    GraphSAGE model (https://arxiv.org/abs/1706.02216) to learn the role of atoms in the molecules inductively.
    Transforms input features into a fixed-length embedding in a vector space. The embedding captures the role.
    """

    def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout):
        super(GraphSage, self).__init__()

        self.feat_dim = feat_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2

        self.layer1 = nn.Linear(2 * feat_dim, hidden_dim1, bias=False)
        self.layer2 = nn.Linear(2 * hidden_dim1, hidden_dim2, bias=False)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, forest, feature_matrix):
        # forest[d] holds the node indices sampled at depth d of the neighborhood forest
        feat_0 = feature_matrix[forest[0]]  # Of shape torch.Size([|B|, feat_dim])
        feat_1 = feature_matrix[forest[1]]  # Of shape torch.Size([|B|, fanouts[0], feat_dim])

        # Depth 1
        x = feature_matrix[forest[1]].mean(dim=1)  # Of shape torch.Size([|B|, feat_dim])
        feat_0 = torch.cat((feat_0, x), dim=1)  # Of shape torch.Size([|B|, 2 * feat_dim])
        feat_0 = self.relu(self.layer1(feat_0))  # Of shape torch.Size([|B|, hidden_dim1])
        feat_0 = self.dropout(feat_0)

        # Depth 2
        x = feature_matrix[forest[2]].mean(dim=1)  # Of shape torch.Size([|B| * fanouts[0], feat_dim])
        feat_1 = torch.cat((feat_1.reshape(-1, self.feat_dim), x),
                           dim=1)  # Of shape torch.Size([|B| * fanouts[0], 2 * feat_dim])
        feat_1 = self.relu(self.layer1(feat_1))  # Of shape torch.Size([|B| * fanouts[0], hidden_dim1])
        feat_1 = self.dropout(feat_1)

        # Combine
        feat_1 = feat_1.reshape(forest[0].shape[0], -1, self.hidden_dim1).mean(
            dim=1)  # Of shape torch.Size([|B|, hidden_dim1])
        combined = torch.cat((feat_0, feat_1), dim=1)  # Of shape torch.Size([|B|, 2 * hidden_dim1])
        embeddings = self.relu(self.layer2(combined))  # Of shape torch.Size([|B|, hidden_dim2])

        return embeddings


class Readout(nn.Module):
    """
    This module learns a single graph-level representation for a molecule given GraphSAGE-generated embeddings
    """
    def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats):
        super(Readout, self).__init__()
        self.attr_dim = attr_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_cats = num_cats

        self.layer1 = nn.Linear(attr_dim + embedding_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.output = nn.Linear(output_dim, num_cats)
        self.act = nn.ReLU()

    def forward(self, node_features, node_embeddings):
        combined_rep = torch.cat((node_features, node_embeddings),
                                 dim=1)  # Concatenate initial node attributes with the embeddings from GraphSAGE
        hidden_rep = self.act(self.layer1(combined_rep))
        graph_rep = self.act(self.layer2(hidden_rep))  # Generate the final graph-level embedding

        logits = torch.mean(self.output(graph_rep), dim=0)  # Mean-pool per-node logits for multi-label classification

        return logits


class SageMoleculeNet(nn.Module):
    """
    Network that consolidates Sage + Readout into a single nn.Module
    """
    def __init__(self, feat_dim, sage_hidden_dim1, node_embedding_dim, sage_dropout, readout_hidden_dim, graph_embedding_dim, num_categories):
        super(SageMoleculeNet, self).__init__()
        self.sage = GraphSage(feat_dim, sage_hidden_dim1, node_embedding_dim, sage_dropout)
        self.readout = Readout(feat_dim, node_embedding_dim, readout_hidden_dim, graph_embedding_dim, num_categories)

    def forward(self, forest, feature_matrix):
        node_embeddings = self.sage(forest, feature_matrix)
        logits = self.readout(feature_matrix, node_embeddings)
        return logits
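
Unlike the dense GAT model, GraphSage.forward does not take an adjacency matrix; it takes a forest of sampled neighbor indices, one tensor per depth: forest[0] holds the |B| target nodes, forest[1] their depth-1 samples of shape (|B|, fanouts[0]), and forest[2] the depth-2 samples of shape (|B| * fanouts[0], fanouts[1]). A minimal sketch with uniform random sampling follows; the fanouts and dimensions are illustrative assumptions, and a real neighbor sampler would replace the random indices here.

# Illustrative sketch of the neighborhood "forest" consumed by GraphSage.forward (assumed sizes).
import torch

num_atoms, feat_dim = 30, 32
fanouts = [5, 3]                         # assumed sample sizes at depth 1 and depth 2
feature_matrix = torch.rand(num_atoms, feat_dim)

batch = torch.arange(num_atoms)                                              # forest[0]: the |B| target nodes
depth1 = torch.randint(0, num_atoms, (num_atoms, fanouts[0]))                # forest[1]: |B| x fanouts[0] neighbors
depth2 = torch.randint(0, num_atoms, (num_atoms * fanouts[0], fanouts[1]))   # forest[2]: neighbors of neighbors
forest = [batch, depth1, depth2]

model = SageMoleculeNet(
    feat_dim=feat_dim,
    sage_hidden_dim1=64,
    node_embedding_dim=64,
    sage_dropout=0.3,
    readout_hidden_dim=64,
    graph_embedding_dim=64,
    num_categories=12,
)
logits = model(forest, feature_matrix)   # shape: (12,)
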
New file (+184 lines): GatMoleculeNetTrainer, a FedML ModelTrainer for MoleculeNet with optional multi-task regularization
import logging

import numpy as np
import torch
import wandb
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from scipy.linalg import fractional_matrix_power

from FedML.fedml_core.trainer.model_trainer import ModelTrainer


# Trainer for MoleculeNet. The evaluation metric is ROC-AUC (PRC-AUC for PCBA).

class GatMoleculeNetTrainer(ModelTrainer):

    def get_model_params(self):
        return self.model.cpu().state_dict()

    def set_model_params(self, model_parameters):
        logging.info("set_model_params")
        self.model.load_state_dict(model_parameters)

    def train(self, train_data, device, args):
        model = self.model
        if args.is_mtl:
            self.model.omega_corr.to(device)
        model.to(device)
        model.train()

        test_data = None
        try:
            test_data = self.test_data
        except AttributeError:
            pass

        criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
        if args.client_optimizer == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
            # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
            # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

        max_test_score = 0
        best_model_params = {}
        eps = 1e-5
        for epoch in range(args.epochs):
            for mol_idxs, (adj_matrix, feature_matrix, label, mask, cli_mask) in enumerate(train_data):
                # Skip molecules that have no labels
                mask = mask.to(device=device, dtype=torch.float32, non_blocking=True)
                cli_mask = cli_mask.to(device=device, dtype=torch.float32, non_blocking=True) if cli_mask is not None else None
                mask = mask * cli_mask if cli_mask is not None else mask
                if torch.all(mask == 0).item():
                    continue

                optimizer.zero_grad()

                adj_matrix = adj_matrix.to(device=device, dtype=torch.float32, non_blocking=True)
                feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True)
                label = label.to(device=device, dtype=torch.float32, non_blocking=True)

                logits = model(adj_matrix, feature_matrix)
                clf_loss = criterion(logits, label) * mask
                clf_loss = clf_loss.sum() / mask.sum()

                if args.is_mtl:
                    # Multi-task relationship regularization on the readout's output weights W:
                    # penalize tr(W^T Omega^{-1} W) plus L2 weight decay on W.
                    W = self.model.readout.output.weight
                    lhs = torch.mm(W.t(), torch.inverse(self.model.omega_corr + eps * torch.eye(self.model.omega_corr.shape[0],
                                                        dtype=self.model.omega_corr.dtype, device=self.model.omega_corr.device)))
                    trace_in = torch.mm(lhs, W)
                    loss = clf_loss + self.task_reg * torch.trace(trace_in) + args.wd * 0.5 * torch.norm(W) ** 2
                    loss.backward()
                    optimizer.step()
                    with torch.no_grad():
                        # Closed-form update of the task-covariance matrix: Omega = (W W^T)^(1/2) / tr((W W^T)^(1/2))
                        W = self.model.readout.output.weight
                        mul = torch.mm(W, W.t())
                        fr_pow = fractional_matrix_power(mul.cpu(), 1 / 2)
                        self.model.omega_corr = torch.nn.Parameter(torch.Tensor(fr_pow / np.trace(fr_pow)).to(device))
                        self.model.omega_corr.requires_grad = False
                else:
                    clf_loss.backward()
                    optimizer.step()

                if ((mol_idxs + 1) % args.frequency_of_the_test == 0) or (mol_idxs == len(train_data) - 1):
                    if test_data is not None:
                        test_score, _ = self.test(self.test_data, device, args)
                        print('Epoch = {}, Iter = {}/{}: Test Score = {}'.format(epoch, mol_idxs + 1, len(train_data), test_score))
                        if test_score > max_test_score:
                            max_test_score = test_score
                            best_model_params = {k: v.cpu() for k, v in model.state_dict().items()}
                        print('Current best = {}'.format(max_test_score))

            self.task_reg *= args.task_reg_decay

        return max_test_score, best_model_params

    def test(self, test_data, device, args):
        logging.info("----------test--------")
        model = self.model
        model.eval()
        model.to(device)

        with torch.no_grad():
            y_pred = []
            y_true = []
            masks = []
            for mol_idx, (adj_matrix, feature_matrix, label, mask, _) in enumerate(test_data):
                adj_matrix = adj_matrix.to(device=device, dtype=torch.float32, non_blocking=True)
                feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True)

                logits = model(adj_matrix, feature_matrix)

                y_pred.append(logits.cpu().numpy())
                y_true.append(label.cpu().numpy())
                masks.append(mask.numpy())

        y_pred = np.array(y_pred)
        y_true = np.array(y_true)
        masks = np.array(masks)

        results = []
        for label in range(masks.shape[1]):
            valid_idxs = np.nonzero(masks[:, label])
            truth = y_true[valid_idxs, label].flatten()
            pred = y_pred[valid_idxs, label].flatten()

            if np.all(truth == 0.0) or np.all(truth == 1.0):
                results.append(float('nan'))
            else:
                if args.metric == 'prc-auc':
                    precision, recall, _ = precision_recall_curve(truth, pred)
                    score = auc(recall, precision)
                else:
                    score = roc_auc_score(truth, pred)

                results.append(score)

        score = np.nanmean(results)

        return score, model

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        logging.info("----------test_on_the_server--------")

        model_list, score_list = [], []
        for client_idx in test_data_local_dict.keys():
            test_data = test_data_local_dict[client_idx]
            score, model = self.test(test_data, device, args)
            for idx in range(len(model_list)):
                self._compare_models(model, model_list[idx])
            model_list.append(model)
            score_list.append(score)
            if args.dataset != "pcba":
                logging.info('Client {}, Test ROC-AUC score = {}'.format(client_idx, score))
                wandb.log({"Client {} Test/ROC-AUC".format(client_idx): score})
            else:
                logging.info('Client {}, Test PRC-AUC score = {}'.format(client_idx, score))
                wandb.log({"Client {} Test/PRC-AUC".format(client_idx): score})
        avg_score = np.mean(np.array(score_list))
        if args.dataset != "pcba":
            logging.info('Test ROC-AUC Score = {}'.format(avg_score))
            wandb.log({"Test/ROC-AUC": avg_score})
        else:
            logging.info('Test PRC-AUC Score = {}'.format(avg_score))
            wandb.log({"Test/PRC-AUC": avg_score})
        return True

    def _compare_models(self, model_1, model_2):
        models_differ = 0
        for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
            if torch.equal(key_item_1[1], key_item_2[1]):
                pass
            else:
                models_differ += 1
                if key_item_1[0] == key_item_2[0]:
                    logging.info('Mismatch found at %s', key_item_1[0])
                else:
                    raise Exception
        if models_differ == 0:
            logging.info('Models match perfectly! :)')
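
For reference, the is_mtl branch in train follows a multi-task relationship learning pattern: with W the readout's (num_tasks x output_dim) output weight, the extra loss term is task_reg * tr(W^T Omega^{-1} W), and after each optimizer step Omega is reset in closed form to (W W^T)^(1/2) divided by its trace. Below is a self-contained sketch of just that computation; the shapes and the identity initialization of Omega are assumptions for illustration, not values taken from this commit.

# Standalone sketch of the task-covariance update used in the is_mtl branch above (assumed shapes).
import numpy as np
import torch
from scipy.linalg import fractional_matrix_power

num_tasks, output_dim = 12, 64
W = torch.randn(num_tasks, output_dim)          # stands in for readout.output.weight
omega = torch.eye(num_tasks) / num_tasks        # assumed initial task-correlation matrix, trace = 1
eps, task_reg = 1e-5, 1e-3

# Regularizer added to the classification loss: task_reg * tr(W^T Omega^{-1} W)
reg = task_reg * torch.trace(W.t() @ torch.inverse(omega + eps * torch.eye(num_tasks)) @ W)

# Closed-form Omega update after the optimizer step: Omega = (W W^T)^(1/2) / tr((W W^T)^(1/2))
with torch.no_grad():
    sqrt_wwt = fractional_matrix_power((W @ W.t()).numpy(), 0.5)
    omega = torch.Tensor(sqrt_wwt / np.trace(sqrt_wwt))

print(reg.item(), torch.trace(omega).item())    # trace of the updated Omega is 1

Because the update keeps tr(Omega) = 1 and requires_grad is switched off, omega_corr acts as a periodically re-estimated task-correlation matrix rather than a parameter learned by the optimizer.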