Skip to content

Commit

Permalink
update srow on csr_builder destruction
Browse files Browse the repository at this point in the history
this also makes both builders non-movable
  • Loading branch information
upsj committed Jan 8, 2020
1 parent 8d55780 commit 1e7314d
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 2 deletions.
6 changes: 6 additions & 0 deletions core/matrix/coo_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ class CooBuilder {
/** Initializes a CsrBuilder from an existing COO matrix. */
CooBuilder(Coo<ValueType, IndexType> *matrix) : matrix_{matrix} {}

// make this type non-movable
CooBuilder(const CooBuilder &) = delete;
CooBuilder(CooBuilder &&) = delete;
CooBuilder &operator=(const CooBuilder &) = delete;
CooBuilder &operator=(CooBuilder &&) = delete;

private:
Coo<ValueType, IndexType> *matrix_;
};
Expand Down
2 changes: 0 additions & 2 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ void Csr<ValueType, IndexType>::apply_impl(const LinOp *b, LinOp *x) const
// if b is a CSR matrix, we compute a SpGeMM
auto x_csr = as<TCsr>(x);
this->get_executor()->run(csr::make_spgemm(this, b_csr, x_csr));
x_csr->make_srow();
} else {
// otherwise we assume that b is dense and compute a SpMV/SpMM
this->get_executor()->run(
Expand All @@ -112,7 +111,6 @@ void Csr<ValueType, IndexType>::apply_impl(const LinOp *alpha, const LinOp *b,
this->get_executor()->run(
csr::make_advanced_spgemm(as<Dense>(alpha), this, b_csr,
as<Dense>(beta), x_copy.get(), x_csr));
x_csr->make_srow();
} else {
// otherwise we assume that b is dense and compute a SpMV/SpMM
this->get_executor()->run(
Expand Down
9 changes: 9 additions & 0 deletions core/matrix/csr_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ class CsrBuilder {
/** Initializes a CsrBuilder from an existing CSR matrix. */
CsrBuilder(Csr<ValueType, IndexType> *matrix) : matrix_{matrix} {}

/** Updates the internal matrix data structures at destruction. */
~CsrBuilder() { matrix_->make_srow(); }

// make this type non-movable
CsrBuilder(const CsrBuilder &) = delete;
CsrBuilder(CsrBuilder &&) = delete;
CsrBuilder &operator=(const CsrBuilder &) = delete;
CsrBuilder &operator=(CsrBuilder &&) = delete;

private:
Csr<ValueType, IndexType> *matrix_;
};
Expand Down
23 changes: 23 additions & 0 deletions core/test/matrix/csr_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,27 @@ TEST_F(CsrBuilder, ReturnsCorrectArrays)
}


TEST_F(CsrBuilder, UpdatesSrowOnDestruction)
{
struct mock_strategy : public Mtx::strategy_type {
virtual void process(const gko::Array<gko::int32> &,
gko::Array<gko::int32> *)
{
*was_called = true;
}
virtual int64_t clac_size(const int64_t nnz) { return 0; }

mock_strategy(bool &flag) : Mtx::strategy_type(""), was_called(&flag) {}

bool *was_called;
};
bool was_called{};
mtx->set_strategy(std::make_shared<mock_strategy>(was_called));

gko::matrix::CsrBuilder<>{mtx.get()};

ASSERT_TRUE(was_called);
}


} // namespace

0 comments on commit 1e7314d

Please sign in to comment.