Skip to content

Commit

Permalink
update actinn example script with new data object (#63)
Browse files Browse the repository at this point in the history
* use logger instead of print

* update actinn example script to use the new data object
  • Loading branch information
RemyLau authored Nov 22, 2022
1 parent acc5986 commit 0e70e4a
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 55 deletions.
9 changes: 1 addition & 8 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,7 @@ def load_data(self):

if self.data_type == "actinn":
self.download_benchmark_data()
(
self.train_set,
self.train_label,
self.test_set,
self.test_label,
self.barcode,
self.label_to_type_dict,
) = load_actinn_data(self.train_set, self.train_label, self.test_set, self.test_label)
return load_actinn_data(self.train_set, self.train_label, self.test_set, self.test_label)

if self.data_type == "celltypist":
self.download_benchmark_data(download_pretrained=False)
Expand Down
28 changes: 13 additions & 15 deletions dance/modules/single_modality/cell_type_annotation/actinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ def fit(self, x_train, y_train, seed=None):
Parameters
----------
x_train : torch.Tensor
training data (genes by cells).
training data (cells by genes).
y_train : torch.Tensor
training labels (cell-types by cells).
training labels (cells by cell-types).
seed : int, optional
Random seed, if set to None, then random.
"""
x_train = x_train.T.clone().detach().float().to(self.device) # cells by genes
y_train = torch.where(y_train.T)[1].to(self.device) # cells
x_train = x_train.clone().detach().float().to(self.device) # cells by genes
y_train = torch.where(y_train)[1].to(self.device) # cells

# Initialize weights, optimizer, and scheduler
self.initialize_parameters(seed)
Expand Down Expand Up @@ -182,35 +182,33 @@ def predict(self, x):
Parameters
----------
x : torch.Tensor
Gene expression input features (genes by cells).
Gene expression input features (cells by genes).
Returns
-------
torch.Tensor
Predicted cell-label indices.
"""
x = x.T.clone().detach().to(self.device)
x = x.clone().detach().to(self.device)
z = self.forward(x)
prediction = torch.argmax(z, dim=-1)
return prediction

def score(self, x, y):
def score(self, pred, true):
"""Model performance score measured by accuracy.
Parameters
----------
x : torch.Tensor
Gene expression input features (genes by cells).
y : torch.Tensor
One-hot encoded ground truth labels (cell-types by cells).
pred : torch.Tensor
Gene expression input features (cells by genes).
true : torch.Tensor
Encoded ground truth cell type labels (cells by cell-types).
Returns
-------
float
Prediction accuracy
Prediction accuracy.
"""
pred = self.predict(x).detach().cpu()
label = torch.where(y.T)[1]
return (pred == label).detach().float().mean().tolist()
return true[range(pred.shape[0]), pred.squeeze(-1)].detach().mean().item()
22 changes: 13 additions & 9 deletions dance/transforms/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,26 +968,30 @@ def load_actinn_data(train_data_paths: List[str], train_label_paths: List[str],
train_set, test_set = scale_sets([train_set, test_set], normalize=normalize)
type_to_label_dict_out = type_to_label_dict(train_label.iloc[:, 1])
label_to_type_dict = {v: k for k, v in type_to_label_dict_out.items()}
print(f"Cell Types in training set:")
pprint.pprint(type_to_label_dict_out)
logger.info("Cell Types in training set:\n%s", pprint.pformat(type_to_label_dict_out))
train_label = convert_type_to_label(train_label.iloc[:, 1], type_to_label_dict_out)
train_label = one_hot_matrix(train_label, nt)
print(f"# Trainng cells: {train_label.shape[1]:,}")
logger.info(f"# Trainng cells: {train_label.shape[1]:,}")

total_test_cells = test_label.shape[0]
indicator = test_label.iloc[:, 1].isin(type_to_label_dict_out)
test_label = test_label[indicator]
barcode = test_label.iloc[:, 0].tolist()
test_set = test_set[:, indicator]
test_label = convert_type_to_label(test_label.iloc[:, 1], type_to_label_dict_out)
test_label = one_hot_matrix(test_label, nt)
print(f"# Testing cells {test_label.shape[1]:,} (original number of cells = {total_test_cells:,})")
logger.info(f"# Testing cells {test_label.shape[1]:,} (original number of cells = {total_test_cells:,})")

# Convert to train_set and test_set to tensor
train_set = torch.from_numpy(train_set)
test_set = torch.from_numpy(test_set)
x_adata = AnnData(sp.csr_matrix(np.hstack((train_set, test_set)).T), dtype=np.float32)
train_size = train_set.shape[1]
tot_size = x_adata.shape[0]

return train_set, train_label, test_set, test_label, barcode, label_to_type_dict
labels = [set() for _ in range(tot_size)]
for i, j in zip(*np.where(np.hstack((train_label, test_label)).T)):
labels[i].add(label_to_type_dict[j])

idx_to_label = list(map(label_to_type_dict.get, range(len(label_to_type_dict))))

return x_adata, labels, idx_to_label, train_size


#######################################################
Expand Down
39 changes: 16 additions & 23 deletions examples/single_modality/cell_type_annotation/actinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import numpy as np
import pandas as pd

from dance.data import Data
from dance.datasets.singlemodality import CellTypeDataset
from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
from dance.utils.preprocess import cell_label_to_adata

if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
Expand Down Expand Up @@ -43,33 +45,24 @@

dataloader = CellTypeDataset(data_type="actinn", train_set=train_data_paths, train_label=train_label_paths,
test_set=test_data_path, test_label=test_label_path)
dataloader = dataloader.load_data()
barcode = dataloader.barcode
train_set = dataloader.train_set
train_label = dataloader.train_label
test_set = dataloader.test_set
test_label = dataloader.test_label

# Initialize and train model
num_genes, num_train_samples = train_set.shape
num_cell_types = train_label.shape[0]
print(f"{num_train_samples=:,}, {num_genes=:,}, {num_cell_types=:,}")
model = ACTINN(num_genes, num_cell_types, hidden_dims=args.hidden_dims, lr=args.learning_rate, device=args.device,
num_epochs=args.num_epochs, batch_size=args.batch_size, print_cost=args.print_cost, lambd=args.lambd)
x_adata, cell_labels, idx_to_label, train_size = dataloader.load_data()
y_adata = cell_label_to_adata(cell_labels, idx_to_label, obs=x_adata.obs)
data = Data(x_adata, y_adata, train_size=train_size)

scores = []
for k in range(args.runs):
model.fit(train_set, train_label, seed=args.seed + k)
test_predict = model.predict(test_set)
model = ACTINN(input_dim=data.num_features, output_dim=len(idx_to_label), hidden_dims=args.hidden_dims,
lr=args.learning_rate, device=args.device, num_epochs=args.num_epochs, batch_size=args.batch_size,
print_cost=args.print_cost, lambd=args.lambd)

predicted_label = []
for i in range(len(test_predict)):
predicted_label.append(dataloader.label_to_type_dict[test_predict[i].item()])
predicted_label = pd.DataFrame({"cellname": barcode, "celltype": predicted_label})
# predicted_label.to_csv("predicted_label.txt", sep="\t", index=False)
x_train, y_train = data.get_train_data(return_type="torch")
x_test, y_test = data.get_test_data(return_type="torch")

scores.append(model.score(test_set, test_label))
print(f"Run {k + 1:>2d}/{args.runs:>2d}: {scores[-1]}")
scores = []
for k in range(args.runs):
model.fit(x_train, y_train, seed=args.seed + k)
pred = model.predict(x_test)
scores.append(score := model.score(pred, y_test))
print(f"{score}")
print(f"Score: {np.mean(scores):04.3f} +/- {np.std(scores):04.3f}")
"""To reproduce ACTINN benchmarks, please refer to command lines belows:
Expand Down

0 comments on commit 0e70e4a

Please sign in to comment.