Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(model): refactor out common minified mode methods #2883

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions src/scvi/data/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

class _ADATA_MINIFY_TYPE_NT(NamedTuple):
    """Closed set of identifiers for the supported AnnData minification modes.

    Each field's value is the string that a ``minified_data_type`` argument is compared
    against and that is stored in the minified AnnData's ``uns`` to record the mode used.
    """

    # Latent posterior mean/variance are stored; the count matrix is not kept.
    LATENT_POSTERIOR: str = "latent_posterior_parameters"
    # Same as above, but the count matrix is kept alongside the latent parameters.
    LATENT_POSTERIOR_WITH_COUNTS: str = "latent_posterior_parameters_with_counts"


# Shared instance. Because a NamedTuple is a tuple of its field values,
# ``value in ADATA_MINIFY_TYPE`` checks whether ``value`` is a supported mode string.
ADATA_MINIFY_TYPE = _ADATA_MINIFY_TYPE_NT()
Expand Down
24 changes: 11 additions & 13 deletions src/scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,6 @@ def minify_adata(
minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
use_latent_qzm_key: str = "X_latent_qzm",
use_latent_qzv_key: str = "X_latent_qzv",
keep_count_data: bool = False,
) -> None:
"""Minify the model's :attr:`~scvi.model.base.BaseModelClass.adata`.

Expand All @@ -917,25 +916,25 @@ def minify_adata(
minified_data_type
Method for minifying the data. One of the following:

- ``"latent_posterior"``: Store the latent posterior mean and variance in
- ``"latent_posterior_parameters"``: Store the latent posterior mean and variance in
:attr:`~anndata.AnnData.obsm` using the keys ``use_latent_qzm_key`` and
``use_latent_qzv_key``.
- ``"latent_posterior_parameters_with_counts"``: Store the latent posterior mean and
variance in :attr:`~anndata.AnnData.obsm` using the keys ``use_latent_qzm_key`` and
``use_latent_qzv_key``, and the raw count data in :attr:`~anndata.AnnData.X`.
use_latent_qzm_key
Key to use for storing the latent posterior mean in :attr:`~anndata.AnnData.obsm` when
``minified_data_type`` is ``"latent_posterior_parameters"`` or
``"latent_posterior_parameters_with_counts"``.
use_latent_qzv_key
Key to use for storing the latent posterior variance in :attr:`~anndata.AnnData.obsm`
when ``minified_data_type`` is ``"latent_posterior_parameters"`` or
``"latent_posterior_parameters_with_counts"``.
keep_count_data
If ``True``, the full count matrix is kept in the minified
:attr:`~scvi.model.base.BaseModelClass.adata`.

Notes
-----
The modification is not done inplace -- instead the model is assigned a new (minified)
version of the :class:`~anndata.AnnData`.
"""
if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
if minified_data_type not in ADATA_MINIFY_TYPE:
raise NotImplementedError(
f"Minification method {minified_data_type} is not supported."
)
Expand All @@ -944,11 +943,13 @@ def minify_adata(
"Minification is not supported for models that do not use observed library size."
)

keep_count_data = minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS
mini_adata = get_minified_adata_scrna(
adata_manager=self.adata_manager,
minified_data_type=minified_data_type,
keep_count_data=keep_count_data,
)
del mini_adata.uns[_SCVI_UUID_KEY]
mini_adata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type
canergen marked this conversation as resolved.
Show resolved Hide resolved
mini_adata.obsm[self._LATENT_QZM_KEY] = self.adata.obsm[use_latent_qzm_key]
mini_adata.obsm[self._LATENT_QZV_KEY] = self.adata.obsm[use_latent_qzv_key]
mini_adata.obs[self._OBSERVED_LIB_SIZE_KEY] = np.squeeze(
Expand All @@ -957,18 +958,16 @@ def minify_adata(
self._update_adata_and_manager_post_minification(
mini_adata,
minified_data_type,
keep_count_data=keep_count_data,
)
self.module.minified_data_type = minified_data_type

@classmethod
def _get_fields_for_adata_minification(
cls,
minified_data_type: MinifiedDataType,
keep_count_data: bool,
):
"""Return the fields required for minification of the given type."""
if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
if minified_data_type not in ADATA_MINIFY_TYPE:
raise NotImplementedError(
f"Minification method {minified_data_type} is not supported."
)
Expand All @@ -979,7 +978,7 @@ def _get_fields_for_adata_minification(
fields.NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, cls._OBSERVED_LIB_SIZE_KEY),
fields.StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY),
]
if keep_count_data:
if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS:
mini_fields.append(fields.LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True))

return mini_fields
Expand All @@ -988,7 +987,6 @@ def _update_adata_and_manager_post_minification(
self,
minified_adata: AnnOrMuData,
minified_data_type: MinifiedDataType,
keep_count_data: bool,
):
"""Update the :class:`~anndata.AnnData` and :class:`~scvi.data.AnnDataManager` in-place.

Expand All @@ -1005,7 +1003,7 @@ def _update_adata_and_manager_post_minification(
self._validate_anndata(minified_adata)
new_adata_manager = self.get_anndata_manager(minified_adata, required=True)
new_adata_manager.register_new_fields(
self._get_fields_for_adata_minification(minified_data_type, keep_count_data)
self._get_fields_for_adata_minification(minified_data_type)
)
self.adata = minified_adata

Expand Down
16 changes: 1 addition & 15 deletions src/scvi/model/utils/_minification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,16 @@
from scipy.sparse import csr_matrix

from scvi import REGISTRY_KEYS
from scvi._types import MinifiedDataType
from scvi.data import AnnDataManager
from scvi.data._constants import (
_ADATA_MINIFY_TYPE_UNS_KEY,
_SCVI_UUID_KEY,
ADATA_MINIFY_TYPE,
)


def get_minified_adata_scrna(
adata_manager: AnnDataManager,
minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
keep_count_data: bool = False,
) -> AnnData:
"""Get a minified version of an :class:`~anndata.AnnData` or :class:`~mudata.MuData` object."""
if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
raise NotImplementedError(f"Minification method {minified_data_type} is not supported.")

counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
mini_adata = AnnData(
return AnnData(
X=counts if keep_count_data else csr_matrix(counts.shape),
obs=adata_manager.adata.obs.copy(),
var=adata_manager.adata.var.copy(),
Expand All @@ -33,7 +23,3 @@ def get_minified_adata_scrna(
obsp=adata_manager.adata.obsp.copy(),
varp=adata_manager.adata.varp.copy(),
)
del mini_adata.uns[_SCVI_UUID_KEY]
mini_adata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type

return mini_adata
Loading