diff --git a/sparse/_common.py b/sparse/_common.py index 6073798b..8b497f7b 100644 --- a/sparse/_common.py +++ b/sparse/_common.py @@ -724,6 +724,13 @@ def _dot_csr_csr( sums[temp] = 0 indptr[i + 1] = nnz + + if len(indices) == (n_col * n_row): + for i in range(len(indices) // n_col): + j = n_col * i + k = n_col * (1 + i) + data[j:k] = data[j:k][::-1] + indices[j:k] = indices[j:k][::-1] return data, indices, indptr return _dot_csr_csr diff --git a/sparse/_utils.py b/sparse/_utils.py index 3e199109..82410412 100644 --- a/sparse/_utils.py +++ b/sparse/_utils.py @@ -49,8 +49,31 @@ def assert_eq(x, y, check_nnz=True, compare_dtype=True, **kwargs): assert check_equal(xx, yy, **kwargs) +def assert_gcxs_slicing(s, x): + """ + Util function to test slicing of GCXS matrices after product multiplication. + For simplicity, it tests only tensors with number of dimension = 3. + Parameters + ---------- + s: sparse product matrix + x: dense product matrix + """ + row = np.random.randint(s.shape[s.ndim - 2]) + assert np.allclose(s[0][row].data, [num for num in x[0][row] if num != 0]) + + # regression test + col = s.shape[s.ndim - 1] + for i in range(len(s.indices) // col): + j = col * i + k = col * (1 + i) + s.data[j:k] = s.data[j:k][::-1] + s.indices[j:k] = s.indices[j:k][::-1] + assert np.array_equal(s[0][row].data, np.array([])) + + def assert_nnz(s, x): fill_value = s.fill_value if hasattr(s, "fill_value") else _zero_of_dtype(s.dtype) + assert np.sum(~equivalent(x, fill_value)) == s.nnz diff --git a/sparse/tests/test_dot.py b/sparse/tests/test_dot.py index 214ff664..e7adaab0 100644 --- a/sparse/tests/test_dot.py +++ b/sparse/tests/test_dot.py @@ -7,7 +7,7 @@ import sparse from sparse._compressed import GCXS from sparse import COO -from sparse._utils import assert_eq +from sparse._utils import assert_eq, assert_gcxs_slicing @pytest.mark.parametrize( @@ -341,3 +341,64 @@ def test_dot_dense(dtype1, dtype2, ndim1, ndim2): assert_eq(sparse.matmul(a, b), np.matmul(a, b)) if ndim1 == 2 and ndim2 == 2: assert_eq(sparse.tensordot(a, b), np.tensordot(a, b)) + + +@pytest.mark.parametrize( + "a_shape, b_shape", + [((3, 4, 5), (5, 6)), ((2, 8, 6), (6, 3))], +) +def test_dot_GCXS_slicing(a_shape, b_shape): + sa = sparse.random(shape=a_shape, density=1, format="gcxs") + sb = sparse.random(shape=b_shape, density=1, format="gcxs") + + a = sa.todense() + b = sb.todense() + + # tests dot + sa_sb = sparse.dot(sa, sb) + a_b = np.dot(a, b) + + assert_gcxs_slicing(sa_sb, a_b) + + +@pytest.mark.parametrize( + "a_shape,b_shape,axes", + [ + [(3, 4, 5), (4, 3), (1, 0)], + [(3, 4), (5, 4, 3), (1, 1)], + [(5, 9), (9, 5, 6), (0, 1)], + ], +) +def test_tensordot_GCXS_slicing(a_shape, b_shape, axes): + sa = sparse.random(shape=a_shape, density=1, format="gcxs") + sb = sparse.random(shape=b_shape, density=1, format="gcxs") + + a = sa.todense() + b = sb.todense() + + sa_sb = sparse.tensordot(sa, sb, axes) + a_b = np.tensordot(a, b, axes) + + assert_gcxs_slicing(sa_sb, a_b) + + +@pytest.mark.parametrize( + "a_shape, b_shape", + [ + [(1, 1, 5), (3, 5, 6)], + [(3, 4, 5), (1, 5, 6)], + [(3, 4, 5), (3, 5, 6)], + [(3, 4, 5), (5, 6)], + ], +) +def test_matmul_GCXS_slicing(a_shape, b_shape): + sa = sparse.random(shape=a_shape, density=1, format="gcxs") + sb = sparse.random(shape=b_shape, density=1, format="gcxs") + + a = sa.todense() + b = sb.todense() + + sa_sb = sparse.matmul(sa, sb) + a_b = np.matmul(a, b) + + assert_gcxs_slicing(sa_sb, a_b)