diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp
index f9d05af6ac8e2..8db51e3a25d31 100644
--- a/sycl/plugins/cuda/pi_cuda.cpp
+++ b/sycl/plugins/cuda/pi_cuda.cpp
@@ -2393,7 +2393,7 @@ pi_result cuda_piQueueFinish(pi_queue command_queue) {
            nullptr); // need PI_ERROR_INVALID_EXTERNAL_HANDLE error code
     ScopedContext active(command_queue->get_context());
 
-    command_queue->for_each_stream([&result](CUstream s) {
+    command_queue->sync_streams([&result](CUstream s) {
       result = PI_CHECK_ERROR(cuStreamSynchronize(s));
     });
 
diff --git a/sycl/plugins/cuda/pi_cuda.hpp b/sycl/plugins/cuda/pi_cuda.hpp
index d253794094817..e9a87a6378c7d 100644
--- a/sycl/plugins/cuda/pi_cuda.hpp
+++ b/sycl/plugins/cuda/pi_cuda.hpp
@@ -391,6 +391,8 @@ struct _pi_queue {
   std::atomic_uint32_t transfer_stream_idx_;
   unsigned int num_compute_streams_;
   unsigned int num_transfer_streams_;
+  unsigned int last_sync_compute_streams_;
+  unsigned int last_sync_transfer_streams_;
   unsigned int flags_;
   std::mutex compute_stream_mutex_;
   std::mutex transfer_stream_mutex_;
@@ -403,7 +405,9 @@ struct _pi_queue {
         transfer_streams_{std::move(transfer_streams)}, context_{context},
         device_{device}, properties_{properties}, refCount_{1}, eventCount_{0},
         compute_stream_idx_{0}, transfer_stream_idx_{0},
-        num_compute_streams_{0}, num_transfer_streams_{0}, flags_(flags) {
+        num_compute_streams_{0}, num_transfer_streams_{0},
+        last_sync_compute_streams_{0}, last_sync_transfer_streams_{0},
+        flags_(flags) {
     cuda_piContextRetain(context_);
     cuda_piDeviceRetain(device_);
   }
@@ -440,6 +444,71 @@ struct _pi_queue {
     }
   }
 
+  /// Invokes \p f on every stream whose index lies between the point
+  /// recorded by the previous sync_streams() call and the current end index,
+  /// then records that end index as the new starting point. Compute streams
+  /// are visited first, then transfer streams, each under its own mutex.
+  ///
+  /// \param f callable invoked as f(stream) once per selected CUstream.
+  template <typename T> void sync_streams(T &&f) {
+    // Applies f to streams[start], ..., streams[stop - 1].
+    auto sync = [&f](const std::vector<CUstream> &streams, unsigned int start,
+                     unsigned int stop) {
+      for (unsigned int i = start; i < stop; i++) {
+        f(streams[i]);
+      }
+    };
+    {
+      unsigned int size = static_cast<unsigned int>(compute_streams_.size());
+      std::lock_guard<std::mutex> compute_guard(compute_stream_mutex_);
+      unsigned int start = last_sync_compute_streams_;
+      unsigned int end = num_compute_streams_ < size
+                             ? num_compute_streams_
+                             : compute_stream_idx_.load();
+      last_sync_compute_streams_ = end;
+      if (end - start >= size) {
+        // The window covers the whole pool (always the case when the pool is
+        // empty, which keeps the modulo below from dividing by zero).
+        sync(compute_streams_, 0, size);
+      } else {
+        start %= size;
+        end %= size;
+        // start == end means end has not advanced past the last sync point,
+        // so the window is empty; only start > end is a wrapped window.
+        if (start <= end) {
+          sync(compute_streams_, start, end);
+        } else {
+          sync(compute_streams_, start, size);
+          sync(compute_streams_, 0, end);
+        }
+      }
+    }
+    {
+      unsigned int size = static_cast<unsigned int>(transfer_streams_.size());
+      if (size > 0) {
+        std::lock_guard<std::mutex> transfer_guard(transfer_stream_mutex_);
+        unsigned int start = last_sync_transfer_streams_;
+        unsigned int end = num_transfer_streams_ < size
+                               ? num_transfer_streams_
+                               : transfer_stream_idx_.load();
+        last_sync_transfer_streams_ = end;
+        if (end - start >= size) {
+          sync(transfer_streams_, 0, size);
+        } else {
+          start %= size;
+          end %= size;
+          // As above: an unmoved end index is an empty window, not a full lap.
+          if (start <= end) {
+            sync(transfer_streams_, start, end);
+          } else {
+            sync(transfer_streams_, start, size);
+            sync(transfer_streams_, 0, end);
+          }
+        }
+      }
+    }
+  }
+
   _pi_context *get_context() const { return context_; };
 
   pi_uint32 increment_reference_count() noexcept { return ++refCount_; }