From 5e8da37ce5d0264456078a6cfcf01b848d0d61af Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 17 May 2023 15:58:35 +0000 Subject: [PATCH] fix restrictive assertion --- .../raft/distance/detail/kernels/gram_matrix.cuh | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index 7cfc75cd96..9b079a8539 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -464,18 +464,19 @@ class GramMatrixBase { csr_input_matrix_view_t x2, dense_output_matrix_view_t out) { - // check is_row_major consistency - bool is_row_major = get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - int minor_out = is_row_major ? out.extent(1) : out.extent(0); - ASSERT(ld_out == minor_out, "Sparse linear Kernel distance does not support ld_out parameter"); + // check layout consistency (w.r.t. strides a matrix might be both row & col major) + bool is_row_major_nopad = get_is_row_major(out) && out.stride(0) == out.extent(1); + bool is_col_major_nopad = get_is_col_major(out) && out.stride(1) == out.extent(0); + + ASSERT(is_row_major_nopad || is_col_major_nopad, + "Sparse linear Kernel distance does not support ld_out parameter"); auto x1_structure = x1.structure_view(); auto x2_structure = x2.structure_view(); raft::sparse::distance::distances_config_t dist_config(handle); - // switch a,b based on is_row_major - if (!is_row_major) { + // switch a,b based on data layout + if (is_col_major_nopad) { dist_config.a_nrows = x2_structure.get_n_rows(); dist_config.a_ncols = x2_structure.get_n_cols(); dist_config.a_nnz = x2_structure.get_nnz();