Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update actinn example script with new data object #63

Merged
merged 2 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,7 @@ def load_data(self):

if self.data_type == "actinn":
self.download_benchmark_data()
(
self.train_set,
self.train_label,
self.test_set,
self.test_label,
self.barcode,
self.label_to_type_dict,
) = load_actinn_data(self.train_set, self.train_label, self.test_set, self.test_label)
return load_actinn_data(self.train_set, self.train_label, self.test_set, self.test_label)

if self.data_type == "celltypist":
self.download_benchmark_data(download_pretrained=False)
Expand Down
28 changes: 13 additions & 15 deletions dance/modules/single_modality/cell_type_annotation/actinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ def fit(self, x_train, y_train, seed=None):
Parameters
----------
x_train : torch.Tensor
training data (genes by cells).
training data (cells by genes).
y_train : torch.Tensor
training labels (cell-types by cells).
training labels (cells by cell-types).
seed : int, optional
Random seed, if set to None, then random.

"""
x_train = x_train.T.clone().detach().float().to(self.device) # cells by genes
y_train = torch.where(y_train.T)[1].to(self.device) # cells
x_train = x_train.clone().detach().float().to(self.device) # cells by genes
y_train = torch.where(y_train)[1].to(self.device) # cells

# Initialize weights, optimizer, and scheduler
self.initialize_parameters(seed)
Expand Down Expand Up @@ -182,35 +182,33 @@ def predict(self, x):
Parameters
----------
x : torch.Tensor
Gene expression input features (genes by cells).
Gene expression input features (cells by genes).

Returns
-------
torch.Tensor
Predicted cell-label indices.

"""
x = x.T.clone().detach().to(self.device)
x = x.clone().detach().to(self.device)
z = self.forward(x)
prediction = torch.argmax(z, dim=-1)
return prediction

def score(self, x, y):
def score(self, pred, true):
"""Model performance score measured by accuracy.

Parameters
----------
x : torch.Tensor
Gene expression input features (genes by cells).
y : torch.Tensor
One-hot encoded ground truth labels (cell-types by cells).
pred : torch.Tensor
Gene expression input features (cells by genes).
true : torch.Tensor
Encoded ground truth cell type labels (cells by cell-types).

Returns
-------
float
Prediction accuracy
Prediction accuracy.

"""
pred = self.predict(x).detach().cpu()
label = torch.where(y.T)[1]
return (pred == label).detach().float().mean().tolist()
return true[range(pred.shape[0]), pred.squeeze(-1)].detach().mean().item()
22 changes: 13 additions & 9 deletions dance/transforms/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,26 +968,30 @@ def load_actinn_data(train_data_paths: List[str], train_label_paths: List[str],
train_set, test_set = scale_sets([train_set, test_set], normalize=normalize)
type_to_label_dict_out = type_to_label_dict(train_label.iloc[:, 1])
label_to_type_dict = {v: k for k, v in type_to_label_dict_out.items()}
print(f"Cell Types in training set:")
pprint.pprint(type_to_label_dict_out)
logger.info("Cell Types in training set:\n%s", pprint.pformat(type_to_label_dict_out))
train_label = convert_type_to_label(train_label.iloc[:, 1], type_to_label_dict_out)
train_label = one_hot_matrix(train_label, nt)
print(f"# Trainng cells: {train_label.shape[1]:,}")
logger.info(f"# Trainng cells: {train_label.shape[1]:,}")

total_test_cells = test_label.shape[0]
indicator = test_label.iloc[:, 1].isin(type_to_label_dict_out)
test_label = test_label[indicator]
barcode = test_label.iloc[:, 0].tolist()
test_set = test_set[:, indicator]
test_label = convert_type_to_label(test_label.iloc[:, 1], type_to_label_dict_out)
test_label = one_hot_matrix(test_label, nt)
print(f"# Testing cells {test_label.shape[1]:,} (original number of cells = {total_test_cells:,})")
logger.info(f"# Testing cells {test_label.shape[1]:,} (original number of cells = {total_test_cells:,})")

# Convert to train_set and test_set to tensor
train_set = torch.from_numpy(train_set)
test_set = torch.from_numpy(test_set)
x_adata = AnnData(sp.csr_matrix(np.hstack((train_set, test_set)).T), dtype=np.float32)
train_size = train_set.shape[1]
tot_size = x_adata.shape[0]

return train_set, train_label, test_set, test_label, barcode, label_to_type_dict
labels = [set() for _ in range(tot_size)]
for i, j in zip(*np.where(np.hstack((train_label, test_label)).T)):
labels[i].add(label_to_type_dict[j])

idx_to_label = list(map(label_to_type_dict.get, range(len(label_to_type_dict))))

return x_adata, labels, idx_to_label, train_size


#######################################################
Expand Down
39 changes: 16 additions & 23 deletions examples/single_modality/cell_type_annotation/actinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import numpy as np
import pandas as pd

from dance.data import Data
from dance.datasets.singlemodality import CellTypeDataset
from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
from dance.utils.preprocess import cell_label_to_adata

if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
Expand Down Expand Up @@ -43,33 +45,24 @@

dataloader = CellTypeDataset(data_type="actinn", train_set=train_data_paths, train_label=train_label_paths,
test_set=test_data_path, test_label=test_label_path)
dataloader = dataloader.load_data()
barcode = dataloader.barcode
train_set = dataloader.train_set
train_label = dataloader.train_label
test_set = dataloader.test_set
test_label = dataloader.test_label

# Initialize and train model
num_genes, num_train_samples = train_set.shape
num_cell_types = train_label.shape[0]
print(f"{num_train_samples=:,}, {num_genes=:,}, {num_cell_types=:,}")
model = ACTINN(num_genes, num_cell_types, hidden_dims=args.hidden_dims, lr=args.learning_rate, device=args.device,
num_epochs=args.num_epochs, batch_size=args.batch_size, print_cost=args.print_cost, lambd=args.lambd)
x_adata, cell_labels, idx_to_label, train_size = dataloader.load_data()
y_adata = cell_label_to_adata(cell_labels, idx_to_label, obs=x_adata.obs)
data = Data(x_adata, y_adata, train_size=train_size)

scores = []
for k in range(args.runs):
model.fit(train_set, train_label, seed=args.seed + k)
test_predict = model.predict(test_set)
model = ACTINN(input_dim=data.num_features, output_dim=len(idx_to_label), hidden_dims=args.hidden_dims,
lr=args.learning_rate, device=args.device, num_epochs=args.num_epochs, batch_size=args.batch_size,
print_cost=args.print_cost, lambd=args.lambd)

predicted_label = []
for i in range(len(test_predict)):
predicted_label.append(dataloader.label_to_type_dict[test_predict[i].item()])
predicted_label = pd.DataFrame({"cellname": barcode, "celltype": predicted_label})
# predicted_label.to_csv("predicted_label.txt", sep="\t", index=False)
x_train, y_train = data.get_train_data(return_type="torch")
x_test, y_test = data.get_test_data(return_type="torch")

scores.append(model.score(test_set, test_label))
print(f"Run {k + 1:>2d}/{args.runs:>2d}: {scores[-1]}")
scores = []
for k in range(args.runs):
model.fit(x_train, y_train, seed=args.seed + k)
pred = model.predict(x_test)
scores.append(score := model.score(pred, y_test))
print(f"{score}")
print(f"Score: {np.mean(scores):04.3f} +/- {np.std(scores):04.3f}")
"""To reproduce ACTINN benchmarks, please refer to command lines belows:

Expand Down