@@ -98,7 +98,7 @@ ur_result_t ur_queue_immediate_out_of_order_t::queueGetInfo(
ur_result_t ur_queue_immediate_out_of_order_t::queueGetNativeHandle(
ur_queue_native_desc_t *pDesc, ur_native_handle_t *phNativeQueue) {
*phNativeQueue = reinterpret_cast<ur_native_handle_t>(
-      (*commandListManagers.get_no_lock())[getNextCommandListId()]
+      (*commandListManagers.get_no_lock())[getNextCommandListId(false)]
.getZeCommandList());
if (pDesc && pDesc->pNativeData) {
// pNativeData == isImmediateQueue
@@ -112,10 +112,16 @@ ur_result_t ur_queue_immediate_out_of_order_t::queueFinish() {

auto commandListManagersLocked = commandListManagers.lock();

+  // Only synchronize command lists that have been used to avoid unnecessary
+  // synchronization overhead.
+  uint32_t usedMask =
+      usedCommandListsMask.exchange(0, std::memory_order_relaxed);
for (size_t i = 0; i < numCommandLists; i++) {
-    ZE2UR_CALL(zeCommandListHostSynchronize,
-               (commandListManagersLocked[i].getZeCommandList(), UINT64_MAX));
-    UR_CALL(commandListManagersLocked[i].releaseSubmittedKernels());
+    if (usedMask & (1u << i)) {
+      ZE2UR_CALL(zeCommandListHostSynchronize,
+                 (commandListManagersLocked[i].getZeCommandList(), UINT64_MAX));
+      UR_CALL(commandListManagersLocked[i].releaseSubmittedKernels());
+    }
}

hContext->getAsyncPool()->cleanupPoolsForQueue(this);
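
To make the bookkeeping concrete, here is a minimal standalone sketch of the pattern this hunk introduces. The names (kNumLists, markUsed, finish) are illustrative only, not the adapter's real API, and printf stands in for zeCommandListHostSynchronize: submissions set one bit per command list, and finish() synchronizes only the lists whose bit is set, clearing the mask in the same step.

#include <atomic>
#include <cstdio>

constexpr unsigned kNumLists = 4; // mirrors numCommandLists in the PR

std::atomic<unsigned> usedMask{0};

void markUsed(unsigned id) {
  // Record that command list `id` received work since the last finish().
  usedMask.fetch_or(1u << id, std::memory_order_relaxed);
}

void finish() {
  // Claim and clear the mask in one step, as queueFinish() does above.
  unsigned mask = usedMask.exchange(0, std::memory_order_relaxed);
  for (unsigned i = 0; i < kNumLists; ++i) {
    if (mask & (1u << i))
      std::printf("synchronize list %u\n", i); // stand-in for zeCommandListHostSynchronize
  }
}

int main() {
  markUsed(2); // e.g. a kernel submitted to list 2
  markUsed(0); // and another to list 0
  finish();    // only lists 0 and 2 are synchronized
}
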
@@ -164,6 +170,11 @@ ur_result_t ur_queue_immediate_out_of_order_t::enqueueEventsWaitWithBarrier(

auto commandListManagersLocked = commandListManagers.lock();

+  // The barrier below is appended to every command list, so mark them all as
+  // used; queueFinish() must then synchronize each list touched here.
+  usedCommandListsMask.fetch_or((1u << numCommandLists) - 1,
+                                std::memory_order_relaxed);

// Enqueue wait for the user-provided events on the first command list.
UR_CALL(commandListManagersLocked[0].appendEventsWait(waitListView,
barrierEvents[0]));
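
As a quick sanity check on the mask arithmetic (an illustration, not code from the PR): with numCommandLists equal to 4, the value ORed in above is 0b1111, i.e. one bit per command list, so a subsequent queueFinish() will synchronize all four lists.

// Illustration only: the all-lists mask for four command lists.
static_assert(((1u << 4) - 1) == 0b1111u, "barrier path flags every command list");
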
@@ -45,20 +45,33 @@ struct ur_queue_immediate_out_of_order_t : ur_object, ur_queue_t_ {
lockable<std::array<ur_command_list_manager, numCommandLists>>
commandListManagers;

+  // Track which command lists have pending work to avoid unnecessary
+  // synchronization in queueFinish(). Each bit represents one command list.
+  std::atomic<uint32_t> usedCommandListsMask = 0;

Contributor comment:

It's probably not gonna happen, but in case the hardcoded number of command lists ever gets changed to something bigger than 32, could you add a comment at the static constexpr size_t numCommandLists = 4; definition noting that there is a mask that would also need to be modified.

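One possible way to act on this suggestion (a sketch, assuming the mask stays a uint32_t; not part of the PR) is a compile-time guard next to the numCommandLists definition:

// Hypothetical guard to place next to numCommandLists: the used-lists mask is
// a uint32_t, so the command-list count must fit into 32 bits.
static_assert(numCommandLists <= 32,
              "usedCommandListsMask is a uint32_t bitmask; widen it before "
              "growing numCommandLists past 32");
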
ur_queue_flags_t flags;

std::array<ur_event_handle_t, numCommandLists> barrierEvents;

-  uint32_t getNextCommandListId() {
+  uint32_t getNextCommandListId(bool markUsed = true) {

Contributor comment:

I looked at the uses of getNextCommandListId, and I don't see the point of using atomic variables for the mask you are adding here or for commandListIndex. The parallelism we can achieve is minimal, and using a single queue from multiple threads is already an infrequent use case.

I suggest changing these to regular variables under a common lockable with commandListManagers. Creating a class with these three variables would make the most sense to me.

bool isGraphCaptureActive;
auto &cmdListManager =
(*commandListManagers.get_no_lock())[captureCmdListManagerIdx];
cmdListManager.isGraphCaptureActive(&isGraphCaptureActive);

-    return isGraphCaptureActive
-               ? captureCmdListManagerIdx
-               : commandListIndex.fetch_add(1, std::memory_order_relaxed) %
-                     numCommandLists;
+    uint32_t id =
+        isGraphCaptureActive
+            ? captureCmdListManagerIdx
+            : commandListIndex.fetch_add(1, std::memory_order_relaxed) %
+                  numCommandLists;
+
+    if (markUsed) {
+      // Mark this command list as used so queueFinish() synchronizes only
+      // lists that actually carried work.
+      usedCommandListsMask.fetch_or(1u << id, std::memory_order_relaxed);
+    }
+
+    return id;
}

public:
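
For reference, a rough sketch of what this reviewer's suggestion could look like, assuming lockable<T>::lock() returns a handle with member access (as the lock()/get_no_lock() calls in this diff suggest). The struct and field names are hypothetical, and the graph-capture branch of getNextCommandListId is omitted for brevity:

// Hypothetical refactor (not part of the PR): keep the round-robin index, the
// used-lists mask, and the command list managers under one lockable so that
// plain integers replace the two atomics.
struct command_list_state_t {
  std::array<ur_command_list_manager, numCommandLists> managers;
  uint32_t nextIndex = 0; // round-robin cursor, protected by the lock
  uint32_t usedMask = 0;  // one bit per command list with pending work
};

lockable<command_list_state_t> commandListState;

uint32_t getNextCommandListId(bool markUsed = true) {
  auto locked = commandListState.lock();
  uint32_t id = locked->nextIndex++ % numCommandLists;
  if (markUsed)
    locked->usedMask |= 1u << id; // consumed and cleared by queueFinish()
  return id;
}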