Skip to content

Commit

Permalink
sparse matrix refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahdhn committed Jul 17, 2024
1 parent 8f0ba20 commit ff22645
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 55 deletions.
57 changes: 2 additions & 55 deletions include/rxmesh/matrix/sparse_matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#include "cusparse.h"
#include "rxmesh/attribute.h"
#include "rxmesh/context.h"
#include "rxmesh/query.cuh"
#include "rxmesh/types.h"

#include "thrust/device_ptr.h"
#include "thrust/execution_policy.h"
Expand All @@ -14,6 +12,8 @@
#include "cusolverSp_LOWLEVEL_PREVIEW.h"
#include "rxmesh/matrix/dense_matrix.cuh"

#include "rxmesh/matrix/sparse_matrix_kernels.cuh"

namespace rxmesh {

/**
Expand All @@ -40,59 +40,6 @@ enum class Reorder
NSTDIS = 3
};

namespace detail {

/**
 * @brief Row-size pass of the CSR construction. For every vertex handed to the
 * callback by the VV query, writes `iter.size() + 1` into row_ptr at the
 * vertex's global index (vertex prefix of its patch + local index in the
 * patch). The extra slot is the one sparse_mat_col_fill uses to store the
 * vertex's own index.
 * NOTE(review): judging by the name, the output is meant to be scanned
 * afterwards to obtain row offsets; that scan is not part of this kernel.
 * @tparam blockThreads block size forwarded to Query
 * @tparam IndexT type of the row_ptr entries
 * @param context RXMesh context the query runs on
 * @param row_ptr output array receiving one count per vertex
 */
template <uint32_t blockThreads, typename IndexT = int>
__global__ static void sparse_mat_prescan(const rxmesh::Context context,
                                          IndexT*               row_ptr)
{
    using namespace rxmesh;

    auto count_row_entries = [&](VertexHandle&         vertex,
                                 const VertexIterator& neighbors) {
        const auto     unpacked = vertex.unpack();
        const uint32_t patch    = unpacked.first;
        const uint16_t local    = unpacked.second;

        // global vertex index == row of the matrix
        const auto row = context.vertex_prefix()[patch] + local;

        row_ptr[row] = neighbors.size() + 1;
    };

    auto                block = cooperative_groups::this_thread_block();
    Query<blockThreads> query(context);
    ShmemAllocator      shrd_alloc;
    query.dispatch<Op::VV>(block, shrd_alloc, count_row_entries);
}

/**
 * @brief Column-index pass of the CSR construction. For every vertex handed to
 * the callback by the VV query, reads the start of its row from row_ptr and
 * writes into col_idx: first the vertex's own global index, then the global
 * index of each vertex returned by the iterator, in iterator order.
 * A global index is the vertex prefix of the patch + the local index in the
 * patch.
 * NOTE(review): row_ptr is assumed to already hold row start offsets when this
 * kernel runs; the step that produces them is not visible here.
 * @tparam blockThreads block size forwarded to Query
 * @tparam IndexT type of the row_ptr and col_idx entries
 * @param context RXMesh context the query runs on
 * @param row_ptr row start offsets, only read by this kernel
 * @param col_idx output array receiving the column indices
 */
template <uint32_t blockThreads, typename IndexT = int>
__global__ static void sparse_mat_col_fill(const rxmesh::Context context,
                                           IndexT*               row_ptr,
                                           IndexT*               col_idx)
{
    using namespace rxmesh;

    auto col_fillin = [&](VertexHandle& v_id, const VertexIterator& iter) {
        auto     ids      = v_id.unpack();
        uint32_t patch_id = ids.first;
        uint16_t local_id = ids.second;

        // global index of this vertex == row of the matrix
        const auto row = context.vertex_prefix()[patch_id] + local_id;

        // Load the row start once. row_ptr and col_idx have the same pointee
        // type, so the compiler has to assume the stores to col_idx below may
        // alias row_ptr and would otherwise reload it on every iteration.
        const IndexT row_start = row_ptr[row];

        // first entry of the row is the vertex itself
        col_idx[row_start] = row;

        // remaining entries are the vertices returned by the VV query
        for (uint32_t v = 0; v < iter.size(); ++v) {
            auto     s_ids      = iter[v].unpack();
            uint32_t s_patch_id = s_ids.first;
            uint16_t s_local_id = s_ids.second;
            col_idx[row_start + v + 1] =
                context.vertex_prefix()[s_patch_id] + s_local_id;
        }
    };

    auto                block = cooperative_groups::this_thread_block();
    Query<blockThreads> query(context);
    ShmemAllocator      shrd_alloc;
    query.dispatch<Op::VV>(block, shrd_alloc, col_fillin);
}

} // namespace detail


/**
* @brief Sparse matrix that represents the VV connectivity, i.e., it
* is a square matrix whose number of rows/cols is equal to the number of vertices
Expand Down
62 changes: 62 additions & 0 deletions include/rxmesh/matrix/sparse_matrix_kernels.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once
#include "cusolverSp.h"
#include "cusparse.h"

#include "rxmesh/context.h"
#include "rxmesh/query.cuh"

namespace rxmesh {

namespace detail {

/**
 * @brief Row-size pass of the CSR construction. For every vertex handed to the
 * callback by the VV query, writes `iter.size() + 1` into row_ptr at the
 * vertex's global index (vertex prefix of its patch + local index in the
 * patch). The extra slot is the one sparse_mat_col_fill uses to store the
 * vertex's own index.
 * NOTE(review): judging by the name, the output is meant to be scanned
 * afterwards to obtain row offsets; that scan is not part of this kernel.
 * @tparam blockThreads block size forwarded to Query
 * @tparam IndexT type of the row_ptr entries
 * @param context RXMesh context the query runs on
 * @param row_ptr output array receiving one count per vertex
 */
template <uint32_t blockThreads, typename IndexT = int>
__global__ static void sparse_mat_prescan(const rxmesh::Context context,
                                          IndexT*               row_ptr)
{
    using namespace rxmesh;

    auto count_row_entries = [&](VertexHandle&         vertex,
                                 const VertexIterator& neighbors) {
        const auto     unpacked = vertex.unpack();
        const uint32_t patch    = unpacked.first;
        const uint16_t local    = unpacked.second;

        // global vertex index == row of the matrix
        const auto row = context.vertex_prefix()[patch] + local;

        row_ptr[row] = neighbors.size() + 1;
    };

    auto                block = cooperative_groups::this_thread_block();
    Query<blockThreads> query(context);
    ShmemAllocator      shrd_alloc;
    query.dispatch<Op::VV>(block, shrd_alloc, count_row_entries);
}

/**
 * @brief Column-index pass of the CSR construction. For every vertex handed to
 * the callback by the VV query, reads the start of its row from row_ptr and
 * writes into col_idx: first the vertex's own global index, then the global
 * index of each vertex returned by the iterator, in iterator order.
 * A global index is the vertex prefix of the patch + the local index in the
 * patch.
 * NOTE(review): row_ptr is assumed to already hold row start offsets when this
 * kernel runs; the step that produces them is not visible here.
 * @tparam blockThreads block size forwarded to Query
 * @tparam IndexT type of the row_ptr and col_idx entries
 * @param context RXMesh context the query runs on
 * @param row_ptr row start offsets, only read by this kernel
 * @param col_idx output array receiving the column indices
 */
template <uint32_t blockThreads, typename IndexT = int>
__global__ static void sparse_mat_col_fill(const rxmesh::Context context,
                                           IndexT*               row_ptr,
                                           IndexT*               col_idx)
{
    using namespace rxmesh;

    auto col_fillin = [&](VertexHandle& v_id, const VertexIterator& iter) {
        auto     ids      = v_id.unpack();
        uint32_t patch_id = ids.first;
        uint16_t local_id = ids.second;

        // global index of this vertex == row of the matrix
        const auto row = context.vertex_prefix()[patch_id] + local_id;

        // Load the row start once. row_ptr and col_idx have the same pointee
        // type, so the compiler has to assume the stores to col_idx below may
        // alias row_ptr and would otherwise reload it on every iteration.
        const IndexT row_start = row_ptr[row];

        // first entry of the row is the vertex itself
        col_idx[row_start] = row;

        // remaining entries are the vertices returned by the VV query
        for (uint32_t v = 0; v < iter.size(); ++v) {
            auto     s_ids      = iter[v].unpack();
            uint32_t s_patch_id = s_ids.first;
            uint16_t s_local_id = s_ids.second;
            col_idx[row_start + v + 1] =
                context.vertex_prefix()[s_patch_id] + s_local_id;
        }
    };

    auto                block = cooperative_groups::this_thread_block();
    Query<blockThreads> query(context);
    ShmemAllocator      shrd_alloc;
    query.dispatch<Op::VV>(block, shrd_alloc, col_fillin);
}

} // namespace detail

} // namespace rxmesh

0 comments on commit ff22645

Please sign in to comment.