Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ static ur_result_t enqueueCommandBufferMemCopyHelper(
SyncPointWaitList, ZeEventList));

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, false, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, false, &LaunchEvent));
LaunchEvent->CommandType = CommandType;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -358,7 +359,8 @@ static ur_result_t enqueueCommandBufferMemCopyRectHelper(
SyncPointWaitList, ZeEventList));

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, false, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, false, &LaunchEvent));
LaunchEvent->CommandType = CommandType;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -401,7 +403,8 @@ static ur_result_t enqueueCommandBufferFillHelper(
SyncPointWaitList, ZeEventList));

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, true, &LaunchEvent));
LaunchEvent->CommandType = CommandType;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -453,8 +456,10 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
// Create signal & wait events to be used in the command-list for sync
// on command-buffer enqueue.
auto RetCommandBuffer = *CommandBuffer;
UR_CALL(EventCreate(Context, nullptr, false, &RetCommandBuffer->SignalEvent));
UR_CALL(EventCreate(Context, nullptr, false, &RetCommandBuffer->WaitEvent));
UR_CALL(EventCreate(Context, nullptr, false, false,
&RetCommandBuffer->SignalEvent));
UR_CALL(EventCreate(Context, nullptr, false, false,
&RetCommandBuffer->WaitEvent));

// Add prefix commands
ZE2UR_CALL(zeCommandListAppendEventReset,
Expand Down Expand Up @@ -550,7 +555,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
SyncPointWaitList, ZeEventList));
ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, false, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, false, &LaunchEvent));
LaunchEvent->CommandType = UR_COMMAND_KERNEL_LAUNCH;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -732,7 +738,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
}

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, true, &LaunchEvent));
LaunchEvent->CommandType = UR_COMMAND_USM_PREFETCH;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -795,7 +802,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
}

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, true, &LaunchEvent));
LaunchEvent->CommandType = UR_COMMAND_USM_ADVISE;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -933,9 +941,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
(SignalCommandList->first, CommandBuffer->WaitEvent->ZeEvent));

if (Event) {
UR_CALL(createEventAndAssociateQueue(Queue, &RetEvent,
UR_COMMAND_COMMAND_BUFFER_ENQUEUE_EXP,
SignalCommandList, false, true));
UR_CALL(createEventAndAssociateQueue(
Queue, &RetEvent, UR_COMMAND_COMMAND_BUFFER_ENQUEUE_EXP,
SignalCommandList, false, false, true));

if ((Queue->Properties & UR_QUEUE_FLAG_PROFILING_ENABLE)) {
// Multiple submissions of a command buffer implies that we need to save
Expand Down
47 changes: 34 additions & 13 deletions source/adapters/level_zero/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,12 +471,17 @@ static const uint32_t MaxNumEventsPerPool = [] {

ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
ze_event_pool_handle_t &Pool, size_t &Index, bool HostVisible,
bool ProfilingEnabled) {
bool ProfilingEnabled, ur_device_handle_t Device) {
// Lock while updating event pool machinery.
std::scoped_lock<ur_mutex> Lock(ZeEventPoolCacheMutex);

ze_device_handle_t ZeDevice = nullptr;

if (Device) {
ZeDevice = Device->ZeDevice;
}
std::list<ze_event_pool_handle_t> *ZePoolCache =
getZeEventPoolCache(HostVisible, ProfilingEnabled);
getZeEventPoolCache(HostVisible, ProfilingEnabled, ZeDevice);

if (!ZePoolCache->empty()) {
if (NumEventsAvailableInEventPool[ZePoolCache->front()] == 0) {
Expand Down Expand Up @@ -511,9 +516,14 @@ ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
urPrint("ze_event_pool_desc_t flags set to: %d\n", ZeEventPoolDesc.flags);

std::vector<ze_device_handle_t> ZeDevices;
std::for_each(
Devices.begin(), Devices.end(),
[&](const ur_device_handle_t &D) { ZeDevices.push_back(D->ZeDevice); });
if (ZeDevice) {
ZeDevices.push_back(ZeDevice);
} else {
std::for_each(Devices.begin(), Devices.end(),
[&](const ur_device_handle_t &D) {
ZeDevices.push_back(D->ZeDevice);
});
}

ZE2UR_CALL(zeEventPoolCreate, (ZeContext, &ZeEventPoolDesc,
ZeDevices.size(), &ZeDevices[0], ZePool));
Expand All @@ -528,11 +538,10 @@ ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
return UR_RESULT_SUCCESS;
}

ur_event_handle_t
ur_context_handle_t_::getEventFromContextCache(bool HostVisible,
bool WithProfiling) {
ur_event_handle_t ur_context_handle_t_::getEventFromContextCache(
bool HostVisible, bool WithProfiling, ur_device_handle_t Device) {
std::scoped_lock<ur_mutex> Lock(EventCacheMutex);
auto Cache = getEventCache(HostVisible, WithProfiling);
auto Cache = getEventCache(HostVisible, WithProfiling, Device);
if (Cache->empty())
return nullptr;

Expand All @@ -546,8 +555,14 @@ ur_context_handle_t_::getEventFromContextCache(bool HostVisible,

void ur_context_handle_t_::addEventToContextCache(ur_event_handle_t Event) {
std::scoped_lock<ur_mutex> Lock(EventCacheMutex);
auto Cache =
getEventCache(Event->isHostVisible(), Event->isProfilingEnabled());
ur_device_handle_t Device = nullptr;

if (!Event->IsMultiDevice && Event->UrQueue) {
Device = Event->UrQueue->Device;
}

auto Cache = getEventCache(Event->isHostVisible(),
Event->isProfilingEnabled(), Device);
Cache->emplace_back(Event);
}

Expand All @@ -562,8 +577,14 @@ ur_context_handle_t_::decrementUnreleasedEventsInPool(ur_event_handle_t Event) {
return UR_RESULT_SUCCESS;
}

std::list<ze_event_pool_handle_t> *ZePoolCache =
getZeEventPoolCache(Event->isHostVisible(), Event->isProfilingEnabled());
ze_device_handle_t ZeDevice = nullptr;

if (!Event->IsMultiDevice && Event->UrQueue) {
ZeDevice = Event->UrQueue->Device->ZeDevice;
}

std::list<ze_event_pool_handle_t> *ZePoolCache = getZeEventPoolCache(
Event->isHostVisible(), Event->isProfilingEnabled(), ZeDevice);

// Put the empty pool to the cache of the pools.
if (NumEventsUnreleasedInEventPool[Event->ZeEventPool] == 0)
Expand Down
78 changes: 66 additions & 12 deletions source/adapters/level_zero/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ struct ur_context_handle_t_ : _ur_object {
//
// Cache of event pools to which host-visible events are added.
std::vector<std::list<ze_event_pool_handle_t>> ZeEventPoolCache{4};
std::vector<std::unordered_map<ze_device_handle_t,
std::list<ze_event_pool_handle_t> *>>
ZeEventPoolCacheDeviceMap{4};

// This map will be used to determine if a pool is full or not
// by storing number of empty slots available in the pool.
Expand All @@ -163,6 +166,9 @@ struct ur_context_handle_t_ : _ur_object {

// Caches for events.
std::vector<std::list<ur_event_handle_t>> EventCaches{4};
std::vector<
std::unordered_map<ur_device_handle_t, std::list<ur_event_handle_t> *>>
EventCachesDeviceMap{4};

// Initialize the PI context.
ur_result_t initialize();
Expand All @@ -188,20 +194,46 @@ struct ur_context_handle_t_ : _ur_object {
// slot for an event with profiling capabilities.
ur_result_t getFreeSlotInExistingOrNewPool(ze_event_pool_handle_t &, size_t &,
bool HostVisible,
bool ProfilingEnabled);
bool ProfilingEnabled,
ur_device_handle_t Device);

// Get ur_event_handle_t from cache.
ur_event_handle_t getEventFromContextCache(bool HostVisible,
bool WithProfiling);
bool WithProfiling,
ur_device_handle_t Device);

// Add ur_event_handle_t to cache.
void addEventToContextCache(ur_event_handle_t);

auto getZeEventPoolCache(bool HostVisible, bool WithProfiling) {
if (HostVisible)
return WithProfiling ? &ZeEventPoolCache[0] : &ZeEventPoolCache[1];
else
return WithProfiling ? &ZeEventPoolCache[2] : &ZeEventPoolCache[3];
auto getZeEventPoolCache(bool HostVisible, bool WithProfiling,
ze_device_handle_t ZeDevice) {
if (HostVisible) {
if (ZeDevice) {
auto ZeEventPoolCacheMap = WithProfiling
? &ZeEventPoolCacheDeviceMap[0]
: &ZeEventPoolCacheDeviceMap[1];
if (ZeEventPoolCacheMap->find(ZeDevice) == ZeEventPoolCacheMap->end()) {
ZeEventPoolCache.emplace_back();
(*ZeEventPoolCacheMap)[ZeDevice] = &ZeEventPoolCache.back();
}
return (*ZeEventPoolCacheMap)[ZeDevice];
} else {
return WithProfiling ? &ZeEventPoolCache[0] : &ZeEventPoolCache[1];
}
} else {
if (ZeDevice) {
auto ZeEventPoolCacheMap = WithProfiling
? &ZeEventPoolCacheDeviceMap[2]
: &ZeEventPoolCacheDeviceMap[3];
if (ZeEventPoolCacheMap->find(ZeDevice) == ZeEventPoolCacheMap->end()) {
ZeEventPoolCache.emplace_back();
(*ZeEventPoolCacheMap)[ZeDevice] = &ZeEventPoolCache.back();
}
return (*ZeEventPoolCacheMap)[ZeDevice];
} else {
return WithProfiling ? &ZeEventPoolCache[2] : &ZeEventPoolCache[3];
}
}
}

// Decrement number of events living in the pool upon event destroy
Expand Down Expand Up @@ -240,11 +272,33 @@ struct ur_context_handle_t_ : _ur_object {

private:
// Get the cache of events for a provided scope and profiling mode.
auto getEventCache(bool HostVisible, bool WithProfiling) {
if (HostVisible)
return WithProfiling ? &EventCaches[0] : &EventCaches[1];
else
return WithProfiling ? &EventCaches[2] : &EventCaches[3];
auto getEventCache(bool HostVisible, bool WithProfiling,
ur_device_handle_t Device) {
if (HostVisible) {
if (Device) {
auto EventCachesMap =
WithProfiling ? &EventCachesDeviceMap[0] : &EventCachesDeviceMap[1];
if (EventCachesMap->find(Device) == EventCachesMap->end()) {
EventCaches.emplace_back();
(*EventCachesMap)[Device] = &EventCaches.back();
}
return (*EventCachesMap)[Device];
} else {
return WithProfiling ? &EventCaches[0] : &EventCaches[1];
}
} else {
if (Device) {
auto EventCachesMap =
WithProfiling ? &EventCachesDeviceMap[2] : &EventCachesDeviceMap[3];
if (EventCachesMap->find(Device) == EventCachesMap->end()) {
EventCaches.emplace_back();
(*EventCachesMap)[Device] = &EventCaches.back();
}
return (*EventCachesMap)[Device];
} else {
return WithProfiling ? &EventCaches[2] : &EventCaches[3];
}
}
}
};

Expand Down
Loading