Skip to content

Commit

Permalink
[Runtime] Support host stream (#166)
Browse files Browse the repository at this point in the history
Add host stream support using host callbacks `cudaLaunchHostFunc` rather
than immediate execution.
  • Loading branch information
zhekunz2 authored Mar 29, 2024
1 parent 1bc8008 commit 3dce9b7
Show file tree
Hide file tree
Showing 25 changed files with 313 additions and 194 deletions.
18 changes: 8 additions & 10 deletions runtime/include/brt/backends/cpu/device/cpu_work_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@ class CPUNaiveWorkQueue : public WorkQueue {
explicit CPUNaiveWorkQueue(const std::string &name = "cpu_naive");

common::Status AddTask(int /*task_type*/, const void * /*func*/,
void ** /*args*/) override;

common::Status AddEventWait(mlir::Operation *,
std::vector<mlir::Operation *>) override;
void ** /*args*/, int /*op_id*/,
const std::vector<int> & /*dependency*/) override;

common::Status Sync() override;

common::Status AddHostTask(std::function<void(void)> &&task) override;
common::Status AddHostTask(const void *task, void **args, int op_id,
const std::vector<int> &dependency) override;
};

// WorkQueue which runs host task lazily
Expand All @@ -46,14 +45,13 @@ class CPULazyWorkQueue : public WorkQueue {
explicit CPULazyWorkQueue(const std::string &name = "cpu_lazy");

common::Status AddTask(int /*task_type*/, const void * /*func*/,
void ** /*args*/) override;

common::Status AddEventWait(mlir::Operation *,
std::vector<mlir::Operation *>) override;
void ** /*args*/, int /*op_id*/,
const std::vector<int> & /*dependency*/) override;

common::Status Sync() override;

common::Status AddHostTask(std::function<void(void)> &&task) override;
common::Status AddHostTask(const void *task, void **args, int op_id,
const std::vector<int> &dependency) override;

private:
std::vector<std::function<void(void)>> tasks;
Expand Down
60 changes: 33 additions & 27 deletions runtime/include/brt/backends/cuda/device/cuda_work_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,18 @@ class CUDAWorkQueue : public WorkQueue {

// Enqueue a func call, thread-safe.
// func is a stateless function
virtual common::Status AddTask(int task_type, const void *func,
void **args) override;

virtual common::Status AddEventWait(mlir::Operation *,
std::vector<mlir::Operation *>) override;
virtual common::Status AddTask(int task_type, const void *func, void **args,
int op_id,
const std::vector<int> &dependency) override;
// Barrier
virtual common::Status Sync() override;

// Base-class accessor for the compute stream. This default implementation
// returns nullptr; it is virtual so a derived queue can return its own stream.
virtual CUstream_st *GetComputeStream() { return nullptr; }

common::Status AddHostTask(std::function<void(void)> &&task) override {
task();
common::Status AddHostTask(const void *task, void **args, int op_id,
const std::vector<int> &dependency) override {
auto func = reinterpret_cast<const std::function<void(void)> *>(task);
(*func)();
return common::Status::OK();
}

Expand All @@ -93,10 +93,10 @@ class CUDASingleStreamWorkQueue final : public CUDAWorkQueue {

// Enqueue a func call, thread-safe.
// func is a stateless function
common::Status AddTask(int task_type, const void *func, void **args) override;
common::Status AddTask(int task_type, const void *func, void **args,
int op_id,
const std::vector<int> &dependency) override;

common::Status AddEventWait(mlir::Operation *,
std::vector<mlir::Operation *>) override;
// Barrier
common::Status Sync() override;

Expand All @@ -110,42 +110,47 @@ class CUDASingleStreamWorkQueue final : public CUDAWorkQueue {
};

/**
* CUDAOneComputeTwoTransferWorkQueue is a derived class of WorkQueue
* CUDAMultiStreamWorkQueue is a derived class of WorkQueue
* that uses multiple CUDA streams as a WorkQueue.
*
* A typical usage is using 3 streams: one for compute, two for bidirectional
* data transfer.
* data transfer, plus a fourth stream for host tasks.
*/
class CUDAOneComputeTwoTransferWorkQueue final : public CUDAWorkQueue {
class CUDAMultiStreamWorkQueue final : public CUDAWorkQueue {
public:
CUDAOneComputeTwoTransferWorkQueue(int device_id);
CUDAMultiStreamWorkQueue(int device_id);

// Undefined what happens to pending work when destructor is called.
virtual ~CUDAOneComputeTwoTransferWorkQueue();
virtual ~CUDAMultiStreamWorkQueue();

// Enqueue a func call, thread-safe.
// func is a stateless function
common::Status AddTask(int task_type, const void *func, void **args) override;
common::Status AddTask(int task_type, const void *func, void **args,
int op_id,
const std::vector<int> &dependency) override;

common::Status AddHostTask(const void *task, void **args, int op_id,
const std::vector<int> &dependency) override;

// Barrier
common::Status Sync() override;

size_t GetStreamIdx(mlir::Operation *op);

common::Status AddEventWait(mlir::Operation *,
std::vector<mlir::Operation *>) override;
common::Status AddEventWait(size_t, std::vector<int>);

// Compute stream: element 0 of streams_.
CUstream_st *GetComputeStream() override { return streams_[0]; }
// Host-to-device (h2d) transfer stream: element 1 of streams_.
CUstream_st *GetH2DStream() { return streams_[1]; }
// Device-to-host (d2h) transfer stream: element 2 of streams_.
CUstream_st *GetD2HStream() { return streams_[2]; }
// Host stream: element 3 of streams_.
CUstream_st *GetHostStream() { return streams_[3]; }

private:
CUstream_st *streams_[3]; // 0 for compute, 1 for h2d, 2 for d2h
// 0 for compute, 1 for h2d, 2 for d2h, 3 for host
CUstream_st *streams_[4];

std::vector<CUevent_st *> events_;

CUDAOneComputeTwoTransferWorkQueue(
const CUDAOneComputeTwoTransferWorkQueue &) = delete;
CUDAOneComputeTwoTransferWorkQueue &
operator=(const CUDAOneComputeTwoTransferWorkQueue &) = delete;
CUDAMultiStreamWorkQueue(const CUDAMultiStreamWorkQueue &) = delete;
CUDAMultiStreamWorkQueue &
operator=(const CUDAMultiStreamWorkQueue &) = delete;
};

/**
Expand All @@ -165,7 +170,9 @@ class CUDAExternalStreamWorkQueue final : public CUDAWorkQueue {

// Enqueue a func call, thread-safe.
// func is a stateless function
common::Status AddTask(int task_type, const void *func, void **args) override;
common::Status AddTask(int task_type, const void *func, void **args,
int op_id,
const std::vector<int> &dependency) override;

// Barrier
common::Status Sync() override;
Expand All @@ -178,5 +185,4 @@ class CUDAExternalStreamWorkQueue final : public CUDAWorkQueue {
CUDAExternalStreamWorkQueue &
operator=(const CUDAExternalStreamWorkQueue &) = delete;
};

} // namespace brt
20 changes: 11 additions & 9 deletions runtime/include/brt/core/context/work_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,13 @@ class WorkQueue {

// Enqueue a func call, thread-safe.
// func is a stateless function
virtual common::Status AddTask(int task_type, const void *func,
void **args) = 0;
virtual common::Status AddTask(int task_type, const void *func, void **args,
int op_id,
const std::vector<int> &dependency) = 0;

// Enqueue a task on host side
virtual common::Status AddHostTask(std::function<void(void)> &&task) = 0;

// Add wait event if any
virtual common::Status AddEventWait(mlir::Operation *,
std::vector<mlir::Operation *>) = 0;
virtual common::Status AddHostTask(const void *task, void **args, int op_id,
const std::vector<int> &dependency) = 0;

// Enqueue through a functor
// Note, the functor is called immediately.
Expand All @@ -81,6 +79,9 @@ class WorkQueue {
// Barrier
virtual common::Status Sync() = 0;

protected:
std::unordered_map<int, int> id_to_stream_map_;

private:
const std::string name_;
WorkQueue(const WorkQueue &) = delete;
Expand All @@ -89,9 +90,10 @@ class WorkQueue {

} // namespace brt

#define DispatchHostTask(wq, stmt) \
#define DispatchHostTask(wq, op_id, dependency, stmt) \
if (wq) { \
wq->AddHostTask([=] { stmt }); \
std::function<void(void)> func = [=]() { stmt }; \
wq->AddHostTask(&func, nullptr, op_id, dependency); \
} else { \
do { \
stmt \
Expand Down
33 changes: 29 additions & 4 deletions runtime/include/brt/core/framework/op_kernel_impl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ struct WorkQueue {
}
};

// Argument tag whose Get() resolves to the value of impl->GetOpId().
// HostOpKernelIfaceTraitsT lists it among the arguments of its base traits.
struct OpId {
  // Forwards to impl->GetOpId(). `ctx` is part of the signature but is not
  // read here.
  template <typename Impl>
  static inline int Get(Impl *impl, const ExecutionContext &ctx) {
    const int op_id = impl->GetOpId();
    return op_id;
  }
};

// Argument tag whose Get() resolves to the result of impl->GetDependency().
// HostOpKernelIfaceTraitsT lists it among the arguments of its base traits.
struct Dependency {
  // Forwards to impl->GetDependency() and hands back the same reference.
  // `ctx` is part of the signature but is not read here.
  template <typename Impl>
  static inline const std::vector<int> &Get(Impl *impl,
                                            const ExecutionContext &ctx) {
    const std::vector<int> &deps = impl->GetDependency();
    return deps;
  }
};

// op kernel which has temporary workspace as argument, should implement
// GetWorkspaceSize() interface
struct Workspace : public PerFrameHookTrait {
Expand Down Expand Up @@ -196,6 +211,12 @@ struct NaiveOpKernelIfaceTraits : public OpKernelIfaceTraitsBase<Arguments...> {
return OpAccessor(info_, ctx.exec_frame);
}

// Op id as reported by the wrapped OpKernelInfo (info_).
int GetOpId() const {
  const int op_id = info_.GetOpId();
  return op_id;
}

// Dependency list as reported by the wrapped OpKernelInfo (info_); the
// reference obtained from info_ is passed through unchanged.
const std::vector<int> &GetDependency() const {
  const std::vector<int> &deps = info_.GetDependency();
  return deps;
}

private:
const OpKernelInfo &info_;
};
Expand Down Expand Up @@ -226,8 +247,10 @@ struct OpKernelWithWorkspaceIfaceTraitsT

template <template <typename...> class Base, typename... Arguments>
struct HostOpKernelIfaceTraitsT
: public Base<argument_type::WorkQueue, Arguments...> {
using BaseTraits = Base<argument_type::WorkQueue, Arguments...>;
: public Base<argument_type::WorkQueue, argument_type::OpId,
argument_type::Dependency, Arguments...> {
using BaseTraits = Base<argument_type::WorkQueue, argument_type::OpId,
argument_type::Dependency, Arguments...>;

template <typename T>
using ImplMixinBase = typename BaseTraits::template ImplMixin<T>;
Expand All @@ -238,8 +261,10 @@ struct HostOpKernelIfaceTraitsT
using ImplMixinBase<ImplBase>::ImplMixinBase;

template <typename... Args>
common::Status Execute(WorkQueue *work_queue, Args &&...args) {
DispatchHostTask(work_queue, { ImplBase::Execute(args...); });
common::Status Execute(WorkQueue *work_queue, int op_id,
const std::vector<int> &dependency, Args &&...args) {
DispatchHostTask(work_queue, op_id, dependency,
{ ImplBase::Execute(args...); });
return common::Status::OK();
}
};
Expand Down
26 changes: 13 additions & 13 deletions runtime/include/brt/core/framework/op_kernel_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,21 @@ class OpKernelInfo {
public:
OpKernelInfo(
const ExecutionProvider &provider, const ir::IRHandle &handle,
mlir::Operation *op,
mlir::Operation *op, int op_id,
const std::unordered_map<std::string, std::unique_ptr<IAllocator>> &alloc,
IAllocator *last_allc,
const std::unordered_map<void *, size_t> &tensor_to_idx,
const std::unordered_map<void *, size_t> &scalar_to_idx,
const std::vector<AsyncValue> &weights, size_t intermediate_begin,
const std::string &ir_path,
const std::vector<mlir::Operation *> &dependency)
: provider_(provider), handle_(handle), op_(op), allocators_(alloc),
last_allocator_(last_allc), tensor_to_idx_(tensor_to_idx),
scalar_to_idx_(scalar_to_idx), weights_(weights),
intermediate_begin_(intermediate_begin), ir_path_(ir_path),
dependency_(dependency) {}
const std::string &ir_path, const std::vector<int> &dependency)
: provider_(provider), handle_(handle), op_(op), op_id_(op_id),
allocators_(alloc), last_allocator_(last_allc),
tensor_to_idx_(tensor_to_idx), scalar_to_idx_(scalar_to_idx),
weights_(weights), intermediate_begin_(intermediate_begin),
ir_path_(ir_path), dependency_(dependency) {}

OpKernelInfo(const OpKernelInfo &other)
: OpKernelInfo(other.provider_, other.handle_, other.op_,
: OpKernelInfo(other.provider_, other.handle_, other.op_, other.op_id_,
other.allocators_, other.last_allocator_,
other.tensor_to_idx_, other.scalar_to_idx_, other.weights_,
other.intermediate_begin_, other.ir_path_,
Expand All @@ -81,6 +80,8 @@ class OpKernelInfo {

mlir::Operation *GetOperation() const { return op_; }

// Returns op_id_, the integer id supplied to the constructor as `op_id`.
int GetOpId() const { return op_id_; }

const std::unordered_map<void *, size_t> &GetTensorToIndex() const {
return tensor_to_idx_;
}
Expand All @@ -91,9 +92,7 @@ class OpKernelInfo {

const std::vector<AsyncValue> &GetWeights() const { return weights_; }

// Returns a const reference to dependency_, expressed here as a vector of
// mlir::Operation pointers.
const std::vector<mlir::Operation *> &GetDependency() const {
return dependency_;
}
// Returns a const reference to dependency_, expressed here as a vector of
// int (presumably op ids -- confirm against the code that builds the list).
const std::vector<int> &GetDependency() const { return dependency_; }

// const BrtMemoryInfo& GetMemoryInfo(int device_id, BrtMemType mem_type)
// const;
Expand All @@ -119,6 +118,7 @@ class OpKernelInfo {
const brt::ir::IRHandle &handle_;

mlir::Operation *op_;
int op_id_;

const std::unordered_map<std::string, std::unique_ptr<IAllocator>>
&allocators_;
Expand All @@ -134,7 +134,7 @@ class OpKernelInfo {

const std::string &ir_path_;

const std::vector<mlir::Operation *> &dependency_;
std::vector<int> dependency_;

OpKernelInfo(OpKernelInfo &&) = delete;
OpKernelInfo &operator=(OpKernelInfo &&) = delete;
Expand Down
37 changes: 17 additions & 20 deletions runtime/lib/backends/cpu/device/cpu_work_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,30 @@ namespace cpu {
// Constructs the queue by forwarding `name` to the WorkQueue base class;
// the constructor body itself does nothing else.
CPUNaiveWorkQueue::CPUNaiveWorkQueue(const std::string &name)
: WorkQueue(name) {}

common::Status CPUNaiveWorkQueue::AddTask(int /*task_type*/,
const void * /*func*/,
void ** /*args*/) {
common::Status
CPUNaiveWorkQueue::AddTask(int /*task_type*/, const void * /*func*/,
void ** /*args*/, int op_id,
const std::vector<int> & /*dependency*/) {
return common::Status(common::StatusCategory::BRT, common::StatusCode::FAIL,
"Use AddHostTask for cpu work queue");
}

// No-op: both arguments are unnamed and ignored, and OK is always returned.
common::Status CPUNaiveWorkQueue::AddEventWait(mlir::Operation *,
std::vector<mlir::Operation *>) {
return common::Status::OK();
}

// Barrier for the naive CPU queue: performs no work and always returns OK.
// CPUNaiveWorkQueue::AddHostTask invokes each task at the point it is added.
common::Status CPUNaiveWorkQueue::Sync() { return common::Status::OK(); }

common::Status
CPUNaiveWorkQueue::AddHostTask(std::function<void(void)> &&task) {
task();
CPUNaiveWorkQueue::AddHostTask(const void *task, void **args, int op_id,
const std::vector<int> &dependency) {
auto func = reinterpret_cast<const std::function<void(void)> *>(task);
(*func)();
return common::Status::OK();
}

// Constructs the queue by forwarding `name` to the WorkQueue base class;
// the constructor body itself does nothing else.
CPULazyWorkQueue::CPULazyWorkQueue(const std::string &name) : WorkQueue(name) {}

common::Status CPULazyWorkQueue::AddTask(int /*task_type*/,
const void * /*func*/,
void ** /*args*/) {
common::Status
CPULazyWorkQueue::AddTask(int /*task_type*/, const void * /*func*/,
void ** /*args*/, int op_id,
const std::vector<int> & /*dependency*/) {
return common::Status(common::StatusCategory::BRT, common::StatusCode::FAIL,
"Use AddHostTask for cpu work queue");
}
Expand All @@ -59,13 +58,11 @@ common::Status CPULazyWorkQueue::Sync() {
return common::Status::OK();
}

// No-op: both arguments are unnamed and ignored, and OK is always returned.
common::Status CPULazyWorkQueue::AddEventWait(mlir::Operation *,
std::vector<mlir::Operation *>) {
return common::Status::OK();
}

common::Status CPULazyWorkQueue::AddHostTask(std::function<void(void)> &&task) {
tasks.push_back(std::move(task));
common::Status
CPULazyWorkQueue::AddHostTask(const void *task, void **args, int op_id,
const std::vector<int> &dependency) {
auto func = reinterpret_cast<const std::function<void(void)> *>(task);
tasks.push_back(std::move(*func));
return common::Status::OK();
}

Expand Down
Loading

0 comments on commit 3dce9b7

Please sign in to comment.