diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.cpp b/unified-runtime/source/adapters/level_zero/v2/memory.cpp index 1f9a3287b51cc..18c72eb592303 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.cpp @@ -94,6 +94,7 @@ ur_integrated_buffer_handle_t::ur_integrated_buffer_handle_t( if (hostPtr) { std::memcpy(this->ptr.get(), hostPtr, size); + writeBackPtr = hostPtr; } } } @@ -110,17 +111,23 @@ ur_integrated_buffer_handle_t::ur_integrated_buffer_handle_t( }); } +ur_integrated_buffer_handle_t::~ur_integrated_buffer_handle_t() { + if (writeBackPtr) { + std::memcpy(writeBackPtr, this->ptr.get(), size); + } +} + void *ur_integrated_buffer_handle_t::getDevicePtr( ur_device_handle_t /*hDevice*/, device_access_mode_t /*access*/, - size_t /*offset*/, size_t /*size*/, + size_t offset, size_t /*size*/, std::function /*migrate*/) { - return ptr.get(); + return ur_cast(ptr.get()) + offset; } void *ur_integrated_buffer_handle_t::mapHostPtr( - ur_map_flags_t /*flags*/, size_t /*offset*/, size_t /*size*/, + ur_map_flags_t /*flags*/, size_t offset, size_t /*size*/, std::function /*migrate*/) { - return ptr.get(); + return ur_cast(ptr.get()) + offset; } void ur_integrated_buffer_handle_t::unmapHostPtr( diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.hpp b/unified-runtime/source/adapters/level_zero/v2/memory.hpp index 8f54ebd550884..9d2a07943437d 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.hpp @@ -87,6 +87,8 @@ struct ur_integrated_buffer_handle_t : ur_mem_buffer_t { size_t size, device_access_mode_t accesMode, bool ownHostPtr); + ~ur_integrated_buffer_handle_t(); + void * getDevicePtr(ur_device_handle_t, device_access_mode_t, size_t offset, size_t size, @@ -98,6 +100,7 @@ struct ur_integrated_buffer_handle_t : ur_mem_buffer_t { private: usm_unique_ptr_t ptr; + void *writeBackPtr = nullptr; }; struct host_allocation_desc_t {