From a29c9b2f7a3206b78f6fb8d3cc56c32e246a9526 Mon Sep 17 00:00:00 2001 From: Gregor Olenik Date: Wed, 18 Oct 2023 10:09:42 +0200 Subject: [PATCH] Fix reference compute mean impl, add test --- reference/matrix/dense_kernels.cpp | 8 ++++---- reference/test/matrix/dense_kernels.cpp | 8 ++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/reference/matrix/dense_kernels.cpp b/reference/matrix/dense_kernels.cpp index ff69dcf2684..47df46b3c86 100644 --- a/reference/matrix/dense_kernels.cpp +++ b/reference/matrix/dense_kernels.cpp @@ -407,11 +407,11 @@ void compute_mean(std::shared_ptr exec, result->at(0, j) = zero(); } - for (size_type i = 0; i < x->get_size()[0]; ++i) { - for (size_type j = 0; j < x->get_size()[1]; ++j) { - result->at(0, i) += x->at(i, j); + for (size_type i = 0; i < x->get_size()[1]; ++i) { + for (size_type j = 0; j < x->get_size()[0]; ++j) { + result->at(0, i) += x->at(j, i); } - result->at(0, i) /= static_cast(x->get_size()[1]); + result->at(0, i) /= static_cast(x->get_size()[0]); } } diff --git a/reference/test/matrix/dense_kernels.cpp b/reference/test/matrix/dense_kernels.cpp index 532bd14ec95..b776f426794 100644 --- a/reference/test/matrix/dense_kernels.cpp +++ b/reference/test/matrix/dense_kernels.cpp @@ -35,6 +35,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include +#include #include @@ -702,6 +703,13 @@ TYPED_TEST(Dense, ComputesMean) { using Mtx = typename TestFixture::Mtx; using T = typename TestFixture::value_type; + + auto iota = Mtx::create(this->exec, gko::dim<2>{10, 1}); + std::iota(iota->get_values(), iota->get_values() + 10, 1); + auto iota_result = Mtx::create(this->exec, gko::dim<2>{1, 1}); + iota->compute_mean(iota_result.get()); + GKO_EXPECT_NEAR(iota_result->at(0, 0), T{5.5}, r::value * 10); + auto result = Mtx::create(this->exec, gko::dim<2>{1, 3}); this->mtx4->compute_mean(result.get());