@@ -391,6 +391,8 @@ struct _pi_queue {
391391 std::atomic_uint32_t transfer_stream_idx_;
392392 unsigned int num_compute_streams_;
393393 unsigned int num_transfer_streams_;
394+ unsigned int last_sync_compute_streams_;
395+ unsigned int last_sync_transfer_streams_;
394396 unsigned int flags_;
395397 std::mutex compute_stream_mutex_;
396398 std::mutex transfer_stream_mutex_;
@@ -403,7 +405,9 @@ struct _pi_queue {
403405 transfer_streams_{std::move (transfer_streams)}, context_{context},
404406 device_{device}, properties_{properties}, refCount_{1 }, eventCount_{0 },
405407 compute_stream_idx_{0 }, transfer_stream_idx_{0 },
406- num_compute_streams_{0 }, num_transfer_streams_{0 }, flags_(flags) {
408+ num_compute_streams_{0 }, num_transfer_streams_{0 },
409+ last_sync_compute_streams_{0 }, last_sync_transfer_streams_{0 },
410+ flags_ (flags) {
407411 cuda_piContextRetain (context_);
408412 cuda_piDeviceRetain (device_);
409413 }
@@ -440,6 +444,59 @@ struct _pi_queue {
440444 }
441445 }
442446
447+ template <typename T> void sync_streams (T &&f) {
448+ auto sync = [&f](const std::vector<CUstream> &streams, unsigned int start,
449+ unsigned int stop) {
450+ for (unsigned int i = start; i < stop; i++) {
451+ f (streams[i]);
452+ }
453+ };
454+ {
455+ unsigned int size = static_cast <unsigned int >(compute_streams_.size ());
456+ std::lock_guard<std::mutex> compute_guard (compute_stream_mutex_);
457+ unsigned int start = last_sync_compute_streams_;
458+ unsigned int end = num_compute_streams_ < size
459+ ? num_compute_streams_
460+ : compute_stream_idx_.load ();
461+ last_sync_compute_streams_ = end;
462+ if (end - start >= size) {
463+ sync (compute_streams_, 0 , size);
464+ } else {
465+ start %= size;
466+ end %= size;
467+ if (start < end) {
468+ sync (compute_streams_, start, end);
469+ } else {
470+ sync (compute_streams_, start, size);
471+ sync (compute_streams_, 0 , end);
472+ }
473+ }
474+ }
475+ {
476+ unsigned int size = static_cast <unsigned int >(transfer_streams_.size ());
477+ if (size > 0 ) {
478+ std::lock_guard<std::mutex> transfer_guard (transfer_stream_mutex_);
479+ unsigned int start = last_sync_transfer_streams_;
480+ unsigned int end = num_transfer_streams_ < size
481+ ? num_transfer_streams_
482+ : transfer_stream_idx_.load ();
483+ last_sync_transfer_streams_ = end;
484+ if (end - start >= size) {
485+ sync (transfer_streams_, 0 , size);
486+ } else {
487+ start %= size;
488+ end %= size;
489+ if (start < end) {
490+ sync (transfer_streams_, start, end);
491+ } else {
492+ sync (transfer_streams_, start, size);
493+ sync (transfer_streams_, 0 , end);
494+ }
495+ }
496+ }
497+ }
498+ }
499+
443500 _pi_context *get_context () const { return context_; };
444501
445502 _pi_device *get_device () const { return device_; };
0 commit comments