diff --git a/dance/data/base.py b/dance/data/base.py index 995a0b6f..6a073b5c 100644 --- a/dance/data/base.py +++ b/dance/data/base.py @@ -5,7 +5,7 @@ import torch from dance import logger -from dance.typing import Any, Dict, FeatType, List, Optional, Sequence, Tuple +from dance.typing import Any, Dict, FeatType, List, Literal, Optional, Sequence, Tuple class BaseData(ABC): @@ -122,6 +122,9 @@ def val_idx(self) -> Sequence[int]: def test_idx(self) -> Sequence[int]: return self.get_split_idx("test", error_on_miss=False) + def shape(self) -> Tuple[int, int]: + return self.data.shape + def copy(self): return deepcopy(self) @@ -156,6 +159,53 @@ def get_split_idx(self, split_name: str, error_on_miss: bool = False): else: return None + def get_feature(self, *, return_type: FeatType = "numpy", channel: Optional[str] = None, + channel_type: Literal["obs", "var"] = "obs", layer: Optional[str] = None, + mod: Optional[str] = None): # yapf: disable + # Pick modality + if mod is None: + data = self.data + elif not hasattr(self.data, "mod"): + raise AttributeError("`mod` option is only available when using multimodality data.") + elif mod not in self.mod: + raise KeyError(f"Unknown modality {mod!r}, available options are {sorted(self.mod)}") + else: + data = self.data.mod[mod] + + # Pick channels - obsm or varm + if channel_type == "obs": + channels = data.obsm + elif channel_type == "var": + channels = data.varm + else: + raise ValueError(f"Unknown channel type {channel_type!r}") + + # Pick specific channl + if (channel is not None) and (layer is not None): + raise ValueError(f"Cannot specify feature layer ({layer!r}) and channel ({channel!r}) simmultaneously.") + elif channel is not None: + feature = channels[channel] + elif layer is not None: + feature = data.layers[layer].X + else: + feature = data.X + + if return_type == "default": + return feature + + # Transform features to other data types + if hasattr(feature, "toarray"): # convert sparse array to dense numpy array + feature = feature.toarray() + elif hasattr(feature, "to_numpy"): # convert dataframe to numpy array + feature = feature.to_numpy() + + if return_type == "torch": + feature = torch.from_numpy(feature) + elif return_type != "numpy": + raise ValueError(f"Unknown return_type {return_type!r}") + + return feature + def _get_data(self, name: str, split_name: Optional[str], return_type: FeatType = "numpy", **kwargs): out = getattr(self, name) diff --git a/dance/datasets/singlemodality.py b/dance/datasets/singlemodality.py index 9d5cfb28..d5d4793f 100644 --- a/dance/datasets/singlemodality.py +++ b/dance/datasets/singlemodality.py @@ -1,4 +1,3 @@ -import argparse import glob import os import os.path as osp @@ -14,9 +13,8 @@ from torch.utils.data import Dataset from dance.data import download_file, download_unzip -from dance.transforms.preprocess import (get_map_dict, load_actinn_data, load_annotation_data_internal, - load_annotation_test_data, load_imputation_data_internal, load_svm_data, - prepare_data_celltypist, splitCommonAnnData) +from dance.transforms.preprocess import (get_map_dict, load_actinn_data, load_annotation_data, + load_imputation_data_internal, load_svm_data, splitCommonAnnData) @dataclass @@ -314,44 +312,14 @@ def is_singlecellnet_complete(self): def load_data(self): # Load data from existing h5ad files, or download files and load data. if self.data_type == "scdeepsort" or self.data_type == "scdeepsort_exp": - if self.is_complete(): - pass - else: + if not self.is_complete(): if self.data_type == "scdeepsort": self.download_all_data() if self.data_type == "scdeepsort_exp": self.download_benchmark_data() assert self.is_complete() - ( - self.num_cells, - self.num_genes, - self.num_labels, - self.graph, - self.train_ids, - self.test_ids, - self.labels, - ) = load_annotation_data_internal(self.params) - - if self.params.score: - ( - self.total_cell_test, - self.num_genes_test, - self.num_labels_test, - self.id2label_test, - self.test_dict, - self.test_label_dict, - self.time_used_test, - ) = load_annotation_test_data(self.params) - else: - ( - self.total_cell_test, - self.num_genes_test, - self.num_labels_test, - self.id2label_test, - self.test_dict, - self.time_used_test, - ) = load_annotation_test_data(self.params) + return load_annotation_data(self.params) if self.data_type == "svm": if self.is_complete(): diff --git a/dance/modules/single_modality/cell_type_annotation/scdeepsort.py b/dance/modules/single_modality/cell_type_annotation/scdeepsort.py index 75fe9c0e..3b82e7ef 100644 --- a/dance/modules/single_modality/cell_type_annotation/scdeepsort.py +++ b/dance/modules/single_modality/cell_type_annotation/scdeepsort.py @@ -8,14 +8,14 @@ """ import os import time +from copy import deepcopy from pathlib import Path import dgl.function as fn -import numpy as np -import pandas as pd import torch import torch.nn as nn from dgl.dataloading import DataLoader, NeighborSampler +from sklearn.metrics import accuracy_score DEBUG = os.environ.get("DANCE_DEBUG") @@ -77,18 +77,18 @@ def message_func(self, edges): """ number_of_edges = edges.src["h"].shape[0] - indices = np.expand_dims(np.array([self.gene_num + 1] * number_of_edges, dtype=np.int32), axis=1) - src_id, dst_id = edges.src["id"].cpu().numpy(), edges.dst["id"].cpu().numpy() - indices = np.where((src_id >= 0) & (dst_id < 0), src_id, indices) # gene->cell - indices = np.where((dst_id >= 0) & (src_id < 0), dst_id, indices) # cell->gene - indices = np.where((dst_id >= 0) & (src_id >= 0), self.gene_num, indices) # gene-gene + src_id, dst_id = edges.src["id"], edges.dst["id"] + indices = (self.gene_num + 1) * torch.ones(number_of_edges, dtype=torch.long, device=src_id.device) + indices = torch.where((src_id >= 0) & (dst_id < 0), src_id, indices) # gene->cell + indices = torch.where((dst_id >= 0) & (src_id < 0), dst_id, indices) # cell->gene + indices = torch.where((dst_id >= 0) & (src_id >= 0), self.gene_num, indices) # gene-gene if DEBUG: print( f"{((src_id >= 0) & (dst_id < 0)).sum():>10,} (geen->cell), " f"{((src_id < 0) & (dst_id >= 0)).sum():>10,} (cell->gene), " f"{((src_id >= 0) & (dst_id >= 0)).sum():>10,} (self-gene), " f"{((src_id < 0) & (dst_id < 0)).sum():>10,} (self-cell), ", ) - h = edges.src["h"] * self.alpha[indices.squeeze()] + h = edges.src["h"] * self.alpha[indices] return {"m": h * edges.data["weight"]} def forward(self, block, h): @@ -165,41 +165,54 @@ def forward(self, blocks, x): class ScDeepSort: - """The ScDeepSort cell-type annotation model. + """The ScDeepSort cell-type annotation model.""" - Parameters - ---------- - params : argparse.Namespace - A Namespace contains arguments of Scdeepsort. See parser documnetation for more details. + def __init__(self, dim_in: int, dim_out: int, dim_hid: int, num_layers: int, species: str, tissue: str, *, + dropout: int = 0, batch_size: int = 500, device: str = "cpu"): + """Initialize the scDeepSort object. - """ + Parameters + ---------- + dim_in + Input dimension, i.e., the number of PCA used for cell and gene features. + dim_out + Output dimension, i.e., the number of possible cell-types. + dim_hid + Hidden dimension. + num_layers + Number of convolution layers. + species + Species name (only used for determining the read/write path). + tissue + Tissue name (only used for determining the read/write path). + dropout + Drop-out rate. + batch_size + Batch size. + device + Computation device, e.g., 'cpu', 'cuda'. + + """ + self.dense_dim = dim_in + self.hidden_dim = dim_hid + self.n_layers = num_layers + self.dropout = dropout + self.species = species + self.tissue = tissue + self.batch_size = batch_size + self.device = device - def __init__(self, params): - self.params = params self.postfix = time.strftime("%d_%m_%Y") + "_" + time.strftime("%H:%M:%S") self.prj_path = Path(__file__).resolve().parents[4] - # TODO: change the prefix from `example` to `saved_models` - self.save_path = (self.prj_path / "example" / "single_modality" / "cell_type_annotation" / "pretrained" / - self.params.species / "models") + self.save_path = (self.prj_path / "saved_models" / "single_modality" / "cell_type_annotation" / "pretrained" / + self.species / "models") if not self.save_path.exists(): self.save_path.mkdir(parents=True) - self.device = torch.device("cpu" if self.params.gpu == -1 else f"cuda:{params.gpu}") - - # Define the variables during training - self.num_cells = None - self.num_genes = None - self.num_labels = None - self.graph = None - self.train_ids = None - self.test_ids = None - self.labels = None - - # Define the variables during prediction - self.id2label = None - self.test_dict = None - - def fit(self, num_cells, num_genes, num_labels, graph, train_ids, test_ids, labels): + + self.num_labels = dim_out + + def fit(self, graph, labels, epochs=300, lr=1e-3, weight_decay=0, val_ratio=0.2): """Train scDeepsort model. Parameters @@ -220,64 +233,76 @@ def fit(self, num_cells, num_genes, num_labels, graph, train_ids, test_ids, labe Node (cell, gene) labels, -1 for genes. """ - self.num_cells = num_cells - self.num_genes = num_genes - self.num_labels = num_labels - - self.train_ids = train_ids.to(self.device) - self.test_ids = test_ids.to(self.device) - self.labels = labels.to(self.device) - self.graph = graph.to(self.device) - self.graph.ndata["label"] = self.labels - - self.model = GNN(self.params.dense_dim, self.num_labels, self.params.hidden_dim, self.params.n_layers, - self.num_genes, activation=nn.ReLU(), dropout=self.params.dropout).to(self.device) + gene_mask = graph.ndata["id"] != -1 + cell_mask = graph.ndata["id"] == -1 + num_genes = gene_mask.sum() + num_cells = cell_mask.sum() + + perm = torch.randperm(num_cells) + num_genes + num_val = int(num_cells * val_ratio) + val_idx = perm[:num_val].to(self.device) + train_idx = perm[num_val:].to(self.device) + + full_labels = -torch.ones(num_genes + num_cells, dtype=torch.long) + full_labels[-num_cells:] = labels + graph = graph.to(self.device) + graph.ndata["label"] = full_labels.to(self.device) + + self.model = GNN(self.dense_dim, self.num_labels, self.hidden_dim, self.n_layers, num_genes, + activation=nn.ReLU(), dropout=self.dropout).to(self.device) print(self.model) - self.sampler = NeighborSampler(fanouts=[-1] * self.params.n_layers, edge_dir="in") - self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params.lr, - weight_decay=self.params.weight_decay) + self.sampler = NeighborSampler(fanouts=[-1] * self.n_layers, edge_dir="in") + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) self.loss_fn = nn.CrossEntropyLoss(reduction="sum") - if self.params.num_neighbors == 0: - self.num_neighbors = self.num_cells + self.num_genes - else: - self.num_neighbors = self.params.num_neighbors - - print(f"Train Number: {len(self.train_ids)}, Test Number: {len(self.test_ids)}") - max_test_acc, _train_acc, _epoch = 0, 0, 0 - for epoch in range(self.params.n_epochs): - loss = self.cal_loss() - train_acc = self.evaluate(self.train_ids)[-1] - test_correct, test_unsure, test_acc = self.evaluate(self.test_ids) - if max_test_acc <= test_acc: - final_test_correct_num = test_correct - final_test_unsure_num = test_unsure + + print(f"Train Number: {len(train_idx)}, Val Number: {len(val_idx)}") + max_val_acc, _train_acc, _epoch = 0, 0, 0 + best_state_dict = None + for epoch in range(epochs): + loss = self.cal_loss(graph, train_idx) + train_acc = self.evaluate(graph, train_idx)[-1] + val_correct, val_unsure, val_acc = self.evaluate(graph, val_idx) + if max_val_acc <= val_acc: + final_val_correct_num = val_correct + final_val_unsure_num = val_unsure _train_acc = train_acc _epoch = epoch - max_test_acc = test_acc + max_val_acc = val_acc self.save_model() - print(f">>>>Epoch {epoch:04d}: Train Acc {train_acc:.4f}, Loss {loss / len(self.train_ids):.4f}, " - f"Test correct {test_correct}, Test unsure {test_unsure}, Test Acc {test_acc:.4f}") + best_state_dict = deepcopy(self.model.state_dict()) + print(f">>>>Epoch {epoch:04d}: Train Acc {train_acc:.4f}, Loss {loss / len(train_idx):.4f}, " + f"Val correct {val_correct}, Val unsure {val_unsure}, Val Acc {val_acc:.4f}") - print(f"---{self.params.species} {self.params.tissue} Best test result:---") - print(f"Epoch {_epoch:04d}, Train Acc {_train_acc:.4f}, Test Correct Num {final_test_correct_num}, " - f"Test Total Num {len(self.test_ids)}, Test Unsure Num {final_test_unsure_num}, " - f"Test Acc {final_test_correct_num / len(self.test_ids):.4f}") + if best_state_dict is not None: + self.model.load_state_dict(best_state_dict) - def cal_loss(self): + print(f"---{self.species} {self.tissue} Best val result:---") + print(f"Epoch {_epoch:04d}, Train Acc {_train_acc:.4f}, Val Correct Num {final_val_correct_num}, " + f"Val Total Num {len(val_idx)}, Val Unsure Num {final_val_unsure_num}, " + f"Val Acc {final_val_correct_num / len(val_idx):.4f}") + + def cal_loss(self, graph, idx): """Calculate loss. + Parameters + ---------- + graph + Input cell-gene graph object. + idx + 1-D tensor containing the indexes of the cell nodes to calculate the loss. + Returns ------- float - Loss function value. + Averaged loss over all batches. """ self.model.train() - total_loss = 0 + total_loss = total_size = 0 - dataloader = DataLoader(graph=self.graph, indices=self.train_ids, graph_sampler=self.sampler, - batch_size=self.params.batch_size, shuffle=True) + dataloader = DataLoader(graph=graph, indices=idx, graph_sampler=self.sampler, batch_size=self.batch_size, + shuffle=True) for _, _, blocks in dataloader: blocks = [b.to(self.device) for b in blocks] input_features = blocks[0].srcdata["features"] @@ -289,18 +314,19 @@ def cal_loss(self): loss.backward() self.optimizer.step() - total_loss += loss.item() + total_size += (size := blocks[-1].num_dst_nodes()) + total_loss += loss.item() * size - return total_loss + return total_loss / total_size @torch.no_grad() - def evaluate(self, ids): - """Evaluate the trained scDeepsort model. + def evaluate(self, graph, idx, unsure_rate: float = 2.0): + """Evaluate the model on certain cell nodes. Parameters ---------- - ids : Tensor - A 1-D tensor containing node IDs to be evaluated on. + idx + 1-D tensor containing the indexes of the cell nodes to be evaluated. Returns ------- @@ -309,10 +335,10 @@ def evaluate(self, ids): """ self.model.eval() - total_correct, total_unsure = 0, 0 + total_correct = total_unsure = 0 - dataloader = DataLoader(graph=self.graph, indices=ids, graph_sampler=self.sampler, - batch_size=self.params.batch_size, shuffle=True) + dataloader = DataLoader(graph=graph, indices=idx, graph_sampler=self.sampler, batch_size=self.batch_size, + shuffle=True) for _, _, blocks in dataloader: blocks = [b.to(self.device) for b in blocks] input_features = blocks[0].srcdata["features"] @@ -321,156 +347,96 @@ def evaluate(self, ids): for pred, label in zip(output_predictions.cpu(), output_labels.cpu()): max_prob = pred.max().item() - if max_prob < self.params.unsure_rate / self.num_labels: + if max_prob < unsure_rate / self.num_labels: total_unsure += 1 elif pred.argmax().item() == label: total_correct += 1 - return total_correct, total_unsure, total_correct / len(ids) + return total_correct, total_unsure, total_correct / len(idx) def save_model(self): """Save the model at the save_path.""" state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()} - torch.save(state, self.save_path / f"{self.params.species}-{self.params.tissue}.pt") + torch.save(state, self.save_path / f"{self.species}-{self.tissue}.pt") def load_model(self): """Load the model from the model path.""" - filename = f"{self.params.species}-{self.params.tissue}.pt" - model_path = self.prj_path / "pretrained" / self.params.species / "models" / filename + filename = f"{self.species}-{self.tissue}.pt" + model_path = self.prj_path / "pretrained" / self.species / "models" / filename state = torch.load(model_path, map_location=self.device) self.model.load_state_dict(state["model"]) @torch.no_grad() - def inference(self, num): + def predict_proba(self, graph): """Perform inference on a test dataset. Parameters ---------- - num : int - Test dataset number. + graph + Input cell-gene graph to be predicted. Returns ------- - list - Predicted labels. + np.ndarray + 2-D array of predicted probabilities of the cell-types, where rows are cells and columns are cell-types. """ self.model.eval() - unsure_threshold = self.params.unsure_rate / self.num_labels - - graph = self.test_dict["graph"][num].to(self.device) - test_indices = self.test_dict["nid"][num].to(self.device) - new_logits = torch.zeros((graph.number_of_nodes(), self.num_labels)) + cell_mask = graph.ndata["id"] == -1 + idx = torch.where(cell_mask)[0].to(self.device) + graph = graph.to(self.device) - dataloader = DataLoader(graph=graph, indices=test_indices, graph_sampler=self.sampler, - batch_size=self.params.batch_size, shuffle=True) + logits = torch.zeros(graph.number_of_nodes(), self.num_labels) + dataloader = DataLoader(graph=graph, indices=idx, graph_sampler=self.sampler, batch_size=self.batch_size) for _, output_nodes, blocks in dataloader: - blocks = [b.to(torch.device(self.device)) for b in blocks] + blocks = [b.to(self.device) for b in blocks] input_features = blocks[0].srcdata["features"] - new_logits[output_nodes] = self.model(blocks, input_features).detach().cpu() + logits[output_nodes] = self.model(blocks, input_features).detach().cpu() - new_logits = new_logits[self.test_dict["mask"][num]] - new_logits = nn.functional.softmax(new_logits, dim=1).numpy() - predict_label = [] - for pred in new_logits: - pred_label = self.id2label[pred.argmax().item()] - predict_label.append("unsure" if pred.max().item() < unsure_threshold else pred_label) - return predict_label + pred_prob = nn.functional.softmax(logits[cell_mask], dim=-1).numpy() + return pred_prob - def predict(self, id2label, test_dict): + def predict(self, graph, unsure_rate: float = 2.0): """Perform prediction on all test datasets. Parameters ---------- - id2label : np.ndarray - Id to label diction. - test_dict : dict - The test dictionary. + graph + Input cell-gene grahp to be predicted. + unsure_rate + Determine the threshold of the maximum predicted probability under which the predictions are considered + uncertain. Returns ------- - dict - A dictionary where the keys are the test dataset IDs and the values are the corresponding predictions. """ - self.id2label = id2label - self.test_dict = test_dict - return {num: self.inference(num) for num in self.params.test_dataset} + pred_prob = self.predict_proba(graph) - @torch.no_grad() - def score(self, predicted_labels, true_labels): + pred = pred_prob.argmax(1) + unsure = pred_prob.max(1) < unsure_rate / self.num_labels + + return pred, unsure + + def score(self, pred, true): """Compute model performance on test datasets based on accuracy. Parameters ---------- - predicted_labels : dict - A dictionary where the keys are test dataset IDs and the values are the predicted labels. - true_labels : dict - A dictionary where the keys are test dataset IDs and the values are the true labels of the cells. Each - element, i.e., the label, can be either a specific value (e.g., string or intger) or a set of values, - allowing multiple mappings. + pred + Predicted cell-labels as a 1-d numpy array. + true + True cell-labels (could contain multiple cell-type per cell). Returns ------- - dict - A diction of correct prediction numbers, total samples, unsured prediction numbers, and accuracy. - - """ - output_score = {} - for num in set(predicted_labels) & set(true_labels): - total_num = len(predicted_labels[num]) - unsure_num = correct = 0 - for pred, true in zip(predicted_labels[num], true_labels[num]): - if pred == "unsure": - unsure_num += 1 - elif pred == true or pred in true: # either a single mapping or multiple mappings - correct += 1 - - output_score[num] = { - "Total number of predictions": total_num, - "Number of correct predictions": correct, - "Number of unsure predictions": unsure_num, - "Accuracy": correct / total_num, - } - - return output_score - - def save_pred(self, num, pred): - """Save predictions for a particular test dataset. - - Parameters - ---------- - num : int - Test file number. - pred : list np.array or dataframe - Predicted labels. + float + Accuracy score of the prediction """ - label_map = pd.read_excel("./map/celltype2subtype.xlsx", sheet_name=self.params.species, header=0, - names=["species", "old_type", "new_type", "new_subtype"]) - label_map = label_map.fillna("N/A", inplace=False) - oldtype2newtype = {} - oldtype2newsubtype = {} - for _, old_type, new_type, new_subtype in label_map.itertuples(index=False): - oldtype2newtype[old_type] = new_type - oldtype2newsubtype[old_type] = new_subtype - - save_path = self.prj_path / self.params.save_dir - if not save_path.exists(): - save_path.mkdir() - if self.params.score: - df = pd.DataFrame({ - "index": self.test_dict["origin_id"][num], - "original label": self.test_dict["label"][num], - "cell_type": [oldtype2newtype.get(p, p) for p in pred], - "cell_subtype": [oldtype2newsubtype.get(p, p) for p in pred] - }) + if true.max() == 1: + num_samples = true.shape[0] + return (true[range(num_samples), pred.ravel()]).sum() / num_samples else: - df = pd.DataFrame({ - "index": self.test_dict["origin_id"][num], - "cell_type": [oldtype2newtype.get(p, p) for p in pred], - "cell_subtype": [oldtype2newsubtype.get(p, p) for p in pred] - }) - df.to_csv(save_path / (self.params.species + f"_{self.params.tissue}_{num}.csv"), index=False) - print(f"output has been stored in {self.params.species}_{self.params.tissue}_{num}.csv") + return accuracy_score(pred, true) diff --git a/dance/transforms/base.py b/dance/transforms/base.py index 17a7531c..33c611ab 100644 --- a/dance/transforms/base.py +++ b/dance/transforms/base.py @@ -22,6 +22,7 @@ def __init__(self, out: Optional[str] = None, log_level: LogLevel = "WARNING"): self.logger = logger.getChild(self.name) self.logger.setLevel(log_level) + self.log_level = log_level @abstractmethod def __call__(self, data: Any) -> Any: diff --git a/dance/transforms/cell_feature.py b/dance/transforms/cell_feature.py index 040454be..5179f50e 100644 --- a/dance/transforms/cell_feature.py +++ b/dance/transforms/cell_feature.py @@ -42,5 +42,6 @@ def __call__(self, data: BaseData) -> BaseData: x = data.get_x() cell_feat = normalize(x, mode="normalize", axis=1) @ gene_feat data.data.obsm[self.out] = cell_feat.astype(np.float32) + data.data.varm[self.out] = gene_feat.astype(np.float32) return data diff --git a/dance/transforms/graph.py b/dance/transforms/graph.py new file mode 100644 index 00000000..9312beb3 --- /dev/null +++ b/dance/transforms/graph.py @@ -0,0 +1,81 @@ +import dgl +import numpy as np +import torch + +from dance.transforms.base import BaseTransform +from dance.transforms.cell_feature import WeightedGenePCA +from dance.typing import LogLevel, Optional + + +class CellGeneGraph(BaseTransform): + + def __init__(self, cell_feature_channel: str, gene_feature_channel: Optional[str] = None, *, + layer: Optional[str] = None, mod: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + + self.cell_feature_channel = cell_feature_channel + self.gene_feature_channel = gene_feature_channel or cell_feature_channel + self.layer = layer + self.mod = mod + + def __call__(self, data): + feat = data.get_feature(return_type="default", layer=self.layer, mod=self.mod) + num_cells, num_feats = feat.shape + + row, col = np.nonzero(feat) + edata = np.array(feat[row, col])[:, None] + self.logger.info(f"Number of nonzero entries: {edata.size:,}") + self.logger.info(f"Nonzero rate = {edata.size / num_cells / num_feats:.1%}") + + row = row + num_feats # offset by feature nodes + col, row = np.hstack((col, row)), np.hstack((row, col)) # convert to undirected + edata = np.vstack((edata, edata)) + + # Convert to tensors + col = torch.LongTensor(col) + row = torch.LongTensor(row) + edata = torch.FloatTensor(edata) + + # Initialize cell-gene graph + g = dgl.graph((row, col)) + g.edata["weight"] = edata + g.ndata["id"] = torch.hstack((torch.arange(num_feats, dtype=torch.int32), + -torch.ones(num_cells, dtype=torch.int32))) # yapf: disable + + # Normalize edges and add self-loop + in_deg = g.in_degrees() + for i in range(g.number_of_nodes()): + src, dst, eidx = g.in_edges(i, form="all") + if src.shape[0] > 0: + edge_w = g.edata["weight"][eidx] + g.edata["weight"][eidx] = in_deg[i] * edge_w / edge_w.sum() + g.add_edges(g.nodes(), g.nodes(), {"weight": torch.ones(g.number_of_nodes())[:, None]}) + + gene_feature = data.get_feature(return_type="torch", channel=self.gene_feature_channel, mod=self.mod, + channel_type="var") + cell_feature = data.get_feature(return_type="torch", channel=self.cell_feature_channel, mod=self.mod, + channel_type="obs") + g.ndata["features"] = torch.vstack((gene_feature, cell_feature)) + + data.data.uns[self.out] = g + + return data + + +class PCACellGeneGraph(BaseTransform): + + def __init__(self, n_components: int = 400, split_name: Optional[str] = None, *, layer: Optional[str] = None, + mod: Optional[str] = None, log_level: LogLevel = "WARNING"): + super().__init__(log_level=log_level) + + self.n_components = n_components + self.split_name = split_name + + self.layer = layer + self.mod = mod + + def __call__(self, data): + WeightedGenePCA(self.n_components, self.split_name, log_level=self.log_level)(data) + CellGeneGraph(cell_feature_channel="WeightedGenePCA", layer=self.layer, mod=self.mod, + log_level=self.log_level)(data) + return data diff --git a/dance/transforms/preprocess.py b/dance/transforms/preprocess.py index e8169b63..e6ed8d56 100644 --- a/dance/transforms/preprocess.py +++ b/dance/transforms/preprocess.py @@ -1,4 +1,5 @@ import collections +import itertools import os import pprint import random @@ -369,180 +370,6 @@ def get_id_to_label(cell_statistics_path): return id2label -def load_annotation_test_data(params): - random_seed = params.random_seed - dense_dim = params.dense_dim - test = params.test_dataset - tissue = params.tissue - - proj_path = Path(params.proj_path) - species_data_path = proj_path / 'pretrained' / params.species - statistics_path = species_data_path / 'statistics' - - if params.score: - map_path = proj_path / 'map' / params.species - map_dict = get_map_dict(map_path, tissue) - - if not statistics_path.exists(): - statistics_path.mkdir() - - gene_statistics_path = statistics_path / (tissue + '_genes.txt') # train+test gene - cell_statistics_path = statistics_path / (tissue + '_cell_type.txt') # train labels - - # generate gene statistics file - id2gene = get_id_to_gene(gene_statistics_path) - # generate cell label statistics file - id2label = np.array(get_id_to_label(cell_statistics_path), dtype=str) - - test_num = 0 - # prepare unified genes - gene2id = {gene: idx for idx, gene in enumerate(id2gene)} - num_genes = len(id2gene) - # prepare unified labels - num_labels = len(id2label) - print(f"The build graph contains {num_genes} gene nodes with {num_labels} labels supported.") - - test_graph_dict = dict() # test-graph dict - if params.score: - test_label_dict = dict() # test label dict - test_index_dict = dict() # test feature indices in all features - test_mask_dict = dict() - test_nid_dict = dict() - test_cell_origin_id_dict = dict() - - ids = torch.arange(num_genes, dtype=torch.int32).unsqueeze(-1) - - # ================================================== - # add all genes as nodes - - for num in test: - test_graph_dict[num] = dgl.DGLGraph() - test_graph_dict[num].add_nodes(num_genes, {'id': ids}) - # ==================================================== - - matrices = [] - - support_data = proj_path / 'pretrained' / f'{params.species}' / 'graphs' / f'{params.species}_{tissue}_data.npz' - support_num = 0 - info = load_npz(support_data) - print(f"load {support_data.name}") - row_idx, gene_idx = np.nonzero(info > 0) - non_zeros = info.data - cell_num = info.shape[0] - support_num += cell_num - matrices.append(info) - ids = torch.tensor([-1] * cell_num, dtype=torch.int32).unsqueeze(-1) - total_cell = support_num - - for n in test: # training cell also in test graph - cell_idx = row_idx + test_graph_dict[n].number_of_nodes() - test_graph_dict[n].add_nodes(cell_num, {'id': ids}) - test_graph_dict[n].add_edges(cell_idx, gene_idx, - {'weight': torch.tensor(non_zeros, dtype=torch.float32).unsqueeze(1)}) - test_graph_dict[n].add_edges(gene_idx, cell_idx, - {'weight': torch.tensor(non_zeros, dtype=torch.float32).unsqueeze(1)}) - - for num in test: - data_path = proj_path / params.test_dir / params.species / f'{params.species}_{tissue}{num}_data.{params.filetype}' - if params.score: - type_path = proj_path / params.test_dir / params.species / f'{params.species}_{tissue}{num}_celltype.csv' - # load celltype file then update labels accordingly - cell2type = pd.read_csv(type_path, index_col=0) - cell2type.columns = ['cell', 'type'] - cell2type['type'] = cell2type['type'].map(str.strip) - test_label_dict[num] = list(map(map_dict[num].get, cell2type['type'].tolist())) - - # load data file then update graph - if params.filetype == 'csv': - df = pd.read_csv(data_path, index_col=0) # (gene, cell) - elif params.filetype == 'gz': - df = pd.read_csv(data_path, compression='gzip', index_col=0) - else: - print(f'Not supported type for {data_path}. Please verify your data file') - - test_cell_origin_id_dict[num] = list(df.columns) - df = df.transpose(copy=True) # (cell, gene) - - df = df.rename(columns=gene2id) - # filter out useless columns if exists (when using gene intersection) - col = [c for c in df.columns if c in gene2id.values()] - df = df[col] - - print( - f'{params.species}_{tissue}{num}_data.{params.filetype} -> Nonzero Ratio: {df.fillna(0).astype(bool).sum().sum() / df.size * 100:.2f}%' - ) - tic = time.time() - print(f'Begin to cumulate time of training/testing ...') - # maintain inter-datasets index for graph and RNA-seq values - arr = df.to_numpy() - row_idx, col_idx = np.nonzero(arr > params.threshold) # intra-dataset index - non_zeros = arr[(row_idx, col_idx)] # non-zero values - # inter-dataset index - cell_idx = row_idx + test_graph_dict[num].number_of_nodes() - gene_idx = df.columns[col_idx].astype(int).tolist() # gene_index - info_shape = (len(df), num_genes) - info = csr_matrix((non_zeros, (row_idx, gene_idx)), shape=info_shape) - matrices.append(info) - - # test_nodes_index_dict[num] = list(range(graph.number_of_nodes(), graph.number_of_nodes() + len(df))) - ids = torch.tensor([-1] * len(df), dtype=torch.int32).unsqueeze(-1) - test_index_dict[num] = list( - range(num_genes + support_num + test_num, num_genes + support_num + test_num + len(df))) - test_nid_dict[num] = list( - range(test_graph_dict[num].number_of_nodes(), test_graph_dict[num].number_of_nodes() + len(df))) - test_num += len(df) - test_graph_dict[num].add_nodes(len(df), {'id': ids}) - # for the test cells, only gene-cell edges are in the test graph - test_graph_dict[num].add_edges(gene_idx, cell_idx, - {'weight': torch.tensor(non_zeros, dtype=torch.float32).unsqueeze(1)}) - - print(f'Added {len(df)} nodes and {len(cell_idx)} edges.') - total_cell += num - - support_index = list(range(num_genes + support_num)) - # 2. create features - sparse_feat = vstack(matrices).toarray() # cell-wise (cell, gene) - # transpose to gene-wise - gene_pca = PCA(dense_dim, random_state=random_seed).fit(sparse_feat[:support_num].T) - gene_feat = gene_pca.transform(sparse_feat[:support_num].T) - gene_evr = sum(gene_pca.explained_variance_ratio_) * 100 - print(f'[PCA] Gene EVR: {gene_evr:.2f} %.') - - # do normalization - sparse_feat = sparse_feat / (np.sum(sparse_feat, axis=1, keepdims=True) + 1e-6) - # use weighted gene_feat as cell_feat - cell_feat = sparse_feat.dot(gene_feat) - gene_feat = torch.from_numpy(gene_feat) # use shared storage - cell_feat = torch.from_numpy(cell_feat) - - features = torch.cat([gene_feat, cell_feat], dim=0).type(torch.float) - for num in test: - test_graph_dict[num].ndata['features'] = features[support_index + test_index_dict[num]] - - for num in test: - test_mask_dict[num] = torch.zeros(test_graph_dict[num].number_of_nodes(), dtype=torch.bool) - test_mask_dict[num][test_nid_dict[num]] = 1 - test_nid_dict[num] = torch.tensor(test_nid_dict[num], dtype=torch.int64) - # normalize weight & add self-loop - normalize_weight(test_graph_dict[num]) - test_graph_dict[num].add_edges( - test_graph_dict[num].nodes(), test_graph_dict[num].nodes(), - {'weight': torch.ones(test_graph_dict[num].number_of_nodes(), dtype=torch.float).unsqueeze(1)}) - - test_dict = { - 'graph': test_graph_dict, - 'nid': test_nid_dict, - 'mask': test_mask_dict, - 'origin_id': test_cell_origin_id_dict - } - time_used = time.time() - tic - - if params.score: - return total_cell, num_genes, num_labels, id2label, test_dict, test_label_dict, time_used - else: - return total_cell, num_genes, num_labels, id2label, test_dict, time_used - - def get_id_2_gene(species_data_path, species, tissue, filetype): data_path = species_data_path data_files = list(data_path.glob(f'{species}_{tissue}*_data.{filetype}')) @@ -585,25 +412,21 @@ def save_statistics(statistics_path, id2label, id2gene, tissue): f.write(label + '\r\n') -def load_annotation_data_internal(params): - random_seed = params.random_seed - dense_dim = params.dense_dim +def load_annotation_data(params): species = params.species tissue = params.tissue + test = params.test_dataset proj_path = Path(params.proj_path) - species_data_path = proj_path / 'train' / species - graph_path = proj_path / 'pretrained' / species / 'graphs' - statistics_path = proj_path / 'pretrained' / species / 'statistics' + species_data_path = proj_path / "train" / species + statistics_path = proj_path / "pretrained" / species / "statistics" + + map_path = proj_path / "map" / params.species + map_dict = get_map_dict(map_path, tissue) if not species_data_path.exists(): raise NotImplementedError - if not statistics_path.exists(): - statistics_path.mkdir(parents=True) - if not graph_path.exists(): - graph_path.mkdir(parents=True) - # generate gene statistics file id2gene = get_id_2_gene(species_data_path, species, tissue, filetype=params.filetype) # generate cell label statistics file @@ -619,36 +442,34 @@ def load_annotation_data_internal(params): num_labels = len(id2label) label2id = {label: idx for idx, label in enumerate(id2label)} save_statistics(statistics_path, id2label, id2gene, tissue) - print(f"The build graph contains {num_genes} genes with {num_labels} labels supported.") - - graph = dgl.DGLGraph() - gene_ids = torch.arange(num_genes, dtype=torch.int32).unsqueeze(-1) - graph.add_nodes(num_genes, {'id': gene_ids}) + print(f"Number of genes: {num_genes:,}, number of labels: {num_labels:,}") all_labels = [] - matrices = [] - num_cells = 0 + dfs = [] + data_ids = [] + train_size = 0 data_path = species_data_path - data_files = data_path.glob(f'*{params.species}_{tissue}*_data.{params.filetype}') + data_files = data_path.glob(f"*{params.species}_{tissue}*_data.{params.filetype}") for data_file in data_files: - number = ''.join(list(filter(str.isdigit, data_file.name))) - type_file = species_data_path / f'{params.species}_{tissue}{number}_celltype.csv' + data_id = "".join(list(filter(str.isdigit, data_file.name))) + type_file = species_data_path / f"{params.species}_{tissue}{data_id}_celltype.csv" + data_ids.append(data_id) # load celltype file then update labels accordingly cell2type = pd.read_csv(type_file, index_col=0) - cell2type.columns = ['cell', 'type'] - cell2type['type'] = cell2type['type'].map(str.strip) - cell2type['id'] = cell2type['type'].map(label2id) + cell2type.columns = ["cell", "type"] + cell2type["type"] = cell2type["type"].map(str.strip) + cell2type["id"] = cell2type["type"].map(label2id) # filter out cells not in label-text - filter_cell = np.where(~pd.isnull(cell2type['id']))[0] + filter_cell = np.where(~pd.isnull(cell2type["id"]))[0] cell2type = cell2type.iloc[filter_cell] - assert not cell2type['id'].isnull().any(), 'something wrong about celltype file.' - all_labels += cell2type['id'].tolist() + assert not cell2type["id"].isnull().any(), "something wrong about celltype file." + all_labels += [{i} for i in cell2type["type"].tolist()] if params.filetype not in ["csv", "gz"]: - print(f'Not supported type for {data_path}. Please verify your data file') + print(f"Not supported type for {data_path}. Please verify your data file") continue # load data file then update graph @@ -656,76 +477,49 @@ def load_annotation_data_internal(params): # filter out cells not in label-text df = df.iloc[filter_cell] - assert cell2type['cell'].tolist() == df.index.tolist() - df = df.rename(columns=gene2id) + assert cell2type["cell"].tolist() == df.index.tolist() # filter out useless columns if exists (when using gene intersection) - col = [c for c in df.columns if c in gene2id.values()] + col = [c for c in df.columns if c in gene2id] df = df[col] + dfs.append(df) + train_size += df.shape[0] - print(f"{params.species}_{tissue}{num}_data.{params.filetype} -> " + print(f"{params.species}_{tissue}{data_id}_data.{params.filetype} -> " f"Nonzero Ratio: {df.fillna(0).astype(bool).sum().sum() / df.size * 100:.2f}%") - # maintain inter-datasets index for graph and RNA-seq values - arr = df.to_numpy() - row_idx, col_idx = np.nonzero(arr > params.threshold) # intra-dataset index - non_zeros = arr[(row_idx, col_idx)] # non-zero values - cell_idx = row_idx + graph.number_of_nodes() # cell_index - gene_idx = df.columns[col_idx].astype(int).tolist() # gene_index - info_shape = (len(df), num_genes) - info = csr_matrix((non_zeros, (row_idx, gene_idx)), shape=info_shape) - matrices.append(info) - - num_cells += len(df) - - ids = torch.tensor([-1] * len(df), dtype=torch.int32).unsqueeze(-1) - graph.add_nodes(len(df), {'id': ids}) - graph.add_edges(cell_idx, gene_idx, {'weight': torch.tensor(non_zeros, dtype=torch.float32).unsqueeze(1)}) - graph.add_edges(gene_idx, cell_idx, {'weight': torch.tensor(non_zeros, dtype=torch.float32).unsqueeze(1)}) - - print(f'Added {len(df)} nodes and {len(cell_idx)} edges.') - print(f'#Nodes in Graph: {graph.number_of_nodes()}, #Edges: {graph.number_of_edges()}.') + for data_id in test: + data_path = proj_path / params.test_dir / params.species / f"{params.species}_{tissue}{data_id}_data.{params.filetype}" + type_path = proj_path / params.test_dir / params.species / f"{params.species}_{tissue}{data_id}_celltype.csv" + data_ids.append(data_id) - assert len(all_labels) == num_cells - - save_npz(graph_path / f'{params.species}_{tissue}_data', vstack(matrices)) - - # 2. create features - sparse_feat = vstack(matrices).toarray() # cell-wise (cell, gene) - assert sparse_feat.shape[0] == num_cells - # transpose to gene-wise - gene_pca = PCA(dense_dim, random_state=random_seed).fit(sparse_feat.T) - gene_feat = gene_pca.transform(sparse_feat.T) - gene_evr = sum(gene_pca.explained_variance_ratio_) * 100 - print(f'[PCA] Gene EVR: {gene_evr:.2f} %.') + # load celltype file then update labels accordingly + cell2type = pd.read_csv(type_path, index_col=0) + cell2type.columns = ["cell", "type"] + cell2type["type"] = cell2type["type"].map(str.strip) + all_labels += list(map(map_dict[data_id].get, cell2type["type"].tolist())) - print('------Train label statistics------') + if params.filetype not in ["csv", "gz"]: + print(f"Not supported type for {data_path}. Please verify your data file") + continue - for i, label in enumerate(id2label, start=1): - print(f"#{i} [{label}]: {label_statistics[label]}") + df = pd.read_csv(data_path, index_col=0).transpose(copy=True) # (cell, gene) + col = [c for c in df.columns if c in gene2id] + df = df[col] + dfs.append(df) - # do normalization - sparse_feat = sparse_feat / (np.sum(sparse_feat, axis=1, keepdims=True) + 1e-6) - # use weighted gene_feat as cell_feat - cell_feat = sparse_feat.dot(gene_feat) - gene_feat = torch.from_numpy(gene_feat) # use shared storage - cell_feat = torch.from_numpy(cell_feat) - graph.ndata['features'] = torch.cat([gene_feat, cell_feat], dim=0).type(torch.float) - labels = torch.tensor([-1] * num_genes + all_labels, dtype=torch.long) # [gene_num+train_num] - # split train set and test set - per = np.random.permutation(range(num_genes, num_genes + num_cells)) - test_ids = torch.tensor(per[:int(num_cells // ((1 - params.test_rate) / params.test_rate + 1))]) - train_ids = torch.tensor(per[int(num_cells // ((1 - params.test_rate) / params.test_rate + 1)):]) - # normalize weight + print(f"{params.species}_{tissue}{data_id}_data.{params.filetype} -> " + f"Nonzero Ratio: {df.fillna(0).astype(bool).sum().sum() / df.size * 100:.2f}%") - # normalize weight - normalize_weight(graph) + df_combined = pd.concat(dfs).fillna(0) + adata = anndata.AnnData(df_combined, dtype=np.float32) + adata.obs_names_make_unique() - # add self-loop - graph.add_edges(graph.nodes(), graph.nodes(), - {'weight': torch.ones(graph.number_of_nodes(), dtype=torch.float).unsqueeze(1)}) + data_ids_list = list(itertools.chain.from_iterable([i] * j.shape[0] for i, j in zip(data_ids, dfs))) + ids_df = pd.DataFrame(data_ids_list, index=adata.obs_names) + adata.obs["data_id"] = ids_df - return num_cells, num_genes, num_labels, graph, train_ids, test_ids, labels + return adata, all_labels, id2label, train_size ######################################## diff --git a/examples/single_modality/cell_type_annotation/scdeepsort.py b/examples/single_modality/cell_type_annotation/scdeepsort.py index 04cd4700..8db442fb 100644 --- a/examples/single_modality/cell_type_annotation/scdeepsort.py +++ b/examples/single_modality/cell_type_annotation/scdeepsort.py @@ -1,19 +1,20 @@ import argparse -import random from pprint import pprint -import numpy as np import torch +from dance.data import Data from dance.datasets.singlemodality import CellTypeDataset from dance.modules.single_modality.cell_type_annotation.scdeepsort import ScDeepSort +from dance.transforms.graph import PCACellGeneGraph +from dance.utils.preprocess import cell_label_to_df if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--random_seed", type=int, default=10) parser.add_argument("--data_type", type=str, default="scdeepsort_exp") parser.add_argument("--dropout", type=float, default=0.1, help="dropout probability") - parser.add_argument("--gpu", type=int, default=-1, help="GPU id, -1 for cpu") + parser.add_argument("--device", type=str, default="cpu", help="Computation device") parser.add_argument("--filetype", default="csv", type=str, choices=["csv", "gz"], help="data file type, csv or gz") parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight for L2 loss") @@ -34,7 +35,6 @@ parser.add_argument("--test", dest="evaluate", action="store_false") parser.set_defaults(evaluate=True) parser.add_argument("--test_dir", default="test", type=str, help="test directory") - parser.add_argument("--save_dir", default="result", type=str, help="save directory") parser.add_argument("--test_rate", type=float, default=0.2) parser.add_argument("--test_dataset", nargs="+", type=int, default=[1759], help="Testing dataset IDs") params = parser.parse_args() @@ -42,21 +42,37 @@ dataloader = CellTypeDataset(data_type="scdeepsort_exp", random_seed=params.random_seed, dense_dim=params.dense_dim, test_dataset=params.test_dataset, species=params.species, tissue=params.tissue, - gpu=params.gpu, evaluate=params.evaluate, test_dir=params.test_dir, - filetype=params.filetype, threshold=params.threshold, exclude_rate=params.exclude_rate, + evaluate=params.evaluate, test_dir=params.test_dir, filetype=params.filetype, + threshold=params.threshold, exclude_rate=params.exclude_rate, test_rate=params.test_rate, score=True) - dataloader = dataloader.load_data() - random.seed(params.random_seed) - np.random.seed(params.random_seed) - torch.manual_seed(params.random_seed) - torch.cuda.manual_seed(params.random_seed) - trainer = ScDeepSort(params) - trainer.fit(dataloader.num_cells, dataloader.num_genes, dataloader.num_labels, dataloader.graph, - dataloader.train_ids, dataloader.test_ids, dataloader.labels) - prediction_labels = trainer.predict(dataloader.id2label_test, dataloader.test_dict) - evaluation_scores = trainer.score(prediction_labels, dataloader.test_label_dict) - pprint(evaluation_scores) + adata, cell_labels, idx_to_label, train_size = dataloader.load_data() + adata.obsm["cell_type"] = cell_label_to_df(cell_labels, idx_to_label, index=adata.obs.index) + data = Data(adata, train_size=train_size) + PCACellGeneGraph(n_components=params.dense_dim, split_name="train", log_level="INFO")(data) + data.set_config(label_channel="cell_type") + + y_train = data.get_y(split_name="train", return_type="torch").argmax(1) + y_test = data.get_y(split_name="test", return_type="torch") + num_labels = y_test.shape[1] + + # TODO: make api for the following blcok? + g = data.data.uns["CellGeneGraph"] + num_genes = data.num_features + gene_ids = torch.arange(num_genes) + train_cell_ids = torch.LongTensor(data.train_idx) + num_genes + test_cell_ids = torch.LongTensor(data.test_idx) + num_genes + g_train = g.subgraph(torch.concat((gene_ids, train_cell_ids))) + g_test = g.subgraph(torch.concat((gene_ids, test_cell_ids))) + + model = ScDeepSort(params.dense_dim, num_labels, params.hidden_dim, params.n_layers, params.species, params.tissue, + dropout=params.dropout, batch_size=params.batch_size, device=params.device) + model.fit(g_train, y_train, epochs=params.n_epochs, lr=params.lr, weight_decay=params.weight_decay, + val_ratio=params.test_rate) + + pred, unsure = model.predict(g_test) + score = model.score(pred, y_test) + print(f"{score=}") """To reproduce the benchmarking results, please run the following commands: python scdeepsort.py --data_type scdeepsort --tissue Brain --test_data 2695