diff --git a/model/gat_readout.py b/model/gat_readout.py new file mode 100644 index 0000000..9314264 --- /dev/null +++ b/model/gat_readout.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GraphAttentionLayer(nn.Module): + """ + Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 + """ + + def __init__(self, in_features, out_features, dropout, alpha, concat=True): + super(GraphAttentionLayer, self).__init__() + self.dropout = dropout + self.in_features = in_features + self.out_features = out_features + self.alpha = alpha + self.concat = concat + + self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) + nn.init.xavier_uniform_(self.W.data, gain=1.414) + self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1))) + nn.init.xavier_uniform_(self.a.data, gain=1.414) + + self.leakyrelu = nn.LeakyReLU(self.alpha) + + def forward(self, h, adj): + Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features) + a_input = self._prepare_attentional_mechanism_input(Wh) + e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) + + zero_vec = -9e15 * torch.ones_like(e) + attention = torch.where(adj > 0, e, zero_vec) + attention = F.softmax(attention, dim=1) + attention = F.dropout(attention, self.dropout, training=self.training) + h_prime = torch.matmul(attention, Wh) + + if self.concat: + return F.elu(h_prime) + else: + return h_prime + + def _prepare_attentional_mechanism_input(self, Wh): + N = Wh.size()[0] + Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0) + Wh_repeated_alternating = Wh.repeat(N, 1) + all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1) + + return all_combinations_matrix.view(N, N, 2 * self.out_features) + + def __repr__(self): + return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' + + +class GAT(nn.Module): + def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout, alpha, nheads): + """Dense version of GAT.""" + super(GAT, self).__init__() + self.dropout = dropout + + self.attentions = nn.ModuleList( + [GraphAttentionLayer(feat_dim, hidden_dim1, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]) + self.out_att = GraphAttentionLayer(hidden_dim1 * nheads, hidden_dim2, dropout=dropout, alpha=alpha, concat=False) + + def forward(self, x, adj): + x = F.dropout(x, self.dropout, training=self.training) + x = torch.cat([att(x, adj) for att in self.attentions], dim=1) + x = F.dropout(x, self.dropout, training=self.training) + embeddings = F.elu(self.out_att(x, adj)) # Node embeddings + + return embeddings + + +class Readout(nn.Module): + """ + This module learns a single graph level representation for a molecule given GNN generated node embeddings + """ + + def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats): + super(Readout, self).__init__() + self.attr_dim = attr_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.num_cats = num_cats + + self.layer1 = nn.Linear(attr_dim + embedding_dim, hidden_dim) + self.layer2 = nn.Linear(hidden_dim, output_dim) + self.output = nn.Linear(output_dim, num_cats) + self.act = nn.ReLU() + + def forward(self, node_features, node_embeddings): + combined_rep = torch.cat((node_features, node_embeddings), + dim=1) # Concat initial node attributed with embeddings from sage + hidden_rep = self.act(self.layer1(combined_rep)) + graph_rep = self.act(self.layer2(hidden_rep)) # Generate final graph level embedding + + logits = torch.mean(self.output(graph_rep), dim=0) # Generated logits for multilabel classification + + return logits + + +class GatMoleculeNet(nn.Module): + """ + Network that consolidates GAT + Readout into a single nn.Module + """ + + def __init__(self, feat_dim, gat_hidden_dim1, node_embedding_dim, gat_dropout, gat_alpha, gat_nheads, readout_hidden_dim, graph_embedding_dim, + num_categories): + super(GatMoleculeNet, self).__init__() + self.gat = GAT(feat_dim, gat_hidden_dim1, node_embedding_dim, gat_dropout, gat_alpha, gat_nheads) + self.readout = Readout(feat_dim, node_embedding_dim, readout_hidden_dim, graph_embedding_dim, num_categories) + + def forward(self, adj_matrix, feature_matrix): + node_embeddings = self.gat(feature_matrix, adj_matrix) + logits = self.readout(feature_matrix, node_embeddings) + return logits diff --git a/model/sage_readout.py b/model/sage_readout.py new file mode 100644 index 0000000..9c60faa --- /dev/null +++ b/model/sage_readout.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn + + +class GraphSage(nn.Module): + """ + GraphSAGE model (https://arxiv.org/abs/1706.02216) to learn the role of atoms in the molecules inductively. + Transforms input features into a fixed length embedding in a vector space. The embedding captures the role. + """ + + def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout): + super(GraphSage, self).__init__() + + self.feat_dim = feat_dim + self.hidden_dim1 = hidden_dim1 + self.hidden_dim2 = hidden_dim2 + + self.layer1 = nn.Linear(2 * feat_dim, hidden_dim1, bias=False) + self.layer2 = nn.Linear(2 * hidden_dim1, hidden_dim2, bias=False) + + self.relu = nn.ReLU() + self.dropout = nn.Dropout(p=dropout) + + def forward(self, forest, feature_matrix): + feat_0 = feature_matrix[forest[0]] # Of shape torch.Size([|B|, feat_dim]) + feat_1 = feature_matrix[forest[1]] # Of shape torch.size(|B|, fanouts[0], feat_dim) + + # Depth 1 + x = feature_matrix[forest[1]].mean(dim=1) # Of shape torch.size(|B|, feat_dim) + feat_0 = torch.cat((feat_0, x), dim=1) # Of shape torch.size(|B|, 2 * feat_dim) + feat_0 = self.relu(self.layer1(feat_0)) # Of shape torch.size(|B|, hidden_dim1) + feat_0 = self.dropout(feat_0) + + # Depth 2 + x = feature_matrix[forest[2]].mean(dim=1) # Of shape torch.size(|B|*fanouts[0], feat_dim) + feat_1 = torch.cat((feat_1.reshape(-1, self.feat_dim), x), + dim=1) # Of shape torch.size(|B|*fanouts[0], 2 * feat_dim) + feat_1 = self.relu(self.layer1(feat_1)) # Of shape torch.size(|B|*fanouts[0], hidden_dim1) + feat_1 = self.dropout(feat_1) + + # Combine + feat_1 = feat_1.reshape(forest[0].shape[0], -1, self.hidden_dim1).mean( + dim=1) # Of shape torch.size([|B|, hidden_dim_1]) + combined = torch.cat((feat_0, feat_1), dim=1) # Of shape torch.Size(|B|, 2 * hidden_dim1) + embeddings = self.relu(self.layer2(combined)) # Of shape torch.Size(|B|, hidden_dim2) + + return embeddings + + +class Readout(nn.Module): + """ + This module learns a single graph level representation for a molecule given GraphSAGE generated embeddings + """ + def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats): + super(Readout, self).__init__() + self.attr_dim = attr_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.num_cats = num_cats + + self.layer1 = nn.Linear(attr_dim + embedding_dim, hidden_dim) + self.layer2 = nn.Linear(hidden_dim, output_dim) + self.output = nn.Linear(output_dim, num_cats) + self.act = nn.ReLU() + + def forward(self, node_features, node_embeddings): + combined_rep = torch.cat((node_features, node_embeddings), + dim=1) # Concat initial node attributed with embeddings from sage + hidden_rep = self.act(self.layer1(combined_rep)) + graph_rep = self.act(self.layer2(hidden_rep)) # Generate final graph level embedding + + logits = torch.mean(self.output(graph_rep), dim=0) # Generated logits for multilabel classification + + return logits + + +class SageMoleculeNet(nn.Module): + """ + Network that consolidates Sage + Readout into a single nn.Module + """ + def __init__(self, feat_dim, sage_hidden_dim1, node_embedding_dim, sage_dropout, readout_hidden_dim, graph_embedding_dim, num_categories): + super(SageMoleculeNet, self).__init__() + self.sage = GraphSage(feat_dim, sage_hidden_dim1, node_embedding_dim, sage_dropout) + self.readout = Readout(feat_dim, node_embedding_dim, readout_hidden_dim, graph_embedding_dim, num_categories) + + def forward(self, forest, feature_matrix): + node_embeddings = self.sage(forest, feature_matrix) + logits = self.readout(feature_matrix, node_embeddings) + return logits \ No newline at end of file diff --git a/training/gat_readout_trainer.py b/training/gat_readout_trainer.py new file mode 100644 index 0000000..2de0360 --- /dev/null +++ b/training/gat_readout_trainer.py @@ -0,0 +1,184 @@ +import logging + +import numpy as np +import torch +import wandb +from sklearn.metrics import roc_auc_score, precision_recall_curve, auc +from torch import trace, inverse +from scipy.linalg import fractional_matrix_power + +from FedML.fedml_core.trainer.model_trainer import ModelTrainer + + +# Trainer for MoleculeNet. The evaluation metric is ROC-AUC + +class GatMoleculeNetTrainer(ModelTrainer): + + def get_model_params(self): + return self.model.cpu().state_dict() + + def set_model_params(self, model_parameters): + logging.info("set_model_params") + self.model.load_state_dict(model_parameters) + + def train(self, train_data, device, args): + model = self.model + if args.is_mtl: + self.model.omega_corr.to(device) + model.to(device) + model.train() + + test_data = None + try: + test_data = self.test_data + except: + pass + + criterion = torch.nn.BCEWithLogitsLoss(reduction='none') + if args.client_optimizer == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay = args.wd) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd) + + + max_test_score = 0 + best_model_params = {} + eps = 1e-5 + for epoch in range(args.epochs): + for mol_idxs, (adj_matrix, feature_matrix, label, mask , cli_mask) in enumerate(train_data): + # Pass on molecules that have no labels + mask = mask.to(device=device, dtype=torch.float32, non_blocking=True) + cli_mask = cli_mask.to(device=device, dtype=torch.float32, non_blocking=True) if cli_mask is not None else None + mask = mask * cli_mask if cli_mask is not None else mask + if torch.all(mask == 0).item(): + continue + + optimizer.zero_grad() + + adj_matrix = adj_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + label = label.to(device=device, dtype=torch.float32, non_blocking=True) + + logits = model(adj_matrix, feature_matrix) + clf_loss = criterion(logits, label) * mask + clf_loss = clf_loss.sum() / mask.sum() + + if args.is_mtl: + W = self.model.readout.output.weight + lhs = torch.mm(W.t() , torch.inverse(self.model.omega_corr + eps * torch.eye(self.model.omega_corr.shape[0], + dtype=self.model.omega_corr.dtype, device=self.model.omega_corr.device))) + trace_in = torch.mm(lhs, W) + loss = clf_loss + self.task_reg * torch.trace(trace_in) + args.wd * 0.5 * torch.norm(W)**2 + loss.backward() + optimizer.step() + with torch.no_grad(): + W = self.model.readout.output.weight + mul = torch.mm(W, W.t()) + fr_pow = fractional_matrix_power(mul.cpu(), 1/2) + self.model.omega_corr = torch.nn.Parameter(torch.Tensor(fr_pow / np.trace(fr_pow)).to(device)) + self.model.omega_corr.requires_grad=False + else: + clf_loss.backward() + optimizer.step() + + if ((mol_idxs + 1) % args.frequency_of_the_test == 0) or (mol_idxs == len(train_data) - 1): + if test_data is not None: + test_score, _ = self.test(self.test_data, device, args) + print('Epoch = {}, Iter = {}/{}: Test Score = {}'.format(epoch, mol_idxs + 1, len(train_data), test_score)) + if test_score > max_test_score: + max_test_score = test_score + best_model_params = {k: v.cpu() for k, v in model.state_dict().items()} + print('Current best = {}'.format(max_test_score)) + + self.task_reg *= args.task_reg_decay + + return max_test_score, best_model_params + + + + def test(self, test_data, device, args): + logging.info("----------test--------") + model = self.model + model.eval() + model.to(device) + + with torch.no_grad(): + y_pred = [] + y_true = [] + masks = [] + for mol_idx, (adj_matrix, feature_matrix, label, mask , _) in enumerate(test_data): + adj_matrix = adj_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + + logits = model(adj_matrix, feature_matrix) + + y_pred.append(logits.cpu().numpy()) + y_true.append(label.cpu().numpy()) + masks.append(mask.numpy()) + + y_pred = np.array(y_pred) + y_true = np.array(y_true) + masks = np.array(masks) + + results = [] + for label in range(masks.shape[1]): + valid_idxs = np.nonzero(masks[:, label]) + truth = y_true[valid_idxs, label].flatten() + pred = y_pred[valid_idxs, label].flatten() + + if np.all(truth == 0.0) or np.all(truth == 1.0): + results.append(float('nan')) + else: + if args.metric == 'prc-auc': + precision, recall, _ = precision_recall_curve(truth, pred) + score = auc(recall, precision) + else: + score = roc_auc_score(truth, pred) + + results.append(score) + + score = np.nanmean(results) + + return score, model + + def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool: + logging.info("----------test_on_the_server--------") + + model_list, score_list = [], [] + for client_idx in test_data_local_dict.keys(): + test_data = test_data_local_dict[client_idx] + score, model = self.test(test_data, device, args) + for idx in range(len(model_list)): + self._compare_models(model, model_list[idx]) + model_list.append(model) + score_list.append(score) + if args.dataset != "pcba": + logging.info('Client {}, Test ROC-AUC score = {}'.format(client_idx, score)) + wandb.log({"Client {} Test/ROC-AUC".format(client_idx): score}) + else: + logging.info('Client {}, Test PRC-AUC score = {}'.format(client_idx, score)) + wandb.log({"Client {} Test/PRC-AUC".format(client_idx): score}) + avg_score = np.mean(np.array(score_list)) + if args.dataset != "pcba": + logging.info('Test ROC-AUC Score = {}'.format(avg_score)) + wandb.log({"Test/ROC-AUC": avg_score}) + else: + logging.info('Test PRC-AUC Score = {}'.format(avg_score)) + wandb.log({"Test/PRC-AUC": avg_score}) + return True + + def _compare_models(self, model_1, model_2): + models_differ = 0 + for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): + if torch.equal(key_item_1[1], key_item_2[1]): + pass + else: + models_differ += 1 + if key_item_1[0] == key_item_2[0]: + logging.info('Mismtach found at', key_item_1[0]) + else: + raise Exception + if models_differ == 0: + logging.info('Models match perfectly! :)') diff --git a/training/gat_readout_trainer_regression.py b/training/gat_readout_trainer_regression.py new file mode 100644 index 0000000..733903e --- /dev/null +++ b/training/gat_readout_trainer_regression.py @@ -0,0 +1,167 @@ +import logging + +import numpy as np +import torch +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score +import wandb +from FedML.fedml_core.trainer.model_trainer import ModelTrainer +from scipy.linalg import fractional_matrix_power + + +# Trainer for MoleculeNet. The evaluation metric is ROC-AUC + +class GatMoleculeNetTrainer(ModelTrainer): + + def get_model_params(self): + return self.model.cpu().state_dict() + + def set_model_params(self, model_parameters): + logging.info("set_model_params") + self.model.load_state_dict(model_parameters) + + def train(self, train_data, device, args): + model = self.model + if args.is_mtl: + self.model.omega_corr.to(device) + model.to(device) + model.train() + + test_data = None + try: + test_data = self.test_data + except: + pass + + criterion = torch.nn.MSELoss(reduction='none') if args.dataset != 'qm9' else torch.nn.L1Loss(reduction='none') + if args.client_optimizer == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay = args.wd) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd) + + + min_score = np.Inf if args.metric != 'r2' else -np.Inf + best_model_params = {} + # print('Training on {}'.format(torch.cuda.get_device_name())) + for epoch in range(args.epochs): + avg_loss = 0 + count = 0 + for mol_idxs, (adj_matrix, feature_matrix, label, _, cli_mask) in enumerate(train_data): + optimizer.zero_grad() + if cli_mask is not None: + cli_mask = cli_mask.to(device=device, dtype=torch.float32, non_blocking=True) + if torch.all(cli_mask == 0).item(): + continue + + adj_matrix = adj_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + label = label.to(device=device, dtype=torch.float32, non_blocking=True) + + logits = model(adj_matrix, feature_matrix) + pred_loss = criterion(logits, label) * cli_mask if cli_mask is not None else criterion(logits, label) + pred_loss = pred_loss.sum() / cli_mask.sum() if cli_mask is not None else pred_loss.mean() + + if args.is_mtl: + W = self.model.readout.output.weight + lhs = torch.mm(W.t() , torch.inverse(self.model.omega_corr) ) + trace_in = torch.mm(lhs, W) + loss = pred_loss + self.task_reg * torch.trace(trace_in) + args.wd * 0.5 * torch.norm(W)**2 + loss.backward() + optimizer.step() + with torch.no_grad(): + W = self.model.readout.output.weight + mul = torch.mm(W, W.t()) + fr_pow = fractional_matrix_power(mul.cpu(), 1/2) + self.model.omega_corr = torch.nn.Parameter(torch.Tensor(fr_pow / np.trace(fr_pow)).to(device)) + self.model.omega_corr.requires_grad=False + else: + pred_loss.backward() + optimizer.step() + + + if ((mol_idxs + 1) % args.frequency_of_the_test == 0) or (mol_idxs == len(train_data) - 1): + if test_data is not None: + test_score, _ = self.test(self.test_data, device, args) + if args.metric != 'r2': + print('Epoch = {}, Iter = {}/{}: Test {} = {}'.format(epoch, mol_idxs + 1, len(train_data), args.metric.upper(),test_score)) + if test_score < min_score: + min_score = test_score + best_model_params = {k: v.cpu() for k, v in model.state_dict().items()} + print('Current best {}= {}'.format(args.metric.upper(),min_score)) + else: + print('Epoch = {}, Iter = {}/{}: Test R2 = {}'.format(epoch, mol_idxs + 1, len(train_data), test_score)) + if test_score > min_score: + min_score = test_score + best_model_params = {k: v.cpu() for k, v in model.state_dict().items()} + print('Current best R2= {}'.format(min_score)) + + self.task_reg *= args.task_reg_decay + + return min_score, best_model_params + + def test(self, test_data, device, args): + logging.info("----------test--------") + model = self.model + model.eval() + model.to(device) + + with torch.no_grad(): + y_pred = [] + y_true = [] + for mol_idx, (adj_matrix, feature_matrix, label, _, _) in enumerate(test_data): + adj_matrix = adj_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + label = label.to(device=device, dtype=torch.float32, non_blocking=True) + logits = model(adj_matrix, feature_matrix) + y_pred.append(logits.cpu().numpy()) + y_true.append(label.cpu().numpy()) + + # logging.info(y_true) + # logging.info(y_pred) + if args.metric == 'rmse': + score = mean_squared_error(np.array(y_true), np.array(y_pred), squared=False) + elif args.metric == 'r2': + score = r2_score(np.array(y_true), np.array(y_pred)) + else: + score = mean_absolute_error(np.array(y_true), np.array(y_pred)) + + return score, model + + def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool: + logging.info("----------test_on_the_server--------") + # for client_idx in train_data_local_dict.keys(): + # train_data = train_data_local_dict[client_idx] + # train_score = self.test(train_data, device, args) + # logging.info('Client {}, Train ROC-AUC score = {}'.format(client_idx, train_score)) + + model_list, score_list = [], [] + for client_idx in test_data_local_dict.keys(): + test_data = test_data_local_dict[client_idx] + score, model = self.test(test_data, device, args) + for idx in range(len(model_list)): + self._compare_models(model, model_list[idx]) + model_list.append(model) + score_list.append(score) + logging.info('Client {}, Test {} = {}'.format(client_idx,args.metric.upper(), score)) + wandb.log({"Client {} Test/{}".format(client_idx,args.metric.upper()): score}) + + avg_score = sum(score_list) / len(score_list) + logging.info('Test {} score = {}'.format(args.metric.upper(),avg_score)) + wandb.log({"Test/{}".format(args.metric.upper()): avg_score}) + + return True + + def _compare_models(self, model_1, model_2): + models_differ = 0 + for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): + if torch.equal(key_item_1[1], key_item_2[1]): + pass + else: + models_differ += 1 + if (key_item_1[0] == key_item_2[0]): + logging.info('Mismtach found at', key_item_1[0]) + else: + raise Exception + if models_differ == 0: + logging.info('Models match perfectly! :)') diff --git a/training/sage_readout_trainer.py b/training/sage_readout_trainer.py new file mode 100644 index 0000000..a2caf6e --- /dev/null +++ b/training/sage_readout_trainer.py @@ -0,0 +1,187 @@ +import logging + +import numpy as np +import torch +import wandb +from sklearn.metrics import roc_auc_score, precision_recall_curve, auc +from tqdm import tqdm + +from torch import trace, inverse +from scipy.linalg import fractional_matrix_power + +from FedML.fedml_core.trainer.model_trainer import ModelTrainer + + +# Trainer for MoleculeNet. The evaluation metric is ROC-AUC + +class SageMoleculeNetTrainer(ModelTrainer): + + def get_model_params(self): + return self.model.cpu().state_dict() + + def set_model_params(self, model_parameters): + logging.info("set_model_params") + self.model.load_state_dict(model_parameters) + + def train(self, train_data, device, args): + model = self.model + if args.is_mtl: + self.model.omega_corr.to(device) + model.to(device) + model.train() + + test_data = None + try: + test_data = self.test_data + except: + pass + + criterion = torch.nn.BCEWithLogitsLoss(reduction='none') + if args.client_optimizer == "sgd": + # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay = args.wd) + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + + else: + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd) + + max_test_score = 0 + eps = 1e-7 + best_model_params = {} + for epoch in range(args.epochs): + for mol_idxs, (forest, feature_matrix, label, mask, cli_mask) in enumerate(train_data): + # Pass on molecules that have no labels + mask = mask.to(device=device, dtype=torch.float32, non_blocking=True) + cli_mask = cli_mask.to(device=device, dtype=torch.float32, non_blocking=True) if cli_mask is not None else None + mask = mask * cli_mask if cli_mask is not None else mask + if torch.all(mask == 0).item(): + continue + + optimizer.zero_grad() + + forest = [level.to(device=device, dtype=torch.long, non_blocking=True) for level in forest] + feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + label = label.to(device=device, dtype=torch.float32, non_blocking=True) + + logits = model(forest, feature_matrix) + clf_loss = criterion(logits, label) * mask + clf_loss = clf_loss.sum() / mask.sum() + + if args.is_mtl: + W = self.model.readout.output.weight + lhs = torch.mm(W.t() , torch.inverse(self.model.omega_corr + eps * torch.eye(self.model.omega_corr.shape[0], + dtype=self.model.omega_corr.dtype, device=self.model.omega_corr.device))) + trace_in = torch.mm(lhs, W) + loss = clf_loss + self.task_reg * torch.trace(trace_in) + args.wd * 0.5 * torch.norm(W)**2 + loss.backward() + optimizer.step() + with torch.no_grad(): + W = self.model.readout.output.weight + mul = torch.mm(W, W.t()) + fr_pow = fractional_matrix_power(mul.cpu(), 1/2) + self.model.omega_corr = torch.nn.Parameter(torch.Tensor(fr_pow / np.trace(fr_pow)).to(device)) + self.model.omega_corr.requires_grad=False + + else: + clf_loss.backward() + optimizer.step() + + if ((mol_idxs + 1) % args.frequency_of_the_test == 0) or (mol_idxs == len(train_data) - 1): + if test_data is not None: + test_score, _ = self.test(self.test_data, device, args) + print('Epoch = {}, Iter = {}/{}: Test Score = {}'.format(epoch, mol_idxs + 1, len(train_data), test_score)) + if test_score > max_test_score: + max_test_score = test_score + best_model_params = {k: v.cpu() for k, v in model.state_dict().items()} + print('Current best = {}'.format(max_test_score)) + + self.task_reg *= args.task_reg_decay + + return max_test_score, best_model_params + + def test(self, test_data, device, args): + logging.info("----------test--------") + model = self.model + model.eval() + model.to(device) + + with torch.no_grad(): + y_pred = [] + y_true = [] + masks = [] + + for mol_idx, (forest, feature_matrix, label, mask, _) in enumerate(test_data): + + forest = [level.to(device=device, dtype=torch.long, non_blocking=True) for level in forest] + feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + + logits = model(forest, feature_matrix) + + y_pred.append(logits.cpu().numpy()) + y_true.append(label.numpy()) + masks.append(mask.numpy()) + + y_pred = np.array(y_pred) + y_true = np.array(y_true) + masks = np.array(masks) + + results = [] + for label in range(masks.shape[1]): + valid_idxs = np.nonzero(masks[:, label]) + truth = y_true[valid_idxs, label].flatten() + pred = y_pred[valid_idxs, label].flatten() + + if np.all(truth == 0.0) or np.all(truth == 1.0): + results.append(float('nan')) + else: + if args.metric == 'prc-auc': + precision, recall, _ = precision_recall_curve(truth, pred) + score = auc(recall, precision) + else: + score = roc_auc_score(truth, pred) + + results.append(score) + + score = np.nanmean(results) + return score, model + + def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool: + logging.info("----------test_on_the_server--------") + + model_list, score_list = [], [] + for client_idx in test_data_local_dict.keys(): + + test_data = test_data_local_dict[client_idx] + score, model = self.test(test_data, device, args) + for idx in range(len(model_list)): + self._compare_models(model, model_list[idx]) + model_list.append(model) + score_list.append(score) + if args.dataset != "pcba": + logging.info('Client {}, Test ROC-AUC score = {}'.format(client_idx, score)) + wandb.log({"Client {} Test/ROC-AUC".format(client_idx): score}) + else: + logging.info('Client {}, Test PRC-AUC score = {}'.format(client_idx, score)) + wandb.log({"Client {} Test/PRC-AUC".format(client_idx): score}) + avg_score = np.mean(np.array(score_list)) + if args.dataset != "pcba": + logging.info('Test ROC-AUC Score = {}'.format(avg_score)) + wandb.log({"Test/ROC-AUC": avg_score}) + else: + logging.info('Test PRC-AUC Score = {}'.format(avg_score)) + wandb.log({"Test/PRC-AUC": avg_score}) + return True + + def _compare_models(self, model_1, model_2): + models_differ = 0 + for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): + if torch.equal(key_item_1[1], key_item_2[1]): + pass + else: + models_differ += 1 + if key_item_1[0] == key_item_2[0]: + logging.info('Mismtach found at', key_item_1[0]) + else: + raise Exception + if models_differ == 0: + logging.info('Models match perfectly! :)') diff --git a/training/sage_readout_trainer_regression.py b/training/sage_readout_trainer_regression.py new file mode 100644 index 0000000..2c7a0f2 --- /dev/null +++ b/training/sage_readout_trainer_regression.py @@ -0,0 +1,172 @@ +import logging + +import numpy as np +import torch +from torch._C import dtype +import wandb +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score + +from torch import trace, inverse +from scipy.linalg import fractional_matrix_power + +from FedML.fedml_core.trainer.model_trainer import ModelTrainer + + +# Trainer for MoleculeNet. The evaluation metrics are RMSE, R2, and MAE + +class SageMoleculeNetTrainer(ModelTrainer): + + def get_model_params(self): + return self.model.cpu().state_dict() + + def set_model_params(self, model_parameters): + logging.info("set_model_params") + self.model.load_state_dict(model_parameters) + + def train(self, train_data, device, args): + model = self.model + if args.is_mtl: + self.model.omega_corr.to(device) + model.to(device) + model.train() + + test_data = None + try: + test_data = self.test_data + except: + pass + + criterion = torch.nn.MSELoss(reduction='none') if args.dataset != 'qm9' else torch.nn.L1Loss(reduction='none') + if args.client_optimizer == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay = args.wd) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd) + + + min_score = np.Inf if args.metric != 'r2' else -np.Inf + + best_model_params = {} + # print('Training on {}'.format(torch.cuda.get_device_name())) + for epoch in range(args.epochs): + for mol_idxs, (forest, feature_matrix, label, _, cli_mask) in enumerate(train_data): + optimizer.zero_grad() + if cli_mask is not None: + cli_mask = cli_mask.to(device=device, dtype=torch.float32, non_blocking=True) + if torch.all(cli_mask == 0).item(): + continue + + forest = [level.to(device=device, dtype=torch.long, non_blocking=True) for level in forest] + feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + label = label.to(device=device, dtype=torch.float32, non_blocking=True) + + + logits = model(forest, feature_matrix) + pred_loss = criterion(logits, label) * cli_mask if cli_mask is not None else criterion(logits, label) + pred_loss = pred_loss.sum() / cli_mask.sum() if cli_mask is not None else pred_loss.mean() + + + if args.is_mtl: + W = self.model.readout.output.weight + lhs = torch.mm(W.t() , torch.inverse(self.model.omega_corr) ) + trace_in = torch.mm(lhs, W) + loss = pred_loss + self.task_reg * torch.trace(trace_in) + args.wd * 0.5 * torch.norm(W)**2 + loss.backward() + optimizer.step() + with torch.no_grad(): + W = self.model.readout.output.weight + mul = torch.mm(W, W.t()) + fr_pow = fractional_matrix_power(mul.cpu(), 1/2) + self.model.omega_corr = torch.nn.Parameter(torch.Tensor(fr_pow / np.trace(fr_pow)).to(device)) + self.model.omega_corr.requires_grad=False + else: + pred_loss.backward() + optimizer.step() + + if ((mol_idxs + 1) % args.frequency_of_the_test == 0) or (mol_idxs == len(train_data) - 1): + if test_data is not None: + test_score, _ = self.test(self.test_data, device, args) + if args.metric != 'r2': + print('Epoch = {}, Iter = {}/{}: Test {} = {}'.format(epoch, mol_idxs + 1, len(train_data), args.metric.upper(),test_score)) + if test_score < min_score: + min_score = test_score + best_model_params = {k: v.cpu() for k, v in model.state_dict().items()} + print('Current best {}= {}'.format(args.metric.upper(),min_score)) + else: + print('Epoch = {}, Iter = {}/{}: Test R2 = {}'.format(epoch, mol_idxs + 1, len(train_data), test_score)) + if test_score > min_score: + min_score = test_score + best_model_params = {k: v.cpu() for k, v in model.state_dict().items()} + print('Current best R2= {}'.format(min_score)) + + self.task_reg *= args.task_reg_decay + + return min_score, best_model_params + + def test(self, test_data, device, args): + logging.info("----------test--------") + model = self.model + model.eval() + model.to(device) + + with torch.no_grad(): + y_pred = [] + y_true = [] + for mol_idx, (forest, feature_matrix, label, _, _) in enumerate(test_data): + forest = [level.to(device=device, dtype=torch.long, non_blocking=True) for level in forest] + feature_matrix = feature_matrix.to(device=device, dtype=torch.float32, non_blocking=True) + label = label.to(device=device, dtype=torch.float32, non_blocking=True) + logits = model(forest, feature_matrix) + y_pred.append(logits.cpu().numpy()) + y_true.append(label.cpu().numpy()) + + # logging.info(y_true) + # logging.info(y_pred) + if args.metric == 'rmse': + score = mean_squared_error(np.array(y_true), np.array(y_pred), squared=False) + elif args.metric == 'r2': + score = r2_score(np.array(y_true), np.array(y_pred)) + else: + score = mean_absolute_error(np.array(y_true), np.array(y_pred)) + + return score, model + + def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool: + logging.info("----------test_on_the_server--------") + # for client_idx in train_data_local_dict.keys(): + # train_data = train_data_local_dict[client_idx] + # train_score = self.test(train_data, device, args) + # logging.info('Client {}, Train ROC-AUC score = {}'.format(client_idx, train_score)) + + model_list, score_list = [], [] + for client_idx in test_data_local_dict.keys(): + test_data = test_data_local_dict[client_idx] + score, model = self.test(test_data, device, args) + for idx in range(len(model_list)): + self._compare_models(model, model_list[idx]) + model_list.append(model) + score_list.append(score) + logging.info('Client {}, Test {} = {}'.format(client_idx,args.metric.upper(), score)) + wandb.log({"Client {} Test/{}".format(client_idx,args.metric.upper()): score}) + + + avg_score = sum(score_list) / len(score_list) + logging.info('Test {} score = {}'.format(args.metric.upper(),avg_score)) + wandb.log({"Test/{}".format(args.metric.upper()): avg_score}) + + return True + + def _compare_models(self, model_1, model_2): + models_differ = 0 + for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): + if torch.equal(key_item_1[1], key_item_2[1]): + pass + else: + models_differ += 1 + if (key_item_1[0] == key_item_2[0]): + logging.info('Mismatch found at', key_item_1[0]) + else: + raise Exception + if models_differ == 0: + logging.info('Models match perfectly! :)')