Skip to content

Commit

Permalink
SPMM: Fix initial constraint release of k
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <joseph.schuchart@stonybrook.edu>
  • Loading branch information
devreal committed Jul 16, 2024
1 parent 95df973 commit fbfdf0c
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ class SpMM25D {
, b_rows_of_col_(b_rows_of_col)
, a_rows_of_col_(a_rows_of_col)
, b_cols_of_row_(b_cols_of_row)
, k_cnt_(a_cols_of_row_.size()+1)
, k_cnt_(a_rows_of_col_.size()+1)
, ij_keymap_(std::move(ij_keymap))
, ijk_keymap_(std::move(ijk_keymap))
, parallel_bcasts_(parallel_bcasts) {
, parallel_bcasts_(std::max(parallel_bcasts, 1L)) {
Edge<Key<2>, void> a_ctl, b_ctl;
Edge<Key<2>, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT?
auto constraint = ttg::make_shared_constraint<ttg::SequencedKeysConstraint<Key<2>>>(USE_AUTO_CONSTRAINT);
Expand All @@ -306,7 +306,7 @@ class SpMM25D {
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_);
multiplyadd_ = std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_,
b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint,
k_cnt_, parallel_bcasts);
k_cnt_, parallel_bcasts_);

reduce_c_ = std::make_unique<ReduceC>(c_ij_p_, c, ij_keymap_);
reduce_c_->template set_input_reducer<0>(
Expand All @@ -325,7 +325,8 @@ class SpMM25D {
std::vector<unsigned long> first_k_map(world.size(), std::numeric_limits<unsigned long>::max());
std::size_t max_k = a_rows_of_col_.size();
std::vector<std::size_t> k_cnt;
k_cnt.resize(a_cols_of_row_.size(), 0);
k_cnt.resize(a_rows_of_col_.size(), 0);
int release_k = 0;
for (auto i = 0ul; i != a_cols_of_row_.size(); ++i) {
if (a_cols_of_row_[i].empty()) continue;
for (auto j = 0ul; j != b_rows_of_col_.size(); ++j) {
Expand Down Expand Up @@ -362,19 +363,12 @@ class SpMM25D {
k_cnt_[i++].store(c, std::memory_order_relaxed);
}

// release the first bcast(s)
auto pbcasts = parallel_bcasts_;
auto release_k = k_cnt.begin();
auto release_k_ = release_k; // this will be released
/* release the first parallel_bcasts_ k that are non-zero */
auto k_cnt_iter = k_cnt.begin();
do {
release_k = std::find_if(release_k, k_cnt.end(), [](std::size_t c){ return c > 0; });
if (k_cnt.end() == release_k) {
break;
}
release_k_ = release_k;
++release_k;
} while (--pbcasts > 0);
constraint->release(*release_k_);
k_cnt_iter = std::find_if(k_cnt_iter, k_cnt.end(), [](auto c){ return c > 0; });
} while (++k_cnt_iter != k_cnt.end() && std::distance(k_cnt_iter, k_cnt.end()) < parallel_bcasts_);
constraint->release(std::distance(k_cnt.begin(), k_cnt_iter));

TTGUNUSED(bcast_a_);
TTGUNUSED(bcast_b_);
Expand Down Expand Up @@ -612,12 +606,16 @@ class SpMM25D {
(have_next_k ? std::to_string(next_k) : "does not exist"));
// release the constraint on the next round of broadcasts
{
std::size_t release_k = k;
std::size_t bcasts_ahead = parallel_bcasts_;
assert(k_cnt_.size() > release_k);
if (0 == k_cnt_[release_k].fetch_sub(1, std::memory_order_relaxed)-1) {
// this was the last gemm in this k, find the one to release
while (++release_k < k_cnt_.size() && (0 == k_cnt_[release_k].load(std::memory_order_relaxed) || --bcasts_ahead > 0))
assert(k_cnt_.size() > k);
long cnt = k_cnt_[k].fetch_sub(1, std::memory_order_relaxed)-1;
assert(cnt >= 0);
if (0 == cnt) {
auto release_k = k;
auto bcasts_ahead = parallel_bcasts_;
// this was the last gemm in this k, find the next one to release
while (++release_k < k_cnt_.size() &&
(0 == k_cnt_[release_k].load(std::memory_order_relaxed)
|| --bcasts_ahead > 0))
{ }
constraint->release(release_k);
}
Expand Down

0 comments on commit fbfdf0c

Please sign in to comment.