diff --git a/dance/modules/base.py b/dance/modules/base.py
index 62d7bebd..ad022960 100644
--- a/dance/modules/base.py
+++ b/dance/modules/base.py
@@ -1,5 +1,13 @@
+import os
 from abc import ABC, abstractmethod, abstractstaticmethod
+from contextlib import contextmanager
+from functools import partialmethod
+from operator import attrgetter
+from time import time
 
+import torch
+
+from dance import logger
 from dance.data import Data
 from dance.transforms.base import BaseTransform
 from dance.typing import Any, Mapping, Optional, Tuple, Union
@@ -47,6 +55,88 @@ def score(self, x, y, score_func: Optional[Union[str, Mapping[Any, float]]] = No
         return (score, y_pred) if return_pred else score
 
+
+class BasePretrain(ABC):
+
+    @property
+    def is_pretrained(self) -> bool:
+        return getattr(self, "_is_pretrained", False)
+
+    def _pretrain(self, *args, force_pretrain: bool = False, **kwargs):
+        pt_path = getattr(self, "pretrain_path", None)
+        if not force_pretrain:
+            if self.is_pretrained:
+                logger.info("Skipping pretraining as the model appears to be pretrained already. "
+                            "If you wish to force pre-training, please set 'force_pretrain' to True.")
+                return
+
+            if pt_path is not None and os.path.isfile(pt_path):
+                logger.info(f"Loading pre-trained model from {pt_path}")
+                self.load_pretrained(pt_path)
+                self._is_pretrained = True
+                return
+
+        logger.info("Pre-training started")
+        if pt_path is None:
+            logger.warning("`pretrain_path` is not set, pre-trained model will not be saved.")
+        else:
+            logger.info(f"Pre-trained model will be saved to {pt_path}")
+
+        t = time()
+        self.pretrain(*args, **kwargs)
+        elapsed = time() - t
+        logger.info(f"Pre-training finished (took {elapsed:.2f} seconds)")
+        self._is_pretrained = True
+
+        if pt_path is not None:
+            logger.info(f"Saving pre-trained model to {pt_path}")
+            self.save_pretrained(pt_path)
+
+    def pretrain(self, *args, **kwargs):
+        ...
+
+    def save_pretrained(self, path, **kwargs):
+        ...
+
+    def load_pretrained(self, path, **kwargs):
+        ...
+
+
+class TorchNNPretrain(BasePretrain, ABC):
+
+    def _fix_unfix_modules(self, *module_names: Tuple[str], unfix: bool = False, single: bool = True):
+        modules = attrgetter(*module_names)(self)
+        modules = [modules] if single else modules
+
+        for module in modules:
+            for p in module.parameters():
+                p.requires_grad = unfix
+
+    fix_module = partialmethod(_fix_unfix_modules, unfix=False, single=True)
+    fix_modules = partialmethod(_fix_unfix_modules, unfix=False, single=False)
+    unfix_module = partialmethod(_fix_unfix_modules, unfix=True, single=True)
+    unfix_modules = partialmethod(_fix_unfix_modules, unfix=True, single=False)
+
+    @contextmanager
+    def pretrain_context(self, *module_names: Tuple[str]):
+        """Unlock module for pretraining and lock once pretraining is done."""
+        is_single = len(module_names) == 1
+        logger.info(f"Entering pre-training context; unlocking: {module_names}")
+        self._fix_unfix_modules(*module_names, unfix=True, single=is_single)
+        try:
+            yield
+        finally:
+            logger.info(f"Exiting pre-training context; locking: {module_names}")
+            self._fix_unfix_modules(*module_names, unfix=False, single=is_single)
+
+    def save_pretrained(self, path):
+        torch.save(self.state_dict(), path)
+
+    def load_pretrained(self, path):
+        device = getattr(self, "device", None)
+        checkpoint = torch.load(path, map_location=device)
+        self.load_state_dict(checkpoint)
+
+
 class BaseClassificationMethod(BaseMethod):
 
     _DEFAULT_METRIC = "acc"
diff --git a/dance/modules/single_modality/clustering/__init__.py b/dance/modules/single_modality/clustering/__init__.py
index bbb5a126..a2b9c0f1 100644
--- a/dance/modules/single_modality/clustering/__init__.py
+++ b/dance/modules/single_modality/clustering/__init__.py
@@ -1,7 +1,7 @@
 from .graphsc import GraphSC
 from .scdcc import ScDCC
 from .scdeepcluster import ScDeepCluster
-from .scdsc import SCDSC
+from .scdsc import ScDSC
 from .sctag import ScTAG
 
 __all__ = [
diff --git a/dance/modules/single_modality/clustering/scdsc.py b/dance/modules/single_modality/clustering/scdsc.py
index 9890b185..8406423c 100644
--- a/dance/modules/single_modality/clustering/scdsc.py
+++ b/dance/modules/single_modality/clustering/scdsc.py
@@ -8,47 +8,104 @@ neural network." Briefings in Bioinformatics 23.2 (2022): bbac018.
 """
-
-import math
-
 import numpy as np
+import pandas as pd
 import scanpy as sc
+import scipy.sparse as sp
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from sklearn import metrics
 from torch.nn import Linear
 from torch.nn.parameter import Parameter
 from torch.optim import Adam
-from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, TensorDataset
 
+from dance import logger
+from dance.modules.base import BaseClusteringMethod, TorchNNPretrain
 from dance.transforms import AnnDataTransform, Compose, SaveRaw, SetConfig
 from dance.transforms.graph import NeighborGraph
 from dance.transforms.preprocess import sparse_mx_to_torch_sparse_tensor
-from dance.typing import LogLevel
+from dance.typing import Any, LogLevel, Optional, Tuple
+from dance.utils import get_device
 from dance.utils.loss import ZINBLoss
-from dance.utils.metrics import cluster_acc
 
 
-class SCDSCWrapper:
+class ScDSC(TorchNNPretrain, BaseClusteringMethod):
     """scDSC wrapper class.
 
     Parameters
     ----------
-    args : argparse.Namespace
-        a Namespace contains arguments of scDSC. For details of parameters in parser args, please refer to link (parser help document).
+    pretrain_path
+        Path of saved autoencoder weights.
+    sigma
+        Balance parameter.
+    n_enc_1
+        Output dimension of encoder layer 1.
+    n_enc_2
+        Output dimension of encoder layer 2.
+    n_enc_3
+        Output dimension of encoder layer 3.
+    n_dec_1
+        Output dimension of decoder layer 1.
+    n_dec_2
+        Output dimension of decoder layer 2.
+    n_dec_3
+        Output dimension of decoder layer 3.
+    n_z1
+        Output dimension of hidden layer 1.
+    n_z2
+        Output dimension of hidden layer 2.
+    n_z3
+        Output dimension of hidden layer 3.
+    n_clusters
+        Number of clusters.
+    n_input
+        Input feature dimension.
+    v
+        Parameter of soft assignment.
+    device
+        Computing device.
 
     """
 
-    def __init__(self, args):
+    def __init__(
+        self,
+        pretrain_path: str,
+        sigma: float = 1,
+        n_enc_1: int = 512,
+        n_enc_2: int = 256,
+        n_enc_3: int = 256,
+        n_dec_1: int = 256,
+        n_dec_2: int = 256,
+        n_dec_3: int = 512,
+        n_z1: int = 256,
+        n_z2: int = 128,
+        n_z3: int = 32,
+        n_clusters: int = 100,
+        n_input: int = 10,
+        v: float = 1,
+        device: str = "auto",
+    ):
         super().__init__()
-        self.args = args
-        self.device = args.device
-        self.model = SCDSC(args).to(self.device)
-        self.model_pre = AE(n_enc_1=args.n_enc_1, n_enc_2=args.n_enc_2, n_enc_3=args.n_enc_3, n_dec_1=args.n_dec_1,
-                            n_dec_2=args.n_dec_2, n_dec_3=args.n_dec_3, n_input=args.n_input, n_z1=args.n_z1,
-                            n_z2=args.n_z2, n_z3=args.n_z3).to(self.device)
+        self.pretrain_path = pretrain_path
+        self.device = get_device(device)
+        self.model = ScDSCModel(
+            sigma=sigma,
+            n_enc_1=n_enc_1,
+            n_enc_2=n_enc_2,
+            n_enc_3=n_enc_3,
+            n_dec_1=n_dec_1,
+            n_dec_2=n_dec_2,
+            n_dec_3=n_dec_3,
+            n_z1=n_z1,
+            n_z2=n_z2,
+            n_z3=n_z3,
+            n_clusters=n_clusters,
+            n_input=n_input,
+            v=v,
+            device=self.device,
+        ).to(self.device)
+        self.fix_module("model.ae")
 
     @staticmethod
     def preprocessing_pipeline(n_top_genes: int = 2000, n_neighbors: int = 50, log_level: LogLevel = "INFO"):
@@ -70,8 +127,8 @@ def preprocessing_pipeline(n_top_genes: int = 2000, n_neighbors: int = 50, log_l
             # Construct k-neighbors graph using the normalized feature matrix
             NeighborGraph(n_neighbors=n_neighbors, metric="correlation", channel="X"),
             SetConfig({
-                "feature_channel": [None, None, "n_counts", "NeighborGraph"],
-                "feature_channel_type": ["X", "raw_X", "obs", "obsp"],
+                "feature_channel": ["NeighborGraph", None, None, "n_counts"],
+                "feature_channel_type": ["obsp", "X", "raw_X", "obs"],
                 "label_channel": "Group"
             }),
             log_level=log_level,
@@ -82,106 +139,109 @@ def target_distribution(self, q):
 
         Parameters
         ----------
-        q :
-            soft label.
+        q
+            Soft label.
 
         Returns
         -------
-        p :
-            target distribution.
+        p
+            Target distribution.
 
         """
         p = q**2 / q.sum(0)
         return (p.t() / p.sum(1)).t()
 
-    def pretrain_ae(self, x, batch_size, n_epochs, fname, lr=1e-3):
+    def pretrain(self, x, batch_size=256, n_epochs=200, lr=1e-3):
         """Pretrain autoencoder.
 
         Parameters
         ----------
-        x : np.ndarray
-            input features.
-        batch_size : int
-            size of batch.
-        n_epochs : int
-            number of epochs.
-        lr : float optional
-            learning rate.
-        fname : str
-            path to save autoencoder weights.
-
-        Returns
-        -------
-        None.
+        x
+            Input features.
+        batch_size
+            Size of batch.
+        n_epochs
+            Number of epochs.
+        lr
+            Learning rate.
 
""" - print("Pretrain:") - device = self.device - x_tensor = torch.from_numpy(x) - train_loader = DataLoader(TensorDataset(x_tensor), batch_size, shuffle=True) - model = self.model_pre - optimizer = Adam(model.parameters(), lr=lr) - for epoch in range(n_epochs): - - total_loss = total_size = 0 - for batch_idx, (x_batch, ) in enumerate(train_loader): - x_batch = x_batch.to(device) - x_bar, _, _, _, _, _, _, _ = model(x_batch) - - loss = F.mse_loss(x_bar, x_batch) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - size = x_batch.shape[0] - total_size += size - total_loss += loss.item() * size - - print(f"Pretrain epoch {epoch + 1:4d}, MSE loss:{total_loss / total_size:.8f}") - - torch.save(model.state_dict(), fname) - - def fit(self, x, y, X_raw, n_counts, adj, lr=1e-03, n_epochs=300, bcl=0.1, cl=0.01, rl=1, zl=0.1): + with self.pretrain_context("model.ae"): + x_tensor = torch.from_numpy(x) + train_loader = DataLoader(TensorDataset(x_tensor), batch_size, shuffle=True) + model = self.model.ae + optimizer = Adam(model.parameters(), lr=lr) + for epoch in range(n_epochs): + + total_loss = total_size = 0 + for batch_idx, (x_batch, ) in enumerate(train_loader): + x_batch = x_batch.to(self.device) + x_bar, _, _, _, _, _, _, _ = model(x_batch) + + loss = F.mse_loss(x_bar, x_batch) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + size = x_batch.shape[0] + total_size += size + total_loss += loss.item() * size + + logger.info(f"Pretrain epoch {epoch + 1:4d}, MSE loss:{total_loss / total_size:.8f}") + + def save_pretrained(self, path): + torch.save(self.model.ae.state_dict(), path) + + def load_pretrained(self, path): + checkpoint = torch.load(self.pretrain_path, map_location=self.device) + self.model.ae.load_state_dict(checkpoint) + + def fit( + self, + inputs: Tuple[sp.spmatrix, np.ndarray, np.ndarray, pd.Series], + y: np.ndarray, + lr: float = 1e-03, + n_epochs: int = 300, + bcl: float = 0.1, + cl: float = 0.01, + rl: float = 1, + zl: float = 0.1, + pt_epochs: int = 200, + pt_batch_size: int = 256, + pt_lr: float = 1e-3, + ): """Train model. Parameters ---------- - x : np.ndarray - input features. - y : np.ndarray - labels. - X_raw : - raw input features. - n_counts : list - total counts for each cell. - adj : - adjacency matrix as a sicpy sparse matrix. - lr : float optional - learning rate. - n_epochs : int optional - number of epochs. - bcl : float optional - parameter of binary crossentropy loss. - cl : float optional - parameter of Kullback–Leibler divergence loss. - rl : float optional - parameter of reconstruction loss. - zl : float optional - parameter of ZINB loss. - - Returns - ------- - None. + inputs + A tuple containing (1) the adjacency matrix, (2) the input features, (3) the raw input features, and (4) + the total counts for each cell. + y + Label. + lr + Learning rate. + n_epochs + Number of epochs. + bcl + Parameter of binary crossentropy loss. + cl + Parameter of Kullback–Leibler divergence loss. + rl + Parameter of reconstruction loss. + zl + Parameter of ZINB loss. 
""" - print("Train:") + adj, x, x_raw, n_counts = inputs + self._pretrain(x, batch_size=pt_batch_size, n_epochs=pt_epochs, lr=pt_lr) + device = self.device model = self.model - optimizer = Adam(model.parameters(), lr=lr) - # optimizer = RAdam(model.parameters(), lr=lr) + optimizer = Adam(filter(lambda x: x.requires_grad, model.parameters()), lr=lr) adj = sparse_mx_to_torch_sparse_tensor(adj).to(device) - X_raw = torch.tensor(X_raw).to(device) + x_raw = torch.tensor(x_raw).to(device) sf = torch.tensor(n_counts / np.median(n_counts)).to(device) data = torch.from_numpy(x).to(device) @@ -203,10 +263,10 @@ def fit(self, x, y, X_raw, n_counts, adj, lr=1e-03, n_epochs=300, bcl=0.1, cl=0. p = self.target_distribution(tmp_q) # calculate ari score for model selection - _, _, ari = self.score(y) + ari = self.score(None, y) aris.append(ari) keys.append(key := f"epoch{epoch}") - print("Epoch %3d, ARI: %.4f, Best ARI: %.4f" % (epoch + 1, ari, max(aris))) + logger.info("Epoch %3d, ARI: %.4f, Best ARI: %.4f", epoch + 1, ari, max(aris)) P[key] = p Q[key] = tmp_q @@ -215,9 +275,9 @@ def fit(self, x, y, X_raw, n_counts, adj, lr=1e-03, n_epochs=300, bcl=0.1, cl=0. x_bar, q, pred, z, meanbatch, dispbatch, pibatch, zinb_loss = model(data, adj) binary_crossentropy_loss = F.binary_cross_entropy(q, p) - ce_loss = F.kl_div(pred.log(), p, reduction='batchmean') + ce_loss = F.kl_div(pred.log(), p, reduction="batchmean") re_loss = F.mse_loss(x_bar, data) - zinb_loss = zinb_loss(X_raw, meanbatch, dispbatch, pibatch, sf) + zinb_loss = zinb_loss(x_raw, meanbatch, dispbatch, pibatch, sf) loss = bcl * binary_crossentropy_loss + cl * ce_loss + rl * re_loss + zl * zinb_loss optimizer.zero_grad() @@ -227,158 +287,166 @@ def fit(self, x, y, X_raw, n_counts, adj, lr=1e-03, n_epochs=300, bcl=0.1, cl=0. index = np.argmax(aris) self.q = Q[keys[index]] - def predict(self): - """Get predictions from the trained model. + def predict_proba(self, x: Optional[Any] = None) -> np.ndarray: + """Get the predicted propabilities for each cell. Parameters ---------- - None. + x + Not used, for compatibility with the BaseClusteringMethod class. Returns ------- - y_pred : np.array - prediction of given clustering method. + pred_prop + Predicted probability for each cell. """ - y_pred = torch.argmax(self.q, dim=1).data.cpu().numpy() - return y_pred + pred_prob = self.q.detach().clone().cpu().numpy() + return pred_prob - def score(self, y): - """Evaluate the trained model. + def predict(self, x: Optional[Any] = None) -> np.ndarray: + """Get predictions from the trained model. Parameters ---------- - y : list - true labels. + x + Not used, for compatibility with the BaseClusteringMethod class. Returns ------- - acc : float - accuracy. - nmi : float - normalized mutual information. - ari : float - adjusted Rand index. + pred + Predicted clustering assignment for each cell. """ - y_pred = torch.argmax(self.q, dim=1).data.cpu().numpy() - acc = np.round(cluster_acc(y, y_pred), 5) - nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5) - ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5) - return acc, nmi, ari + pred = self.predict_proba().argmax(1) + return pred -class SCDSC(nn.Module): +class ScDSCModel(nn.Module): """scDSC class. Parameters ---------- - args : argparse.Namespace - a Namespace contains arguments of GCNAE. For details of parameters in parser args, please refer to link (parser help document). - device : str - computing device. - sigma : float - balance parameter. 
- pretrain_path : str - path of saved autoencoder weights. - n_enc_1 : int - output dimension of encoder layer 1. - n_enc_2 : int - output dimension of encoder layer 2. - n_enc_3 : int - output dimension of encoder layer 3. - n_dec_1 : int - output dimension of decoder layer 1. - n_dec_2 : int - output dimension of decoder layer 2. - n_dec_3 : int - output dimension of decoder layer 3. - n_z1 : int - output dimension of hidden layer 1. - n_z2 : int - output dimension of hidden layer 2. - n_z3 : int - output dimension of hidden layer 3. - n_clusters : int - number of clusters. - n_input : int - input feature dimension. - v : float - parameter of soft assignment. + sigma + Balance parameter. + n_enc_1 + Output dimension of encoder layer 1. + n_enc_2 + Output dimension of encoder layer 2. + n_enc_3 + Output dimension of encoder layer 3. + n_dec_1 + Output dimension of decoder layer 1. + n_dec_2 + Output dimension of decoder layer 2. + n_dec_3 + Output dimension of decoder layer 3. + n_z1 + Output dimension of hidden layer 1. + n_z2 + Output dimension of hidden layer 2. + n_z3 + Output dimension of hidden layer 3. + n_clusters + Number of clusters. + n_input + Input feature dimension. + v + Parameter of soft assignment. + device + Computing device. """ - def __init__(self, args): + def __init__( + self, + sigma: float = 1, + n_enc_1: int = 512, + n_enc_2: int = 256, + n_enc_3: int = 256, + n_dec_1: int = 256, + n_dec_2: int = 256, + n_dec_3: int = 512, + n_z1: int = 256, + n_z2: int = 128, + n_z3: int = 32, + n_clusters: int = 10, + n_input: int = 100, + v: float = 1, + device: str = "auto", + ): super().__init__() - device = args.device - self.sigma = args.sigma - self.pretrain_path = args.pretrain_path + self.device = get_device(device) + self.sigma = sigma + self.ae = AE( - n_enc_1=args.n_enc_1, - n_enc_2=args.n_enc_2, - n_enc_3=args.n_enc_3, - n_dec_1=args.n_dec_1, - n_dec_2=args.n_dec_2, - n_dec_3=args.n_dec_3, - n_input=args.n_input, - n_z1=args.n_z1, - n_z2=args.n_z2, - n_z3=args.n_z3, + n_enc_1=n_enc_1, + n_enc_2=n_enc_2, + n_enc_3=n_enc_3, + n_dec_1=n_dec_1, + n_dec_2=n_dec_2, + n_dec_3=n_dec_3, + n_input=n_input, + n_z1=n_z1, + n_z2=n_z2, + n_z3=n_z3, ) - self.gnn_1 = GNNLayer(args.n_input, args.n_enc_1) - self.gnn_2 = GNNLayer(args.n_enc_1, args.n_enc_2) - self.gnn_3 = GNNLayer(args.n_enc_2, args.n_enc_3) - self.gnn_4 = GNNLayer(args.n_enc_3, args.n_z1) - self.gnn_5 = GNNLayer(args.n_z1, args.n_z2) - self.gnn_6 = GNNLayer(args.n_z2, args.n_z3) - self.gnn_7 = GNNLayer(args.n_z3, args.n_clusters) + + self.gnn_1 = GNNLayer(n_input, n_enc_1) + self.gnn_2 = GNNLayer(n_enc_1, n_enc_2) + self.gnn_3 = GNNLayer(n_enc_2, n_enc_3) + self.gnn_4 = GNNLayer(n_enc_3, n_z1) + self.gnn_5 = GNNLayer(n_z1, n_z2) + self.gnn_6 = GNNLayer(n_z2, n_z3) + self.gnn_7 = GNNLayer(n_z3, n_clusters) # cluster layer - self.cluster_layer = Parameter(torch.Tensor(args.n_clusters, args.n_z3)) + self.cluster_layer = Parameter(torch.Tensor(n_clusters, n_z3)) torch.nn.init.xavier_normal_(self.cluster_layer.data) - self._dec_mean = nn.Sequential(nn.Linear(args.n_dec_3, args.n_input), MeanAct()) - self._dec_disp = nn.Sequential(nn.Linear(args.n_dec_3, args.n_input), DispAct()) - self._dec_pi = nn.Sequential(nn.Linear(args.n_dec_3, args.n_input), nn.Sigmoid()) + self._dec_mean = nn.Sequential(nn.Linear(n_dec_3, n_input), MeanAct()) + self._dec_disp = nn.Sequential(nn.Linear(n_dec_3, n_input), DispAct()) + self._dec_pi = nn.Sequential(nn.Linear(n_dec_3, n_input), nn.Sigmoid()) # degree - self.v = args.v - self.zinb_loss = 
ZINBLoss().to(device) + self.v = v + self.zinb_loss = ZINBLoss().to(self.device) + + self.to(self.device) def forward(self, x, adj): """Forward propagation. Parameters ---------- - x : - input features. - adj : - adjacency matrix + x + Input features. + adj + Adjacency matrix Returns ------- x_bar: - reconstructed features. - q : - soft label. + Reconstructed features. + q + Soft label. predict: - prediction given by softmax assignment of embedding of GCN module - z3 : - embedding of autoencoder. - _mean : - data mean from ZINB. - _disp : - data dispersion from ZINB. - _pi : - data dropout probability from ZINB. + Prediction given by softmax assignment of embedding of GCN module + z3 + Embedding of autoencoder. + _mean + Data mean from ZINB. + _disp + Data dispersion from ZINB. + _pi + Data dropout probability from ZINB. zinb_loss: ZINB loss class. """ # DNN Module - self.ae.load_state_dict(torch.load(self.pretrain_path, map_location='cpu')) x_bar, tra1, tra2, tra3, z3, z2, z1, dec_h3 = self.ae(x) - sigma = self.sigma # GCN Module + sigma = self.sigma h = self.gnn_1(x, adj) h = self.gnn_2((1 - sigma) * h + sigma * tra1, adj) h = self.gnn_3((1 - sigma) * h + sigma * tra2, adj) @@ -406,10 +474,10 @@ class GNNLayer(nn.Module): Parameters ---------- - in_features : int - input dimension of GNN layer. - out_features : int - output dimension of GNN layer. + in_features + Input dimension of GNN layer. + out_features + Output dimension of GNN layer. """ @@ -435,26 +503,26 @@ class AE(nn.Module): Parameters ---------- - n_enc_1 : int - output dimension of encoder layer 1. - n_enc_2 : int - output dimension of encoder layer 2. - n_enc_3 : int - output dimension of encoder layer 3. - n_dec_1 : int - output dimension of decoder layer 1. - n_dec_2 : int - output dimension of decoder layer 2. - n_dec_3 : int - output dimension of decoder layer 3. - n_input : int - input feature dimension. - n_z1 : int - output dimension of hidden layer 1. - n_z2 : int - output dimension of hidden layer 2. - n_z3 : int - output dimension of hidden layer 3. + n_enc_1 + Output dimension of encoder layer 1. + n_enc_2 + Output dimension of encoder layer 2. + n_enc_3 + Output dimension of encoder layer 3. + n_dec_1 + Output dimension of decoder layer 1. + n_dec_2 + Output dimension of decoder layer 2. + n_dec_3 + Output dimension of decoder layer 3. + n_input + Input feature dimension. + n_z1 + Output dimension of hidden layer 1. + n_z2 + Output dimension of hidden layer 2. + n_z3 + Output dimension of hidden layer 3. """ @@ -488,27 +556,27 @@ def forward(self, x): Parameters ---------- - x : - input features. + x + Input features. Returns ------- - x_bar: - reconstructed features. - enc_h1: - output of encoder layer 1. - enc_h2: - output of encoder layer 2. - enc_h3: - output of encoder layer 3. - z3 : - output of hidden layer 3. - z2 : - output of hidden layer 2. - z1 : - output of hidden layer 1. - dec_h3 : - output of decoder layer 3. + x_bar + Reconstructed features. + enc_h1 + Output of encoder layer 1. + enc_h2 + Output of encoder layer 2. + enc_h3 + Output of encoder layer 3. + z3 + Output of hidden layer 3. + z2 + Output of hidden layer 2. + z1 + Output of hidden layer 1. + dec_h3 + Output of decoder layer 3. """ enc_h1 = F.relu(self.BN1(self.enc_1(x))) @@ -527,119 +595,6 @@ def forward(self, x): return x_bar, enc_h1, enc_h2, enc_h3, z3, z2, z1, dec_h3 -class RAdam(Optimizer): - """RAdam optimizer class. - - Parameters - ---------- - params : - model parameters. - lr : float optional - learning rate. 
- betas : tuple optional - coefficients used for computing running averages of gradient and its square. - eps : float optional - term added to the denominator to improve numerical stability. - weight decay : float optional - weight decay (L2 penalty). - degenerated_to_sgd : bool optional - degenerated to SGD or not. - - """ - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - - self.degenerated_to_sgd = degenerated_to_sgd - if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): - for param in params: - if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): - param['buffer'] = [[None, None, None] for _ in range(10)] - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, - buffer=[[None, None, None] for _ in range(10)]) - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - - def step(self, closure=None): - - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - - for p in group['params']: - if p.grad is None: - continue - grad = p.grad.data.float() - if grad.is_sparse: - raise RuntimeError('RAdam does not support sparse gradients') - - p_data_fp32 = p.data.float() - - state = self.state[p] - - if len(state) == 0: - state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p_data_fp32) - state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) - else: - state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) - state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] - - exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) - exp_avg.mul_(beta1).add_(1 - beta1, grad) - - state['step'] += 1 - buffered = group['buffer'][int(state['step'] % 10)] - if state['step'] == buffered[0]: - N_sma, step_size = buffered[1], buffered[2] - else: - buffered[0] = state['step'] - beta2_t = beta2**state['step'] - N_sma_max = 2 / (1 - beta2) - 1 - N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) - buffered[1] = N_sma - - # more conservative since it's an approximated value - if N_sma >= 5: - step_size = math.sqrt( - (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / - (N_sma_max - 2)) / (1 - beta1**state['step']) - elif self.degenerated_to_sgd: - step_size = 1.0 / (1 - beta1**state['step']) - else: - step_size = -1 - buffered[2] = step_size - - # more conservative since it's an approximated value - if N_sma >= 5: - if group['weight_decay'] != 0: - p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) - denom = exp_avg_sq.sqrt().add_(group['eps']) - p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) - p.data.copy_(p_data_fp32) - elif step_size > 0: - if group['weight_decay'] != 0: - p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) - p_data_fp32.add_(-step_size * group['lr'], exp_avg) - p.data.copy_(p_data_fp32) - - return loss - - class MeanAct(nn.Module): """Mean activation class.""" diff --git 
a/dance/modules/single_modality/clustering/sctag.py b/dance/modules/single_modality/clustering/sctag.py index 8c45ab87..3e612874 100644 --- a/dance/modules/single_modality/clustering/sctag.py +++ b/dance/modules/single_modality/clustering/sctag.py @@ -9,8 +9,6 @@ doi:10.1609/aaai.v36i4.20392. """ -import os - import dgl import numpy as np import scanpy as sc @@ -23,14 +21,14 @@ from torch.nn import Parameter from dance import logger -from dance.modules.base import BaseClusteringMethod +from dance.modules.base import BaseClusteringMethod, TorchNNPretrain from dance.transforms import AnnDataTransform, CellPCA, Compose, SaveRaw, SetConfig from dance.transforms.graph import NeighborGraph from dance.typing import Any, LogLevel, Optional, Tuple from dance.utils.loss import ZINBLoss, dist_loss -class ScTAG(nn.Module, BaseClusteringMethod): +class ScTAG(nn.Module, TorchNNPretrain, BaseClusteringMethod): """The scTAG clustering model. Parameters @@ -108,10 +106,6 @@ def init_model(self, adj: np.ndarray, x: np.ndarray): self.zinb_loss = ZINBLoss().to(self.device) self.to(self.device) - @property - def is_pretrained(self) -> bool: - return self._is_pretrained - @property def in_dim(self) -> int: if self._in_dim is None: @@ -183,7 +177,7 @@ def forward(self, g, x_input): return adj_out, z, q, _mean, _disp, _pi - def pre_train( + def pretrain( self, adj, x, @@ -233,21 +227,6 @@ def pre_train( or even the pre-trained model file is available to load. """ - pt_path = self.pretrain_save_path - if not force_pretrain: - if self.is_pretrained: - logger.info("Skipping pre_train as the model appears to be pretrained already. " - "If you wish to force pre-training, please set 'force_pretrain' to True.") - return - - if pt_path is not None and os.path.isfile(pt_path): - logger.info(f"Loading pre-trained model from {pt_path}") - checkpoint = torch.load(pt_path) - # TODO: change device? 
- self.load_state_dict(checkpoint) - self._is_pretrained = True - return - x = torch.Tensor(x).to(self.device) x_raw = torch.Tensor(x_raw).to(self.device) scale_factor = torch.tensor(n_counts / np.median(n_counts)).to(self.device) @@ -278,13 +257,6 @@ def pre_train( loss.backward() optimizer.step() - if pt_path is not None: - logger.info(f"Saving pre-trained model to {pt_path}") - torch.save(self.state_dict(), pt_path) - - logger.info("Pre-training done") - self._is_pretrained = True - def fit( self, inputs: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], @@ -334,7 +306,7 @@ def fit( """ adj, x, x_raw, n_counts = inputs self.init_model(adj, x) - self.pre_train(adj, x, x_raw, n_counts, epochs=pretrain_epochs, info_step=info_step, lr=lr, w_a=w_a, w_x=w_x, + self._pretrain(adj, x, x_raw, n_counts, epochs=pretrain_epochs, info_step=info_step, lr=lr, w_a=w_a, w_x=w_x, w_d=w_d, min_dist=min_dist, max_dist=max_dist, force_pretrain=force_pretrain) x = torch.Tensor(x).to(self.device) diff --git a/examples/single_modality/clustering/scdsc.py b/examples/single_modality/clustering/scdsc.py index 35a83445..24ab8295 100644 --- a/examples/single_modality/clustering/scdsc.py +++ b/examples/single_modality/clustering/scdsc.py @@ -1,18 +1,14 @@ import argparse -import os -from argparse import Namespace -from time import time from dance.data import Data from dance.datasets.singlemodality import ClusteringDataset -from dance.modules.single_modality.clustering.scdsc import SCDSCWrapper +from dance.modules.single_modality.clustering.scdsc import ScDSC from dance.utils import set_seed # for repeatability set_seed(42) if __name__ == "__main__": - time_start = time() parser = argparse.ArgumentParser() # model_para = [n_enc_1(n_dec_3), n_enc_2(n_dec_2), n_enc_3(n_dec_1)] @@ -36,6 +32,7 @@ parser.add_argument("--n_dec_3", default=model_para[0], type=int) parser.add_argument("--topk", type=int, default=50) parser.add_argument("--lr", type=float, default=1e-2) + parser.add_argument("--pretrain_lr", type=float, default=1e-3) parser.add_argument("--pretrain_epochs", type=int, default=200) parser.add_argument("--n_epochs", type=int, default=1000) parser.add_argument("--n_z1", default=Cluster_para[0], type=int) @@ -43,7 +40,7 @@ parser.add_argument("--n_z3", default=Cluster_para[2], type=int) parser.add_argument("--n_input", type=int, default=Cluster_para[4]) parser.add_argument("--n_clusters", type=int, default=Cluster_para[5]) - parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--device", type=str, default="auto") parser.add_argument("--v", type=int, default=1) parser.add_argument("--nb_genes", type=int, default=2000) parser.add_argument("--binary_crossentropy_loss", type=float, default=Balance_para[0]) @@ -53,43 +50,30 @@ parser.add_argument("--sigma", type=float, default=Balance_para[4]) args = parser.parse_args() - # File = [gene_expresion data file, Graph file, h5 file, pretrain_path] - File = [ - os.path.join("data", args.name), - None, - os.path.join("data", f"{args.name}.h5"), - os.path.join("model", f"{args.name}_pre.pkl"), - ] - args.pretrain_path = File[3] - if not os.path.exists("./graph/"): - os.makedirs("./graph/") - if not os.path.exists("./model/"): - os.makedirs("./model/") - adata, labels = ClusteringDataset("./data", args.name).load_data() adata.obsm["Group"] = labels data = Data(adata, train_size="all") - preprocessing_pipeline = SCDSCWrapper.preprocessing_pipeline(n_top_genes=args.nb_genes, n_neighbors=args.topk) + preprocessing_pipeline = 
ScDSC.preprocessing_pipeline(n_top_genes=args.nb_genes, n_neighbors=args.topk)
     preprocessing_pipeline(data)
 
-    (x, x_raw, n_counts, adj), y = data.get_data(return_type="default")
-    args.n_input = x.shape[1]
+    # inputs: adj, x, x_raw, n_counts
+    inputs, y = data.get_data(return_type="default")
+    args.n_input = inputs[1].shape[1]
 
-    # Pretrain AE
-    model = SCDSCWrapper(Namespace(**vars(args)))
-    if not os.path.exists(args.pretrain_path):
-        model.pretrain_ae(x, args.batch_size, args.pretrain_epochs, args.pretrain_path)
+    model = ScDSC(pretrain_path=f"{args.name}_scdsc_pre.pkl", sigma=args.sigma, n_enc_1=args.n_enc_1,
+                  n_enc_2=args.n_enc_2, n_enc_3=args.n_enc_3, n_dec_1=args.n_dec_1, n_dec_2=args.n_dec_2,
+                  n_dec_3=args.n_dec_3, n_z1=args.n_z1, n_z2=args.n_z2, n_z3=args.n_z3, n_clusters=args.n_clusters,
+                  n_input=args.n_input, v=args.v, device=args.device)
 
-    # Train scDSC
-    model.fit(x, y, x_raw, n_counts, adj, lr=args.lr, n_epochs=args.n_epochs, bcl=args.binary_crossentropy_loss,
-              cl=args.ce_loss, rl=args.re_loss, zl=args.zinb_loss)
-    print(f"Running Time:{int(time() - time_start)} seconds")
+    # Build and train model
+    model.fit(inputs, y, lr=args.lr, n_epochs=args.n_epochs, bcl=args.binary_crossentropy_loss, cl=args.ce_loss,
+              rl=args.re_loss, zl=args.zinb_loss, pt_epochs=args.pretrain_epochs, pt_batch_size=args.batch_size,
+              pt_lr=args.pretrain_lr)
 
-    y_pred = model.predict()
-    print(f"Prediction (first ten): {y_pred[:10]}")
-    acc, nmi, ari = model.score(y)
-    print("ACC: {:.4f}, NMI: {:.4f}, ARI: {:.4f}".format(acc, nmi, ari))
+    # Evaluate model predictions
+    score = model.score(None, y)
+    print(f"{score=:.4f}")
 
 """Reproduction information
 10X PBMC:
 python scdsc.py --name 10X_PBMC --method cosine --topk 30 --v 7 --binary_crossentropy_loss 0.75 --ce_loss 0.5 --re_loss 0.1 --zinb_loss 2.5 --sigma 0.4
diff --git a/examples/single_modality/clustering/sctag.py b/examples/single_modality/clustering/sctag.py
index 9725b8aa..354a90ca 100644
--- a/examples/single_modality/clustering/sctag.py
+++ b/examples/single_modality/clustering/sctag.py
@@ -56,7 +56,7 @@
     inputs, y = data.get_train_data()
     n_clusters = len(np.unique(y))
 
-    # Build model & training
+    # Build and train model
     model = ScTAG(n_clusters=n_clusters, k=args.k, hidden_dim=args.hidden_dim, latent_dim=args.latent_dim,
                   dec_dim=args.dec_dim, dropout=args.dropout, device=args.device, alpha=args.alpha,
                   pretrain_save_path=args.pretrain_file)
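As a usage sketch of the new pretraining hooks introduced in dance/modules/base.py (not part of the patch; ToyModel, its layers, and the dummy loss are hypothetical, and the example assumes the dance package with the changes above is installed): a model opts in by setting `pretrain_path`, freezing the pretrainable submodule with `fix_module`, and implementing `pretrain` inside `pretrain_context`.

import torch
import torch.nn as nn

from dance.modules.base import TorchNNPretrain


class ToyModel(nn.Module, TorchNNPretrain):
    """Hypothetical model illustrating the BasePretrain/TorchNNPretrain workflow."""

    def __init__(self, pretrain_path=None):
        super().__init__()
        self.pretrain_path = pretrain_path  # consulted by BasePretrain._pretrain for caching
        self.encoder = nn.Linear(4, 2)
        self.head = nn.Linear(2, 1)
        self.fix_module("encoder")  # freeze the encoder until it has been pretrained

    def pretrain(self, x, n_epochs=5, lr=1e-3):
        # pretrain_context temporarily unfreezes the encoder and refreezes it on exit
        with self.pretrain_context("encoder"):
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=lr)
            for _ in range(n_epochs):
                loss = self.encoder(x).pow(2).mean()  # dummy objective for illustration
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()


model = ToyModel()
x = torch.randn(8, 4)
model._pretrain(x)  # runs pretrain() since no checkpoint exists and pretrain_path is unset
model._pretrain(x)  # skipped: the model is already marked as pretrained

ScDSC.fit and ScTAG.fit drive the same workflow through self._pretrain(...), so a checkpoint saved at `pretrain_path` is loaded instead of re-running pre-training on subsequent fits.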