Skip to content

Commit

Permalink
Fix get_df (#1231)
Browse files Browse the repository at this point in the history
* Fix `get_df` for default parameters (#649)

* If  is a string, put it into a list

* Add deprecation warning to `get_df`

---------

Co-authored-by: Philipp Weiler <weiler.philipp@gmail.com>
  • Loading branch information
Oisin-M and WeilerP authored Apr 12, 2024
1 parent 31226cd commit f21651c
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 56 deletions.
124 changes: 68 additions & 56 deletions scvelo/core/_anndata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import warnings
from typing import List, Literal, Optional, Union

import numpy as np
Expand Down Expand Up @@ -178,6 +179,13 @@ def get_df(
:class:`pd.DataFrame`
A dataframe.
"""
warnings.warn(
"`get_df` is deprecated since scvelo==0.4.0 and will be removed in a future version "
"of scVelo. Please `AnnData::get_df` or Scanpy's `scanpy.get.obs_df` or `scanpy.get.var_df`.",
DeprecationWarning,
stacklevel=2,
)

if precision is not None:
pd.set_option("display.precision", precision)

Expand All @@ -188,8 +196,6 @@ def get_df(
keys, key_add = (
keys.split("/") if isinstance(keys, str) and "/" in keys else (keys, None)
)
keys = [keys] if isinstance(keys, str) else keys
key = keys[0]

s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"]
d_keys = [
Expand All @@ -207,62 +213,68 @@ def get_df(

if keys is None:
df = data.to_df()
elif key in data.var_names:
df = obs_df(data, keys, layer=layer)
elif key in data.obs_names:
df = var_df(data, keys, layer=layer)
else:
if keys_split is not None:
keys = [
k
for k in list(data.obs.keys()) + list(data.var.keys())
if key in k and keys_split in k
]
key = keys[0]
s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
if len(s_key) == 0:
raise ValueError(f"'{key}' not found in any of {', '.join(s_keys)}.")
if len(s_key) > 1:
logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.")

s_key = s_key[-1]
df = getattr(data, s_key)[keys if len(keys) > 1 else key]
if key_add is not None:
df = df[key_add]
if index is None:
index = (
data.var_names
if s_key == "varm"
else data.obs_names
if s_key in {"obsm", "layers"}
else None
)
if index is None and s_key == "uns" and hasattr(df, "shape"):
key_cats = np.array(
[
key
for key in data.obs.keys()
if is_categorical_dtype(data.obs[key])
]
)
num_cats = [
len(data.obs[key].cat.categories) == df.shape[0]
for key in key_cats
keys = [keys] if isinstance(keys, str) else keys
key = keys[0]

if key in data.var_names:
df = obs_df(data, keys, layer=layer)
elif key in data.obs_names:
df = var_df(data, keys, layer=layer)
else:
if keys_split is not None:
keys = [
k
for k in list(data.obs.keys()) + list(data.var.keys())
if key in k and keys_split in k
]
if np.sum(num_cats) == 1:
index = data.obs[key_cats[num_cats][0]].cat.categories
if (
columns is None
and len(df.shape) > 1
and df.shape[0] == df.shape[1]
):
columns = index
elif isinstance(index, str) and index in data.obs.keys():
index = pd.Categorical(data.obs[index]).categories
if columns is None and s_key == "layers":
columns = data.var_names
elif isinstance(columns, str) and columns in data.obs.keys():
columns = pd.Categorical(data.obs[columns]).categories
key = keys[0]
s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
if len(s_key) == 0:
raise ValueError(
f"'{key}' not found in any of {', '.join(s_keys)}."
)
if len(s_key) > 1:
logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.")

s_key = s_key[-1]
df = getattr(data, s_key)[keys if len(keys) > 1 else key]
if key_add is not None:
df = df[key_add]
if index is None:
index = (
data.var_names
if s_key == "varm"
else data.obs_names
if s_key in {"obsm", "layers"}
else None
)
if index is None and s_key == "uns" and hasattr(df, "shape"):
key_cats = np.array(
[
key
for key in data.obs.keys()
if is_categorical_dtype(data.obs[key])
]
)
num_cats = [
len(data.obs[key].cat.categories) == df.shape[0]
for key in key_cats
]
if np.sum(num_cats) == 1:
index = data.obs[key_cats[num_cats][0]].cat.categories
if (
columns is None
and len(df.shape) > 1
and df.shape[0] == df.shape[1]
):
columns = index
elif isinstance(index, str) and index in data.obs.keys():
index = pd.Categorical(data.obs[index]).categories
if columns is None and s_key == "layers":
columns = data.var_names
elif isinstance(columns, str) and columns in data.obs.keys():
columns = pd.Categorical(data.obs[columns]).categories
elif isinstance(data, pd.DataFrame):
if isinstance(keys, str) and "*" in keys:
keys, keys_split = keys.split("*")
Expand Down
14 changes: 14 additions & 0 deletions tests/core/test_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,20 @@ def test_data_as_array(
else:
assert (df.columns == ["col_1", "col_2"]).all()

@given(
adata=get_adata(
max_obs=5,
max_vars=5,
layer_keys=["layer_1", "layer_2"],
),
modality=st.sampled_from([None, "X", "layer_1", "layer_2"]),
)
def test_default(self, adata: AnnData, modality: Optional[None]):
df = get_df(adata, layer=modality)

assert isinstance(df, pd.DataFrame)
np.testing.assert_equal(adata.to_df().values, df.values)


class TestGetInitialSize(TestBase):
@given(
Expand Down

0 comments on commit f21651c

Please sign in to comment.