Skip to content

Commit

Permalink
Minor Updates to Sparse Structures (rapidsai#1432)
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala authored and ahendriksen committed Apr 27, 2023
1 parent 2e9c61c commit 867dffc
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 102 deletions.
35 changes: 22 additions & 13 deletions cpp/include/raft/core/coo_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@ class coordinate_structure_view
{
}

/**
* Create a view from this view. Note that this is for interface compatibility
* @return
*/
view_type view() { return view_type(rows_, cols_, this->get_n_rows(), this->get_n_cols()); }

/**
* Return span containing underlying rows array
* @return span containing underlying rows array
Expand Down Expand Up @@ -209,6 +203,10 @@ class coo_matrix_view
coordinate_structure_view<RowType, ColType, NZType, is_device>,
is_device> {
public:
using element_type = ElementType;
using row_type = RowType;
using col_type = ColType;
using nnz_type = NZType;
coo_matrix_view(raft::span<ElementType, is_device> element_span,
coordinate_structure_view<RowType, ColType, NZType, is_device> structure_view)
: sparse_matrix_view<ElementType,
Expand Down Expand Up @@ -238,6 +236,9 @@ class coo_matrix
ContainerPolicy> {
public:
using element_type = ElementType;
using row_type = RowType;
using col_type = ColType;
using nnz_type = NZType;
using structure_view_type = typename structure_type::view_type;
using container_type = typename ContainerPolicy<ElementType>::container_type;
using sparse_matrix_type =
Expand All @@ -258,14 +259,9 @@ class coo_matrix
// Constructor that owns the data but not the structure
template <SparsityType sparsity_type_ = get_sparsity_type(),
typename = typename std::enable_if_t<sparsity_type_ == SparsityType::PRESERVING>>
coo_matrix(raft::resources const& handle, std::shared_ptr<structure_type> structure) noexcept(
coo_matrix(raft::resources const& handle, structure_type structure) noexcept(
std::is_nothrow_default_constructible_v<container_type>)
: sparse_matrix_type(handle, structure){};
/**
* Return a view of the structure underlying this matrix
* @return
*/
structure_view_type structure_view() { return this->structure_.get()->view(); }

/**
* Initialize the sparsity on this instance if it was not known upon construction
Expand All @@ -277,7 +273,20 @@ class coo_matrix
void initialize_sparsity(NZType nnz)
{
sparse_matrix_type::initialize_sparsity(nnz);
this->structure_.get()->initialize_sparsity(nnz);
this->structure_.initialize_sparsity(nnz);
}

/**
* Return a view of the structure underlying this matrix
* @return
*/
structure_view_type structure_view()
{
if constexpr (get_sparsity_type() == SparsityType::OWNING) {
return this->structure_.view();
} else {
return this->structure_;
}
}
};
} // namespace raft
26 changes: 17 additions & 9 deletions cpp/include/raft/core/csr_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,6 @@ class compressed_structure_view
*/
span<indices_type, is_device> get_indices() override { return indices_; }

/**
* Create a view from this view. Note that this is for interface compatibility
* @return
*/
view_type view() { return view_type(indptr_, indices_, this->get_n_cols()); }

protected:
raft::span<indptr_type, is_device> indptr_;
raft::span<indices_type, is_device> indices_;
Expand Down Expand Up @@ -221,6 +215,10 @@ class csr_matrix_view
compressed_structure_view<IndptrType, IndicesType, NZType, is_device>,
is_device> {
public:
using element_type = ElementType;
using indptr_type = IndptrType;
using indices_type = IndicesType;
using nnz_type = NZType;
csr_matrix_view(
raft::span<ElementType, is_device> element_span,
compressed_structure_view<IndptrType, IndicesType, NZType, is_device> structure_view)
Expand Down Expand Up @@ -249,6 +247,9 @@ class csr_matrix
ContainerPolicy> {
public:
using element_type = ElementType;
using indptr_type = IndptrType;
using indices_type = IndicesType;
using nnz_type = NZType;
using structure_view_type = typename structure_type::view_type;
static constexpr auto get_sparsity_type() { return sparsity_type; }
using sparse_matrix_type =
Expand All @@ -271,7 +272,7 @@ class csr_matrix

template <SparsityType sparsity_type_ = get_sparsity_type(),
typename = typename std::enable_if_t<sparsity_type_ == SparsityType::PRESERVING>>
csr_matrix(raft::resources const& handle, std::shared_ptr<structure_type> structure) noexcept(
csr_matrix(raft::resources const& handle, structure_type structure) noexcept(
std::is_nothrow_default_constructible_v<container_type>)
: sparse_matrix_type(handle, structure){};

Expand All @@ -284,13 +285,20 @@ class csr_matrix
void initialize_sparsity(NZType nnz)
{
sparse_matrix_type::initialize_sparsity(nnz);
this->structure_.get()->initialize_sparsity(nnz);
this->structure_.initialize_sparsity(nnz);
}

/**
* Return a view of the structure underlying this matrix
* @return
*/
structure_view_type structure_view() { return this->structure_.get()->view(); }
structure_view_type structure_view()
{
if constexpr (get_sparsity_type() == SparsityType::OWNING) {
return this->structure_.view();
} else {
return this->structure_;
}
}
};
} // namespace raft
36 changes: 16 additions & 20 deletions cpp/include/raft/core/device_coo_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,15 @@ auto make_device_coo_matrix(raft::resources const& handle,
* @tparam ColType
* @tparam NZType
* @param[in] handle raft handle for managing expensive device resources
* @param[in] structure_ a sparsity-preserving coordinate structural view
* @param[in] structure a sparsity-preserving coordinate structural view
* @return a sparsity-preserving sparse matrix in coordinate (coo) format
*/
template <typename ElementType, typename RowType, typename ColType, typename NZType>
auto make_device_coo_matrix(raft::resources const& handle,
device_coordinate_structure_view<RowType, ColType, NZType> structure_)
device_coordinate_structure_view<RowType, ColType, NZType> structure)
{
return device_sparsity_preserving_coo_matrix<ElementType, RowType, ColType, NZType>(
handle,
std::make_shared<device_coordinate_structure_view<RowType, ColType, NZType>>(structure_));
return device_sparsity_preserving_coo_matrix<ElementType, RowType, ColType, NZType>(handle,
structure);
}

/**
Expand Down Expand Up @@ -212,16 +211,15 @@ auto make_device_coo_matrix(raft::resources const& handle,
* @tparam ColType
* @tparam NZType
* @param[in] ptr a pointer to array of nonzero matrix elements on device (size nnz)
* @param[in] structure_ a sparsity-preserving coordinate structural view
* @param[in] structure a sparsity-preserving coordinate structural view
* @return a sparsity-preserving sparse matrix in coordinate (coo) format
*/
template <typename ElementType, typename RowType, typename ColType, typename NZType>
auto make_device_coo_matrix_view(
ElementType* ptr, device_coordinate_structure_view<RowType, ColType, NZType> structure_)
ElementType* ptr, device_coordinate_structure_view<RowType, ColType, NZType> structure)
{
return device_coo_matrix_view<ElementType, RowType, ColType, NZType>(
raft::device_span<ElementType>(ptr, structure_.get_nnz()),
std::make_shared<device_coordinate_structure_view<RowType, ColType, NZType>>(structure_));
raft::device_span<ElementType>(ptr, structure.get_nnz()), structure);
}

/**
Expand Down Expand Up @@ -251,19 +249,17 @@ auto make_device_coo_matrix_view(
* @tparam ColType
* @tparam NZType
* @param[in] elements a device span containing nonzero matrix elements (size nnz)
* @param[in] structure_ a sparsity-preserving coordinate structural view
* @param[in] structure a sparsity-preserving coordinate structural view
* @return
*/
template <typename ElementType, typename RowType, typename ColType, typename NZType>
auto make_device_coo_matrix_view(
raft::device_span<ElementType> elements,
device_coordinate_structure_view<RowType, ColType, NZType> structure_)
device_coordinate_structure_view<RowType, ColType, NZType> structure)
{
RAFT_EXPECTS(elements.size() == structure_.get_nnz(),
RAFT_EXPECTS(elements.size() == structure.get_nnz(),
"Size of elements must be equal to the nnz from the structure");
return device_coo_matrix_view<ElementType, RowType, ColType, NZType>(
elements,
std::make_shared<device_coordinate_structure_view<RowType, ColType, NZType>>(structure_));
return device_coo_matrix_view<ElementType, RowType, ColType, NZType>(elements, structure);
}

/**
Expand Down Expand Up @@ -338,7 +334,7 @@ auto make_device_coordinate_structure(raft::resources const& handle,
* @return a sparsity-preserving coordinate structural view
*/
template <typename RowType, typename ColType, typename NZType>
auto make_device_coo_structure_view(
auto make_device_coordinate_structure_view(
RowType* rows, ColType* cols, RowType n_rows, ColType n_cols, NZType nnz)
{
return device_coordinate_structure_view<RowType, ColType, NZType>(
Expand Down Expand Up @@ -376,10 +372,10 @@ auto make_device_coo_structure_view(
* @return a sparsity-preserving coordinate structural view
*/
template <typename RowType, typename ColType, typename NZType>
auto make_device_coo_structure_view(raft::device_span<RowType> rows,
raft::device_span<ColType> cols,
RowType n_rows,
ColType n_cols)
auto make_device_coordinate_structure_view(raft::device_span<RowType> rows,
raft::device_span<ColType> cols,
RowType n_rows,
ColType n_cols)
{
return device_coordinate_structure_view<RowType, ColType, NZType>(rows, cols, n_rows, n_cols);
}
Expand Down
31 changes: 14 additions & 17 deletions cpp/include/raft/core/device_csr_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ auto make_device_csr_matrix(raft::device_resources const& handle,
* @tparam IndicesType
* @tparam NZType
* @param[in] handle raft handle for managing expensive device resources
* @param[in] structure_ a sparsity-preserving compressed structural view
* @param[in] structure a sparsity-preserving compressed structural view
* @return a sparsity-preserving sparse matrix in compressed (csr) format
*/
template <typename ElementType,
Expand All @@ -198,12 +198,10 @@ template <typename ElementType,
typename NZType = uint64_t>
auto make_device_csr_matrix(
raft::device_resources const& handle,
device_compressed_structure_view<IndptrType, IndicesType, NZType> structure_)
device_compressed_structure_view<IndptrType, IndicesType, NZType> structure)
{
return device_sparsity_preserving_csr_matrix<ElementType, IndptrType, IndicesType, NZType>(
handle,
std::make_shared<device_compressed_structure_view<IndptrType, IndicesType, NZType>>(
structure_));
handle, structure);
}

/**
Expand Down Expand Up @@ -232,18 +230,18 @@ auto make_device_csr_matrix(
* @tparam IndicesType
* @tparam NZType
* @param[in] ptr a pointer to array of nonzero matrix elements on device (size nnz)
* @param[in] structure_ a sparsity-preserving compressed sparse structural view
* @param[in] structure a sparsity-preserving compressed sparse structural view
* @return a sparsity-preserving csr matrix view
*/
template <typename ElementType,
typename IndptrType,
typename IndicesType,
typename NZType = uint64_t>
auto make_device_csr_matrix_view(
ElementType* ptr, device_compressed_structure_view<IndptrType, IndicesType, NZType> structure_)
ElementType* ptr, device_compressed_structure_view<IndptrType, IndicesType, NZType> structure)
{
return device_csr_matrix_view<ElementType, IndptrType, IndicesType, NZType>(
raft::device_span<ElementType>(ptr, structure_.get_nnz()), std::make_shared(structure_));
raft::device_span<ElementType>(ptr, structure.get_nnz()), structure);
}

/**
Expand Down Expand Up @@ -273,7 +271,7 @@ auto make_device_csr_matrix_view(
* @tparam IndicesType
* @tparam NZType
* @param[in] elements device span containing array of matrix elements (size nnz)
* @param[in] structure_ a sparsity-preserving structural view
* @param[in] structure a sparsity-preserving structural view
* @return a sparsity-preserving csr matrix view
*/
template <typename ElementType,
Expand All @@ -282,12 +280,11 @@ template <typename ElementType,
typename NZType = uint64_t>
auto make_device_csr_matrix_view(
raft::device_span<ElementType> elements,
device_compressed_structure_view<IndptrType, IndicesType, NZType> structure_)
device_compressed_structure_view<IndptrType, IndicesType, NZType> structure)
{
RAFT_EXPECTS(elements.size() == structure_.get_nnz(),
RAFT_EXPECTS(elements.size() == structure.get_nnz(),
"Size of elements must be equal to the nnz from the structure");
return device_csr_matrix_view<ElementType, IndptrType, IndicesType, NZType>(
elements, std::make_shared(structure_));
return device_csr_matrix_view<ElementType, IndptrType, IndicesType, NZType>(elements, structure);
}

/**
Expand Down Expand Up @@ -365,7 +362,7 @@ auto make_device_compressed_structure(raft::device_resources const& handle,
* @return a sparsity-preserving compressed structural view
*/
template <typename IndptrType, typename IndicesType, typename NZType = uint64_t>
auto make_device_csr_structure_view(
auto make_device_compressed_structure_view(
IndptrType* indptr, IndicesType* indices, IndptrType n_rows, IndicesType n_cols, NZType nnz)
{
return device_compressed_structure_view<IndptrType, IndicesType, NZType>(
Expand Down Expand Up @@ -408,9 +405,9 @@ auto make_device_csr_structure_view(
*
*/
template <typename IndptrType, typename IndicesType, typename NZType = uint64_t>
auto make_device_csr_structure_view(raft::device_span<IndptrType> indptr,
raft::device_span<IndicesType> indices,
IndicesType n_cols)
auto make_device_compressed_structure_view(raft::device_span<IndptrType> indptr,
raft::device_span<IndicesType> indices,
IndicesType n_cols)
{
return device_compressed_structure_view<IndptrType, IndicesType, NZType>(indptr, indices, n_cols);
}
Expand Down
Loading

0 comments on commit 867dffc

Please sign in to comment.