Skip to content

Commit

Permalink
[core] allow filling the device_matrix_data
Browse files Browse the repository at this point in the history
The main use case is in combination with `sum_duplicates` and `remove_zeros` to simplify the assembly setup.
  • Loading branch information
MarcelKoch committed Sep 17, 2024
1 parent b5745ac commit 1c286a1
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
18 changes: 16 additions & 2 deletions core/base/device_matrix_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,17 @@ GKO_REGISTER_OPERATION(sort_row_major, components::sort_row_major);

template <typename ValueType, typename IndexType>
device_matrix_data<ValueType, IndexType>::device_matrix_data(
std::shared_ptr<const Executor> exec, dim<2> size, size_type num_entries)
std::shared_ptr<const Executor> exec, dim<2> size, size_type num_entries,
fill_mode fm)
: size_{size},
row_idxs_{exec, num_entries},
col_idxs_{exec, num_entries},
values_{exec, num_entries}
{}
{
if (fm == fill_mode::zero) {
fill_zero();
}
}


template <typename ValueType, typename IndexType>
Expand Down Expand Up @@ -93,6 +98,15 @@ device_matrix_data<ValueType, IndexType>::create_from_host(
}


template <typename ValueType, typename IndexType>
void device_matrix_data<ValueType, IndexType>::fill_zero()
{
row_idxs_.fill(0);
col_idxs_.fill(0);
values_.fill(ValueType{0});
}


template <typename ValueType, typename IndexType>
void device_matrix_data<ValueType, IndexType>::sort_row_major()
{
Expand Down
17 changes: 16 additions & 1 deletion include/ginkgo/core/base/device_matrix_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@

namespace gko {

/**
* Enum that describes how allocated data is filled
*/
enum class fill_mode {
uninitialized, //!< no fill operation is done
zero //!< fill with zeros
};


/**
* This type is a device-side equivalent to matrix_data.
Expand Down Expand Up @@ -48,9 +56,11 @@ class device_matrix_data {
* @param exec the executor to be used to store the matrix entries
* @param size the matrix dimensions
* @param num_entries the number of entries to be stored
* @param fm describes how the data is filled
*/
explicit device_matrix_data(std::shared_ptr<const Executor> exec,
dim<2> size = {}, size_type num_entries = 0);
dim<2> size = {}, size_type num_entries = 0,
fill_mode fm = fill_mode::uninitialized);

/**
* Initializes a device_matrix_data object by copying an existing object on
Expand Down Expand Up @@ -114,6 +124,11 @@ class device_matrix_data {
static device_matrix_data create_from_host(
std::shared_ptr<const Executor> exec, const host_type& data);

/**
* Fills the matrix entries with zeros
*/
void fill_zero();

/**
* Sorts the matrix entries in row-major order
* This means that they will be sorted by row index first, and then by
Expand Down
46 changes: 46 additions & 0 deletions test/base/device_matrix_data_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,30 @@ TYPED_TEST(DeviceMatrixData, ConstructsCorrectly)
}


TYPED_TEST(DeviceMatrixData, ConstructsWithZerosCorrectly)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;

gko::device_matrix_data<value_type, index_type> local_data{
this->exec, gko::dim<2>{4, 3}, 10, gko::fill_mode::zero};

ASSERT_EQ((gko::dim<2>{4, 3}), local_data.get_size());
ASSERT_EQ(this->exec, local_data.get_executor());
ASSERT_EQ(local_data.get_num_stored_elements(), 10);
auto arrays = local_data.empty_out();
auto expected_row_idxs = gko::array<index_type>(this->exec, 10);
auto expected_col_idxs = gko::array<index_type>(this->exec, 10);
auto expected_values = gko::array<value_type>(this->exec, 10);
expected_row_idxs.fill(0);
expected_col_idxs.fill(0);
expected_values.fill(0.0);
GKO_ASSERT_ARRAY_EQ(arrays.row_idxs, expected_row_idxs);
GKO_ASSERT_ARRAY_EQ(arrays.col_idxs, expected_col_idxs);
GKO_ASSERT_ARRAY_EQ(arrays.values, expected_values);
}


TYPED_TEST(DeviceMatrixData, CopyConstructsOnOtherExecutorCorrectly)
{
using value_type = typename TestFixture::value_type;
Expand Down Expand Up @@ -241,6 +265,28 @@ TYPED_TEST(DeviceMatrixData, CopiesToHost)
}


TYPED_TEST(DeviceMatrixData, CanFillEntriesWithZeros)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using device_matrix_data = gko::device_matrix_data<value_type, index_type>;
auto device_data = device_matrix_data{this->exec, gko::dim<2>{4, 3}, 10};

device_data.fill_zero();

auto arrays = device_data.empty_out();
auto expected_row_idxs = gko::array<index_type>(this->exec, 10);
auto expected_col_idxs = gko::array<index_type>(this->exec, 10);
auto expected_values = gko::array<value_type>(this->exec, 10);
expected_row_idxs.fill(0);
expected_col_idxs.fill(0);
expected_values.fill(0.0);
GKO_ASSERT_ARRAY_EQ(arrays.row_idxs, expected_row_idxs);
GKO_ASSERT_ARRAY_EQ(arrays.col_idxs, expected_col_idxs);
GKO_ASSERT_ARRAY_EQ(arrays.values, expected_values);
}


TYPED_TEST(DeviceMatrixData, SortsRowMajor)
{
using value_type = typename TestFixture::value_type;
Expand Down

0 comments on commit 1c286a1

Please sign in to comment.