diff --git a/source/loader/layers/sanitizer/asan_buffer.cpp b/source/loader/layers/sanitizer/asan_buffer.cpp index 4cf90c7da4..382d6e3ada 100644 --- a/source/loader/layers/sanitizer/asan_buffer.cpp +++ b/source/loader/layers/sanitizer/asan_buffer.cpp @@ -75,12 +75,14 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { return UR_RESULT_SUCCESS; } + std::scoped_lock Guard(Mutex); auto &Allocation = Allocations[Device]; + ur_result_t URes = UR_RESULT_SUCCESS; if (!Allocation) { ur_usm_desc_t USMDesc{}; USMDesc.align = getAlignment(); ur_usm_pool_handle_t Pool{}; - ur_result_t URes = getContext()->interceptor->allocateMemory( + URes = getContext()->interceptor->allocateMemory( Context, Device, &USMDesc, Pool, Size, AllocType::MEM_BUFFER, ur_cast(&Allocation)); if (URes != UR_RESULT_SUCCESS) { @@ -105,7 +107,60 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { Handle = Allocation; - return UR_RESULT_SUCCESS; + if (!LastSyncedDevice.hDevice) { + LastSyncedDevice = MemBuffer::Device_t{Device, Handle}; + return URes; + } + + // If the device required to allocate memory is not the previous one, we + // need to do data migration. + if (Device != LastSyncedDevice.hDevice) { + auto &HostAllocation = Allocations[nullptr]; + if (!HostAllocation) { + ur_usm_desc_t USMDesc{}; + USMDesc.align = getAlignment(); + ur_usm_pool_handle_t Pool{}; + URes = getContext()->interceptor->allocateMemory( + Context, nullptr, &USMDesc, Pool, Size, AllocType::HOST_USM, + ur_cast(&HostAllocation)); + if (URes != UR_RESULT_SUCCESS) { + getContext()->logger.error("Failed to allocate {} bytes host " + "USM for buffer {} migration", + Size, this); + return URes; + } + } + + // Copy data from last synced device to host + { + ManagedQueue Queue(Context, LastSyncedDevice.hDevice); + URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + Queue, true, HostAllocation, LastSyncedDevice.MemHandle, Size, + 0, nullptr, nullptr); + if (URes != UR_RESULT_SUCCESS) { + getContext()->logger.error( + "Failed to migrate memory buffer data"); + return URes; + } + } + + // Sync data back to device + { + ManagedQueue Queue(Context, Device); + URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + Queue, true, Allocation, HostAllocation, Size, 0, nullptr, + nullptr); + if (URes != UR_RESULT_SUCCESS) { + getContext()->logger.error( + "Failed to migrate memory buffer data"); + return URes; + } + } + } + + LastSyncedDevice = MemBuffer::Device_t{Device, Handle}; + + return URes; } ur_result_t MemBuffer::free() { diff --git a/source/loader/layers/sanitizer/asan_buffer.hpp b/source/loader/layers/sanitizer/asan_buffer.hpp index b4eba4e4ba..989ef4249f 100644 --- a/source/loader/layers/sanitizer/asan_buffer.hpp +++ b/source/loader/layers/sanitizer/asan_buffer.hpp @@ -48,6 +48,12 @@ struct MemBuffer { ur_context_handle_t Context; + struct Device_t { + ur_device_handle_t hDevice; + char *MemHandle; + }; + Device_t LastSyncedDevice{}; + size_t Size; char *HostPtr{};