refactor preprocessing pipelines for spatial_domain tasks #165

Merged: 6 commits, Jan 27, 2023
3 changes: 3 additions & 0 deletions dance/data/base.py
@@ -81,6 +81,9 @@ def __init__(self, data: Union[AnnData, MuData], train_size: Optional[int] = Non
if "dance_config" not in self._data.uns:
self._data.uns["dance_config"] = dict()

def __repr__(self) -> str:
return f"{self.__class__.__name__} object that wraps (.data):\n{self.data}"

def _setup_splits(self, train_size: Optional[Union[int, str]], val_size: int, test_size: int):
if train_size is None:
return
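The new `__repr__` simply delegates to the wrapped AnnData's own summary. A minimal sketch of what this looks like in practice, using a hypothetical toy AnnData that is not part of this PR:

```python
import anndata as ad
import numpy as np

from dance.data import Data

# Toy AnnData with 10 cells and 5 genes, purely for illustration.
adata = ad.AnnData(np.random.rand(10, 5))
data = Data(adata, train_size="all")

print(repr(data))
# Data object that wraps (.data):
# AnnData object with n_obs × n_vars = 10 × 5
#     uns: 'dance_config'
```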
9 changes: 7 additions & 2 deletions dance/datasets/base.py
@@ -129,10 +129,15 @@ def _maybe_load_cache(self, transform, cache, redo_cache) -> Union[Data, str]:

cache_file_path = osp.join(cache_dir, f"{md5_hash}.pkl")
if osp.isfile(cache_file_path) and cache:
logger.info(f"Loading cached data at {cache_file_path}\n{'Cache data info':-^100}\n"
f"Dataset: {self!r}\nTransformation: {transform!r}\n{'End of cache data info':-^100}")
with open(cache_file_path, "rb") as f:
data = pickle.load(f)
terminal_width = os.get_terminal_size().columns
logger.info(f"Loading cached data at {cache_file_path}\n"
f"{'Cache data info':=^{terminal_width}}\n"
f"{'Dataset object info':-^{terminal_width}}\n{self!r}\n"
f"{'Transformation info':-^{terminal_width}}\n{transform!r}\n"
f"{'Loaded data info':-^{terminal_width}}\n{data!r}\n"
f"{'End of cache data info':=^{terminal_width}}")
return data
else:
return cache_file_path
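The banner lines above rely on Python's format-spec fill/align syntax: `{'Cache data info':=^{terminal_width}}` centers the text and pads it with `=` out to the full terminal width. A small self-contained sketch of the same pattern; the 80-column fallback is an assumption for environments without a terminal (e.g. CI), not something this diff adds:

```python
import os

try:
    width = os.get_terminal_size().columns
except OSError:  # stdout not attached to a terminal, e.g. piped output or CI
    width = 80

print(f"{'Cache data info':=^{width}}")      # =====Cache data info=====
print(f"{'Dataset object info':-^{width}}")  # -----Dataset object info-----
```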
173 changes: 96 additions & 77 deletions dance/datasets/spatial.py
@@ -8,25 +8,13 @@
import scanpy as sc

from dance import logger
from dance.data import Data
from dance.datasets.base import BaseDataset
from dance.registers import register_dataset
from dance.utils.download import download_file, download_unzip, unzip_file

IGNORED_FILES = ["readme.txt"]

dataset = {
"151510": "https://www.dropbox.com/sh/41h9brsk6my546x/AADa18mkJge-KQRTndRelTpMa?dl=1",
"151507": "https://www.dropbox.com/sh/m3554vfrdzbwv2c/AACGsFNVKx8rjBgvF7Pcm2L7a?dl=1",
"151508": "https://www.dropbox.com/sh/tm47u3fre8692zt/AAAJJf8-za_Lpw614ft096qqa?dl=1",
"151509": "https://www.dropbox.com/sh/hihr7906vyirjet/AACslV5mKIkF2CF5QqE1LE6ya?dl=1",
"151669": "https://www.dropbox.com/sh/ulw2nnnmgtbswvc/AAC0fT549EwtxKZWWoB89gb4a?dl=1",
"151670": "https://www.dropbox.com/sh/8fw44zyyjgh0ddc/AAA1asGAmyDiMmvhRmL7pN1Na?dl=1",
"151671": "https://www.dropbox.com/sh/9g5qzd5ykx2mpk3/AAD3xjx1i2h0RhYBc-Vft6CEa?dl=1",
"151672": "https://www.dropbox.com/sh/l6519tr280krd4p/AAAWefCSp2iKhVmLgytlyxTta?dl=1",
"151673": "https://www.dropbox.com/sh/qc64ps6gd64dm0c/AAC_5_mP4AczKj8lORLLKcIba?dl=1",
"151674": "https://www.dropbox.com/sh/q7io99psd2xuqgw/AABske8dgX_kc1oaDSxuiqjpa?dl=1",
"151675": "https://www.dropbox.com/sh/uahka2h5klnrzvj/AABe7K0_ewqOcqKUxHebE6qLa?dl=1",
"151676": "https://www.dropbox.com/sh/jos5jjurezy5zp1/AAB2uaVm3-Us1a4mDkS1Q-iAa?dl=1",
}

cellDeconvo_dataset = {
"CARD_synthetic": "https://www.dropbox.com/sh/v0vpv0jsnfexj7f/AADpizLGOrF7M8EesDihgbBla?dl=1",
"GSE174746": "https://www.dropbox.com/sh/spfv06yfttetrab/AAAgORS6ocyoZEyxiRYKTymCa?dl=1",
@@ -38,81 +26,112 @@
}


class SpotDataset:
@register_dataset("spatiallibd")
class SpatialLIBDDataset(BaseDataset):

_DISPLAY_ATTRS = ("data_id", )
url_dict = {
"151510": "https://www.dropbox.com/sh/41h9brsk6my546x/AADa18mkJge-KQRTndRelTpMa?dl=1",
"151507": "https://www.dropbox.com/sh/m3554vfrdzbwv2c/AACGsFNVKx8rjBgvF7Pcm2L7a?dl=1",
"151508": "https://www.dropbox.com/sh/tm47u3fre8692zt/AAAJJf8-za_Lpw614ft096qqa?dl=1",
"151509": "https://www.dropbox.com/sh/hihr7906vyirjet/AACslV5mKIkF2CF5QqE1LE6ya?dl=1",
"151669": "https://www.dropbox.com/sh/ulw2nnnmgtbswvc/AAC0fT549EwtxKZWWoB89gb4a?dl=1",
"151670": "https://www.dropbox.com/sh/8fw44zyyjgh0ddc/AAA1asGAmyDiMmvhRmL7pN1Na?dl=1",
"151671": "https://www.dropbox.com/sh/9g5qzd5ykx2mpk3/AAD3xjx1i2h0RhYBc-Vft6CEa?dl=1",
"151672": "https://www.dropbox.com/sh/l6519tr280krd4p/AAAWefCSp2iKhVmLgytlyxTta?dl=1",
"151673": "https://www.dropbox.com/sh/qc64ps6gd64dm0c/AAC_5_mP4AczKj8lORLLKcIba?dl=1",
"151674": "https://www.dropbox.com/sh/q7io99psd2xuqgw/AABske8dgX_kc1oaDSxuiqjpa?dl=1",
"151675": "https://www.dropbox.com/sh/uahka2h5klnrzvj/AABe7K0_ewqOcqKUxHebE6qLa?dl=1",
"151676": "https://www.dropbox.com/sh/jos5jjurezy5zp1/AAB2uaVm3-Us1a4mDkS1Q-iAa?dl=1",
}

def __init__(self, root=".", full_download=False, data_id="151673", data_dir="data/spot"):
super().__init__(root, full_download)

def __init__(self, data_id="151673", data_dir="data/spot", build_graph_fn="default"):
self.data_id = data_id
self.data_dir = data_dir + "/{}".format(data_id)
self.data_url = dataset[data_id]
self._load_data()
self.adj = None

def get_all_data(self):
# provide an interface to get all data at one time
print("All data includes {} datasets: {}".format(len(dataset), ",".join(dataset.keys())))
res = {}
for each_dataset in dataset.keys():
res[each_dataset] = SpotDataset(each_dataset)
return res

def download_data(self):
# judge whether a file exists or not
isdownload = download_file(self.data_url, self.data_dir + "/{}.zip".format(self.data_id))
if isdownload:
unzip_file(self.data_dir + "/{}.zip".format(self.data_id), self.data_dir + "/")
return self

def download_all(self):
logger.info(f"All data includes {len(self.url_dict)} datasets: {list(self.url_dict)}")
_data_id = self.data_id
for data_id in self.url_dict:
self.data_id = data_id
self.download()
self.data_id = _data_id

def is_complete_all(self):
_data_id = self.data_id
for data_id in self.url_dict:
self.data_id = data_id
if not self.is_complete():
self.data_id = _data_id
return False
self.data_id = _data_id
return True

def download(self):
out_path = osp.join(self.data_dir, f"{self.data_id}.zip")
if download_file(self.url_dict[self.data_id], out_path):
unzip_file(out_path, self.data_dir)

def is_complete(self):
# data.h5ad
# histology.tif
# positions.txt
# judge whether data is complete or not
check = [
self.data_dir + "/{}_raw_feature_bc_matrix.h5".format(self.data_id),
self.data_dir + "/{}_full_image.tif".format(self.data_id), self.data_dir + "/tissue_positions_list.txt"
osp.join(self.data_dir, f"{self.data_id}_raw_feature_bc_matrix.h5"), # expression
osp.join(self.data_dir, f"{self.data_id}_full_image.tif"), # histology
osp.join(self.data_dir, "tissue_positions_list.txt"), # positions
]

for i in check:
if not os.path.exists(i):
print("lack {}".format(i))
logger.info(f"lack {i}")
return False

return True

def _load_data(self):
if self.is_complete():
pass
else:
self.download_data()
self.data = sc.read_10x_h5(self.data_dir + "/{}_raw_feature_bc_matrix.h5".format(self.data_id))
self.img = cv2.imread(self.data_dir + "/{}_full_image.tif".format(self.data_id))
label = pd.read_csv(self.data_dir + "/cluster_labels.csv")
classes = {layer_class: idx for idx, layer_class in enumerate(set(label["ground_truth"].tolist()))}
self.spatial = pd.read_csv(self.data_dir + "/tissue_positions_list.txt", sep=",", header=None, na_filter=False,
index_col=0)
self.data.obs["x1"] = self.spatial[1]
self.data.obs["x2"] = self.spatial[2]
self.data.obs["x3"] = self.spatial[3]
self.data.obs["x4"] = self.spatial[4]
self.data.obs["x5"] = self.spatial[5]
self.data.obs["x"] = self.data.obs["x2"]
self.data.obs["y"] = self.data.obs["x3"]
self.data.obs["x_pixel"] = self.data.obs["x4"]
self.data.obs["y_pixel"] = self.data.obs["x5"]

self.data = self.data[self.data.obs["x1"] == 1]
self.data.var_names = [i.upper() for i in list(self.data.var_names)]
self.data.var["genename"] = self.data.var.index.astype("str")
self.data.obs["label"] = list(map(lambda x: classes[x], label["ground_truth"].tolist()))
self.data.obs["ground_truth"] = label["ground_truth"].tolist()
return self

def load_data(self):
adata = self.data
spatial = adata.obs[["x", "y"]]
spatial_pixel = adata.obs[["x_pixel", "y_pixel"]]
image = self.img
label = adata.obs[["label"]]
return image, adata, spatial, spatial_pixel, label
def _load_raw_data(self):
image_path = osp.join(self.data_dir, f"{self.data_id}_full_image.tif")
data_path = osp.join(self.data_dir, f"{self.data_id}_raw_feature_bc_matrix.h5")
spatial_path = osp.join(self.data_dir, "tissue_positions_list.txt")
meta_path = osp.join(self.data_dir, "cluster_labels.csv")

logger.info(f"Loading image data from {image_path}")
img = cv2.imread(image_path)

logger.info(f"Loading expression data from {data_path}")
adata = sc.read_10x_h5(data_path)

logger.info(f"Loading spatial info from {spatial_path}")
spatial = pd.read_csv(spatial_path, header=None, index_col=0).loc[adata.obs_names]

logger.info(f"Loading label info from {meta_path}")
meta_df = pd.read_csv(meta_path)

# Restrict to captured spots
indicator = spatial[1].values == 1
adata = adata[indicator]
spatial = spatial.iloc[indicator]

# Prepare spatial info tables
xy = spatial[[2, 3]].rename(columns={2: "x", 3: "y"})
xy_pixel = spatial[[4, 5]].rename(columns={4: "x_pixel", 5: "y_pixel"})

# Prepare meta data and create a column with indexed label info
label_classes = {j: i for i, j in enumerate(meta_df["ground_truth"].unique())}
meta_df["label"] = list(map(label_classes.get, meta_df["ground_truth"]))

return img, adata, xy, xy_pixel, meta_df

def _raw_to_dance(self, raw_data):
img, adata, xy, xy_pixel, meta_df = raw_data
adata.var_names_make_unique()

adata.obs = meta_df.set_index(adata.obs_names)
adata.obsm["spatial"] = xy.set_index(adata.obs_names)
adata.obsm["spatial_pixel"] = xy_pixel.set_index(adata.obs_names)
adata.uns["image"] = img

data = Data(adata, train_size="all")
return data


class CellTypeDeconvoDatasetLite:
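A hedged usage sketch of the refactored dataset class. The `load_data(transform=..., cache=...)` signature is assumed from `BaseDataset._maybe_load_cache` above and from how other dance datasets are invoked; it is not itself shown in this diff:

```python
from dance.datasets.spatial import SpatialLIBDDataset

# Downloads the 151673 slice if it is not already on disk, reads the image,
# expression matrix, spatial coordinates, and labels, and wraps them in a Data object.
dataset = SpatialLIBDDataset(data_id="151673")
data = dataset.load_data(cache=True)

print(repr(data))  # uses the new Data.__repr__ from dance/data/base.py
```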
22 changes: 22 additions & 0 deletions dance/modules/spatial/spatial_domain/louvain.py
@@ -14,6 +14,11 @@

import networkx as nx
import numpy as np
import scanpy as sc

from dance.transforms import AnnDataTransform, CellPCA, Compose, FilterGenesMatch, SetConfig
from dance.transforms.graph import NeighborGraph
from dance.typing import LogLevel

PASS_MAX = -1
MIN = 0.0000001
@@ -333,6 +338,23 @@ class Louvain:
def __init__(self, resolution: float = 1):
self.resolution = resolution

@staticmethod
def preprocessing_pipeline(dim: int = 50, n_neighbors: int = 17, log_level: LogLevel = "INFO"):
return Compose(
FilterGenesMatch(prefixes=["ERCC", "MT-"]),
AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
AnnDataTransform(sc.pp.log1p),
CellPCA(n_components=dim),
NeighborGraph(n_neighbors=n_neighbors),
SetConfig({
"feature_channel": "NeighborGraph",
"feature_channel_type": "obsp",
"label_channel": "label",
"label_channel_type": "obs"
}),
log_level=log_level,
)

def fit(self, adj, partition=None, weight="weight", randomize=None, random_state=None):
"""Fit function for model training.

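A hedged sketch of applying the new pipeline and then clustering, continuing from the `data` object in the dataset sketch above. That `Compose` is callable on a `Data` object, and that `get_train_data`/`predict` behave as shown, are assumptions based on the SetConfig channels rather than guarantees from this diff:

```python
# Preprocess in place: filter ERCC/MT- genes, normalize, log1p, PCA, build the neighbor graph.
preprocess = Louvain.preprocessing_pipeline(dim=50, n_neighbors=17)
preprocess(data)

# Per SetConfig: feature = neighbor graph (.obsp), label = "label" column (.obs).
adj, y = data.get_train_data()

model = Louvain(resolution=1)
model.fit(adj)
pred = model.predict()
```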
24 changes: 22 additions & 2 deletions dance/modules/spatial/spatial_domain/spagcn.py
@@ -23,6 +23,9 @@
from torch.nn.parameter import Parameter

from dance import utils
from dance.transforms import AnnDataTransform, CellPCA, Compose, FilterGenesMatch, SetConfig
from dance.transforms.graph import SpaGCNGraph, SpaGCNGraph2D
from dance.typing import LogLevel


def refine(sample_id, pred, dis, shape="hexagon"):
@@ -447,9 +450,27 @@ class SpaGCN:

def __init__(self, l=None):
super().__init__()
self.l = l or None
self.l = l
self.res = None

@staticmethod
def preprocessing_pipeline(alpha: float = 1, beta: int = 49, dim: int = 50, log_level: LogLevel = "INFO"):
return Compose(
FilterGenesMatch(prefixes=["ERCC", "MT-"]),
AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
AnnDataTransform(sc.pp.log1p),
SpaGCNGraph(alpha=alpha, beta=beta),
SpaGCNGraph2D(),
CellPCA(n_components=dim),
SetConfig({
"feature_channel": ["CellPCA", "SpaGCNGraph", "SpaGCNGraph2D"],
"feature_channel_type": ["obsm", "obsp", "obsp"],
"label_channel": "label",
"label_channel_type": "obs"
}),
log_level=log_level,
)

def search_l(self, p, adj, start=0.01, end=1000, tol=0.01, max_run=100):
"""Search best l.

@@ -625,5 +646,4 @@ def score(self, y_true):
"""
from sklearn.metrics.cluster import adjusted_rand_score
score = adjusted_rand_score(y_true, self.y_pred)
print("ARI {}".format(adjusted_rand_score(y_true, self.y_pred)))
return score
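Unlike the Louvain pipeline, this one configures a list of feature channels, so the retrieved feature is expected to be a three-part tuple. A hedged sketch under the same `load_data`/`get_train_data` assumptions as above:

```python
preprocess = SpaGCN.preprocessing_pipeline(alpha=1, beta=49, dim=50)
data = SpatialLIBDDataset(data_id="151673").load_data(transform=preprocess, cache=True)

# Three feature channels per SetConfig: the PCA embedding (.obsm) plus the two SpaGCN graphs (.obsp).
(x_pca, adj, adj_2d), y = data.get_train_data()
```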
21 changes: 21 additions & 0 deletions dance/modules/spatial/spatial_domain/stagate.py
@@ -25,6 +25,10 @@
from torch_sparse import SparseTensor, set_diag
from tqdm import tqdm

from dance.transforms import AnnDataTransform, Compose, SetConfig
from dance.transforms.graph import StagateGraph
from dance.typing import LogLevel


def transfer_pytorch_data(adata, adj):
edgeList = adj
@@ -175,6 +179,23 @@ def __init__(self, hidden_dims):
self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False, dropout=0, add_self_loops=False, bias=False)
self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False, dropout=0, add_self_loops=False, bias=False)

@staticmethod
def preprocessing_pipeline(hvg_flavor: str = "seurat_v3", n_top_hvgs: int = 3000, model_name: str = "radius",
radius: float = 150, n_neighbors: int = 5, log_level: LogLevel = "INFO"):
return Compose(
AnnDataTransform(sc.pp.highly_variable_genes, flavor=hvg_flavor, n_top_genes=n_top_hvgs, subset=True),
AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
AnnDataTransform(sc.pp.log1p),
StagateGraph(model_name, radius=radius, n_neighbors=n_neighbors),
SetConfig({
"feature_channel": "StagateGraph",
"feature_channel_type": "obsp",
"label_channel": "label",
"label_channel_type": "obs"
}),
log_level=log_level,
)

def forward(self, features, edge_index):
"""Forward function for training.

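The `model_name` argument selects how `StagateGraph` builds the spatial network. A hedged sketch of the two configurations, assuming the enclosing class is `Stagate` and that the graph transform accepts a "radius" or "knn" mode as the signature above suggests:

```python
# Radius mode: connect spots whose spatial distance is within 150 units.
radius_pipeline = Stagate.preprocessing_pipeline(model_name="radius", radius=150)

# KNN mode: connect each spot to its 5 nearest spatial neighbors.
knn_pipeline = Stagate.preprocessing_pipeline(model_name="knn", n_neighbors=5)
```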