From 3dce9b7c5de177aaba39d05b4effa01871c233a4 Mon Sep 17 00:00:00 2001 From: "Zhekun(Josh) Zhang" <32320144+zhekunz2@users.noreply.github.com> Date: Fri, 29 Mar 2024 01:22:37 -0700 Subject: [PATCH] [Runtime] Support host stream (#166) Add host stream support using host callbacks `cudaLaunchHostFunc` rather than immediate execution. --- .../brt/backends/cpu/device/cpu_work_queue.h | 18 ++- .../backends/cuda/device/cuda_work_queue.h | 60 +++++----- runtime/include/brt/core/context/work_queue.h | 20 ++-- .../brt/core/framework/op_kernel_impl_base.h | 33 +++++- .../brt/core/framework/op_kernel_info.h | 26 ++--- .../lib/backends/cpu/device/cpu_work_queue.cc | 37 +++--- .../providers/default/custom_call/tf_equal.cc | 2 +- .../default/custom_call/tf_select.cc | 2 +- .../custom_call/tf_string_to_number.cc | 17 ++- .../providers/default/custom_call/tf_where.cc | 9 +- .../cpu/providers/default/custom_call/topk.cc | 11 +- .../cpu/providers/default/llvm/jit.cc | 2 +- .../providers/default/math/elementwise_ops.cc | 2 +- .../providers/default/tensor_generate/fill.cc | 2 +- .../default/tensor_generate/rng_state.cc | 4 +- .../backends/cuda/device/cuda_work_queue.cc | 106 +++++++++--------- .../cuda/providers/default/codegen/ptx.cc | 3 +- .../cuda/providers/default/copy/copy.cc | 4 +- .../providers/default/math/elementwise_ops.cc | 4 +- runtime/lib/core/framework/execution_plan.cc | 17 ++- .../cuda/device/cuda_work_queue_test.cc | 104 ++++++++++++++--- .../test/backends/cuda/device/nvrtc_test.cc | 4 +- runtime/test/backends/cuda/device/ptx_test.cc | 12 +- .../default/kernel/multi_stream_test.cc | 6 +- runtime/test/external_kernels/cuda/kernels.cc | 2 +- 25 files changed, 313 insertions(+), 194 deletions(-) diff --git a/runtime/include/brt/backends/cpu/device/cpu_work_queue.h b/runtime/include/brt/backends/cpu/device/cpu_work_queue.h index 633ad1d78..2260faac9 100644 --- a/runtime/include/brt/backends/cpu/device/cpu_work_queue.h +++ b/runtime/include/brt/backends/cpu/device/cpu_work_queue.h @@ -30,14 +30,13 @@ class CPUNaiveWorkQueue : public WorkQueue { explicit CPUNaiveWorkQueue(const std::string &name = "cpu_naive"); common::Status AddTask(int /*task_type*/, const void * /*func*/, - void ** /*args*/) override; - - common::Status AddEventWait(mlir::Operation *, - std::vector) override; + void ** /*args*/, int /*op_id*/, + const std::vector & /*dependency*/) override; common::Status Sync() override; - common::Status AddHostTask(std::function &&task) override; + common::Status AddHostTask(const void *task, void **args, int op_id, + const std::vector &dependency) override; }; // WorkQueue which runs host task lazily @@ -46,14 +45,13 @@ class CPULazyWorkQueue : public WorkQueue { explicit CPULazyWorkQueue(const std::string &name = "cpu_lazy"); common::Status AddTask(int /*task_type*/, const void * /*func*/, - void ** /*args*/) override; - - common::Status AddEventWait(mlir::Operation *, - std::vector) override; + void ** /*args*/, int /*op_id*/, + const std::vector & /*dependency*/) override; common::Status Sync() override; - common::Status AddHostTask(std::function &&task) override; + common::Status AddHostTask(const void *task, void **args, int op_id, + const std::vector &dependency) override; private: std::vector> tasks; diff --git a/runtime/include/brt/backends/cuda/device/cuda_work_queue.h b/runtime/include/brt/backends/cuda/device/cuda_work_queue.h index 93ba8a190..68627395d 100644 --- a/runtime/include/brt/backends/cuda/device/cuda_work_queue.h +++ b/runtime/include/brt/backends/cuda/device/cuda_work_queue.h @@ -56,18 +56,18 @@ class CUDAWorkQueue : public WorkQueue { // Enqueue a func call, thread-safe. // func is a stateless function - virtual common::Status AddTask(int task_type, const void *func, - void **args) override; - - virtual common::Status AddEventWait(mlir::Operation *, - std::vector) override; + virtual common::Status AddTask(int task_type, const void *func, void **args, + int op_id, + const std::vector &dependency) override; // Barrier virtual common::Status Sync() override; virtual CUstream_st *GetComputeStream() { return nullptr; } - common::Status AddHostTask(std::function &&task) override { - task(); + common::Status AddHostTask(const void *task, void **args, int op_id, + const std::vector &dependency) override { + auto func = reinterpret_cast *>(task); + (*func)(); return common::Status::OK(); } @@ -93,10 +93,10 @@ class CUDASingleStreamWorkQueue final : public CUDAWorkQueue { // Enqueue a func call, thread-safe. // func is a stateless function - common::Status AddTask(int task_type, const void *func, void **args) override; + common::Status AddTask(int task_type, const void *func, void **args, + int op_id, + const std::vector &dependency) override; - common::Status AddEventWait(mlir::Operation *, - std::vector) override; // Barrier common::Status Sync() override; @@ -110,42 +110,47 @@ class CUDASingleStreamWorkQueue final : public CUDAWorkQueue { }; /** - * CUDAOneComputeTwoTransferWorkQueue is a derived class of WorkQueue + * CUDAMultiStreamWorkQueue is a derived class of WorkQueue * that uses mutliple CUDA Stream's as a WorkQueue. * * A typical usage is using 3 streams: one for compute, two for bidirectional - * data transfer. + * data transfer, one for host. */ -class CUDAOneComputeTwoTransferWorkQueue final : public CUDAWorkQueue { +class CUDAMultiStreamWorkQueue final : public CUDAWorkQueue { public: - CUDAOneComputeTwoTransferWorkQueue(int device_id); + CUDAMultiStreamWorkQueue(int device_id); // Undefined what happens to pending work when destructor is called. - virtual ~CUDAOneComputeTwoTransferWorkQueue(); + virtual ~CUDAMultiStreamWorkQueue(); // Enqueue a func call, thread-safe. // func is a stateless function - common::Status AddTask(int task_type, const void *func, void **args) override; + common::Status AddTask(int task_type, const void *func, void **args, + int op_id, + const std::vector &dependency) override; + + common::Status AddHostTask(const void *task, void **args, int op_id, + const std::vector &dependency) override; // Barrier common::Status Sync() override; - size_t GetStreamIdx(mlir::Operation *op); - - common::Status AddEventWait(mlir::Operation *, - std::vector) override; + common::Status AddEventWait(size_t, std::vector); CUstream_st *GetComputeStream() override { return streams_[0]; } + CUstream_st *GetH2DStream() { return streams_[1]; } + CUstream_st *GetD2HStream() { return streams_[2]; } + CUstream_st *GetHostStream() { return streams_[3]; } private: - CUstream_st *streams_[3]; // 0 for compute, 1 for h2d, 2 for d2h + // 0 for compute, 1 for h2d, 2 for d2h, 3 for host + CUstream_st *streams_[4]; std::vector events_; - CUDAOneComputeTwoTransferWorkQueue( - const CUDAOneComputeTwoTransferWorkQueue &) = delete; - CUDAOneComputeTwoTransferWorkQueue & - operator=(const CUDAOneComputeTwoTransferWorkQueue &) = delete; + CUDAMultiStreamWorkQueue(const CUDAMultiStreamWorkQueue &) = delete; + CUDAMultiStreamWorkQueue & + operator=(const CUDAMultiStreamWorkQueue &) = delete; }; /** @@ -165,7 +170,9 @@ class CUDAExternalStreamWorkQueue final : public CUDAWorkQueue { // Enqueue a func call, thread-safe. // func is a stateless function - common::Status AddTask(int task_type, const void *func, void **args) override; + common::Status AddTask(int task_type, const void *func, void **args, + int op_id, + const std::vector &dependency) override; // Barrier common::Status Sync() override; @@ -178,5 +185,4 @@ class CUDAExternalStreamWorkQueue final : public CUDAWorkQueue { CUDAExternalStreamWorkQueue & operator=(const CUDAExternalStreamWorkQueue &) = delete; }; - } // namespace brt diff --git a/runtime/include/brt/core/context/work_queue.h b/runtime/include/brt/core/context/work_queue.h index 9bb97f9c2..57bc8aef7 100644 --- a/runtime/include/brt/core/context/work_queue.h +++ b/runtime/include/brt/core/context/work_queue.h @@ -61,15 +61,13 @@ class WorkQueue { // Enqueue a func call, thread-safe. // func is a stateless function - virtual common::Status AddTask(int task_type, const void *func, - void **args) = 0; + virtual common::Status AddTask(int task_type, const void *func, void **args, + int op_id, + const std::vector &dependency) = 0; // Enqueue a task on host side - virtual common::Status AddHostTask(std::function &&task) = 0; - - // Add wait event if any - virtual common::Status AddEventWait(mlir::Operation *, - std::vector) = 0; + virtual common::Status AddHostTask(const void *task, void **args, int op_id, + const std::vector &dependency) = 0; // Enqueue through a functor // Note, the functor is called immediately. @@ -81,6 +79,9 @@ class WorkQueue { // Barrier virtual common::Status Sync() = 0; +protected: + std::unordered_map id_to_stream_map_; + private: const std::string name_; WorkQueue(const WorkQueue &) = delete; @@ -89,9 +90,10 @@ class WorkQueue { } // namespace brt -#define DispatchHostTask(wq, stmt) \ +#define DispatchHostTask(wq, op_id, dependency, stmt) \ if (wq) { \ - wq->AddHostTask([=] { stmt }); \ + std::function func = [=]() { stmt }; \ + wq->AddHostTask(&func, nullptr, op_id, dependency); \ } else { \ do { \ stmt \ diff --git a/runtime/include/brt/core/framework/op_kernel_impl_base.h b/runtime/include/brt/core/framework/op_kernel_impl_base.h index 9d21dc20a..208522450 100644 --- a/runtime/include/brt/core/framework/op_kernel_impl_base.h +++ b/runtime/include/brt/core/framework/op_kernel_impl_base.h @@ -96,6 +96,21 @@ struct WorkQueue { } }; +struct OpId { + template + static inline int Get(Impl *impl, const ExecutionContext &ctx) { + return impl->GetOpId(); + } +}; + +struct Dependency { + template + static inline const std::vector &Get(Impl *impl, + const ExecutionContext &ctx) { + return impl->GetDependency(); + } +}; + // op kernel which has temporary workspace as argument, should implement // GetWorkspaceSize() interface struct Workspace : public PerFrameHookTrait { @@ -196,6 +211,12 @@ struct NaiveOpKernelIfaceTraits : public OpKernelIfaceTraitsBase { return OpAccessor(info_, ctx.exec_frame); } + int GetOpId() const { return info_.GetOpId(); } + + const std::vector &GetDependency() const { + return info_.GetDependency(); + } + private: const OpKernelInfo &info_; }; @@ -226,8 +247,10 @@ struct OpKernelWithWorkspaceIfaceTraitsT template