Skip to content

Commit

Permalink
SPMM: Add missing includes
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <joseph.schuchart@stonybrook.edu>
  • Loading branch information
devreal committed Jul 26, 2024
1 parent d0d8855 commit 7bd3804
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,18 @@
#include "ttg.h"
#include "../ttg_matrix.h"

using namespace ttg;

#include "ttg/util/future.h"

#include "ttg/util/multiindex.h"
#include "ttg/serialization/std/pair.h"

#include "ttg/util/bug.h"

#include "devicetensor.h"
#include "devicegemm.h"

using namespace ttg;

#if defined(TTG_ENABLE_CUDA)
#define HAVE_SPMM_DEVICE 1
static constexpr ttg::ExecutionSpace space = ttg::ExecutionSpace::CUDA;
Expand Down Expand Up @@ -572,10 +575,6 @@ class SpMM25D {
ttg::typelist<const Blk, const Blk, Blk>> {
static constexpr const bool is_device_space = (Space_ != ttg::ExecutionSpace::Host);
using task_return_type = std::conditional_t<is_device_space, ttg::device::Task, void>;
/* communicate to the runtime which device we support (if any) */
static constexpr bool have_cuda_op = (Space_ == ttg::ExecutionSpace::CUDA);
static constexpr bool have_hip_op = (Space_ == ttg::ExecutionSpace::HIP);
static constexpr bool have_level_zero_op = (Space_ == ttg::ExecutionSpace::L0);

void release_next_k(long k) {
assert(k_cnt_.size() > k);
Expand All @@ -597,6 +596,11 @@ class SpMM25D {
public:
using baseT = typename MultiplyAdd::ttT;

/* communicate to the runtime which device we support (if any) */
static constexpr bool have_cuda_op = (Space_ == ttg::ExecutionSpace::CUDA);
static constexpr bool have_hip_op = (Space_ == ttg::ExecutionSpace::HIP);
static constexpr bool have_level_zero_op = (Space_ == ttg::ExecutionSpace::L0);

MultiplyAdd(Edge<Key<3>, Blk> &a_ijk, Edge<Key<3>, Blk> &b_ijk, Edge<Key<3>, Blk> &c_ijk, Edge<Key<2>, Blk> &c,
const std::vector<std::vector<long>> &a_cols_of_row,
const std::vector<std::vector<long>> &b_rows_of_col, const std::vector<int> &mTiles,
Expand Down

0 comments on commit 7bd3804

Please sign in to comment.