diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index bd0982c06..7cf656e4e 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -289,10 +289,10 @@ class SpMM25D { , b_rows_of_col_(b_rows_of_col) , a_rows_of_col_(a_rows_of_col) , b_cols_of_row_(b_cols_of_row) - , k_cnt_(a_cols_of_row_.size()+1) + , k_cnt_(a_rows_of_col_.size()+1) , ij_keymap_(std::move(ij_keymap)) , ijk_keymap_(std::move(ijk_keymap)) - , parallel_bcasts_(parallel_bcasts) { + , parallel_bcasts_(std::max(parallel_bcasts, 1L)) { Edge, void> a_ctl, b_ctl; Edge, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT? auto constraint = ttg::make_shared_constraint>>(USE_AUTO_CONSTRAINT); @@ -306,7 +306,7 @@ class SpMM25D { local_bcast_b_ = std::make_unique(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_); multiplyadd_ = std::make_unique(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_, b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint, - k_cnt_, parallel_bcasts); + k_cnt_, parallel_bcasts_); reduce_c_ = std::make_unique(c_ij_p_, c, ij_keymap_); reduce_c_->template set_input_reducer<0>( @@ -325,7 +325,8 @@ class SpMM25D { std::vector first_k_map(world.size(), std::numeric_limits::max()); std::size_t max_k = a_rows_of_col_.size(); std::vector k_cnt; - k_cnt.resize(a_cols_of_row_.size(), 0); + k_cnt.resize(a_rows_of_col_.size(), 0); + int release_k = 0; for (auto i = 0ul; i != a_cols_of_row_.size(); ++i) { if (a_cols_of_row_[i].empty()) continue; for (auto j = 0ul; j != b_rows_of_col_.size(); ++j) { @@ -362,19 +363,12 @@ class SpMM25D { k_cnt_[i++].store(c, std::memory_order_relaxed); } - // release the first bcast(s) - auto pbcasts = parallel_bcasts_; - auto release_k = k_cnt.begin(); - auto release_k_ = release_k; // this will be released + /* release the first parallel_bcasts_ k that are non-zero */ + auto k_cnt_iter = k_cnt.begin(); do { - release_k = std::find_if(release_k, k_cnt.end(), [](std::size_t c){ return c > 0; }); - if (k_cnt.end() == release_k) { - break; - } - release_k_ = release_k; - ++release_k; - } while (--pbcasts > 0); - constraint->release(*release_k_); + k_cnt_iter = std::find_if(k_cnt_iter, k_cnt.end(), [](auto c){ return c > 0; }); + } while (++k_cnt_iter != k_cnt.end() && std::distance(k_cnt_iter, k_cnt.end()) < parallel_bcasts_); + constraint->release(std::distance(k_cnt.begin(), k_cnt_iter)); TTGUNUSED(bcast_a_); TTGUNUSED(bcast_b_); @@ -612,12 +606,16 @@ class SpMM25D { (have_next_k ? std::to_string(next_k) : "does not exist")); // release the constraint on the next round of broadcasts { - std::size_t release_k = k; - std::size_t bcasts_ahead = parallel_bcasts_; - assert(k_cnt_.size() > release_k); - if (0 == k_cnt_[release_k].fetch_sub(1, std::memory_order_relaxed)-1) { - // this was the last gemm in this k, find the one to release - while (++release_k < k_cnt_.size() && (0 == k_cnt_[release_k].load(std::memory_order_relaxed) || --bcasts_ahead > 0)) + assert(k_cnt_.size() > k); + long cnt = k_cnt_[k].fetch_sub(1, std::memory_order_relaxed)-1; + assert(cnt >= 0); + if (0 == cnt) { + auto release_k = k; + auto bcasts_ahead = parallel_bcasts_; + // this was the last gemm in this k, find the next one to release + while (++release_k < k_cnt_.size() && + (0 == k_cnt_[release_k].load(std::memory_order_relaxed) + || --bcasts_ahead > 0)) { } constraint->release(release_k); }