Skip to content

Commit

Permalink
remove uns_key from set_plotting_vars
Browse files Browse the repository at this point in the history
  • Loading branch information
MUCDK committed Feb 3, 2023
1 parent 1f36cd1 commit af0d9ab
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 20 deletions.
11 changes: 5 additions & 6 deletions src/moscot/_constants/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,18 @@ class uns:
def set_plotting_vars(
cls,
adata: AnnData,
uns_key: str,
pl_func_key: Optional[str] = None,
key: Optional[str] = None,
value: Optional[Any] = None,
override: bool = True,
) -> None:
adata.uns.setdefault(uns_key, {})
adata.uns.setdefault(AdataKeys.UNS, {})
if pl_func_key is not None:
adata.uns[uns_key].setdefault(pl_func_key, {})
adata.uns[AdataKeys.UNS].setdefault(pl_func_key, {})
if key is not None:
if not override and key in adata.uns[uns_key][pl_func_key]:
if not override and key in adata.uns[AdataKeys.UNS][pl_func_key]:
raise KeyError(
f"Data in `adata.uns[{uns_key!r}][{pl_func_key!r}][{key!r}]` "
f"Data in `adata.uns[{AdataKeys.UNS!r}][{pl_func_key!r}][{key!r}]` "
f"already exists, use `override=True`."
)
adata.uns[uns_key][pl_func_key][key] = value
adata.uns[AdataKeys.UNS][pl_func_key][key] = value
2 changes: 0 additions & 2 deletions src/moscot/problems/base/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
)
from moscot._constants._constants import (
Key,
AdataKeys,
CorrMethod,
PlottingKeys,
CorrTestMethod,
Expand Down Expand Up @@ -164,7 +163,6 @@ def _cell_transition(
}
Key.uns.set_plotting_vars(
adata=self.adata,
uns_key=AdataKeys.UNS,
pl_func_key=PlottingKeys.CELL_TRANSITION,
key=key_added,
value=plot_vars,
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/problems/generic/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from moscot._types import ArrayLike, Str_Dict_t
from moscot._docs._docs_mixins import d_mixins
from moscot._constants._constants import Key, AdataKeys, PlottingKeys, PlottingDefaults
from moscot._constants._constants import Key, PlottingKeys, PlottingDefaults
from moscot.problems.base._mixins import AnalysisMixin, AnalysisMixinProtocol
from moscot.problems.base._compound_problem import B, K, ApplyOutput_t

Expand Down Expand Up @@ -142,7 +142,7 @@ def push(
"distribution_key": self.batch_key,
}
self.adata.obs[key_added] = self._flatten(result, key=self.batch_key)
Key.uns.set_plotting_vars(self.adata, AdataKeys.UNS, PlottingKeys.PUSH, key_added, plot_vars)
Key.uns.set_plotting_vars(self.adata, PlottingKeys.PUSH, key_added, plot_vars)
if return_data:
return result

Expand Down Expand Up @@ -196,7 +196,7 @@ def pull(
"key": self.batch_key,
}
self.adata.obs[key_added] = self._flatten(result, key=self.batch_key)
Key.uns.set_plotting_vars(self.adata, AdataKeys.UNS, PlottingKeys.PULL, key_added, plot_vars)
Key.uns.set_plotting_vars(self.adata, PlottingKeys.PULL, key_added, plot_vars)
if return_data:
return result

Expand Down
8 changes: 4 additions & 4 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from moscot._types import ArrayLike, Numeric_t, Str_Dict_t
from moscot._docs._docs_mixins import d_mixins
from moscot._constants._constants import Key, AdataKeys, PlottingKeys, PlottingDefaults
from moscot._constants._constants import Key, PlottingKeys, PlottingDefaults
from moscot.problems.base._mixins import AnalysisMixin, AnalysisMixinProtocol
from moscot.solvers._tagged_array import Tag
from moscot.problems.base._compound_problem import B, K, ApplyOutput_t
Expand Down Expand Up @@ -287,7 +287,7 @@ def sankey(
"target_groups": target_groups,
"captions": [str(t) for t in tuples],
}
Key.uns.set_plotting_vars(self.adata, AdataKeys.UNS, PlottingKeys.SANKEY, key_added, plot_vars)
Key.uns.set_plotting_vars(self.adata, PlottingKeys.SANKEY, key_added, plot_vars)
if return_data:
return cell_transitions_updated

Expand Down Expand Up @@ -346,7 +346,7 @@ def push(
"subset": subset,
}
self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key)
Key.uns.set_plotting_vars(self.adata, AdataKeys.UNS, PlottingKeys.PUSH, key_added, plot_vars)
Key.uns.set_plotting_vars(self.adata, PlottingKeys.PUSH, key_added, plot_vars)
if return_data:
return result

Expand Down Expand Up @@ -404,7 +404,7 @@ def pull(
"target": target,
}
self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key)
Key.uns.set_plotting_vars(self.adata, AdataKeys.UNS, PlottingKeys.PULL, key_added, plot_vars)
Key.uns.set_plotting_vars(self.adata, PlottingKeys.PULL, key_added, plot_vars)
if return_data:
return result

Expand Down
10 changes: 5 additions & 5 deletions tests/plotting/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from anndata import AnnData

from moscot._constants._constants import Key, AdataKeys, PlottingKeys, PlottingDefaults
from moscot._constants._constants import Key, PlottingKeys, PlottingDefaults

HERE: Path = Path(__file__).parent

Expand All @@ -32,7 +32,7 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData:
"target": 1,
}
Key.uns.set_plotting_vars(
gt_temporal_adata, AdataKeys.UNS, PlottingKeys.CELL_TRANSITION, PlottingDefaults.CELL_TRANSITION, plot_vars
gt_temporal_adata, PlottingKeys.CELL_TRANSITION, PlottingDefaults.CELL_TRANSITION, plot_vars
)

return gt_temporal_adata
Expand All @@ -44,7 +44,7 @@ def adata_pl_push(adata_time: AnnData) -> AnnData:
plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
Key.uns.set_plotting_vars(adata_time, AdataKeys.UNS, PlottingKeys.PUSH, PlottingDefaults.PUSH, plot_vars)
Key.uns.set_plotting_vars(adata_time, PlottingKeys.PUSH, PlottingDefaults.PUSH, plot_vars)
push_initial_dist = np.zeros(
shape=(len(adata_time[adata_time.obs["time"] == 0]),)
) # we need this for a cat. distr. in plots
Expand All @@ -63,7 +63,7 @@ def adata_pl_pull(adata_time: AnnData) -> AnnData:
plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
Key.uns.set_plotting_vars(adata_time, AdataKeys.UNS, PlottingKeys.PULL, PlottingDefaults.PULL, plot_vars)
Key.uns.set_plotting_vars(adata_time, PlottingKeys.PULL, PlottingDefaults.PULL, plot_vars)
pull_initial_dist = np.zeros(
shape=(len(adata_time[adata_time.obs["time"] == 1]),)
) # we need this for a cat. distr. in plots
Expand All @@ -86,7 +86,7 @@ def adata_pl_sankey(adata_time: AnnData) -> AnnData:
tm1 = pd.DataFrame(data=data1, index=celltypes, columns=celltypes)
tm2 = pd.DataFrame(data=data2, index=celltypes, columns=celltypes)
plot_vars = {"transition_matrices": [tm1, tm2], "captions": ["0", "1"], "key": "celltype"}
Key.uns.set_plotting_vars(adata_time, AdataKeys.UNS, PlottingKeys.SANKEY, PlottingDefaults.SANKEY, plot_vars)
Key.uns.set_plotting_vars(adata_time, PlottingKeys.SANKEY, PlottingDefaults.SANKEY, plot_vars)

return adata_time

Expand Down

0 comments on commit af0d9ab

Please sign in to comment.