Skip to content

Commit

Permalink
Allow dynamic specification of device task inputs
Browse files Browse the repository at this point in the history
ttg::device::Input can be filled with buffers and passed to
ttg::device::select().

Signed-off-by: Joseph Schuchart <joseph.schuchart@stonybrook.edu>
  • Loading branch information
devreal committed Oct 16, 2024
1 parent ba03182 commit 9535f7f
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 34 deletions.
12 changes: 6 additions & 6 deletions ttg/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
#include "ttg/config.h"
#include "ttg/fwd.h"

#if defined(TTG_USE_PARSEC)
#include "ttg/parsec/ttg.h"
#elif defined(TTG_USE_MADNESS)
#include "ttg/madness/ttg.h"
#endif // TTG_USE_{PARSEC|MADNESS}

#include "ttg/runtimes.h"
#include "ttg/util/demangle.h"
#include "ttg/util/hash.h"
Expand Down Expand Up @@ -36,12 +42,6 @@
#include "ttg/device/device.h"
#include "ttg/device/task.h"

#if defined(TTG_USE_PARSEC)
#include "ttg/parsec/ttg.h"
#elif defined(TTG_USE_MADNESS)
#include "ttg/madness/ttg.h"
#endif // TTG_USE_{PARSEC|MADNESS}

// these headers use the default backend
#include "ttg/run.h"

Expand Down
1 change: 1 addition & 0 deletions ttg/ttg/base/tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "ttg/base/terminal.h"
#include "ttg/util/demangle.h"
#include "ttg/util/trace.h"

namespace ttg {

Expand Down
62 changes: 61 additions & 1 deletion ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <span>
#include <type_traits>
#include <utility>


#include "ttg/fwd.h"
#include "ttg/impl_selector.h"
#include "ttg/ptr.h"
Expand All @@ -14,12 +15,58 @@
namespace ttg::device {

namespace detail {

/// Type-erased description of a single device task input: the backend's data
/// handle plus the scope and constness recorded for it.
struct device_input_data_t {
/// Backend-specific handle type: whatever TTG_IMPL_NS::buffer_data() returns
/// when called with a ttg::Buffer<int>.
using impl_data_t = decltype(TTG_IMPL_NS::buffer_data(std::declval<ttg::Buffer<int>>()));

/// Backend handle of the input.
impl_data_t impl_data;
/// Scope recorded for the input (filled from the object's scope() by the code in this file).
ttg::scope scope;
/// Whether the input was provided as a const object.
bool is_const;
};

/// Wrapper around a statically known list of device inputs. It stores lvalue
/// references to the wrapped objects (no copies, no ownership), so the
/// referenced objects have to stay alive while the wrapper is in use.
template <typename... Ts>
struct to_device_t {
/// One lvalue reference per wrapped object, in argument order.
std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
};

/* extract buffer information from to_device_t */
template<typename... Ts, std::size_t... Is>
auto extract_buffer_data(detail::to_device_t<Ts...>& a) {
return std::array<device_input_data_t, sizeof...(Is)>{
{TTG_IMPL_NS::buffer_data(std::get<Is>(a.ties)),
std::get<Is>(a.ties).scope(), std::get<Is>(a.ties)}...};
}
} // namespace detail

struct Input {
private:
std::vector<detail::device_input_data_t> m_data;

public:
Input() { }
template<typename... Args>
Input(Args&&... args)
: m_data{{TTG_IMPL_NS::buffer_data(args), args.scope(), std::is_const_v<Args>}...}
{ }

template<typename T>
void add(T&& v) {
m_data.emplace_back(TTG_IMPL_NS::buffer_data(v), v.scope(), std::is_const_v<T>);
}

ttg::span<detail::device_input_data_t> span() {
return ttg::span(m_data);
}
};

namespace detail {
// overload for Input
// Specialization of to_device_t for a dynamically built Input list: instead
// of a tuple of references it stores a single reference to the Input, which
// therefore has to outlive this wrapper.
template <>
struct to_device_t<Input> {
Input& input;
};
} // namespace detail

/**
* Select a device to execute on based on the provided buffer and scratchspace objects.
* Returns an object that should be awaited on using \c co_await.
Expand All @@ -33,6 +80,11 @@ namespace ttg::device {
return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
}

/**
 * Overload of select() for a dynamically built list of inputs.
 * The returned wrapper only references \c input; it does not copy it.
 */
[[nodiscard]]
inline auto select(Input& input) {
  using wrapper_type = detail::to_device_t<Input>;
  return wrapper_type{input};
}

namespace detail {

enum ttg_device_coro_state {
Expand Down Expand Up @@ -558,7 +610,15 @@ namespace ttg::device {

/**
 * Transforms an awaited to_device_t (static list of inputs): converts the
 * wrapped objects into device_input_data_t entries, registers them with the
 * backend, records the wait-for-transfer state and always suspends.
 */
template<typename... Ts>
ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
  /* keep the array alive in a local: the span handed to the backend only refers to it */
  auto arr = detail::extract_buffer_data(a);
  /* the result is not acted upon yet (see TODO below) */
  [[maybe_unused]] const bool need_transfer = !(TTG_IMPL_NS::register_device_memory(ttg::span(arr)));
  /* TODO: are we allowed to not suspend here and launch the kernel directly? */
  m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
  return {};
}

ttg::suspend_always await_transform(detail::to_device_t<Input>&& a) {
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.input.span()));
/* TODO: are we allowed to not suspend here and launch the kernel directly? */
m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
return {};
Expand Down
2 changes: 2 additions & 0 deletions ttg/ttg/madness/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include "ttg/serialization/traits.h"

#include "ttg/device/device.h"

namespace ttg_madness {

/// A runtime-managed buffer mirrored between host and device memory
Expand Down
8 changes: 7 additions & 1 deletion ttg/ttg/parsec/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,14 @@ struct Buffer : public detail::ttg_parsec_data_wrapper_t
m_data->device_copies[i]->version = 0;
}
}
m_data->owner_device = 0;
}
m_data->owner_device = 0;
}

/**
 * Scope to use when this buffer is passed as a device input.
 * \return ttg::scope::Allocate if the host owns the data and its copy has
 *         version zero (only an allocation is needed), ttg::scope::SyncIn otherwise.
 */
ttg::scope scope() const {
  const bool host_copy_unversioned = (m_data->device_copies[0]->version == 0);
  if (host_copy_unversioned && m_data->owner_device == 0) {
    return ttg::scope::Allocate;
  }
  return ttg::scope::SyncIn;
}

void prefer_device(ttg::device::Device dev) {
Expand Down
108 changes: 92 additions & 16 deletions ttg/ttg/parsec/devicefunc.h
Original file line number Diff line number Diff line change
@@ -1,24 +1,14 @@
#ifndef TTG_PARSEC_DEVICEFUNC_H
#define TTG_PARSEC_DEVICEFUNC_H

#if defined(TTG_HAVE_CUDART)
#include <cuda.h>
#endif

#include "ttg/parsec/task.h"
#include <parsec.h>
#include <parsec/mca/device/device_gpu.h>

#if defined(PARSEC_HAVE_DEV_CUDA_SUPPORT)
#include <parsec/mca/device/cuda/device_cuda.h>
#elif defined(PARSEC_HAVE_DEV_HIP_SUPPORT)
#include <parsec/mca/device/hip/device_hip.h>
#endif // PARSEC_HAVE_DEV_CUDA_SUPPORT

namespace ttg_parsec {
namespace detail {
template<typename... Views, std::size_t I, std::size_t... Is>
inline bool register_device_memory(std::tuple<Views&...> &views, std::index_sequence<I, Is...>) {
bool register_device_memory(std::tuple<Views&...> &views, std::index_sequence<I, Is...>) {
static_assert(I < MAX_PARAM_COUNT,
"PaRSEC only supports MAX_PARAM_COUNT device input/outputs. "
"Increase MAX_PARAM_COUNT and recompile PaRSEC/TTG.");
Expand Down Expand Up @@ -88,7 +78,7 @@ namespace ttg_parsec {
* with the currently executing task. Returns true if all memory
* is current on the target device, false if transfers are required. */
template<typename... Views>
inline bool register_device_memory(std::tuple<Views&...> &views) {
bool register_device_memory(std::tuple<Views&...> &views) {
bool is_current = true;
if (nullptr == detail::parsec_ttg_caller) {
throw std::runtime_error("register_device_memory may only be invoked from inside a task!");
Expand All @@ -114,9 +104,87 @@ namespace ttg_parsec {
return is_current;
}

/**
 * Register a dynamically sized list of device inputs with the currently
 * executing task.
 *
 * For every entry of \c span a PaRSEC flow is set up in the task's device
 * structures; all remaining flows up to MAX_PARAM_COUNT are reset.
 *
 * \param span entries providing \c impl_data (parsec_data_t*, may be null),
 *             \c scope and \c is_const
 * \return always false: the calling thread must not submit kernels itself
 * \throws std::runtime_error if called outside of a task, inside a non-gpu
 *         task, or with more than MAX_PARAM_COUNT entries
 */
// templated to break circular dependency with fwd.h
template<typename T, std::size_t N>
bool register_device_memory(const ttg::span<T, N>& span)
{

  if (nullptr == detail::parsec_ttg_caller) {
    throw std::runtime_error("register_device_memory may only be invoked from inside a task!");
  }

  if (nullptr == detail::parsec_ttg_caller->dev_ptr) {
    throw std::runtime_error("register_device_memory called inside a non-gpu task!");
  }

  /* The number of inputs is only known at runtime here, so enforce the limit
   * dynamically; otherwise the writes below would run past the per-task
   * flow/data arrays. */
  if (span.size() > static_cast<std::size_t>(MAX_PARAM_COUNT)) {
    throw std::runtime_error("PaRSEC only supports MAX_PARAM_COUNT device input/outputs. "
                             "Increase MAX_PARAM_COUNT and recompile PaRSEC/TTG.");
  }

  uint8_t i; // only limited number of flows
  detail::parsec_ttg_task_base_t *caller = detail::parsec_ttg_caller;
  assert(nullptr != caller->dev_ptr);
  parsec_gpu_task_t *gpu_task = caller->dev_ptr->gpu_task;
  parsec_flow_t *flows = caller->dev_ptr->flows;

  bool is_current = false;
  for (i = 0; i < span.size(); ++i) {
    /* the backend data handle was extracted when the input was recorded */
    parsec_data_t* data = span[i].impl_data;
    /* TODO: check whether the device is current */
    bool is_const = span[i].is_const;
    ttg::scope scope = span[i].scope;

    if (nullptr != data) {
      /* default to read-write; narrow to write-only for allocate-only inputs
       * and to read-only for const inputs */
      auto access = PARSEC_FLOW_ACCESS_RW;
      if (ttg::scope::Allocate == scope) {
        access = PARSEC_FLOW_ACCESS_WRITE;
      } else if (is_const) {
        access = PARSEC_FLOW_ACCESS_READ;
      }

      /* build the flow */
      /* TODO: reuse the flows of the task class? How can we control the sync direction then? */
      flows[i] = parsec_flow_t{.name = nullptr,
                               .sym_type = PARSEC_SYM_INOUT,
                               .flow_flags = static_cast<uint8_t>(access),
                               .flow_index = i,
                               .flow_datatype_mask = ~0 };

      gpu_task->flow_nb_elts[i] = data->nb_elts; // size in bytes
      gpu_task->flow[i] = &flows[i];

      /* set the input data copy, parsec will take care of the transfer
       * and the buffer will look at the parsec_data_t for the current pointer */
      //detail::parsec_ttg_caller->parsec_task.data[I].data_in = data->device_copies[data->owner_device];
      assert(nullptr != data->device_copies[0]->original);
      caller->parsec_task.data[i].data_in = data->device_copies[0];
      caller->parsec_task.data[i].source_repo_entry = nullptr;

    } else {
      /* ignore the flow */
      flows[i] = parsec_flow_t{.name = nullptr,
                               .sym_type = PARSEC_FLOW_ACCESS_NONE,
                               .flow_flags = 0,
                               .flow_index = i,
                               .flow_datatype_mask = ~0 };
      gpu_task->flow[i] = &flows[i];
      gpu_task->flow_nb_elts[i] = 0; // size in bytes
      caller->parsec_task.data[i].data_in = nullptr;
    }
  }

  /* reset all remaining entries in the current task */
  for (; i < MAX_PARAM_COUNT; ++i) {
    caller->parsec_task.data[i].data_in = nullptr;
    flows[i].flow_flags = PARSEC_FLOW_ACCESS_NONE;
    flows[i].flow_index = i;
    gpu_task->flow[i] = &flows[i];
    gpu_task->flow_nb_elts[i] = 0;
  }
  // we cannot allow the calling thread to submit kernels so say we're not ready
  return is_current;
}

namespace detail {
template<typename... Views, std::size_t I, std::size_t... Is, bool DeviceAvail = false>
inline void mark_device_out(std::tuple<Views&...> &views, std::index_sequence<I, Is...>) {
void mark_device_out(std::tuple<Views&...> &views, std::index_sequence<I, Is...>) {

using view_type = std::remove_reference_t<std::tuple_element_t<I, std::tuple<Views&...>>>;
auto& view = std::get<I>(views);
Expand All @@ -142,7 +210,7 @@ namespace ttg_parsec {
} // namespace detail

template<typename... Buffer>
inline void mark_device_out(std::tuple<Buffer&...> &b) {
void mark_device_out(std::tuple<Buffer&...> &b) {

if (nullptr == detail::parsec_ttg_caller) {
throw std::runtime_error("mark_device_out may only be invoked from inside a task!");
Expand All @@ -158,7 +226,7 @@ namespace ttg_parsec {
namespace detail {

template<typename... Views, std::size_t I, std::size_t... Is>
inline void post_device_out(std::tuple<Views&...> &views, std::index_sequence<I, Is...>) {
void post_device_out(std::tuple<Views&...> &views, std::index_sequence<I, Is...>) {

using view_type = std::remove_reference_t<std::tuple_element_t<I, std::tuple<Views&...>>>;

Expand All @@ -177,11 +245,19 @@ namespace ttg_parsec {
}
}
} // namespace detail

template<typename... Buffer>
inline void post_device_out(std::tuple<Buffer&...> &b) {
void post_device_out(std::tuple<Buffer&...> &b) {
detail::post_device_out(b, std::index_sequence_for<Buffer...>{});
}

/**
 * Return the PaRSEC data handle behind a buffer or device scratch object.
 * Fails to compile for any other argument type.
 */
template<typename T>
parsec_data_t* buffer_data(T&& buffer) {
  using object_type = std::remove_reference_t<T>;
  constexpr bool is_supported = ttg::meta::is_buffer_v<object_type>
                             || ttg::meta::is_devicescratch_v<object_type>;
  static_assert(is_supported);
  parsec_data_t* const handle = detail::get_parsec_data(buffer);
  return handle;
}

} // namespace ttg_parsec

#endif // TTG_PARSEC_DEVICEFUNC_H
18 changes: 11 additions & 7 deletions ttg/ttg/parsec/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

#include "ttg/fwd.h"
#include "ttg/util/typelist.h"
#include "ttg/util/span.h"

#include <future>

#include <parsec.h>

extern "C" struct parsec_context_s;

namespace ttg_parsec {
Expand Down Expand Up @@ -74,20 +77,21 @@ namespace ttg_parsec {
inline Ptr<std::decay_t<T>> get_ptr(T&& obj);

template<typename... Views>
inline bool register_device_memory(std::tuple<Views&...> &views);
bool register_device_memory(std::tuple<Views&...> &views);

template<typename T, std::size_t N>
bool register_device_memory(const ttg::span<T, N>& span);

template<typename... Buffer>
inline void post_device_out(std::tuple<Buffer&...> &b);
void post_device_out(std::tuple<Buffer&...> &b);

template<typename... Buffer>
inline void mark_device_out(std::tuple<Buffer&...> &b);
void mark_device_out(std::tuple<Buffer&...> &b);

inline int num_devices();

#if 0
template<typename... Args>
inline std::pair<bool, std::tuple<ptr<std::decay_t<Args>>...>> get_ptr(Args&&... args);
#endif
template<typename T>
parsec_data_t* buffer_data(T&& buffer);

} // namespace ttg_parsec

Expand Down
4 changes: 1 addition & 3 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@
#include "ttg/util/print.h"
#include "ttg/util/trace.h"
#include "ttg/util/typelist.h"
#ifdef TTG_HAVE_DEVICE
#include "ttg/device/task.h"
#endif // TTG_HAVE_DEVICE

#include "ttg/serialization/data_descriptor.h"

Expand All @@ -48,6 +45,7 @@
#include "ttg/parsec/thread_local.h"
#include "ttg/parsec/devicefunc.h"
#include "ttg/parsec/ttvalue.h"
#include "ttg/device/task.h"

#include <algorithm>
#include <array>
Expand Down

0 comments on commit 9535f7f

Please sign in to comment.