Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cuda event and stream API #32460

Merged
merged 33 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b0f82d7
add cuda event and stream api
MingMingShangTian Apr 14, 2021
1ae2410
add cuda event and stream api
MingMingShangTian Apr 21, 2021
15a8e37
add get_current_stream api
MingMingShangTian Apr 21, 2021
5cf8c0e
add get_current_stream api
MingMingShangTian Apr 21, 2021
a79fc01
init streams
MingMingShangTian Apr 22, 2021
ee61c4d
modify get_current_stream
MingMingShangTian Apr 23, 2021
c56b9e7
modify get_cuttent_stream
MingMingShangTian Apr 23, 2021
c3dc86f
add synchronize func
MingMingShangTian Apr 25, 2021
174d28b
merge develop branch
MingMingShangTian Apr 25, 2021
fe05664
add current_stream doc and test file
MingMingShangTian Apr 25, 2021
1cd30bf
move get_current_stream into CUDA macro
MingMingShangTian Apr 25, 2021
811daf5
move CudaEvent into CUDA macro
MingMingShangTian Apr 25, 2021
0ff122a
move _get_current_stream and _device_synchronize into cuda macro
MingMingShangTian Apr 25, 2021
c369baf
modify the macro of cuda stream and event
MingMingShangTian Apr 26, 2021
cd278ea
add test case for synchronize
MingMingShangTian Apr 26, 2021
90549fc
add paddle.devices.cuda module
MingMingShangTian Apr 26, 2021
67049e5
event and stream support hip
MingMingShangTian Apr 26, 2021
b128268
add doc for stream and event class
MingMingShangTian Apr 26, 2021
5a611e7
move cuda stream and event into single pybind
MingMingShangTian Apr 29, 2021
11df0ff
Merge branch 'develop' into cuda_event
MingMingShangTian Apr 29, 2021
e5a6de5
merge develop branch
MingMingShangTian Apr 29, 2021
ebeefed
add cuda_streams_py.cc to cmakelist
MingMingShangTian May 6, 2021
9a1940e
add _device_synchronize and _get_current_stream to core module
MingMingShangTian May 7, 2021
a9ece68
add test case for cudastream and cudaevent
MingMingShangTian May 10, 2021
ac72134
move __all__ in streams.py
MingMingShangTian May 10, 2021
e83b17b
fix test fail
MingMingShangTian May 11, 2021
2a4ed96
Merge branch 'develop' into cuda_event
MingMingShangTian Jul 8, 2021
6fa6665
add cuda to devices __all__
MingMingShangTian Jul 8, 2021
41519d4
fix current_stream doc writing error
MingMingShangTian Jul 13, 2021
025813e
move devices to device direction, and merge device.py into __init__.py
MingMingShangTian Jul 15, 2021
3c933c6
add required:gpu to sample codes
MingMingShangTian Jul 15, 2021
caa7744
Merge branch 'develop' into cuda_event
MingMingShangTian Jul 15, 2021
b26be1c
remove cuda direction from device/__init__.py
MingMingShangTian Jul 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions paddle/fluid/platform/event.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include <hip/hip_runtime.h>
#endif
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"

namespace paddle {
namespace platform {
Expand Down Expand Up @@ -117,5 +118,98 @@ class MemEvent {
std::string annotation_;
};

class CudaEvent {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
public:
CudaEvent() {
#ifdef PADDLE_WITH_HIP
hipEventCreateWithFlags(&event_, flags_);
#else
cudaEventCreateWithFlags(&event_, flags_);
#endif
}

CudaEvent(unsigned int flags) : flags_(flags) {
#ifdef PADDLE_WITH_HIP
hipEventCreateWithFlags(&event_, flags_);
#else
cudaEventCreateWithFlags(&event_, flags_);
#endif
}

void Record(paddle::platform::stream::CUDAStream& stream) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event_, stream.raw_stream()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event_, stream.raw_stream()));
#endif
}

bool Query() {
#ifdef PADDLE_WITH_HIP
gpuError_t err = hipEventQuery(event_);
if (err == hipSuccess) {
return true;
}
if (err == hipErrorNotReady) {
return false;
}
#else
gpuError_t err = cudaEventQuery(event_);
if (err == cudaSuccess) {
return true;
}
if (err == cudaErrorNotReady) {
return false;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(err);
return false;
}

void Synchronize() {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventSynchronize(event_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventSynchronize(event_));
#endif
}
gpuEvent_t GetRawCudaEvent() { return event_; }

private:
#ifdef PADDLE_WITH_HIP
unsigned int flags_ = hipEventDefault;
#else
unsigned int flags_ = cudaEventDefault;
#endif
gpuEvent_t event_;
#endif
};

static unsigned int get_cuda_flags(bool enable_timing, bool blocking,
bool interprocess) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

#ifdef PADDLE_WITH_HIP
unsigned int flags =
(blocking ? hipEventBlockingSync : hipEventDefault) |
(enable_timing ? hipEventDefault : hipEventDisableTiming) |
(interprocess ? hipEventInterprocess : hipEventDefault);
return flags;
#else
unsigned int flags =
(blocking ? cudaEventBlockingSync : cudaEventDefault) |
(enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
(interprocess ? cudaEventInterprocess : cudaEventDefault);
return flags;
#endif

#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot get the cuda event flags."));
return 0;
#endif
}

} // namespace platform
} // namespace paddle
23 changes: 23 additions & 0 deletions paddle/fluid/platform/stream/cuda_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
Expand Down Expand Up @@ -95,6 +96,28 @@ void CUDAStream::Wait() const {
PADDLE_ENFORCE_CUDA_SUCCESS(e_sync);
}

CUDAStream* get_current_stream(int deviceId) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (deviceId == -1) {
deviceId = platform::GetCurrentDeviceId();
}

auto& pool = platform::DeviceContextPool::Instance();

platform::Place device = CUDAPlace(deviceId);

auto stream = static_cast<platform::CUDADeviceContext*>(pool.Get(device))
->context()
->Stream()
.get();
return stream;
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit cuda current stream."));
return nullptr;
#endif
}

} // namespace stream
} // namespace platform
} // namespace paddle
38 changes: 35 additions & 3 deletions paddle/fluid/platform/stream/cuda_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ enum class Priority : uint8_t {
kHigh = 0x1,
kNormal = 0x2,
};

#endif
class CUDAStream final {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
public:
CUDAStream() = default;
explicit CUDAStream(const Place& place,
Expand Down Expand Up @@ -93,6 +94,37 @@ class CUDAStream final {
#endif
void Destroy();

bool Query() const {
#ifdef PADDLE_WITH_HIP
hipError_t err = hipStreamQuery(stream_);
if (err == hipSuccess) {
return true;
}
if (err == hipErrorNotReady) {
return false;
}
#else
cudaError_t err = cudaStreamQuery(stream_);
if (err == cudaSuccess) {
return true;
}
if (err == cudaErrorNotReady) {
return false;
}
#endif

PADDLE_ENFORCE_CUDA_SUCCESS(err);
return false;
}

void Synchronize() const {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
#endif
}

private:
Place place_;
#ifdef PADDLE_WITH_HIP
Expand All @@ -102,11 +134,11 @@ class CUDAStream final {
#endif
Priority priority_{Priority::kNormal};
std::unique_ptr<StreamCallbackManager<gpuStream_t>> callback_manager_;

#endif
DISABLE_COPY_AND_ASSIGN(CUDAStream);
};

#endif
CUDAStream* get_current_stream(int deviceId);

} // namespace stream
} // namespace platform
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ set(PYBIND_SRCS
inference_api.cc
compatible.cc
io.cc
generator_py.cc)
generator_py.cc
cuda_streams_py.cc)

if(WITH_ASCEND)
set(PYBIND_DEPS ${PYBIND_DEPS} ascend_wrapper)
Expand Down
Loading