Skip to content

Commit

Permalink
lazy create cuda_stream (#6806)
Browse files Browse the repository at this point in the history
* lazily create cuda_stream in CudaCopyD2HDeviceCtx and CudaStreamHandleDeviceCtx

* refine
  • Loading branch information
guo-ran authored Nov 18, 2021
1 parent 47e85b0 commit 55b1013
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
18 changes: 12 additions & 6 deletions oneflow/core/vm/cuda_copy_d2h_device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "oneflow/core/device/cuda_event.h"
#include "oneflow/core/vm/cuda_host_allocator.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/core/common/cpp_attribute.h"

namespace oneflow {
namespace vm {
Expand All @@ -36,22 +37,27 @@ class CudaCopyD2HDeviceCtx : public DeviceCtx, public SingleThreadQueryCudaEvent
CudaCopyD2HDeviceCtx(int64_t device_id)
: DeviceCtx(),
SingleThreadQueryCudaEventProvider(device_id),
stream_(device_id),
cuda_allocator_(std::make_unique<CudaHostAllocator>(device_id)),
device_id_(device_id) {}

cudaStream_t cuda_stream() const override { return stream_.cuda_stream(); }
cublasHandle_t cublas_handle() const override { return stream_.cublas_handle(); }
cudnnHandle_t cudnn_handle() const override { return stream_.cudnn_handle(); }
// Handle accessors. Each one goes through GetOrCreateCudaStream(), so the
// underlying ep::CudaStream is constructed on first use if it does not exist yet.
cudaStream_t cuda_stream() const override { return GetOrCreateCudaStream()->cuda_stream(); }
cublasHandle_t cublas_handle() const override { return GetOrCreateCudaStream()->cublas_handle(); }
cudnnHandle_t cudnn_handle() const override { return GetOrCreateCudaStream()->cudnn_handle(); }

ep::Stream* stream() override { return &stream_; }
// Returns the stream obtained from GetOrCreateCudaStream(), which creates it
// if it does not exist yet.
ep::Stream* stream() override { return GetOrCreateCudaStream(); }

vm::Allocator* mut_allocator() override { return cuda_allocator_.get(); }

DeviceType device_type() const override { return DeviceType::kGPU; }

private:
// Returns the ep::CudaStream owned by this context, constructing it for
// device_id_ on the first call. The method is const so that the const handle
// accessors can trigger creation; this relies on stream_ being declared mutable.
// NOTE(review): the check-then-create below is not synchronized; presumably a
// context is only ever used from a single thread -- confirm against callers.
ep::CudaStream* GetOrCreateCudaStream() const {
  // The stream is only missing before the first request, hence `unlikely`.
  if (unlikely(!stream_)) { stream_ = std::make_unique<ep::CudaStream>(device_id_); }
  return stream_.get();
}

protected:
ep::CudaStream stream_;
mutable std::unique_ptr<ep::CudaStream> stream_;
std::unique_ptr<CudaHostAllocator> cuda_allocator_;
int64_t device_id_;
};
Expand Down
18 changes: 12 additions & 6 deletions oneflow/core/vm/cuda_stream_handle_device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "oneflow/core/vm/thread_safe_allocator.h"
#include "oneflow/core/common/single_thread_obj_pool.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/core/common/cpp_attribute.h"

namespace oneflow {
namespace vm {
Expand All @@ -38,23 +39,28 @@ class CudaStreamHandleDeviceCtx : public DeviceCtx, public SingleThreadQueryCuda
CudaStreamHandleDeviceCtx(int64_t device_id)
: DeviceCtx(),
SingleThreadQueryCudaEventProvider(device_id),
stream_(device_id),
cuda_allocator_(
new ThreadSafeAllocator(std::unique_ptr<Allocator>(new CudaAllocator(device_id)))),
device_id_(device_id) {}

cudaStream_t cuda_stream() const override { return stream_.cuda_stream(); }
cublasHandle_t cublas_handle() const override { return stream_.cublas_handle(); }
cudnnHandle_t cudnn_handle() const override { return stream_.cudnn_handle(); }
// Handle accessors. Each one goes through GetOrCreateCudaStream(), so the
// underlying ep::CudaStream is constructed on first use if it does not exist yet.
cudaStream_t cuda_stream() const override { return GetOrCreateCudaStream()->cuda_stream(); }
cublasHandle_t cublas_handle() const override { return GetOrCreateCudaStream()->cublas_handle(); }
cudnnHandle_t cudnn_handle() const override { return GetOrCreateCudaStream()->cudnn_handle(); }

ep::Stream* stream() override { return &stream_; }
// Returns the stream obtained from GetOrCreateCudaStream(), which creates it
// if it does not exist yet.
ep::Stream* stream() override { return GetOrCreateCudaStream(); }

vm::Allocator* mut_allocator() override { return cuda_allocator_.get(); }

DeviceType device_type() const override { return DeviceType::kGPU; }

private:
// Returns the ep::CudaStream owned by this context, constructing it for
// device_id_ on the first call. The method is const so that the const handle
// accessors can trigger creation; this relies on stream_ being declared mutable.
// NOTE(review): the check-then-create below is not synchronized; presumably a
// context is only ever used from a single thread -- confirm against callers.
ep::CudaStream* GetOrCreateCudaStream() const {
  // The stream is only missing before the first request, hence `unlikely`.
  if (unlikely(!stream_)) { stream_ = std::make_unique<ep::CudaStream>(device_id_); }
  return stream_.get();
}

protected:
ep::CudaStream stream_;
mutable std::unique_ptr<ep::CudaStream> stream_;
std::unique_ptr<Allocator> cuda_allocator_;
int64_t device_id_;
};
Expand Down

0 comments on commit 55b1013

Please sign in to comment.