Skip to content

Commit

Permalink
SPMM: Add option to disable device mapping hint
Browse files (browse the repository at this point in the history)
Signed-off-by: Joseph Schuchart <joseph.schuchart@stonybrook.edu>
  • Loading branch information
devreal committed Aug 16, 2024
1 parent a15f724 commit 8982c25
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions examples/spmm/spmm.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -314,15 +314,16 @@ class SpMM25D {
const std::vector<std::vector<long>> &b_cols_of_row,
const std::vector<std::vector<long>> &b_rows_of_col, const std::vector<int> &mTiles,
const std::vector<int> &nTiles, const std::vector<int> &kTiles, Keymap2 ij_keymap, Keymap3 ijk_keymap,
long R, long parallel_bcasts = 1)
long R, long parallel_bcasts = 1, bool enable_device_map = true)
: a_cols_of_row_(a_cols_of_row)
, b_rows_of_col_(b_rows_of_col)
, a_rows_of_col_(a_rows_of_col)
, b_cols_of_row_(b_cols_of_row)
, k_cnt_(a_rows_of_col_.size()+1)
, ij_keymap_(std::move(ij_keymap))
, ijk_keymap_(std::move(ijk_keymap))
, parallel_bcasts_(std::max(parallel_bcasts, 1L)) {
, parallel_bcasts_(std::max(parallel_bcasts, 1L))
, enable_device_map_(enable_device_map) {
Edge<Key<2>, void> a_ctl, b_ctl;
Edge<Key<2>, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT?
auto constraint = ttg::make_shared_constraint<ttg::SequencedKeysConstraint<Key<2>>>(USE_AUTO_CONSTRAINT);
Expand All @@ -336,7 +337,7 @@ class SpMM25D {
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_);
multiplyadd_ = std::make_unique<MultiplyAdd<Space>>(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_,
b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint,
k_cnt_, parallel_bcasts_);
k_cnt_, parallel_bcasts_, enable_device_map_);

reduce_c_ = std::make_unique<ReduceC>(c_ij_p_, c, ij_keymap_);
reduce_c_->template set_input_reducer<0>(
Expand Down Expand Up @@ -609,7 +610,7 @@ class SpMM25D {
const std::vector<int> &nTiles, const Keymap3 &ijk_keymap,
std::shared_ptr<ttg::SequencedKeysConstraint<Key<2>>> constraint,
std::vector<std::atomic<std::size_t>>& k_cnt,
std::size_t parallel_bcasts)
std::size_t parallel_bcasts, bool enable_device_map)
: baseT(edges(a_ijk, b_ijk, c_ijk), edges(c, c_ijk), "SpMM25D::MultiplyAdd", {"a_ijk", "b_ijk", "c_ijk"},
{"c_ij", "c_ijk"}, ijk_keymap)
, a_cols_of_row_(a_cols_of_row)
Expand All @@ -619,11 +620,13 @@ class SpMM25D {
, parallel_bcasts_(parallel_bcasts) {
this->set_priomap([=,this](const Key<3> &ijk) { return this->prio(ijk); }); // map a key to an integral priority value
if constexpr (is_device_space) {
auto num_devices = ttg::device::num_devices();
this->set_devicemap(
[num_devices](const Key<3> &ijk){
return ((((uint64_t)ijk[0]) << 32) + ijk[1]) % num_devices;
});
if (enable_device_map) {
auto num_devices = ttg::device::num_devices();
this->set_devicemap(
[num_devices](const Key<3> &ijk){
return ((((uint64_t)ijk[0]) << 32) + ijk[1]) % num_devices;
});
}
}
// for each {i,j} determine first k that contributes AND belongs to this node,
// initialize input {i,j,first_k} flow to 0
Expand Down Expand Up @@ -871,6 +874,7 @@ class SpMM25D {
Keymap2 ij_keymap_;
Keymap3 ijk_keymap_;
long parallel_bcasts_;
bool enable_device_map_;
};

class Control : public TT<void, std::tuple<Out<Key<3>>>, Control> {
Expand Down Expand Up @@ -1442,7 +1446,7 @@ static void timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::function<
const std::vector<std::vector<long>> &b_cols_of_row,
const std::vector<std::vector<long>> &b_rows_of_col, std::vector<int> &mTiles,
std::vector<int> &nTiles, std::vector<int> &kTiles, int M, int N, int K, int minTs,
int maxTs, int P, int Q, int R, int parallel_bcasts) {
int maxTs, int P, int Q, int R, int parallel_bcasts, bool enable_device_map) {
int MT = (int)A.rows();
int NT = (int)B.cols();
int KT = (int)A.cols();
Expand All @@ -1469,7 +1473,7 @@ static void timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::function<
assert(!has_value(c_status));
// SpMM25D a_times_b(world, eA, eB, eC, A, B);
SpMM25D<> a_times_b(eA, eB, eC, A, B, a_cols_of_row, a_rows_of_col, b_cols_of_row, b_rows_of_col,
mTiles, nTiles, kTiles, ij_keymap, ijk_keymap, R, parallel_bcasts);
mTiles, nTiles, kTiles, ij_keymap, ijk_keymap, R, parallel_bcasts, enable_device_map);
TTGUNUSED(a);
TTGUNUSED(b);
TTGUNUSED(a_times_b);
Expand Down Expand Up @@ -1646,6 +1650,9 @@ int main(int argc, char **argv) {
parallel_bcasts = std::stol(pStr);
}

/* whether we set a device mapping */
bool enable_device_map = !cmdOptionExists(argv, argv+argc, "--default-device-map");

std::string PStr(getCmdOption(argv, argv + argc, "-P"));
P = parseOption(PStr, P);
std::string QStr(getCmdOption(argv, argv + argc, "-Q"));
Expand Down Expand Up @@ -1812,7 +1819,7 @@ int main(int argc, char **argv) {
#endif // TTG_USE_PARSEC
timed_measurement(A, B, ij_keymap, ijk_keymap, tiling_type, gflops, avg_nb, Adensity, Bdensity,
a_cols_of_row, a_rows_of_col, b_cols_of_row, b_rows_of_col, mTiles,
nTiles, kTiles, M, N, K, minTs, maxTs, P, Q, R, parallel_bcasts);
nTiles, kTiles, M, N, K, minTs, maxTs, P, Q, R, parallel_bcasts, enable_device_map);
#if TTG_USE_PARSEC
/* reset PaRSEC's load tracking */
parsec_devices_reset_load(default_execution_context().impl().context());
Expand Down

0 comments on commit 8982c25

Please sign in to comment.