Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and refactor filter_genes #537

Merged
merged 2 commits into from
Jul 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scvelo/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
set_modality,
show_proportions,
)
from ._arithmetic import clipped_log, invert, prod_sum, sum
from ._arithmetic import clipped_log, invert, multiply, prod_sum, sum
from ._linear_models import LinearRegression
from ._metrics import l2_norm
from ._models import SplicingDynamics
Expand All @@ -33,6 +33,7 @@
"make_dense",
"make_sparse",
"merge",
"multiply",
"parallelize",
"prod_sum",
"set_initial_size",
Expand Down
26 changes: 26 additions & 0 deletions scvelo/core/_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,32 @@ def invert(x: ndarray) -> ndarray:
return x_inv


def multiply(
    a: Union[ndarray, spmatrix], b: Union[ndarray, spmatrix]
) -> Union[ndarray, spmatrix]:
    """Point-wise multiplication of arrays or sparse matrices.

    Arguments
    ---------
    a
        First array/sparse matrix.
    b
        Second array/sparse matrix.

    Returns
    -------
    Union[ndarray, spmatrix]
        Point-wise product of `a` and `b`. If at least one operand is sparse,
        the product is computed by that operand's `multiply` method (`a` is
        checked before `b`); otherwise it is the result of `a * b`.
    """

    # scipy sparse matrices overload `*` as the matrix product, so a sparse
    # operand has to go through its `multiply` method to get the point-wise
    # product.
    if issparse(a):
        return a.multiply(b)
    # Point-wise multiplication is commutative, so the operands can be swapped
    # when only `b` is sparse.
    if issparse(b):
        return b.multiply(a)
    return a * b


def prod_sum(
a1: Union[ndarray, spmatrix], a2: Union[ndarray, spmatrix], axis: Optional[int]
) -> ndarray:
Expand Down
12 changes: 4 additions & 8 deletions scvelo/preprocessing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from scvelo.core import cleanup as _cleanup
from scvelo.core import get_initial_size as _get_initial_size
from scvelo.core import get_size as _get_size
from scvelo.core import multiply
from scvelo.core import set_initial_size as _set_initial_size
from scvelo.core import show_proportions as _show_proportions
from scvelo.core import sum
Expand Down Expand Up @@ -233,14 +234,9 @@ def filter_genes(
X = adata.layers[layer]
else: # shared counts/cells
Xs, Xu = adata.layers["spliced"], adata.layers["unspliced"]
nonzeros = (
(Xs > 0).multiply(Xu > 0) if issparse(Xs) else (Xs > 0) * (Xu > 0)
)
X = (
nonzeros.multiply(Xs) + nonzeros.multiply(Xu)
if issparse(nonzeros)
else nonzeros * (Xs + Xu)
)

nonzeros = multiply(Xs > 0, Xu > 0)
X = multiply(nonzeros, Xs) + multiply(nonzeros, Xu)

gene_subset = np.ones(adata.n_vars, dtype=bool)

Expand Down
58 changes: 57 additions & 1 deletion tests/core/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import numpy as np
from numpy import ndarray
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy.sparse import csr_matrix, issparse

from scvelo.core import clipped_log, invert, prod_sum, sum
from scvelo.core import clipped_log, invert, multiply, prod_sum, sum


class TestClippedLog:
Expand Down Expand Up @@ -133,6 +134,61 @@ def test_2d_arrays(self, a: ndarray):
assert set(a_inv[a == 0]) == set()


class TestMultiply:
    """Property-based tests of `multiply` for dense and sparse operands."""

    @staticmethod
    def _check_products(a: ndarray) -> None:
        """Check `multiply` on dense/dense, dense/sparse and sparse/dense input.

        Arguments
        ---------
        a
            Dense array that is multiplied point-wise with itself, once as a
            dense array and once converted to a `csr_matrix`.
        """

        b = csr_matrix(a)
        # Dense reference result that every operand combination has to match.
        expected = a * a

        res = multiply(a, a)
        assert res.shape == a.shape
        assert not issparse(res)
        np.testing.assert_almost_equal(res, expected)

        for res in (multiply(a, b), multiply(b, a)):
            assert res.shape == b.shape
            assert issparse(res)
            # Compare against the dense reference rather than against
            # `b.multiply(a)`, which is the very call `multiply` delegates to
            # and therefore could not reveal a wrong product. The reshape
            # aligns flat input with the shape of its sparse counterpart.
            np.testing.assert_almost_equal(
                res.toarray(), expected.reshape(b.shape)
            )

    @given(
        a=arrays(
            float,
            shape=st.integers(min_value=1, max_value=100),
            elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
        )
    )
    def test_flat_arrays(self, a: ndarray):
        """Flat (1D) arrays of 1 to 100 finite floats."""
        self._check_products(a)

    @given(
        a=arrays(
            float,
            shape=st.tuples(
                st.integers(min_value=1, max_value=100),
                st.integers(min_value=1, max_value=100),
            ),
            elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
        ),
    )
    def test_2d_arrays(self, a: ndarray):
        """2D arrays of finite floats with 1 to 100 rows and columns."""
        self._check_products(a)


# TODO: Extend test to generate sparse inputs as well
# TODO: Make test to generate two different arrays a1, a2
# TODO: Check why tests fail with assert_almost_equal
Expand Down