diff --git a/core/base/device_matrix_data.cpp b/core/base/device_matrix_data.cpp index a2e5d6e7044..11d2536268f 100644 --- a/core/base/device_matrix_data.cpp +++ b/core/base/device_matrix_data.cpp @@ -29,12 +29,17 @@ GKO_REGISTER_OPERATION(sort_row_major, components::sort_row_major); template device_matrix_data::device_matrix_data( - std::shared_ptr exec, dim<2> size, size_type num_entries) + std::shared_ptr exec, dim<2> size, size_type num_entries, + fill_mode fm) : size_{size}, row_idxs_{exec, num_entries}, col_idxs_{exec, num_entries}, values_{exec, num_entries} -{} +{ + if (fm == fill_mode::zero) { + fill_zero(); + } +} template @@ -93,6 +98,15 @@ device_matrix_data::create_from_host( } +template +void device_matrix_data::fill_zero() +{ + row_idxs_.fill(0); + col_idxs_.fill(0); + values_.fill(ValueType{0}); +} + + template void device_matrix_data::sort_row_major() { diff --git a/include/ginkgo/core/base/device_matrix_data.hpp b/include/ginkgo/core/base/device_matrix_data.hpp index 35e3f300954..dfdd08b261c 100644 --- a/include/ginkgo/core/base/device_matrix_data.hpp +++ b/include/ginkgo/core/base/device_matrix_data.hpp @@ -16,6 +16,14 @@ namespace gko { +/** + * Enum that describes how allocated data is filled + */ +enum class fill_mode { + uninitialized, //!< no fill operation is done + zero //!< fill with zeros +}; + /** * This type is a device-side equivalent to matrix_data. @@ -48,9 +56,11 @@ class device_matrix_data { * @param exec the executor to be used to store the matrix entries * @param size the matrix dimensions * @param num_entries the number of entries to be stored + * @param fm describes how the data is filled */ explicit device_matrix_data(std::shared_ptr exec, - dim<2> size = {}, size_type num_entries = 0); + dim<2> size = {}, size_type num_entries = 0, + fill_mode fm = fill_mode::uninitialized); /** * Initializes a device_matrix_data object by copying an existing object on @@ -114,6 +124,11 @@ class device_matrix_data { static device_matrix_data create_from_host( std::shared_ptr exec, const host_type& data); + /** + * Fills the matrix entries with zeros + */ + void fill_zero(); + /** * Sorts the matrix entries in row-major order * This means that they will be sorted by row index first, and then by diff --git a/test/base/device_matrix_data_kernels.cpp b/test/base/device_matrix_data_kernels.cpp index ffadbcfb245..039cc9eac20 100644 --- a/test/base/device_matrix_data_kernels.cpp +++ b/test/base/device_matrix_data_kernels.cpp @@ -116,6 +116,30 @@ TYPED_TEST(DeviceMatrixData, ConstructsCorrectly) } +TYPED_TEST(DeviceMatrixData, ConstructsWithZerosCorrectly) +{ + using value_type = typename TestFixture::value_type; + using index_type = typename TestFixture::index_type; + + gko::device_matrix_data local_data{ + this->exec, gko::dim<2>{4, 3}, 10, gko::fill_mode::zero}; + + ASSERT_EQ((gko::dim<2>{4, 3}), local_data.get_size()); + ASSERT_EQ(this->exec, local_data.get_executor()); + ASSERT_EQ(local_data.get_num_stored_elements(), 10); + auto arrays = local_data.empty_out(); + auto expected_row_idxs = gko::array(this->exec, 10); + auto expected_col_idxs = gko::array(this->exec, 10); + auto expected_values = gko::array(this->exec, 10); + expected_row_idxs.fill(0); + expected_col_idxs.fill(0); + expected_values.fill(0.0); + GKO_ASSERT_ARRAY_EQ(arrays.row_idxs, expected_row_idxs); + GKO_ASSERT_ARRAY_EQ(arrays.col_idxs, expected_col_idxs); + GKO_ASSERT_ARRAY_EQ(arrays.values, expected_values); +} + + TYPED_TEST(DeviceMatrixData, CopyConstructsOnOtherExecutorCorrectly) { using value_type = typename TestFixture::value_type; @@ -241,6 +265,28 @@ TYPED_TEST(DeviceMatrixData, CopiesToHost) } +TYPED_TEST(DeviceMatrixData, CanFillEntriesWithZeros) +{ + using value_type = typename TestFixture::value_type; + using index_type = typename TestFixture::index_type; + using device_matrix_data = gko::device_matrix_data; + auto device_data = device_matrix_data{this->exec, gko::dim<2>{4, 3}, 10}; + + device_data.fill_zero(); + + auto arrays = device_data.empty_out(); + auto expected_row_idxs = gko::array(this->exec, 10); + auto expected_col_idxs = gko::array(this->exec, 10); + auto expected_values = gko::array(this->exec, 10); + expected_row_idxs.fill(0); + expected_col_idxs.fill(0); + expected_values.fill(0.0); + GKO_ASSERT_ARRAY_EQ(arrays.row_idxs, expected_row_idxs); + GKO_ASSERT_ARRAY_EQ(arrays.col_idxs, expected_col_idxs); + GKO_ASSERT_ARRAY_EQ(arrays.values, expected_values); +} + + TYPED_TEST(DeviceMatrixData, SortsRowMajor) { using value_type = typename TestFixture::value_type;