diff --git a/dance/datasets/singlemodality.py b/dance/datasets/singlemodality.py index 0642ef30..9d5cfb28 100644 --- a/dance/datasets/singlemodality.py +++ b/dance/datasets/singlemodality.py @@ -367,14 +367,7 @@ def load_data(self): if self.data_type == "actinn": self.download_benchmark_data() - ( - self.train_set, - self.train_label, - self.test_set, - self.test_label, - self.barcode, - self.label_to_type_dict, - ) = load_actinn_data(self.train_set, self.train_label, self.test_set, self.test_label) + return load_actinn_data(self.train_set, self.train_label, self.test_set, self.test_label) if self.data_type == "celltypist": self.download_benchmark_data(download_pretrained=False) diff --git a/dance/modules/single_modality/cell_type_annotation/actinn.py b/dance/modules/single_modality/cell_type_annotation/actinn.py index d6d4f85e..6be30a53 100644 --- a/dance/modules/single_modality/cell_type_annotation/actinn.py +++ b/dance/modules/single_modality/cell_type_annotation/actinn.py @@ -145,15 +145,15 @@ def fit(self, x_train, y_train, seed=None): Parameters ---------- x_train : torch.Tensor - training data (genes by cells). + training data (cells by genes). y_train : torch.Tensor - training labels (cell-types by cells). + training labels (cells by cell-types). seed : int, optional Random seed, if set to None, then random. """ - x_train = x_train.T.clone().detach().float().to(self.device) # cells by genes - y_train = torch.where(y_train.T)[1].to(self.device) # cells + x_train = x_train.clone().detach().float().to(self.device) # cells by genes + y_train = torch.where(y_train)[1].to(self.device) # cells # Initialize weights, optimizer, and scheduler self.initialize_parameters(seed) @@ -182,7 +182,7 @@ def predict(self, x): Parameters ---------- x : torch.Tensor - Gene expression input features (genes by cells). + Gene expression input features (cells by genes). Returns ------- @@ -190,27 +190,25 @@ def predict(self, x): Predicted cell-label indices. """ - x = x.T.clone().detach().to(self.device) + x = x.clone().detach().to(self.device) z = self.forward(x) prediction = torch.argmax(z, dim=-1) return prediction - def score(self, x, y): + def score(self, pred, true): """Model performance score measured by accuracy. Parameters ---------- - x : torch.Tensor - Gene expression input features (genes by cells). - y : torch.Tensor - One-hot encoded ground truth labels (cell-types by cells). + pred : torch.Tensor + Gene expression input features (cells by genes). + true : torch.Tensor + Encoded ground truth cell type labels (cells by cell-types). Returns ------- float - Prediction accuracy + Prediction accuracy. """ - pred = self.predict(x).detach().cpu() - label = torch.where(y.T)[1] - return (pred == label).detach().float().mean().tolist() + return true[range(pred.shape[0]), pred.squeeze(-1)].detach().mean().item() diff --git a/dance/transforms/preprocess.py b/dance/transforms/preprocess.py index ae5b3823..dcf7d303 100644 --- a/dance/transforms/preprocess.py +++ b/dance/transforms/preprocess.py @@ -968,26 +968,30 @@ def load_actinn_data(train_data_paths: List[str], train_label_paths: List[str], train_set, test_set = scale_sets([train_set, test_set], normalize=normalize) type_to_label_dict_out = type_to_label_dict(train_label.iloc[:, 1]) label_to_type_dict = {v: k for k, v in type_to_label_dict_out.items()} - print(f"Cell Types in training set:") - pprint.pprint(type_to_label_dict_out) + logger.info("Cell Types in training set:\n%s", pprint.pformat(type_to_label_dict_out)) train_label = convert_type_to_label(train_label.iloc[:, 1], type_to_label_dict_out) train_label = one_hot_matrix(train_label, nt) - print(f"# Trainng cells: {train_label.shape[1]:,}") + logger.info(f"# Trainng cells: {train_label.shape[1]:,}") total_test_cells = test_label.shape[0] indicator = test_label.iloc[:, 1].isin(type_to_label_dict_out) test_label = test_label[indicator] - barcode = test_label.iloc[:, 0].tolist() test_set = test_set[:, indicator] test_label = convert_type_to_label(test_label.iloc[:, 1], type_to_label_dict_out) test_label = one_hot_matrix(test_label, nt) - print(f"# Testing cells {test_label.shape[1]:,} (original number of cells = {total_test_cells:,})") + logger.info(f"# Testing cells {test_label.shape[1]:,} (original number of cells = {total_test_cells:,})") - # Convert to train_set and test_set to tensor - train_set = torch.from_numpy(train_set) - test_set = torch.from_numpy(test_set) + x_adata = AnnData(sp.csr_matrix(np.hstack((train_set, test_set)).T), dtype=np.float32) + train_size = train_set.shape[1] + tot_size = x_adata.shape[0] - return train_set, train_label, test_set, test_label, barcode, label_to_type_dict + labels = [set() for _ in range(tot_size)] + for i, j in zip(*np.where(np.hstack((train_label, test_label)).T)): + labels[i].add(label_to_type_dict[j]) + + idx_to_label = list(map(label_to_type_dict.get, range(len(label_to_type_dict)))) + + return x_adata, labels, idx_to_label, train_size ####################################################### diff --git a/examples/single_modality/cell_type_annotation/actinn.py b/examples/single_modality/cell_type_annotation/actinn.py index db0411bd..fc3df34d 100644 --- a/examples/single_modality/cell_type_annotation/actinn.py +++ b/examples/single_modality/cell_type_annotation/actinn.py @@ -4,8 +4,10 @@ import numpy as np import pandas as pd +from dance.data import Data from dance.datasets.singlemodality import CellTypeDataset from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN +from dance.utils.preprocess import cell_label_to_adata if __name__ == "__main__": parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -43,33 +45,24 @@ dataloader = CellTypeDataset(data_type="actinn", train_set=train_data_paths, train_label=train_label_paths, test_set=test_data_path, test_label=test_label_path) - dataloader = dataloader.load_data() - barcode = dataloader.barcode - train_set = dataloader.train_set - train_label = dataloader.train_label - test_set = dataloader.test_set - test_label = dataloader.test_label - # Initialize and train model - num_genes, num_train_samples = train_set.shape - num_cell_types = train_label.shape[0] - print(f"{num_train_samples=:,}, {num_genes=:,}, {num_cell_types=:,}") - model = ACTINN(num_genes, num_cell_types, hidden_dims=args.hidden_dims, lr=args.learning_rate, device=args.device, - num_epochs=args.num_epochs, batch_size=args.batch_size, print_cost=args.print_cost, lambd=args.lambd) + x_adata, cell_labels, idx_to_label, train_size = dataloader.load_data() + y_adata = cell_label_to_adata(cell_labels, idx_to_label, obs=x_adata.obs) + data = Data(x_adata, y_adata, train_size=train_size) - scores = [] - for k in range(args.runs): - model.fit(train_set, train_label, seed=args.seed + k) - test_predict = model.predict(test_set) + model = ACTINN(input_dim=data.num_features, output_dim=len(idx_to_label), hidden_dims=args.hidden_dims, + lr=args.learning_rate, device=args.device, num_epochs=args.num_epochs, batch_size=args.batch_size, + print_cost=args.print_cost, lambd=args.lambd) - predicted_label = [] - for i in range(len(test_predict)): - predicted_label.append(dataloader.label_to_type_dict[test_predict[i].item()]) - predicted_label = pd.DataFrame({"cellname": barcode, "celltype": predicted_label}) - # predicted_label.to_csv("predicted_label.txt", sep="\t", index=False) + x_train, y_train = data.get_train_data(return_type="torch") + x_test, y_test = data.get_test_data(return_type="torch") - scores.append(model.score(test_set, test_label)) - print(f"Run {k + 1:>2d}/{args.runs:>2d}: {scores[-1]}") + scores = [] + for k in range(args.runs): + model.fit(x_train, y_train, seed=args.seed + k) + pred = model.predict(x_test) + scores.append(score := model.score(pred, y_test)) + print(f"{score}") print(f"Score: {np.mean(scores):04.3f} +/- {np.std(scores):04.3f}") """To reproduce ACTINN benchmarks, please refer to command lines belows: