Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minification function for scPoli #230

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 204 additions & 0 deletions minification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
from anndata import AnnData
from scipy.sparse import csr_matrix
import torch
from scipy import sparse
import numpy as np
import scanpy as sc
import scarches as sca



# TODO: this should become a method of scPoli
def get_latent(module, x, c=None, mean=False, mean_var=False):
    """Map `x` into the latent space by feeding it through the encoder of `module`.

    Parameters
    ----------
    module
        Object exposing `recon_loss`, `inject_condition`, `encoder`, `sampling`
        and, when the encoder is conditioned, `embeddings`.
    x: torch.Tensor
        Torch Tensor to be mapped to latent space. `x` has to be in shape [n_obs, input_dim].
    c: torch.Tensor
        Torch Tensor of condition labels for each sample, one column per
        condition. Required when "encoder" is in `module.inject_condition`.
    mean: boolean
        If True, return the encoder mean instead of a random sample.
        Takes precedence over `mean_var`.
    mean_var: boolean
        If True (and `mean` is False), return the tuple `(z_mean, z_log_var)`.

    Returns
    -------
    Torch Tensor containing the latent space encoding of `x`, or a tuple
    `(z_mean, z_log_var)` when `mean_var` is requested.

    Raises
    ------
    ValueError
        If the encoder is conditioned but `c` is None.
    """
    # The encoder input is log(1 + x) unless the reconstruction loss is "mse",
    # in which case `x` is passed through unchanged.
    x_ = x if module.recon_loss == "mse" else torch.log(1 + x)

    if "encoder" in module.inject_condition:
        if c is None:
            raise ValueError(
                "Condition labels `c` are required when the encoder is conditioned."
            )
        c = c.long()
        # One embedding table per condition column; concatenate along features.
        embed_c = torch.hstack(
            [module.embeddings[i](c[:, i]) for i in range(c.shape[1])]
        )
        z_mean, z_log_var = module.encoder(x_, embed_c)
    else:
        z_mean, z_log_var = module.encoder(x_)

    if mean:
        return z_mean
    if mean_var:
        return (z_mean, z_log_var)
    # Only draw a sample when it is actually the requested output.
    return module.sampling(z_mean, z_log_var)

# TODO: this should become a method of scPoli
def get_latent_representation(
    model,
    adata,
    mean: bool = False,
    mean_var: bool = False
):
    """Encode the cells of `adata` into the latent space of `model`, in batches.

    Parameters
    ----------
    model
        Model wrapper exposing `model` (the torch module), `condition_keys_`
        and `conditions_`.
    adata
        Annotated data whose `X` is encoded; `adata.obs` must contain one
        column per entry of `model.condition_keys_`.
    mean
        Return the mean instead of a random sample from the latent space.
    mean_var
        Return the tuple `(z_mean, z_log_var)` (only used when `mean` is False).

    Returns
    -------
    A CPU torch Tensor with the latent encoding of every cell, or a tuple of
    such tensors when `mean_var` is requested.

    Raises
    ------
    ValueError
        If `adata.obs` holds a condition value not listed in `model.conditions_`.
    """
    device = next(model.model.parameters()).device

    # Translate the raw condition labels into the integer codes the module expects.
    label_columns = []
    for cond in model.condition_keys_:
        query_conditions = adata.obs[cond].values
        if not set(query_conditions).issubset(model.conditions_[cond]):
            raise ValueError("Incorrect conditions")
        labels = np.zeros(query_conditions.shape[0])
        for condition, label in model.model.condition_encoders[cond].items():
            labels[query_conditions == condition] = label
        label_columns.append(labels)
    # Stack into one [n_obs, n_conditions] array first; building a tensor from
    # a list of ndarrays is slow.
    c = torch.tensor(np.stack(label_columns, axis=1), device=device)

    x = adata.X
    latents = []
    # Batch the latent transformation; densify per batch so a large sparse
    # matrix is never materialised in full.
    indices = torch.arange(x.shape[0])
    with torch.no_grad():
        for batch in indices.split(512):
            x_batch = x[batch.numpy(), :]
            if sparse.issparse(x_batch):
                x_batch = x_batch.toarray()
            x_batch = torch.tensor(x_batch, dtype=torch.float32, device=device)
            latent = get_latent(model.model, x_batch, c[batch, :], mean, mean_var)
            # Normalise to a tuple so single- and multi-output cases share one path.
            latent = latent if isinstance(latent, tuple) else (latent,)
            latents.append(tuple(part.cpu().detach() for part in latent))

    result = tuple(torch.cat(parts) for parts in zip(*latents))
    return result[0] if len(result) == 1 else result


def get_minified_adata_scrna(
    adata: AnnData,
) -> AnnData:
    """Return a minified version of `adata` with the expression data blanked out.

    `X` and every layer are replaced by empty CSR matrices of the same shape
    as `adata.X`. `obs`, `var`, `varm`, `obsm` and `obsp` are handed to the
    new AnnData as they are; `uns` is passed as `adata.uns.copy()`.

    Parameters
    ----------
    adata
        Original adata, of which we want to create a minified version.
    """
    empty = csr_matrix(adata.X.shape)

    # Every layer gets its own empty matrix.
    empty_layers = {}
    for layer_name in adata.layers:
        empty_layers[layer_name] = empty.copy()

    return AnnData(
        X=empty,
        obs=adata.obs,
        var=adata.var,
        uns=adata.uns.copy(),
        obsm=adata.obsm,
        varm=adata.varm,
        obsp=adata.obsp,
        layers=empty_layers,
    )



def minify_adata(model, adata, output_path="adata.h5ad"):
    """Minify `adata` using latent posterior parameters and write it to disk.

    This function is adapted from scvi-tools
    https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI.minify_adata

    * the original count data is removed (`adata.X` and any layers)
    * the parameters of the latent representation of the original data are stored
    * everything else is left untouched

    Parameters
    ----------
    model
        Model passed on to `get_latent_representation`.
    adata
        Data to minify. Its `obsm` also receives the latent parameters.
    output_path
        Where the minified data is written (defaults to "adata.h5ad").

    Returns
    -------
    The minified AnnData object that was written to `output_path`.
    """
    # `get_latent_representation` is a module-level helper in this file, not a
    # method of the model, so it is called as a function.
    qzm, qzv = get_latent_representation(model, adata, mean_var=True)

    # Store plain numpy arrays rather than torch tensors.
    # NOTE: "qzv" holds the second element returned with mean_var=True,
    # i.e. the encoder's log-variance output.
    adata.obsm["X_latent_qzm"] = np.asarray(qzm)
    adata.obsm["X_latent_qzv"] = np.asarray(qzv)

    # We cannot minify data where we do not use observed library size for gene
    # count generation. In the SCVI model the library size can be modelled as a
    # latent variable; in scPoli it is set to be observed (equal to the total
    # UMI RNA count of a cell).
    minified_adata = get_minified_adata_scrna(adata)
    minified_adata.obsm["X_latent_qzm"] = adata.obsm["X_latent_qzm"]
    minified_adata.obsm["X_latent_qzv"] = adata.obsm["X_latent_qzv"]
    # ravel() keeps a 1-D result even for a single observation.
    minified_adata.obs["observed_lib_size"] = np.asarray(
        adata.X.sum(axis=1)
    ).ravel()

    # TODO: set is_minified attribute to True

    minified_adata.write(output_path)
    return minified_adata

def main():
    """Load the stored atlas and scPoli model, then minify the atlas."""
    adata = sc.read("atlas_646ddf52fd46b85aafce28c2_data_not_minifiied.h5ad")

    model = sca.models.scPoli.load("model", adata)

    # minify_adata takes the model first and the data second.
    minify_adata(model, adata)


if __name__ == "__main__":
    main()


# import scarches as sca
# from scanpy.datasets import pbmc3k_processed, pbmc3k #replace with stored atlas trained on scPoli

# @pytest.mark.parametrize('get_adata', ["path/to/atlas1", "path/to/atlas2"])
# def test_minification(get_adata):
# adata = get_adata()
# model = scarches.models.scPoli.load(path = "path/to/model", adata=adata)

# minify_adata(model, adata)















Empty file added minify.py
Empty file.
53 changes: 53 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report
from scarches.models.scpoli import scPoli


# Load the pancreas dataset.
adata = sc.read('tmp/pancreas.h5ad')

# Columns of `adata.obs` holding the batch/study and the cell-type labels.
condition_key = 'study'
cell_type_key = 'cell_type'

# Studies used as reference and studies held out as query.
reference = [
    'inDrop1', 'inDrop2', 'inDrop3', 'inDrop4',
    'fluidigmc1', 'smartseq2', 'smarter',
]
query = ['celseq', 'celseq2']

# Early-stopping / learning-rate schedule settings handed to `train`.
early_stopping_kwargs = dict(
    early_stopping_metric="val_prototype_loss",
    mode="min",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

# Flag query cells as a categorical column.
adata.obs['query'] = adata.obs[condition_key].isin(query).astype('category')

# Reference subset, with every cell whose cell type contains 'alpha' removed.
source_adata = adata[adata.obs[condition_key].isin(reference)].copy()
source_adata = source_adata[
    ~source_adata.obs[cell_type_key].str.contains('alpha')
].copy()
# Query subset.
target_adata = adata[adata.obs[condition_key].isin(query)].copy()

# Build and train scPoli on the reference data.
scpoli_model = scPoli(
    adata=source_adata,
    condition_keys=condition_key,
    cell_type_keys=cell_type_key,
    embedding_dims=5,
    recon_loss='nb',
)
scpoli_model.train(
    n_epochs=50,
    pretraining_epochs=40,
    early_stopping_kwargs=early_stopping_kwargs,
    eta=5,
)
32 changes: 31 additions & 1 deletion scarches/models/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from anndata import AnnData, read
from scipy.sparse import issparse

from ._utils import UnpicklerCpu, _validate_var_names
from ._utils import UnpicklerCpu, _validate_var_names, get_minified_adata_scrna


class BaseMixin:
Expand Down Expand Up @@ -193,6 +193,36 @@ def load(
model.is_trained_ = attr_dict['is_trained_']

return model

def minify_adata(self, adata=None, model_name=None):
    """Replace `self.adata` with a minified version of `adata`.

    This function is adapted from scvi-tools
    https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI.minify_adata
    minify adata using latent posterior parameters:
    * the original count data is removed (`adata.X` and any layers)
    * the parameters of the latent representation of the original data are stored
    * everything else is left untouched

    Parameters
    ----------
    adata
        Data to minify; defaults to `self.adata`. Its `obsm` also receives
        the latent parameters.
    model_name
        Optional suffix for the `obsm` keys (`X_latent_qzm_<model_name>`,
        `X_latent_qzv_<model_name>`). When None, the keys carry no suffix.
    """
    if adata is None:
        adata = self.adata

    # Avoid keys ending in the literal "_None" when no model name is given.
    suffix = "" if model_name is None else f"_{model_name}"
    qzm_key = f"X_latent_qzm{suffix}"
    qzv_key = f"X_latent_qzv{suffix}"

    # Get the latent representation and store it in the adata.
    qzm, qzv = self.get_latent(adata, mean_var=True)
    adata.obsm[qzm_key] = qzm
    adata.obsm[qzv_key] = qzv

    minified_adata = get_minified_adata_scrna(adata)
    minified_adata.obsm[qzm_key] = adata.obsm[qzm_key]
    minified_adata.obsm[qzv_key] = adata.obsm[qzv_key]
    # Per-cell total of X; ravel() keeps it 1-D even for a single observation.
    minified_adata.obs["observed_lib_size"] = np.asarray(
        adata.X.sum(axis=1)
    ).ravel()
    self.adata = minified_adata



class SurgeryMixin:
Expand Down
29 changes: 28 additions & 1 deletion scarches/models/base/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,33 @@ def _validate_var_names(adata, source_var_names):

return new_adata

def get_minified_adata_scrna(
    adata: AnnData,
) -> AnnData:
    """Return a minified adata.

    This function is adapted from scvi-tools
    https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.utils.get_minified_adata_scrna.html

    `X` and all layers of the result are empty CSR matrices shaped like
    `adata.X`; `obs`, `var`, `varm`, `obsm` and `obsp` are passed through
    and `uns` is passed as `adata.uns.copy()`.

    Parameters
    ----------
    adata
        Original adata, of which we want to create a minified version.
    """
    blank = csr_matrix(adata.X.shape)

    # One independent empty matrix per existing layer.
    blank_layers = {}
    for key in adata.layers:
        blank_layers[key] = blank.copy()

    minified = AnnData(
        X=blank,
        obs=adata.obs,
        var=adata.var,
        obsm=adata.obsm,
        varm=adata.varm,
        obsp=adata.obsp,
        uns=adata.uns.copy(),
        layers=blank_layers,
    )
    return minified

class UnpicklerCpu(pickle.Unpickler):
"""Helps to pickle.load a model trained on GPU to CPU.
Expand All @@ -72,4 +99,4 @@ def find_class(self, module, name):
if module == 'torch.storage' and name == '_load_from_bytes':
return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
else:
return super().find_class(module, name)
return super().find_class(module, name)
4 changes: 3 additions & 1 deletion scarches/models/scpoli/scpoli.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def sampling(self, mu, log_var):
var = torch.exp(log_var) + 1e-4
return Normal(mu, var.sqrt()).rsample()

def get_latent(self, x, c=None, mean=False):
def get_latent(self, x, c=None, mean=False, mean_var=False):
"""Map `x` in to the latent space. This function will feed data in encoder and return z for each sample in
data.
Parameters
Expand All @@ -357,6 +357,8 @@ def get_latent(self, x, c=None, mean=False):
latent = self.sampling(z_mean, z_log_var)
if mean:
return z_mean
elif mean_var:
return (z_mean, z_log_var)
return latent


Expand Down
16 changes: 12 additions & 4 deletions scarches/models/scpoli/scpoli_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def get_latent(
self,
adata,
mean: bool = False,
mean_var: bool = False
):
"""Map `x` in to the latent space. This function will feed data in encoder and return z for each sample in
data.
Expand Down Expand Up @@ -356,11 +357,13 @@ def get_latent(
x_batch = x_batch.toarray()
x_batch = torch.tensor(x_batch, device=device).float()
latent = self.model.get_latent(
x_batch, c[batch, :], mean
x_batch, c[batch, :], mean, mean_var
)
latents += [latent.cpu().detach()]
latents = torch.cat(latents)
return np.array(latents)
latent = (latent,) if not isinstance(latent, tuple) else latent
latents += [tuple(l.cpu().detach() for l in latent)]
result = tuple(np.array(torch.cat(l)) for l in zip(*latents))
result = result[0] if len(result) == 1 else result
return result

def get_conditional_embeddings(self):
"""
Expand Down Expand Up @@ -969,3 +972,8 @@ def _load_expand_params_from_dict(self, state_dict):
load_state_dict[key] = fixed_ten

self.model.load_state_dict(load_state_dict)


def minify_adata(self, adata=None):
    """Minify `adata` through the parent implementation, using "scpoli" as the model name."""
    parent_minify = super().minify_adata
    parent_minify(adata, model_name="scpoli")