diff --git a/dance/modules/spatial/spatial_domain/stagate.py b/dance/modules/spatial/spatial_domain/stagate.py index cade693d..11c8e90f 100644 --- a/dance/modules/spatial/spatial_domain/stagate.py +++ b/dance/modules/spatial/spatial_domain/stagate.py @@ -9,21 +9,18 @@ """ -import matplotlib.pyplot as plt import numpy as np -import pandas as pd import scanpy as sc import scipy.sparse as sp -import sklearn.neighbors import torch import torch.nn as nn import torch.nn.functional as F from sklearn import mixture +from sklearn.metrics.cluster import adjusted_rand_score from torch import Tensor from torch.nn import Parameter from torch_geometric.data import Data from torch_geometric.nn.conv import MessagePassing -from torch_geometric.nn.dense.linear import Linear from torch_geometric.utils import add_self_loops, remove_self_loops, softmax from torch_sparse import SparseTensor, set_diag from tqdm import tqdm @@ -40,62 +37,22 @@ def transfer_pytorch_data(adata, adj): return data -def Stats_Spatial_Net(adata): - Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0] - Mean_edge = Num_edge / adata.shape[0] - plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1'])) - plot_df = plot_df / adata.shape[0] - fig, ax = plt.subplots(figsize=[3, 2]) - plt.ylabel('Percentage') - plt.xlabel('') - plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge) - ax.bar(plot_df.index, plot_df) - - -def mclust_P(adata, num_cluster, used_obsm='STAGATE', modelNames='EEE'): - from sklearn import mixture - g = mixture.GaussianMixture(n_components=num_cluster, covariance_type='tied', warm_start=True, n_init=100, +def mclust(adata, num_cluster, used_obsm="STAGATE", modelNames="EEE"): + g = mixture.GaussianMixture(n_components=num_cluster, covariance_type="tied", warm_start=True, n_init=100, max_iter=300, reg_covar=1.4663143602030552e-04, random_state=36282, tol=0.00022187708009762592) res = g.fit_predict(adata.obsm[used_obsm]) - adata.obs['mclust'] = res + adata.obs["mclust"] = res 
return adata -''' -def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020): - """\ - Clustering using the mclust algorithm. - The parameters are the same as those in the R package mclust. - """ - - np.random.seed(random_seed) - import rpy2.robjects as robjects - robjects.r.library("mclust") - - import rpy2.robjects.numpy2ri - rpy2.robjects.numpy2ri.activate() - r_random_seed = robjects.r['set.seed'] - r_random_seed(random_seed) - rmclust = robjects.r['Mclust'] - - res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames) - mclust_res = np.array(res[-2]) - - adata.obs['mclust'] = mclust_res - adata.obs['mclust'] = adata.obs['mclust'].astype('int') - adata.obs['mclust'] = adata.obs['mclust'].astype('category') - return adata -''' - - class GATConv(MessagePassing): """Graph attention layer from Graph Attention Network.""" _alpha = None def __init__(self, in_channels, out_channels, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops=True, bias=True, **kwargs): - kwargs.setdefault('aggr', 'add') + kwargs.setdefault("aggr", "add") super().__init__(node_dim=0, **kwargs) self.in_channels = in_channels @@ -126,12 +83,12 @@ def forward(self, x, edge_index, size=None, return_attention_weights=None, atten # We first transform the input node features. 
If a tuple is passed, we # transform source and target node features via separate weights: if isinstance(x, Tensor): - assert x.dim() == 2, "Static graphs not supported in 'GATConv'" + assert x.dim() == 2, "Static graphs not supported in GATConv" # x_src = x_dst = self.lin_src(x).view(-1, H, C) x_src = x_dst = torch.mm(x, self.lin_src).view(-1, H, C) else: # Tuple of source and target node features: x_src, x_dst = x - assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'" + assert x_src.dim() == 2, "Static graphs not supported in GATConv" x_src = self.lin_src(x_src).view(-1, H, C) if x_dst is not None: x_dst = self.lin_dst(x_dst).view(-1, H, C) @@ -142,7 +99,7 @@ def forward(self, x, edge_index, size=None, return_attention_weights=None, atten return x[0].mean(dim=1) # return x[0].view(-1, self.heads * self.out_channels) - if tied_attention == None: + if tied_attention is None: # Next, we compute node-level attention coefficients, both for source # and target nodes (if present): alpha_src = (x_src * self.att_src).sum(dim=-1) @@ -180,7 +137,7 @@ def forward(self, x, edge_index, size=None, return_attention_weights=None, atten if isinstance(edge_index, Tensor): return out, (edge_index, alpha) elif isinstance(edge_index, SparseTensor): - return out, edge_index.set_value(alpha, layout='coo') + return out, edge_index.set_value(alpha, layout="coo") else: return out @@ -196,7 +153,7 @@ def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i): return x_j * alpha.unsqueeze(-1) def __repr__(self): - return '{}({}, {}, heads={})'.format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads) + return "{}({}, {}, heads={})".format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads) class Stagate(torch.nn.Module): @@ -219,21 +176,19 @@ def __init__(self, hidden_dims): self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False, dropout=0, add_self_loops=False, bias=False) def forward(self, features, edge_index): - 
"""forward function for training. + """Forward function for training. Parameters ---------- features : - node features. + Node features. edge_index : - adjacent matrix. + Adjacency matrix. Returns ------- - h2 : - the second hidden layer. - h4 : - the forth hidden layer. + Tuple[Tensor, Tensor] + The second and the fourth hidden layers. """ h1 = F.elu(self.conv1(features, edge_index)) @@ -247,56 +202,50 @@ def forward(self, features, edge_index): return h2, h4 # F.log_softmax(x, dim=-1) - def fit(self, adata, graph, n_epochs=1, lr=0.001, key_added='STAGATE', gradient_clipping=5., pre_resolution=0.2, + def fit(self, adata, graph, n_epochs=1, lr=0.001, key_added="STAGATE", gradient_clipping=5., pre_resolution=0.2, weight_decay=0.0001, verbose=True, random_seed=0, save_loss=False, save_reconstrction=False, - device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')): - """fit function for training. + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")): + """Fit function for training. Parameters ---------- adata : - input data. + Input data. graph : - graph structure. + Graph structure. n_epochs : int optional - number of epochs. + Number of epochs. lr : float optional - learning rate. + Learning rate. key_added : str optional - by default 'STAGATE'. + Default "STAGATE". gradient_clipping : float optional - gradient clipping. + Gradient clipping. pre_resolution : float optional - pre resolution. + Pre-resolution. weight_decay : float optional - weight decay. + Weight decay. verbose : bool optional - verbose, by default to be True. + Verbosity; defaults to True. random_seed : int optional - random seed by default to be 0. + Random seed. save_loss : bool optional - by default to be False. + Whether to save loss or not. save_reconstrction : bool optional - by default to be False. + Whether to save reconstruction or not. device : str optional - to indicate gpu or cpu device. - - Returns - ------- - None. + Computation device.
""" adata.X = sp.csr_matrix(adata.X) - if 'highly_variable' in adata.var.columns: - adata_Vars = adata[:, adata.var['highly_variable']] + if "highly_variable" in adata.var.columns: + adata_Vars = adata[:, adata.var["highly_variable"]] else: adata_Vars = adata if verbose: - print('Size of Input: ', adata_Vars.shape) - if 'Spatial_Net' not in adata.uns.keys(): - raise ValueError("Spatial_Net is not existed! Run Cal_Spatial_Net first!") + print("Size of Input: ", adata_Vars.shape) data = transfer_pytorch_data(adata_Vars, graph) @@ -320,39 +269,25 @@ def fit(self, adata, graph, n_epochs=1, lr=0.001, key_added='STAGATE', gradient_ model.eval() z, out = model(data.x, data.edge_index) - STAGATE_rep = z.to('cpu').detach().numpy() + STAGATE_rep = z.to("cpu").detach().numpy() adata.obsm[key_added] = STAGATE_rep if save_loss: - adata.uns['STAGATE_loss'] = loss + adata.uns["STAGATE_loss"] = loss if save_reconstrction: - ReX = out.to('cpu').detach().numpy() + ReX = out.to("cpu").detach().numpy() ReX[ReX < 0] = 0 - adata.layers['STAGATE_ReX'] = ReX + adata.layers["STAGATE_ReX"] = ReX print("post process...") - sc.pp.neighbors(adata, use_rep='STAGATE') + sc.pp.neighbors(adata, use_rep="STAGATE") sc.tl.umap(adata) - #adata = mclust_R(adata, used_obsm='STAGATE', num_cluster=7) - adata = mclust_P(adata, used_obsm='STAGATE', num_cluster=7) + adata = mclust(adata, used_obsm="STAGATE", num_cluster=7) self.adata = adata - def predict(self, ): - """prediction function. - - Parameters - ---------- - - Returns - ------- - self.y_pred : - predicted label. - - """ - data_dropna = self.adata.obs.dropna() - self.y_pred = data_dropna['mclust'] - self.target = data_dropna['ground_truth'] - return data_dropna['mclust'] + def predict(self): + """Prediction function.""" + return self.adata.obs["mclust"].values def score(self, y_true=None): """score function to get score of prediction. @@ -360,15 +295,13 @@ def score(self, y_true=None): Parameters ---------- y_true : - ground truth label. 
+ Ground truth label. Returns ------- - score : float - metric eval score. + float + Adjusted rand index score. """ - from sklearn.metrics.cluster import adjusted_rand_score - score = adjusted_rand_score(self.target, self.y_pred) - print("ARI {}".format(adjusted_rand_score(self.target, self.y_pred))) + score = adjusted_rand_score(y_true, self.predict()) return score diff --git a/dance/transforms/graph/__init__.py b/dance/transforms/graph/__init__.py index 88866a91..96bdf481 100644 --- a/dance/transforms/graph/__init__.py +++ b/dance/transforms/graph/__init__.py @@ -1,7 +1,7 @@ from dance.transforms.graph.cell_feature_graph import CellFeatureGraph, PCACellFeatureGraph from dance.transforms.graph.dstg_graph import DSTGraph from dance.transforms.graph.neighbor_graph import NeighborGraph -from dance.transforms.graph.spatial_graph import SMEGraph, SpaGCNGraph, SpaGCNGraph2D +from dance.transforms.graph.spatial_graph import SMEGraph, SpaGCNGraph, SpaGCNGraph2D, StagateGraph __all__ = [ "CellFeatureGraph", @@ -11,4 +11,5 @@ "SMEGraph", "SpaGCNGraph", "SpaGCNGraph2D", + "StagateGraph", ] # yapf: disable diff --git a/dance/transforms/graph/spatial_graph.py b/dance/transforms/graph/spatial_graph.py index fa08d79b..b0308ee6 100644 --- a/dance/transforms/graph/spatial_graph.py +++ b/dance/transforms/graph/spatial_graph.py @@ -1,6 +1,7 @@ import numpy as np from sklearn.linear_model import LinearRegression from sklearn.metrics import pairwise_distances +from sklearn.neighbors import NearestNeighbors from dance.transforms.base import BaseTransform from dance.typing import Sequence @@ -103,3 +104,45 @@ def __call__(self, data): adj = adj_p * adj_m * adj_g data.data.obsp[self.out] = adj + + +class StagateGraph(BaseTransform): + """STAGATE spatial graph.""" + + _MODELS = ("radius", "knn") + _DISPLAY_ATTRS = ("model_name", "radius", "n_neighbors") + + def __init__(self, model_name: str = "radius", *, radius: float = 1, n_neighbors: int = 5, + channel: str = "spatial_pixel", 
channel_type: str = "obsm", **kwargs): + """Initialize StagateGraph. + + Parameters + ---------- + model_name + Type of graph to construct. Currently supports `radius` and `knn`. See + :class:`~sklearn.neighbors.NearestNeighbors` for more info. + radius + Radius parameter for `radius_neighbors_graph`. + n_neighbors + Number of neighbors for `kneighbors_graph`. + + """ + super().__init__(**kwargs) + + if not isinstance(model_name, str) or (model_name.lower() not in self._MODELS): + raise ValueError(f"Unknown model {model_name!r}, available options are {self._MODELS}") + self.model_name = model_name + self.radius = radius + self.n_neighbors = n_neighbors + self.channel = channel + self.channel_type = channel_type + + def __call__(self, data): + xy_pixel = data.get_feature(return_type="numpy", channel=self.channel, channel_type=self.channel_type) + + if self.model_name.lower() == "radius": + adj = NearestNeighbors(radius=self.radius).fit(xy_pixel).radius_neighbors_graph(xy_pixel) + elif self.model_name.lower() == "knn": + adj = NearestNeighbors(n_neighbors=self.n_neighbors).fit(xy_pixel).kneighbors_graph(xy_pixel) + + data.data.obsp[self.out] = adj diff --git a/dance/transforms/graph_construct.py b/dance/transforms/graph_construct.py index 2fc08ded..55b70d0e 100644 --- a/dance/transforms/graph_construct.py +++ b/dance/transforms/graph_construct.py @@ -1218,58 +1218,3 @@ def edgeList2edgeDict(edgeList, nodesize): adj = nx.adjacency_matrix(nx.from_dict_of_lists(graphdict)) return adj, edgeList - - -############################ -# stagate # -############################ - - -def stagate_construct_graph(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True): - x_pixel = pd.DataFrame(adata.obs['x_pixel']) - y_pixel = pd.DataFrame(adata.obs['y_pixel']) - coor = pd.concat([y_pixel, x_pixel], axis=1) - coor.index = adata.obs.index - coor.columns = ['imagerow', 'imagecol'] - if model == 'Radius': - nbrs = 
sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor) - distances, indices = nbrs.radius_neighbors(coor, return_distance=True) - KNN_list = [] - for it in range(indices.shape[0]): - KNN_list.append(pd.DataFrame(zip([it] * indices[it].shape[0], indices[it], distances[it]))) - - if model == 'KNN': - nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff + 1).fit(coor) - distances, indices = nbrs.kneighbors(coor) - KNN_list = [] - for it in range(indices.shape[0]): - KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :]))) - - KNN_df = pd.concat(KNN_list) - KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] - - Spatial_Net = KNN_df.copy() - Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0, ] - id_cell_trans = dict(zip( - range(coor.shape[0]), - np.array(coor.index), - )) - Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans) - Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans) - if verbose: - print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs)) - print('%.4f neighbors per cell on average.' 
% (Spatial_Net.shape[0] / adata.n_obs)) - - adata.uns['Spatial_Net'] = Spatial_Net - - G_df = adata.uns['Spatial_Net'].copy() - cells = np.array(adata.obs_names) - cells_id_tran = dict(zip(cells, range(cells.shape[0]))) - G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran) - G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran) - - G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs)) - G = G + sp.eye(G.shape[0]) - - edgeList = np.nonzero(G) - return edgeList diff --git a/dance/transforms/preprocess.py b/dance/transforms/preprocess.py index 4d36142f..15aa99ef 100644 --- a/dance/transforms/preprocess.py +++ b/dance/transforms/preprocess.py @@ -2483,128 +2483,3 @@ def inspect_data(data): data_dict['test_idx'] = test_idx return data_dict - - -def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True): - """\ - Construct the spatial neighbor networks. - Parameters - ---------- - adata - AnnData object of scanpy package. - rad_cutoff - radius cutoff when model='Radius' - k_cutoff - The number of nearest neighbors when model='KNN' - model - The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors. 
- - Returns - ------- - The spatial networks are saved in adata.uns['Spatial_Net'] - """ - - assert (model in ['Radius', 'KNN']) - if verbose: - print('------Calculating spatial graph...') - - # coor = pd.DataFrame(adata.obsm['spatial']) - x_pixel = pd.DataFrame(adata.obs['x_pixel']) - y_pixel = pd.DataFrame(adata.obs['y_pixel']) - coor = pd.concat([x_pixel, y_pixel], axis=1) - coor.columns = ['imagerow', 'imagecol'] - - if model == 'Radius': - nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor) - distances, indices = nbrs.radius_neighbors(coor, return_distance=True) - KNN_list = [] - for it in range(indices.shape[0]): - KNN_list.append(pd.DataFrame(zip([it] * indices[it].shape[0], indices[it], distances[it]))) - - if model == 'KNN': - nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff + 1).fit(coor) - distances, indices = nbrs.kneighbors(coor) - KNN_list = [] - for it in range(indices.shape[0]): - KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :]))) - - KNN_df = pd.concat(KNN_list) - KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] - - Spatial_Net = KNN_df.copy() - Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0, ] - id_cell_trans = dict(zip( - range(coor.shape[0]), - np.array(coor.index), - )) - Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans) - Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans) - if verbose: - print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs)) - print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata.n_obs)) - - adata.uns['Spatial_Net'] = Spatial_Net - - -def Cal_Spatial_Net_3D(adata, rad_cutoff_2D, rad_cutoff_Zaxis, key_section='Section_id', section_order=None, - verbose=True): - """\ - Construct the spatial neighbor networks. - Parameters - ---------- - adata - AnnData object of scanpy package. - rad_cutoff_2D - radius cutoff for 2D SNN construction. 
- rad_cutoff_Zaxis - radius cutoff for 2D SNN construction for consturcting SNNs between adjacent sections. - key_section - The columns names of section_ID in adata.obs. - section_order - The order of sections. The SNNs between adjacent sections are constructed according to this order. - - Returns - ------- - The 3D spatial networks are saved in adata.uns['Spatial_Net']. - """ - adata.uns['Spatial_Net_2D'] = pd.DataFrame() - adata.uns['Spatial_Net_Zaxis'] = pd.DataFrame() - num_section = np.unique(adata.obs[key_section]).shape[0] - if verbose: - print('Radius used for 2D SNN:', rad_cutoff_2D) - print('Radius used for SNN between sections:', rad_cutoff_Zaxis) - for temp_section in np.unique(adata.obs[key_section]): - if verbose: - print('------Calculating 2D SNN of section ', temp_section) - temp_adata = adata[adata.obs[key_section] == temp_section, ] - Cal_Spatial_Net(temp_adata, rad_cutoff=rad_cutoff_2D, verbose=False) - temp_adata.uns['Spatial_Net']['SNN'] = temp_section - if verbose: - print('This graph contains %d edges, %d cells.' % - (temp_adata.uns['Spatial_Net'].shape[0], temp_adata.n_obs)) - print('%.4f neighbors per cell on average.' 
% (temp_adata.uns['Spatial_Net'].shape[0] / temp_adata.n_obs)) - adata.uns['Spatial_Net_2D'] = pd.concat([adata.uns['Spatial_Net_2D'], temp_adata.uns['Spatial_Net']]) - for it in range(num_section - 1): - section_1 = section_order[it] - section_2 = section_order[it + 1] - if verbose: - print('------Calculating SNN between adjacent section {} and {}.'.format(section_1, section_2)) - Z_Net_ID = section_1 + '-' + section_2 - temp_adata = adata[adata.obs[key_section].isin([section_1, section_2]), ] - Cal_Spatial_Net(temp_adata, rad_cutoff=rad_cutoff_Zaxis, verbose=False) - spot_section_trans = dict(zip(temp_adata.obs.index, temp_adata.obs[key_section])) - temp_adata.uns['Spatial_Net']['Section_id_1'] = temp_adata.uns['Spatial_Net']['Cell1'].map(spot_section_trans) - temp_adata.uns['Spatial_Net']['Section_id_2'] = temp_adata.uns['Spatial_Net']['Cell2'].map(spot_section_trans) - used_edge = temp_adata.uns['Spatial_Net'].apply(lambda x: x['Section_id_1'] != x['Section_id_2'], axis=1) - temp_adata.uns['Spatial_Net'] = temp_adata.uns['Spatial_Net'].loc[used_edge, ] - temp_adata.uns['Spatial_Net'] = temp_adata.uns['Spatial_Net'].loc[:, ['Cell1', 'Cell2', 'Distance']] - temp_adata.uns['Spatial_Net']['SNN'] = Z_Net_ID - if verbose: - print('This graph contains %d edges, %d cells.' % - (temp_adata.uns['Spatial_Net'].shape[0], temp_adata.n_obs)) - print('%.4f neighbors per cell on average.' % (temp_adata.uns['Spatial_Net'].shape[0] / temp_adata.n_obs)) - adata.uns['Spatial_Net_Zaxis'] = pd.concat([adata.uns['Spatial_Net_Zaxis'], temp_adata.uns['Spatial_Net']]) - adata.uns['Spatial_Net'] = pd.concat([adata.uns['Spatial_Net_2D'], adata.uns['Spatial_Net_Zaxis']]) - if verbose: - print('3D SNN contains %d edges, %d cells.' % (adata.uns['Spatial_Net'].shape[0], adata.n_obs)) - print('%.4f neighbors per cell on average.' 
% (adata.uns['Spatial_Net'].shape[0] / adata.n_obs)) diff --git a/docs/source/conf.py b/docs/source/conf.py index 10673faf..68d6a784 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -35,6 +35,7 @@ intersphinx_mapping = { 'mudata': ('https://mudata.readthedocs.io/en/stable/', None), 'scanpy': ('https://scanpy.readthedocs.io/en/stable/', None), + 'sklearn': ('https://scikit-learn.org/stable', None), } templates_path = ['_templates'] diff --git a/examples/spatial/spatial_domain/stagate.py b/examples/spatial/spatial_domain/stagate.py index 1f2abba6..b8158023 100644 --- a/examples/spatial/spatial_domain/stagate.py +++ b/examples/spatial/spatial_domain/stagate.py @@ -1,19 +1,20 @@ import argparse +import numpy as np import scanpy as sc +from dance.data import Data from dance.datasets.spatial import SpotDataset from dance.modules.spatial.spatial_domain.stagate import Stagate -from dance.transforms.graph_construct import construct_graph, stagate_construct_graph -from dance.transforms.preprocess import (log1p, normalize, normalize_total, prefilter_cells, prefilter_genes, - prefilter_specialgenes, set_seed) +from dance.transforms import AnnDataTransform +from dance.transforms.graph import StagateGraph +from dance.transforms.preprocess import set_seed -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--sample_number", type=str, default="151673", - help="12 samples of human dorsolateral prefrontal cortex dataset supported in the task of spatial domain task.") + parser.add_argument("--sample_number", type=str, default="151673", + help="12 human dorsolateral prefrontal cortex datasets for the spatial domain task.") parser.add_argument("--hidden_dims", type=list, default=[512, 32], help="hidden dimensions") parser.add_argument("--rad_cutoff", type=int, default=150, help="") parser.add_argument("--seed", type=int, default=3, help="") @@ -22,29 +23,32 @@ args = parser.parse_args() set_seed(args.seed) - 
# from dance.modules.spatial.spatial_domain.stagan import Stagate - # get data - dataset = SpotDataset(args.sample_number, data_dir="../../../data/spot") - ## dataset.data has repeat name , be careful - - # preprocess data - dataset.data.var_names_make_unique() - sc.pp.highly_variable_genes(dataset.data, flavor="seurat_v3", n_top_genes=args.high_variable_genes) - normalize_total(dataset.data) - log1p(dataset.data) + # Load raw data + dataset = SpotDataset(args.sample_number, data_dir="../../../data/spot") + _, adata, _, spatial_pixel, label = dataset.load_data() - dataset.adj = stagate_construct_graph(dataset.data, rad_cutoff=args.rad_cutoff) + # Construct dance data object + adata.var_names_make_unique() + adata.obsm["spatial_pixel"] = spatial_pixel + adata.obsm["label"] = label + data = Data(adata, train_size="all") - hidden_dims = args.hidden_dims + # Data preprocessing pipeline + AnnDataTransform(sc.pp.highly_variable_genes, flavor="seurat_v3", n_top_genes=args.high_variable_genes)(data) + AnnDataTransform(sc.pp.normalize_total, target_sum=1e4)(data) + AnnDataTransform(sc.pp.log1p)(data) - hidden_dims = [args.high_variable_genes] + hidden_dims + # Construct cell graph + StagateGraph("radius", radius=args.rad_cutoff)(data) + data.set_config(feature_channel="StagateGraph", feature_channel_type="obsp", label_channel="label") + adj, y = data.get_data(return_type="default") - model = Stagate(hidden_dims) - model.fit(dataset.data, dataset.adj, n_epochs=args.n_epochs) + model = Stagate([args.high_variable_genes] + args.hidden_dims) + model.fit(data.data, np.nonzero(adj), n_epochs=args.n_epochs) predict = model.predict() - curr_ARI = model.score() - print(curr_ARI) + score = model.score(y.values.ravel()) + print(f"ARI: {score:.4f}") """ To reproduce Stagate on other samples, please refer to command lines belows: NOTE: since the stagate method is unstable, you have to run at least 5 times to get best performance. (same with original Stagate paper)