diff --git a/scvelo/core/__init__.py b/scvelo/core/__init__.py index 156eeb6e..9c40d7e6 100644 --- a/scvelo/core/__init__.py +++ b/scvelo/core/__init__.py @@ -12,7 +12,7 @@ set_modality, show_proportions, ) -from ._arithmetic import clipped_log, invert, prod_sum, sum +from ._arithmetic import clipped_log, invert, multiply, prod_sum, sum from ._linear_models import LinearRegression from ._metrics import l2_norm from ._models import SplicingDynamics @@ -33,6 +33,7 @@ "make_dense", "make_sparse", "merge", + "multiply", "parallelize", "prod_sum", "set_initial_size", diff --git a/scvelo/core/_arithmetic.py b/scvelo/core/_arithmetic.py index 037de5f2..92bf9998 100644 --- a/scvelo/core/_arithmetic.py +++ b/scvelo/core/_arithmetic.py @@ -49,6 +49,32 @@ def invert(x: ndarray) -> ndarray: return x_inv +def multiply( + a: Union[ndarray, spmatrix], b: Union[ndarray, spmatrix] +) -> Union[ndarray, spmatrix]: + """Point-wise multiplication of arrays or sparse matrices. + + Arguments + --------- + a + First array/sparse matrix. + b + Second array/sparse matrix. + + Returns + ------- + Union[ndarray, spmatrix] + Point-wise product of `a` and `b`. + """ + + if issparse(a): + return a.multiply(b) + elif issparse(b): + return b.multiply(a) + else: + return a * b + + def prod_sum( a1: Union[ndarray, spmatrix], a2: Union[ndarray, spmatrix], axis: Optional[int] ) -> ndarray: diff --git a/scvelo/preprocessing/utils.py b/scvelo/preprocessing/utils.py index fca3db11..ea89ff60 100644 --- a/scvelo/preprocessing/utils.py +++ b/scvelo/preprocessing/utils.py @@ -11,6 +11,7 @@ from scvelo.core import cleanup as _cleanup from scvelo.core import get_initial_size as _get_initial_size from scvelo.core import get_size as _get_size +from scvelo.core import multiply from scvelo.core import set_initial_size as _set_initial_size from scvelo.core import show_proportions as _show_proportions from scvelo.core import sum @@ -233,14 +234,9 @@ def filter_genes( X = adata.layers[layer] else: # shared counts/cells Xs, Xu = adata.layers["spliced"], adata.layers["unspliced"] - nonzeros = ( - (Xs > 0).multiply(Xu > 0) if issparse(Xs) else (Xs > 0) * (Xu > 0) - ) - X = ( - nonzeros.multiply(Xs) + nonzeros.multiply(Xu) - if issparse(nonzeros) - else nonzeros * (Xs + Xu) - ) + + nonzeros = multiply(Xs > 0, Xu > 0) + X = multiply(nonzeros, Xs) + multiply(nonzeros, Xu) gene_subset = np.ones(adata.n_vars, dtype=bool) diff --git a/tests/core/test_arithmetic.py b/tests/core/test_arithmetic.py index d8f12092..82e4cdf8 100644 --- a/tests/core/test_arithmetic.py +++ b/tests/core/test_arithmetic.py @@ -7,8 +7,9 @@ import numpy as np from numpy import ndarray from numpy.testing import assert_almost_equal, assert_array_equal +from scipy.sparse import csr_matrix, issparse -from scvelo.core import clipped_log, invert, prod_sum, sum +from scvelo.core import clipped_log, invert, multiply, prod_sum, sum class TestClippedLog: @@ -133,6 +134,61 @@ def test_2d_arrays(self, a: ndarray): assert set(a_inv[a == 0]) == set() +class TestMultiply: + @given( + a=arrays( + float, + shape=st.integers(min_value=1, max_value=100), + elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), + ) + ) + def test_flat_arrays(self, a: ndarray): + b = csr_matrix(a) + + res = multiply(a, a) + assert res.shape == a.shape + assert not issparse(res) + np.testing.assert_almost_equal(res, a * a) + + res = multiply(a, b) + assert res.shape == b.shape + assert issparse(res) + np.testing.assert_almost_equal(res.data, b.multiply(a).data) + + res = multiply(b, a) + assert res.shape == b.shape + assert issparse(res) + np.testing.assert_almost_equal(res.data, b.multiply(a).data) + + @given( + a=arrays( + float, + shape=st.tuples( + st.integers(min_value=1, max_value=100), + st.integers(min_value=1, max_value=100), + ), + elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), + ), + ) + def test_2d_arrays(self, a: ndarray): + b = csr_matrix(a) + + res = multiply(a, a) + assert res.shape == a.shape + assert not issparse(res) + np.testing.assert_almost_equal(res, a * a) + + res = multiply(a, b) + assert res.shape == b.shape + assert issparse(res) + np.testing.assert_almost_equal(res.data, b.multiply(a).data) + + res = multiply(b, a) + assert res.shape == b.shape + assert issparse(res) + np.testing.assert_almost_equal(res.data, b.multiply(a).data) + + # TODO: Extend test to generate sparse inputs as well # TODO: Make test to generate two different arrays a1, a2 # TODO: Check why tests fail with assert_almost_equal