From ffceee21e36d6b479d2635f94693bc4ce9923c5f Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 24 Jul 2024 08:51:54 -0700 Subject: [PATCH] [FEA] add the support of `masked_matmul` (#2362) https://github.com/rapidsai/raft/issues/2336 Authors: - rhdong (https://github.com/rhdong) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2362 --- cpp/bench/prims/CMakeLists.txt | 1 + cpp/bench/prims/linalg/masked_matmul.cu | 268 ++++++++++++++ .../sparse/linalg/detail/masked_matmul.cuh | 107 ++++++ .../raft/sparse/linalg/masked_matmul.hpp | 71 ++++ cpp/test/CMakeLists.txt | 1 + cpp/test/sparse/masked_matmul.cu | 328 ++++++++++++++++++ 6 files changed, 776 insertions(+) create mode 100644 cpp/bench/prims/linalg/masked_matmul.cu create mode 100644 cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh create mode 100644 cpp/include/raft/sparse/linalg/masked_matmul.hpp create mode 100644 cpp/test/sparse/masked_matmul.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index c72649e350..c8c68f19bf 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -117,6 +117,7 @@ if(BUILD_PRIMS_BENCH) PATH linalg/add.cu linalg/map_then_reduce.cu + linalg/masked_matmul.cu linalg/matrix_vector_op.cu linalg/norm.cu linalg/normalize.cu diff --git a/cpp/bench/prims/linalg/masked_matmul.cu b/cpp/bench/prims/linalg/masked_matmul.cu new file mode 100644 index 0000000000..eda9cb1710 --- /dev/null +++ b/cpp/bench/prims/linalg/masked_matmul.cu @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace raft::bench::linalg { + +template +struct MaskedMatmulBenchParams { + size_t m; + size_t k; + size_t n; + float sparsity; + value_t alpha = 1.0; + value_t beta = 0.0; +}; + +template +inline auto operator<<(std::ostream& os, const MaskedMatmulBenchParams& params) + -> std::ostream& +{ + os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n + << "\tsparsity=" << params.sparsity; + if (params.sparsity == 1.0) { os << "<-inner product for comparison"; } + return os; +} + +template +struct MaskedMatmulBench : public fixture { + MaskedMatmulBench(const MaskedMatmulBenchParams& p) + : fixture(true), + params(p), + handle(stream), + a_data_d(0, stream), + b_data_d(0, stream), + c_indptr_d(0, stream), + c_indices_d(0, stream), + c_data_d(0, stream), + bitmap_d(0, stream), + c_dense_data_d(0, stream) + { + index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bitmap_t) * 8)); + std::vector bitmap_h(element); + + a_data_d.resize(params.m * params.k, stream); + b_data_d.resize(params.k * params.n, stream); + bitmap_d.resize(element, stream); + + raft::random::RngState rng(2024ULL); + raft::random::uniform( + handle, rng, a_data_d.data(), params.m * params.k, value_t(-1.0), value_t(1.0)); + raft::random::uniform( + handle, rng, b_data_d.data(), params.k * params.n, value_t(-1.0), value_t(1.0)); + + std::vector c_dense_data_h(params.m * params.n); + + c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bitmap_h); + + std::vector values(c_true_nnz); + std::vector indices(c_true_nnz); + std::vector indptr(params.m + 1); + + c_data_d.resize(c_true_nnz, stream); + c_indptr_d.resize(params.m + 1, stream); + c_indices_d.resize(c_true_nnz, stream); + c_dense_data_d.resize(params.m * params.n, stream); + + cpu_convert_to_csr(bitmap_h, params.m, params.n, indices, indptr); + RAFT_EXPECTS(c_true_nnz == c_indices_d.size(), + "Something wrong. The c_true_nnz != c_indices_d.size()!"); + + update_device(c_data_d.data(), values.data(), c_true_nnz, stream); + update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream); + update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream); + update_device(bitmap_d.data(), bitmap_h.data(), element, stream); + } + + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void cpu_convert_to_csr(std::vector& bitmap, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitmap_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitmap[index / (8 * sizeof(bitmap_t))]; + bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + ~MaskedMatmulBench() {} + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto a = raft::make_device_matrix_view( + a_data_d.data(), params.m, params.k); + + auto b = raft::make_device_matrix_view( + b_data_d.data(), params.n, params.k); + + auto c_structure = raft::make_device_compressed_structure_view( + c_indptr_d.data(), + c_indices_d.data(), + params.m, + params.n, + static_cast(c_indices_d.size())); + + auto mask = + raft::core::bitmap_view(bitmap_d.data(), params.m, params.n); + + auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + + if (params.sparsity < 1.0) { + raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + } else { + raft::distance::pairwise_distance(handle, + a_data_d.data(), + b_data_d.data(), + c_dense_data_d.data(), + static_cast(params.m), + static_cast(params.n), + static_cast(params.k), + raft::distance::DistanceType::InnerProduct, + true); + } + resource::sync_stream(handle); + + raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + resource::sync_stream(handle); + + loop_on_state(state, [this, &a, &b, &mask, &c]() { + if (params.sparsity < 1.0) { + raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + } else { + raft::distance::pairwise_distance(handle, + a_data_d.data(), + b_data_d.data(), + c_dense_data_d.data(), + static_cast(params.m), + static_cast(params.n), + static_cast(params.k), + raft::distance::DistanceType::InnerProduct, + true); + } + resource::sync_stream(handle); + }); + } + + private: + const raft::device_resources handle; + MaskedMatmulBenchParams params; + + rmm::device_uvector a_data_d; + rmm::device_uvector b_data_d; + rmm::device_uvector bitmap_d; + + rmm::device_uvector c_dense_data_d; + + size_t c_true_nnz = 0; + rmm::device_uvector c_indptr_d; + rmm::device_uvector c_indices_d; + rmm::device_uvector c_data_d; +}; + +template +static std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + size_t m; + size_t k; + size_t n; + float sparsity; + }; + + const std::vector params_group = + raft::util::itertools::product({size_t(10), size_t(1024)}, + {size_t(128), size_t(1024)}, + {size_t(1024 * 1024)}, + {0.01f, 0.1f, 0.2f, 0.5f, 1.0f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back( + MaskedMatmulBenchParams({params.m, params.k, params.n, params.sparsity})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((MaskedMatmulBench), "", getInputs()); + +} // namespace raft::bench::linalg diff --git a/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh b/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh new file mode 100644 index 0000000000..208328f2f3 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain A copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +template +void masked_matmul(raft::resources const& handle, + raft::device_matrix_view& A, + raft::device_matrix_view& B, + raft::core::bitmap_view& mask, + raft::device_csr_matrix_view& C, + std::optional> alpha, + std::optional> beta) +{ + index_t m = A.extent(0); + index_t n = B.extent(0); + index_t dim = A.extent(1); + + auto compressed_C_view = C.structure_view(); + + RAFT_EXPECTS(A.extent(1) == B.extent(1), "The dim of A must be equal to the dim of B."); + RAFT_EXPECTS(A.extent(0) == compressed_C_view.get_n_rows(), + "Number of rows in C must match the number of rows in A."); + RAFT_EXPECTS(B.extent(0) == compressed_C_view.get_n_cols(), + "Number of columns in C must match the number of columns in B."); + + auto stream = raft::resource::get_cuda_stream(handle); + + auto C_matrix = raft::make_device_csr_matrix(handle, compressed_C_view); + + // fill C + raft::sparse::convert::bitmap_to_csr(handle, mask, C_matrix); + + if (m > 10 || alpha.has_value() || beta.has_value()) { + auto C_view = raft::make_device_csr_matrix_view( + C.get_elements().data(), compressed_C_view); + + // create B col_major view + auto B_col_major = raft::make_device_matrix_view( + B.data_handle(), dim, n); + + value_t default_alpha = static_cast(1.0f); + value_t default_beta = static_cast(0.0f); + + if (!alpha.has_value()) { alpha = raft::make_host_scalar_view(&default_alpha); } + if (!beta.has_value()) { beta = raft::make_host_scalar_view(&default_beta); } + + raft::sparse::linalg::sddmm(handle, + A, + B_col_major, + C_view, + raft::linalg::Operation::NON_TRANSPOSE, + raft::linalg::Operation::NON_TRANSPOSE, + *alpha, + *beta); + } else { + raft::sparse::distance::detail::faster_dot_on_csr(handle, + C.get_elements().data(), + compressed_C_view.get_nnz(), + compressed_C_view.get_indptr().data(), + compressed_C_view.get_indices().data(), + A.data_handle(), + B.data_handle(), + compressed_C_view.get_n_rows(), + dim); + } +} + +} // namespace detail +} // namespace linalg +} // namespace sparse +} // namespace raft diff --git a/cpp/include/raft/sparse/linalg/masked_matmul.hpp b/cpp/include/raft/sparse/linalg/masked_matmul.hpp new file mode 100644 index 0000000000..560cd3f715 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/masked_matmul.hpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain A copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @defgroup masked_matmul Masked Matrix Multiplication + * @{ + */ + +/** + * @brief Performs a masked multiplication of dense matrices A and B, followed by an element-wise + * multiplication with the sparsity pattern defined by the mask, resulting in the computation + * C = alpha * ((A * B) ∘ spy(mask)) + beta * C. + * + * This function multiplies two dense matrices A and B, and then applies an element-wise + * multiplication using the sparsity pattern provided by the mask. The result is scaled by alpha + * and added to beta times the original matrix C. + * + * @tparam value_t Data type of elements in the input/output matrices (e.g., float, double) + * @tparam index_t Type used for matrix indices + * @tparam nnz_t Type used for the number of non-zero entries in CSR format + * @tparam bitmap_t Type of the bitmap used for the mask + * + * @param[in] handle RAFT handle for resource management + * @param[in] A Input dense matrix (device_matrix_view) with shape [m, k] + * @param[in] B Input dense matrix (device_matrix_view) with shape [n, k] + * @param[in] mask Bitmap view representing the sparsity pattern (bitmap_view) with logical shape + * [m, n]. Each bit in the mask indicates whether the corresponding element pair in A and B is + * included (1) or masked out (0). + * @param[inout] C Output sparse matrix in CSR format (device_csr_matrix_view) with dense shape [m, + * n] + * @param[in] alpha Optional scalar multiplier for the product of A and B (default: 1.0 if + * std::nullopt) + * @param[in] beta Optional scalar multiplier for the original matrix C (default: 0 if std::nullopt) + */ +template +void masked_matmul(raft::resources const& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + raft::core::bitmap_view mask, + raft::device_csr_matrix_view C, + std::optional> alpha = std::nullopt, + std::optional> beta = std::nullopt) +{ + detail::masked_matmul(handle, A, B, mask, C, alpha, beta); +} + +/** @} */ // end of masked_matmul + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3ac0f281a6..cb96ce2264 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -318,6 +318,7 @@ if(BUILD_TESTS) sparse/csr_transpose.cu sparse/degree.cu sparse/filter.cu + sparse/masked_matmul.cu sparse/norm.cu sparse/normalize.cu sparse/reduce.cu diff --git a/cpp/test/sparse/masked_matmul.cu b/cpp/test/sparse/masked_matmul.cu new file mode 100644 index 0000000000..0ece716a1b --- /dev/null +++ b/cpp/test/sparse/masked_matmul.cu @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace raft { +namespace sparse { + +template +struct MaskedMatmulInputs { + value_t tolerance; + + index_t m; + index_t k; + index_t n; + + value_t sparsity; + + unsigned long long int seed; +}; + +template +struct sum_abs_op { + __host__ __device__ value_t operator()(const value_t& x, const value_t& y) const + { + return y >= value_t(0.0) ? (x + y) : (x - y); + } +}; + +template +::std::ostream& operator<<(::std::ostream& os, const MaskedMatmulInputs& params) +{ + os << " m: " << params.m << "\tk: " << params.k << "\tn: " << params.n + << "\tsparsity: " << params.sparsity; + + return os; +} + +template +class MaskedMatmulTest : public ::testing::TestWithParam> { + public: + MaskedMatmulTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + a_data_d(0, resource::get_cuda_stream(handle)), + b_data_d(0, resource::get_cuda_stream(handle)), + bitmap_d(0, resource::get_cuda_stream(handle)), + c_indptr_d(0, resource::get_cuda_stream(handle)), + c_indices_d(0, resource::get_cuda_stream(handle)), + c_data_d(0, resource::get_cuda_stream(handle)), + c_expected_data_d(0, resource::get_cuda_stream(handle)) + { + } + + protected: + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void cpu_convert_to_csr(std::vector& bitmap, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitmap_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitmap[index / (8 * sizeof(bitmap_t))]; + bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + void cpu_sddmm(const std::vector& A, + const std::vector& B, + std::vector& vals, + const std::vector& cols, + const std::vector& row_ptrs, + bool is_row_major_A, + bool is_row_major_B) + { + if (params.m * params.k != static_cast(A.size()) || + params.k * params.n != static_cast(B.size())) { + std::cerr << "Matrix dimensions and vector size do not match!" << std::endl; + return; + } + + for (index_t i = 0; i < params.m; ++i) { + for (index_t j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) { + value_t sum = 0; + for (index_t l = 0; l < params.k; ++l) { + index_t a_index = i * params.k + l; + index_t b_index = cols[j] * params.k + l; + sum += A[a_index] * B[b_index]; + } + vals[j] = sum; + } + } + } + + void make_data() + { + index_t a_size = params.m * params.k; + index_t b_size = params.k * params.n; + index_t c_size = params.m * params.n; + + index_t element = raft::ceildiv(params.m * params.n, index_t(sizeof(bitmap_t) * 8)); + std::vector bitmap_h(element); + + std::vector a_data_h(a_size); + std::vector b_data_h(b_size); + + a_data_d.resize(a_size, stream); + b_data_d.resize(b_size, stream); + bitmap_d.resize(bitmap_h.size(), stream); + + auto blobs_a_b = raft::make_device_matrix(handle, 1, a_size + b_size); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_a_b.data_handle(), + labels.data_handle(), + 1, + a_size + b_size, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-1.0f), + value_t(1.0f), + uint64_t(2024)); + + raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_h.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + + raft::copy(a_data_d.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_d.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + + resource::sync_stream(handle); + + index_t c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bitmap_h); + + std::vector c_indptr_h(params.m + 1); + std::vector c_indices_h(c_true_nnz); + std::vector c_data_h(c_true_nnz); + + cpu_convert_to_csr(bitmap_h, params.m, params.n, c_indices_h, c_indptr_h); + + c_data_d.resize(c_data_h.size(), stream); + + update_device(c_data_d.data(), c_data_h.data(), c_data_h.size(), stream); + update_device(bitmap_d.data(), bitmap_h.data(), bitmap_h.size(), stream); + resource::sync_stream(handle); + + cpu_sddmm(a_data_h, b_data_h, c_data_h, c_indices_h, c_indptr_h, true, true); + + c_indptr_d.resize(c_indptr_h.size(), stream); + c_indices_d.resize(c_indices_h.size(), stream); + c_expected_data_d.resize(c_data_h.size(), stream); + + update_device(c_indptr_d.data(), c_indptr_h.data(), c_indptr_h.size(), stream); + update_device(c_indices_d.data(), c_indices_h.data(), c_indices_h.size(), stream); + update_device(c_expected_data_d.data(), c_data_h.data(), c_data_h.size(), stream); + + resource::sync_stream(handle); + } + + void SetUp() override { make_data(); } + + void Run() + { + auto A = + raft::make_device_matrix_view(a_data_d.data(), params.m, params.k); + auto B = + raft::make_device_matrix_view(b_data_d.data(), params.n, params.k); + + auto mask = + raft::core::bitmap_view(bitmap_d.data(), params.m, params.n); + + auto c_structure = raft::make_device_compressed_structure_view( + c_indptr_d.data(), + c_indices_d.data(), + params.m, + params.n, + static_cast(c_indices_d.size())); + + auto C = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + + raft::sparse::linalg::masked_matmul(handle, A, B, mask, C); + + resource::sync_stream(handle); + + ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), + C.get_elements().data(), + c_expected_data_d.size(), + raft::CompareApprox(params.tolerance), + stream)); + + thrust::device_ptr expected_data_ptr = + thrust::device_pointer_cast(c_expected_data_d.data()); + value_t sum_abs = thrust::reduce(thrust::cuda::par.on(stream), + expected_data_ptr, + expected_data_ptr + c_expected_data_d.size(), + value_t(0.0f), + sum_abs_op()); + value_t avg = sum_abs / (1.0f * c_expected_data_d.size()); + + ASSERT_GE(avg, (params.tolerance * static_cast(0.001f))); + } + + raft::resources handle; + cudaStream_t stream; + MaskedMatmulInputs params; + + rmm::device_uvector a_data_d; + rmm::device_uvector b_data_d; + rmm::device_uvector bitmap_d; + + rmm::device_uvector c_indptr_d; + rmm::device_uvector c_indices_d; + rmm::device_uvector c_data_d; + + rmm::device_uvector c_expected_data_d; +}; + +using MaskedMatmulTestF = MaskedMatmulTest; +TEST_P(MaskedMatmulTestF, Result) { Run(); } + +using MaskedMatmulTestD = MaskedMatmulTest; +TEST_P(MaskedMatmulTestD, Result) { Run(); } + +const std::vector> sddmm_inputs_f = { + {0.0001f, 10, 5, 32, 0.1, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.1, 1234ULL}, + {0.0003f, 32, 1024, 1024, 0.2, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 32, 0.3, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.4, 1234ULL}, + {0.0003f, 32, 1024, 1024, 0.19, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.1, 1234ULL}}; + +const std::vector> sddmm_inputs_d = { + {0.0001f, 10, 5, 32, 0.01, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.1, 1234ULL}, + {0.0001f, 32, 1024, 1024, 0.2, 1234ULL}, + {0.0001f, 1024, 1024, 1024, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 32, 0.3, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.4, 1234ULL}, + {0.0001f, 32, 1024, 1024, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 1024, 0.1, 1234ULL}}; + +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, MaskedMatmulTestF, ::testing::ValuesIn(sddmm_inputs_f)); + +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, MaskedMatmulTestD, ::testing::ValuesIn(sddmm_inputs_d)); + +} // namespace sparse +} // namespace raft