Skip to content

Commit

Permalink
Fix pandas/PETSc issues (#1136)
Browse files Browse the repository at this point in the history
* Fix tests

* Remove `is_categorical_dtype`

* Fix type checking

* Fix converged flag in `PETSc>=3.20`

* Don't pass dtype to `AnnData`

* Start fixing warnings

* Fix `pandas.__getitem__` warnings

* Correct serie's dtype

* Fix map -> apply

* No duplicate names

* Fix `DataFrame.applymap`
  • Loading branch information
michalk8 authored Nov 2, 2023
1 parent f3a6a0e commit 01a79ca
Show file tree
Hide file tree
Showing 27 changed files with 115 additions and 105 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ dependencies = [
"matplotlib>=3.5.0,<3.7.2",
"networkx>=2.2",
"numba>=0.51.0,!=0.57.0",
"numpy>=1.17.0",
"numpy>=1.22.0",
"pandas>=1.5.0",
"pygam>=0.8.0",
"pygpcca>=1.0.4",
Expand Down
6 changes: 3 additions & 3 deletions src/cellrank/_utils/_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
import scipy.stats as st
from pandas.api.types import infer_dtype, is_categorical_dtype
from pandas.api.types import infer_dtype

from matplotlib import cm, colors

Expand Down Expand Up @@ -175,9 +175,9 @@ def _map_names_and_colors(
Series with updated category names and a corresponding array of colors.
"""
# checks: dtypes, matching indices, make sure colors match the categories
if not is_categorical_dtype(series_reference):
if not isinstance(series_reference.dtype, pd.CategoricalDtype):
raise TypeError(f"Reference series must be `categorical`, found `{infer_dtype(series_reference)}`.")
if not is_categorical_dtype(series_query):
if not isinstance(series_query.dtype, pd.CategoricalDtype):
raise TypeError(f"Query series must be `categorical`, found `{infer_dtype(series_query)}`.")
if len(series_reference) != len(series_query):
raise ValueError(
Expand Down
39 changes: 22 additions & 17 deletions src/cellrank/_utils/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import pandas as pd
import scipy.stats as st
from pandas.api.types import infer_dtype, is_categorical_dtype
from pandas.api.types import infer_dtype

import matplotlib.pyplot as plt
from matplotlib import colors
Expand Down Expand Up @@ -102,7 +102,7 @@ def wrap(numpy_func: Callable) -> Callable:
"""

@functools.wraps(numpy_func)
def decorator(array, *args, **kwargs):
def decorator(array: "Lineage", *args, **kwargs):
if not isinstance(array, Lineage):
raise TypeError(f"Expected array to be of type `Lineage`, found `{type(array).__name__}`.")
if fname == "squeeze":
Expand Down Expand Up @@ -172,16 +172,21 @@ def decorator(array, *args, **kwargs):


def _register_handled_functions():
# adapted from:
# https://github.com/numpy/numpy/blob/v1.26.0/numpy/testing/overrides.py#L50
try:
from numpy.core.overrides import ARRAY_FUNCTIONS
except ImportError:
ARRAY_FUNCTIONS = [getattr(np, attr) for attr in dir(np)]

handled_fns = {}
for attrname in dir(np):
fn = getattr(np, attrname)
if isinstance(fn, types.FunctionType):
try:
sig = inspect.signature(fn)
if "axis" in sig.parameters:
handled_fns[fn] = wrap(fn)
except ValueError:
pass
for fn in ARRAY_FUNCTIONS:
try:
sig = inspect.signature(fn)
if "axis" in sig.parameters:
handled_fns[fn] = wrap(fn)
except Exception: # noqa: BLE001
pass

handled_fns.pop(np.expand_dims, None)

Expand Down Expand Up @@ -289,7 +294,7 @@ def __array_function__(self, func, types, args, kwargs):
return NotImplemented
# Note: this allows subclasses that don't override
# __array_function__ to handle MyArray objects
if not all(issubclass(t, type(self)) for t in types):
if not all(issubclass(t, self.__class__) for t in types):
return NotImplemented

return _HANDLED_FUNCTIONS[func](*args, **kwargs)
Expand Down Expand Up @@ -648,7 +653,7 @@ def reduce(
"cosine_sim", "wasserstein_dist", "kl_div", "js_div", "mutual_info", "equal"
] = DistanceMeasure.MUTUAL_INFO,
normalize_weights: Literal["scale", "softmax"] = NormWeights.SOFTMAX,
softmax_scale: float = 1,
softmax_scale: float = 1.0,
return_weights: bool = False,
) -> Union["Lineage", Tuple["Lineage", Optional[pd.DataFrame]]]:
"""Subset states and normalize them so that they again sum to :math:`1`.
Expand Down Expand Up @@ -854,7 +859,7 @@ def from_adata(
states = adata.obs.get(nkey, None)
if states is None:
logg.warning(f"Unable to find states in `adata.obs[{nkey!r}]`. Using default names")
elif not is_categorical_dtype(states):
elif not isinstance(states.dtype, pd.CategoricalDtype):
logg.warning(
f"Expected `adata.obs[{key!r}]` to be `categorical`, "
f"found `{infer_dtype(adata.obs[nkey])}`. Using default names"
Expand Down Expand Up @@ -1129,10 +1134,10 @@ def _softmax(X, beta: float = 1):
return np.exp(X * beta) / np.expand_dims(np.sum(np.exp(X * beta), axis=1), -1)


def _row_normalize(X):
def _row_normalize(X: Union[np.ndarray, Lineage]) -> Union[np.ndarray, Lineage]:
if isinstance(X, Lineage):
return X / X.sum(1) # Lineage is shape-preserving
return X / np.expand_dims(X.sum(1), -1)
return X / X.sum(1) # lineage is shape-preserving
return X / X.sum(1, keepdims=True)


def _col_normalize(X, norm_ord=2):
Expand Down
11 changes: 8 additions & 3 deletions src/cellrank/_utils/_linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@ def _(
b.setArray(mat_b.squeeze())

ksp.solve(b, x)
# `is_converged` in PETSc >= 3.20
converged = ksp.is_converged if hasattr(ksp, "is_converged") else ksp.converged

return np.atleast_1d(x.getArray().copy().squeeze()), int(ksp.converged)
return np.atleast_1d(x.getArray().copy().squeeze()), int(converged)


@_solve_many_sparse_problems_petsc.register(sp.csc_matrix)
Expand All @@ -174,7 +176,8 @@ def _(
ksp.solve(b, x)

xs.append(np.atleast_1d(x.getArray().copy().squeeze()))
converged += ksp.converged
# `is_converged` in PETSc >= 3.20
converged += ksp.is_converged if hasattr(ksp, "is_converged") else ksp.converged

if queue is not None:
queue.put(1)
Expand Down Expand Up @@ -306,8 +309,10 @@ def _petsc_direct_solve(
factored_matrix.matSolve(B, x)

res = np.array(x.getDenseArray(), copy=True)
# `is_converged` in PETSc >= 3.20
converged = ksp.is_converged if hasattr(ksp, "is_converged") else ksp.converged

if not ksp.converged:
if not converged:
logg.debug(
f"The solution for system "
f"`A{list(A.getSize())} * X{list(x.getSize())} = B{list(B.getSize())}` "
Expand Down
14 changes: 7 additions & 7 deletions src/cellrank/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pandas as pd
import scipy.sparse as sp
import scipy.stats as st
from pandas.api.types import infer_dtype, is_categorical_dtype
from pandas.api.types import infer_dtype
from sklearn.cluster import KMeans
from statsmodels.stats.multitest import multipletests

Expand Down Expand Up @@ -183,7 +183,7 @@ def _process_series(
process_colors = cols is not None

# assert dtype of the series
if not is_categorical_dtype(series):
if not isinstance(series.dtype, pd.CategoricalDtype):
raise TypeError(f"Series must be `categorical`, found `{infer_dtype(series)}`.")

# if keys is None, just return
Expand Down Expand Up @@ -530,7 +530,7 @@ def perm_test_extractor(res: Sequence[Tuple[np.ndarray, np.ndarray]]) -> Tuple[n

def _filter_cells(distances: sp.spmatrix, rc_labels: pd.Series, n_matches_min: int) -> pd.Series:
"""Filter out some cells that look like transient states based on their neighbors."""
if not is_categorical_dtype(rc_labels):
if not isinstance(rc_labels.dtype, pd.CategoricalDtype):
raise TypeError(f"Expected `categories` be `categorical`, found `{infer_dtype(rc_labels)}`.")

# retrieve knn graph
Expand Down Expand Up @@ -886,7 +886,7 @@ def _convert_to_categorical_series(
"that there are no conflicting keys, such as `0` and `'0'`."
)

term_states = pd.Series([np.nan] * len(cell_names), index=cell_names)
term_states = pd.Series([None] * len(cell_names), index=cell_names, dtype=str)
for ts, cells in mapper.items():
term_states[cells] = ts

Expand Down Expand Up @@ -947,10 +947,10 @@ def get_color_mapper(

return cols

if not is_categorical_dtype(old):
if not isinstance(old.dtype, pd.CategoricalDtype):
raise TypeError(f"Expected old approx. recurrent classes to be categorical, found " f"`{infer_dtype(old)}`.")

if not is_categorical_dtype(new):
if not isinstance(new.dtype, pd.CategoricalDtype):
raise TypeError(f"Expected new approx. recurrent classes to be categorical, found " f"`{infer_dtype(new)}`.")

if (old.index != new.index).any():
Expand Down Expand Up @@ -1158,7 +1158,7 @@ def _series_from_one_hot_matrix(
target_series = pd.Series(index=index, dtype="category")
for vec, name in zip(membership.T, names):
target_series = target_series.cat.add_categories(name)
target_series[np.where(vec)[0]] = name
target_series.iloc[np.where(vec)[0]] = name

return target_series

Expand Down
4 changes: 2 additions & 2 deletions src/cellrank/estimators/mixins/_fate_probabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype, is_categorical_dtype
from pandas.api.types import infer_dtype

from anndata import AnnData

Expand Down Expand Up @@ -386,7 +386,7 @@ def compute_lineage_priming(
key = next(iter(early_cells.keys()))
if key not in self.adata.obs:
raise KeyError(f"Unable to find clusters in `adata.obs[{key!r}]`.")
if not is_categorical_dtype(self.adata.obs[key]):
if not isinstance(self.adata.obs[key].dtype, pd.CategoricalDtype):
raise TypeError(
f"Expected `adata.obs[{key!r}]` to be categorical, " f"found `{infer_dtype(self.adata.obs[key])}`."
)
Expand Down
10 changes: 5 additions & 5 deletions src/cellrank/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype, is_categorical_dtype
from pandas.api.types import infer_dtype

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
Expand Down Expand Up @@ -367,7 +367,7 @@ def predict_initial_states(self, n_states: int = 1, n_cells: int = 30, allow_ove
if stat_dist is None:
raise RuntimeError("No coarse-grained stationary distribution found.")

states = list(stat_dist[np.argsort(stat_dist)][:n_states].index)
states = list(stat_dist.iloc[np.argsort(stat_dist)][:n_states].index)
return self.set_initial_states(states, n_cells=n_cells, allow_overlap=allow_overlap)

@d.dedent
Expand Down Expand Up @@ -876,7 +876,7 @@ def plot_macrostate_composition(
raise RuntimeError("Compute macrostates first as `.compute_macrostates()`.")
if key not in self.adata.obs:
raise KeyError(f"Data not found in `adata.obs[{key!r}]`.")
if not is_categorical_dtype(self.adata.obs[key]):
if not isinstance(self.adata.obs[key].dtype, pd.CategoricalDtype):
raise TypeError(
f"Expected `adata.obs[{key!r}]` to be `categorical`, " f"found `{infer_dtype(self.adata.obs[key])}`."
)
Expand All @@ -893,7 +893,7 @@ def plot_macrostate_composition(
cats_colors = _create_categorical_colors(len(self.adata.obs[key].cat.categories))
cat_color_mapper = dict(zip(self.adata.obs[key].cat.categories, cats_colors))
x_indices = np.arange(len(macrostates.cat.categories))
bottom = np.zeros_like(x_indices, dtype=np.float32)
bottom = np.zeros_like(x_indices, dtype=float)

width = min(1, max(0, width))
fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def _write_macrostates(
self._set("_coarse_stat_dist", value=stat_dist, shadow_only=True)
self._set(
obj=self.adata.uns, key=Key.uns.coarse(self.backward),
value=AnnData(tmat, obs=dists, dtype=float)
value=AnnData(tmat, obs=dists),
)
else:
for attr in ["_schur_vectors", "_schur_matrix", "_coarse_tmat", "_coarse_init_dist", "_coarse_stat_dist"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype, is_categorical_dtype
from pandas.api.types import infer_dtype

from matplotlib.colors import to_hex

Expand Down Expand Up @@ -411,7 +411,7 @@ def _plot_discrete(
) -> None:
if not isinstance(_data, pd.Series):
raise TypeError(f"Expected `data` to be of type `pandas.Series`, found `{type(_data)}`.")
if not is_categorical_dtype(_data):
if not isinstance(_data.dtype, pd.CategoricalDtype):
raise TypeError(f"Expected `data` to be `categorical`, found `{infer_dtype(_data)}`.")

names = list(_data.cat.categories)
Expand Down Expand Up @@ -564,7 +564,9 @@ def _set_categorical_labels(
# fmt: off
if isinstance(categories, dict):
key = next(iter(categories.keys()))
if len(categories) == 1 and is_categorical_dtype(self.adata.obs.get(key, None)):
data = self.adata.obs.get(key, None)
is_categorical = data is not None and isinstance(data.dtype, pd.CategoricalDtype)
if len(categories) == 1 and is_categorical:
vals = categories[key]
if isinstance(vals, str) or not isinstance(vals, Sequence):
vals = (categories[key],)
Expand All @@ -575,7 +577,7 @@ def _set_categorical_labels(
categories = {cat: self.adata[clusters == cat].obs_names for cat in vals}

categories = _convert_to_categorical_series(categories, list(self.adata.obs_names))
if not is_categorical_dtype(categories):
if not isinstance(categories.dtype, pd.CategoricalDtype):
raise TypeError(f"Expected object to be `categorical`, found `{infer_dtype(categories)}`.")

if existing is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/cellrank/kernels/_precomputed_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _from_matrix(
# fmt: off
if adata is None:
logg.warning(f"Creating empty `AnnData` object of shape `{matrix.shape[0], 1}`")
adata = AnnData(sp.csr_matrix((matrix.shape[0], 1), dtype=np.float64))
adata = AnnData(sp.csr_matrix((matrix.shape[0], 1)))
super().__init__(adata)
self._backward: Optional[bool] = backward
self.transition_matrix = matrix.copy() if copy else matrix
Expand Down
16 changes: 8 additions & 8 deletions src/cellrank/kernels/_real_time_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype, is_categorical_dtype
from pandas.api.types import infer_dtype

import scanpy as sc
from anndata import AnnData
Expand Down Expand Up @@ -105,7 +105,7 @@ def _read_from_adata(
) -> None:
super()._read_from_adata(**kwargs)
self._time = self.adata.obs[time_key].copy()
if not is_categorical_dtype(self._time):
if not isinstance(self._time.dtype, pd.CategoricalDtype):
raise TypeError(f"Expected `adata.obs[{time_key!r}]` to be categorical, found `{infer_dtype(self._time)}`.")
self._time = self._time.cat.remove_unused_categories()
cats = self._time.cat.categories
Expand Down Expand Up @@ -448,7 +448,7 @@ def _restich_couplings(
for ix in range(len(blocks)):
index.extend(obs_names[ix])

tmp = AnnData(sp.bmat(blocks, format="csr"), dtype="float64")
tmp = AnnData(sp.bmat(blocks, format="csr"))
tmp.obs_names = index
tmp.var_names = index
tmp = tmp[self.adata.obs_names, :][:, self.adata.obs_names]
Expand Down Expand Up @@ -513,14 +513,14 @@ def _sparsify_couplings(
if threshold == "auto_local":
thresh = min(tmat[i].max() for i in range(tmat.shape[0]))
logg.debug(f"Using `threshold={thresh}` at `{key}`")
elif isinstance(threshold, (int, float)):
elif isinstance(threshold, (int, float, np.number)):
thresh = np.percentile(tmat.data, threshold)
logg.debug(f"Using `threshold={thresh}` at `{key}`")

tmat = sp.csr_matrix(tmat, dtype=tmat.dtype)
tmat.data[tmat.data < thresh] = 0.0
tmat.eliminate_zeros()
couplings[key] = AnnData(tmat, obs=adata.obs, var=adata.var, dtype=tmat.dtype)
couplings[key] = AnnData(tmat, obs=adata.obs, var=adata.var)

return couplings if copy else None

Expand Down Expand Up @@ -579,9 +579,9 @@ def assert_same(expected: Sequence[Any], actual: Sequence[Any], msg: Optional[st
def _coupling_to_adata(self, src: Any, tgt: Any, coupling: Coupling_t) -> AnnData:
"""Convert the coupling to :class:`~anndata.AnnData`."""
if not isinstance(coupling, AnnData):
coupling = AnnData(X=coupling, dtype=coupling.dtype)
coupling.obs_names = self.adata[self._time == src].obs_names
coupling.var_names = self.adata[self._time == tgt].obs_names
coupling = AnnData(X=coupling)
coupling.obs_names = np.asarray(self.adata.obs_names)[self.time == src]
coupling.var_names = np.asarray(self.adata.obs_names)[self.time == tgt]

if sp.issparse(coupling.X) and not sp.isspmatrix_csr(coupling.X):
coupling.X = coupling.X.tocsr()
Expand Down
Loading

0 comments on commit 01a79ca

Please sign in to comment.