Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions source/loader/layers/sanitizer/asan_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
return UR_RESULT_SUCCESS;
}

std::scoped_lock<ur_shared_mutex> Guard(Mutex);
auto &Allocation = Allocations[Device];
ur_result_t URes = UR_RESULT_SUCCESS;
if (!Allocation) {
ur_usm_desc_t USMDesc{};
USMDesc.align = getAlignment();
ur_usm_pool_handle_t Pool{};
ur_result_t URes = getContext()->interceptor->allocateMemory(
URes = getContext()->interceptor->allocateMemory(
Context, Device, &USMDesc, Pool, Size, AllocType::MEM_BUFFER,
ur_cast<void **>(&Allocation));
if (URes != UR_RESULT_SUCCESS) {
Expand All @@ -105,7 +107,60 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

Handle = Allocation;

return UR_RESULT_SUCCESS;
if (!LastSyncedDevice.hDevice) {
LastSyncedDevice = MemBuffer::Device_t{Device, Handle};
return URes;
}

// If the device required to allocate memory is not the previous one, we
// need to do data migration.
if (Device != LastSyncedDevice.hDevice) {
auto &HostAllocation = Allocations[nullptr];
if (!HostAllocation) {
ur_usm_desc_t USMDesc{};
USMDesc.align = getAlignment();
ur_usm_pool_handle_t Pool{};
URes = getContext()->interceptor->allocateMemory(
Context, nullptr, &USMDesc, Pool, Size, AllocType::HOST_USM,
ur_cast<void **>(&HostAllocation));
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.error("Failed to allocate {} bytes host "
"USM for buffer {} migration",
Size, this);
return URes;
}
}

// Copy data from last synced device to host
{
ManagedQueue Queue(Context, LastSyncedDevice.hDevice);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, HostAllocation, LastSyncedDevice.MemHandle, Size,
0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.error(
"Failed to migrate memory buffer data");
return URes;
}
}

// Sync data back to device
{
ManagedQueue Queue(Context, Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostAllocation, Size, 0, nullptr,
nullptr);
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.error(
"Failed to migrate memory buffer data");
return URes;
}
}
}

LastSyncedDevice = MemBuffer::Device_t{Device, Handle};

return URes;
}

ur_result_t MemBuffer::free() {
Expand Down
6 changes: 6 additions & 0 deletions source/loader/layers/sanitizer/asan_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ struct MemBuffer {

ur_context_handle_t Context;

struct Device_t {
ur_device_handle_t hDevice;
char *MemHandle;
};
Device_t LastSyncedDevice{};

size_t Size;

char *HostPtr{};
Expand Down