Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update scDeepSort example script to use dance data object #75

Merged
merged 9 commits into from
Dec 14, 2022
52 changes: 51 additions & 1 deletion dance/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from dance import logger
from dance.typing import Any, Dict, FeatType, List, Optional, Sequence, Tuple
from dance.typing import Any, Dict, FeatType, List, Literal, Optional, Sequence, Tuple


class BaseData(ABC):
Expand Down Expand Up @@ -122,6 +122,9 @@ def val_idx(self) -> Sequence[int]:
def test_idx(self) -> Sequence[int]:
return self.get_split_idx("test", error_on_miss=False)

def shape(self) -> Tuple[int, int]:
    """Return the ``shape`` attribute of the wrapped data object (``self.data``)."""
    wrapped = self.data
    return wrapped.shape

def copy(self):
    """Return an independent deep copy of this object, made with ``deepcopy``."""
    duplicate = deepcopy(self)
    return duplicate

Expand Down Expand Up @@ -156,6 +159,53 @@ def get_split_idx(self, split_name: str, error_on_miss: bool = False):
else:
return None

def get_feature(self, *, return_type: FeatType = "numpy", channel: Optional[str] = None,
                channel_type: Literal["obs", "var"] = "obs", layer: Optional[str] = None,
                mod: Optional[str] = None):  # yapf: disable
    """Retrieve a feature matrix from the wrapped data object.

    Parameters
    ----------
    return_type
        ``"default"`` returns the selected feature object untouched;
        ``"numpy"`` returns it after densifying / converting to a numpy array;
        ``"torch"`` additionally wraps the numpy array with ``torch.from_numpy``.
    channel
        Key into ``obsm`` / ``varm`` (see ``channel_type``). Mutually exclusive
        with ``layer``.
    channel_type
        ``"obs"`` selects ``data.obsm``; ``"var"`` selects ``data.varm``.
    layer
        Key into ``data.layers``. Mutually exclusive with ``channel``.
    mod
        Modality key into ``self.data.mod``. When ``None``, ``self.data`` itself
        is used.

    Returns
    -------
    The selected feature, in the representation requested by ``return_type``.
    When neither ``channel`` nor ``layer`` is given, ``data.X`` is used.

    Raises
    ------
    AttributeError
        If ``mod`` is given but ``self.data`` has no ``mod`` attribute.
    KeyError
        If ``mod`` is not a key of ``self.data.mod`` (lookups of an unknown
        ``channel`` / ``layer`` propagate whatever the underlying mapping raises).
    ValueError
        If ``channel_type`` or ``return_type`` is not recognized, or if both
        ``channel`` and ``layer`` are specified.

    """
    # Pick modality
    if mod is None:
        data = self.data
    elif not hasattr(self.data, "mod"):
        raise AttributeError("`mod` option is only available when using multimodality data.")
    elif mod not in self.data.mod:
        # Check the same container that is indexed below (``self.data.mod``), so
        # the guard, the error message and the lookup all agree.
        raise KeyError(f"Unknown modality {mod!r}, available options are {sorted(self.data.mod)}")
    else:
        data = self.data.mod[mod]

    # Pick channels - obsm or varm
    if channel_type == "obs":
        channels = data.obsm
    elif channel_type == "var":
        channels = data.varm
    else:
        raise ValueError(f"Unknown channel type {channel_type!r}")

    # Pick specific channel
    if (channel is not None) and (layer is not None):
        raise ValueError(f"Cannot specify feature layer ({layer!r}) and channel ({channel!r}) simmultaneously.")
    elif channel is not None:
        feature = channels[channel]
    elif layer is not None:
        # A layer entry is already the matrix itself; it has no ``.X`` attribute.
        feature = data.layers[layer]
    else:
        feature = data.X

    if return_type == "default":
        return feature

    # Transform features to other data types
    if hasattr(feature, "toarray"):  # convert sparse array to dense numpy array
        feature = feature.toarray()
    elif hasattr(feature, "to_numpy"):  # convert dataframe to numpy array
        feature = feature.to_numpy()

    if return_type == "torch":
        feature = torch.from_numpy(feature)
    elif return_type != "numpy":
        raise ValueError(f"Unknown return_type {return_type!r}")

    return feature

def _get_data(self, name: str, split_name: Optional[str], return_type: FeatType = "numpy", **kwargs):
out = getattr(self, name)

Expand Down
40 changes: 4 additions & 36 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import glob
import os
import os.path as osp
Expand All @@ -14,9 +13,8 @@
from torch.utils.data import Dataset

from dance.data import download_file, download_unzip
from dance.transforms.preprocess import (get_map_dict, load_actinn_data, load_annotation_data_internal,
load_annotation_test_data, load_imputation_data_internal, load_svm_data,
prepare_data_celltypist, splitCommonAnnData)
from dance.transforms.preprocess import (get_map_dict, load_actinn_data, load_annotation_data,
load_imputation_data_internal, load_svm_data, splitCommonAnnData)


@dataclass
Expand Down Expand Up @@ -314,44 +312,14 @@ def is_singlecellnet_complete(self):
def load_data(self):
# Load data from existing h5ad files, or download files and load data.
if self.data_type == "scdeepsort" or self.data_type == "scdeepsort_exp":
if self.is_complete():
pass
else:
if not self.is_complete():
if self.data_type == "scdeepsort":
self.download_all_data()
if self.data_type == "scdeepsort_exp":
self.download_benchmark_data()
assert self.is_complete()

(
self.num_cells,
self.num_genes,
self.num_labels,
self.graph,
self.train_ids,
self.test_ids,
self.labels,
) = load_annotation_data_internal(self.params)

if self.params.score:
(
self.total_cell_test,
self.num_genes_test,
self.num_labels_test,
self.id2label_test,
self.test_dict,
self.test_label_dict,
self.time_used_test,
) = load_annotation_test_data(self.params)
else:
(
self.total_cell_test,
self.num_genes_test,
self.num_labels_test,
self.id2label_test,
self.test_dict,
self.time_used_test,
) = load_annotation_test_data(self.params)
return load_annotation_data(self.params)

if self.data_type == "svm":
if self.is_complete():
Expand Down
Loading