Skip to content

Commit

Permalink
tests for coo2coo and a fix for a bug in coo2csr and coo2coo
Browse files Browse the repository at this point in the history
  • Loading branch information
superwhiskers committed Jul 23, 2024
1 parent 3f3b271 commit 617404e
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 32 deletions.
4 changes: 2 additions & 2 deletions resolve/matrix/Utilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ namespace ReSolve
// spaces is equivalent to the amount of nonzeroes in the row, and if not,
// shifts every subsequent row back the amount of unused spaces

for (index_type column = 0; column < n_columns - 1; column++) {
for (index_type column = 0; column < n_columns; column++) {
index_type column_nnz = partitions[column + 1] - partitions[column];
if (used[column] != column_nnz) {
index_type correction = column_nnz - used[column];
Expand Down Expand Up @@ -363,7 +363,7 @@ namespace ReSolve
// spaces is equivalent to the amount of nonzeroes in the row, and if not,
// shifts every subsequent row back the amount of unused spaces

for (index_type row = 0; row < n_rows - 1; row++) {
for (index_type row = 0; row < n_rows; row++) {
index_type row_nnz = csr_rows[row + 1] - csr_rows[row];
if (used[row] != row_nnz) {
index_type correction = row_nnz - used[row];
Expand Down
145 changes: 115 additions & 30 deletions tests/unit/matrix/MatrixConversionTests.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <iomanip>

#include <resolve/matrix/Coo.hpp>
#include <resolve/matrix/Csr.hpp>
#include <resolve/matrix/Utilities.hpp>
Expand Down Expand Up @@ -44,12 +46,23 @@ namespace ReSolve

status *= ReSolve::matrix::coo2csr(&A, &B, memory::HOST) == 0;
status *= this->verifyAnswer(&B,
simple_symmetric_expected_n_,
simple_symmetric_expected_m_,
simple_symmetric_expected_nnz_,
simple_symmetric_expected_i_,
simple_symmetric_expected_j_,
simple_symmetric_expected_a_);
simple_symmetric_expected_csr_n_,
simple_symmetric_expected_csr_m_,
simple_symmetric_expected_csr_nnz_,
simple_symmetric_expected_csr_i_,
simple_symmetric_expected_csr_j_,
simple_symmetric_expected_csr_a_);

ReSolve::matrix::Coo C(A.getNumRows(), A.getNumColumns(), 0);

status *= ReSolve::matrix::coo2coo(&A, &C, memory::HOST) == 0;
status *= this->verifyAnswer(&C,
simple_symmetric_expected_coo_col_n_,
simple_symmetric_expected_coo_col_m_,
simple_symmetric_expected_coo_col_nnz_,
simple_symmetric_expected_coo_col_i_,
simple_symmetric_expected_coo_col_j_,
simple_symmetric_expected_coo_col_a_);

return status.report(__func__);
}
Expand All @@ -72,12 +85,23 @@ namespace ReSolve

status *= ReSolve::matrix::coo2csr(&A, &B, memory::HOST) == 0;
status *= this->verifyAnswer(&B,
simple_symmetric_expected_n_,
simple_symmetric_expected_m_,
simple_symmetric_expected_nnz_,
simple_symmetric_expected_i_,
simple_symmetric_expected_j_,
simple_symmetric_expected_a_);
simple_symmetric_expected_csr_n_,
simple_symmetric_expected_csr_m_,
simple_symmetric_expected_csr_nnz_,
simple_symmetric_expected_csr_i_,
simple_symmetric_expected_csr_j_,
simple_symmetric_expected_csr_a_);

ReSolve::matrix::Coo C(A.getNumRows(), A.getNumColumns(), 0);

status *= ReSolve::matrix::coo2coo(&A, &C, memory::HOST) == 0;
status *= this->verifyAnswer(&C,
simple_symmetric_expected_coo_col_n_,
simple_symmetric_expected_coo_col_m_,
simple_symmetric_expected_coo_col_nnz_,
simple_symmetric_expected_coo_col_i_,
simple_symmetric_expected_coo_col_j_,
simple_symmetric_expected_coo_col_a_);

return status.report(__func__);
}
Expand Down Expand Up @@ -107,6 +131,17 @@ namespace ReSolve
simple_main_diagonal_only_i_j_,
simple_main_diagonal_only_a_);

ReSolve::matrix::Coo C(A.getNumRows(), A.getNumColumns(), 0);

status *= ReSolve::matrix::coo2coo(&A, &C, memory::HOST) == 0;
status *= this->verifyAnswer(&C,
simple_main_diagonal_only_n_,
simple_main_diagonal_only_m_,
simple_main_diagonal_only_nnz_,
simple_main_diagonal_only_i_j_,
simple_main_diagonal_only_i_j_,
simple_main_diagonal_only_a_);

return status.report(__func__);
}

Expand All @@ -126,23 +161,41 @@ namespace ReSolve

status *= ReSolve::matrix::coo2csr(&A, &B, memory::HOST) == 0;
status *= this->verifyAnswer(&B,
simple_asymmetric_expected_n_,
simple_asymmetric_expected_m_,
simple_asymmetric_expected_nnz_,
simple_asymmetric_expected_i_,
simple_asymmetric_expected_j_,
simple_asymmetric_expected_a_);
simple_asymmetric_expected_csr_n_,
simple_asymmetric_expected_csr_m_,
simple_asymmetric_expected_csr_nnz_,
simple_asymmetric_expected_csr_i_,
simple_asymmetric_expected_csr_j_,
simple_asymmetric_expected_csr_a_);

ReSolve::matrix::Coo C(A.getNumRows(), A.getNumColumns(), 0);

status *= ReSolve::matrix::coo2coo(&A, &C, memory::HOST) == 0;
status *= this->verifyAnswer(&C,
simple_asymmetric_expected_coo_col_n_,
simple_asymmetric_expected_coo_col_m_,
simple_asymmetric_expected_coo_col_nnz_,
simple_asymmetric_expected_coo_col_i_,
simple_asymmetric_expected_coo_col_j_,
simple_asymmetric_expected_coo_col_a_);

return status.report(__func__);
}

private:
const index_type simple_symmetric_expected_n_ = 5;
const index_type simple_symmetric_expected_m_ = 5;
const index_type simple_symmetric_expected_nnz_ = 8;
index_type simple_symmetric_expected_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4};
index_type simple_symmetric_expected_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3};
real_type simple_symmetric_expected_a_[8] = {3.0, 7.0, 11.0, 7.0, 11.0, 7.0, 8.0, 8.0};
const index_type simple_symmetric_expected_csr_n_ = 5;
const index_type simple_symmetric_expected_csr_m_ = 5;
const index_type simple_symmetric_expected_csr_nnz_ = 8;
index_type simple_symmetric_expected_csr_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4};
index_type simple_symmetric_expected_csr_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3};
real_type simple_symmetric_expected_csr_a_[8] = {3.0, 7.0, 11.0, 7.0, 11.0, 7.0, 8.0, 8.0};

const index_type simple_symmetric_expected_coo_col_n_ = 5;
const index_type simple_symmetric_expected_coo_col_m_ = 5;
const index_type simple_symmetric_expected_coo_col_nnz_ = 8;
index_type simple_symmetric_expected_coo_col_i_[8] = {0, 1, 2, 3, 1, 1, 4, 3};
index_type simple_symmetric_expected_coo_col_j_[8] = {0, 1, 1, 1, 2, 3, 3, 4};
real_type simple_symmetric_expected_coo_col_a_[8] = {3.0, 7.0, 11.0, 7.0, 11.0, 7.0, 8.0, 8.0};

const index_type simple_upper_unexpanded_symmetric_n_ = 5;
const index_type simple_upper_unexpanded_symmetric_m_ = 5;
Expand Down Expand Up @@ -171,12 +224,19 @@ namespace ReSolve
index_type simple_asymmetric_j_[10] = {0, 1, 3, 1, 1, 4, 4, 3, 2, 2};
real_type simple_asymmetric_a_[10] = {2.0, 4.0, 7.0, 9.0, 6.0, 7.0, 8.0, 8.0, 5.0, 6.0};

const index_type simple_asymmetric_expected_n_ = 5;
const index_type simple_asymmetric_expected_m_ = 5;
const index_type simple_asymmetric_expected_nnz_ = 8;
index_type simple_asymmetric_expected_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4};
index_type simple_asymmetric_expected_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3};
real_type simple_asymmetric_expected_a_[8] = {2.0, 4.0, 11.0, 7.0, 9.0, 6.0, 15.0, 8.0};
const index_type simple_asymmetric_expected_csr_n_ = 5;
const index_type simple_asymmetric_expected_csr_m_ = 5;
const index_type simple_asymmetric_expected_csr_nnz_ = 8;
index_type simple_asymmetric_expected_csr_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4};
index_type simple_asymmetric_expected_csr_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3};
real_type simple_asymmetric_expected_csr_a_[8] = {2.0, 4.0, 11.0, 7.0, 9.0, 6.0, 15.0, 8.0};

const index_type simple_asymmetric_expected_coo_col_n_ = 5;
const index_type simple_asymmetric_expected_coo_col_m_ = 5;
const index_type simple_asymmetric_expected_coo_col_nnz_ = 8;
index_type simple_asymmetric_expected_coo_col_i_[8] = {0, 1, 2, 3, 1, 1, 4, 3};
index_type simple_asymmetric_expected_coo_col_j_[8] = {0, 1, 1, 1, 2, 3, 3, 4};
real_type simple_asymmetric_expected_coo_col_a_[8] = {2.0, 4.0, 9.0, 6.0, 11.0, 7.0, 8.0, 15.0};

bool verifyAnswer(matrix::Csr* A,
const index_type& n,
Expand Down Expand Up @@ -206,6 +266,31 @@ namespace ReSolve

return true;
}

bool verifyAnswer(matrix::Coo* A,
const index_type& n,
const index_type& m,
const index_type& nnz,
index_type* is,
index_type* js,
real_type* as)
{
if (n != A->getNumRows() || m != A->getNumColumns() || nnz != A->getNnz()) {
return false;
}

index_type* rows = A->getRowData(memory::HOST);
index_type* columns = A->getColData(memory::HOST);
real_type* values = A->getValues(memory::HOST);

for (index_type i = 0; i < nnz; i++) {
if (rows[i] != is[i] || columns[i] != js[i] || !isEqual(values[i], as[i])) {
return false;
}
}

return true;
}
};
} // namespace tests
} // namespace ReSolve

0 comments on commit 617404e

Please sign in to comment.