diff --git a/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp b/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp index ee913fc8d0eeb..bb18a7214fd3f 100644 --- a/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp +++ b/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp @@ -1293,6 +1293,48 @@ class BinaryWrapper { appendToGlobalDtors(M, Func, /*Priority*/ 1); } + void createSyclRegisterWithAtexitUnregister(GlobalVariable *BinDesc) { + auto *UnregFuncTy = + FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); + auto *UnregFunc = + Function::Create(UnregFuncTy, GlobalValue::InternalLinkage, + "sycl.descriptor_unreg.atexit", &M); + UnregFunc->setSection(".text.startup"); + + // Declaration for __sycl_unregister_lib(void*). + auto *UnregTargetTy = + FunctionType::get(Type::getVoidTy(C), getPtrTy(), /*isVarArg=*/false); + FunctionCallee UnregTargetC = + M.getOrInsertFunction("__sycl_unregister_lib", UnregTargetTy); + + IRBuilder<> UnregBuilder(BasicBlock::Create(C, "entry", UnregFunc)); + UnregBuilder.CreateCall(UnregTargetC, BinDesc); + UnregBuilder.CreateRetVoid(); + + auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); + auto *RegFunc = Function::Create(RegFuncTy, GlobalValue::InternalLinkage, + "sycl.descriptor_reg", &M); + RegFunc->setSection(".text.startup"); + + auto *RegTargetTy = + FunctionType::get(Type::getVoidTy(C), getPtrTy(), false); + FunctionCallee RegTargetC = + M.getOrInsertFunction("__sycl_register_lib", RegTargetTy); + + // `atexit` takes a `void(*)()` function pointer arg and returns an i32. + FunctionType *AtExitTy = + FunctionType::get(Type::getInt32Ty(C), getPtrTy(), false); + FunctionCallee AtExitC = M.getOrInsertFunction("atexit", AtExitTy); + + IRBuilder<> RegBuilder(BasicBlock::Create(C, "entry", RegFunc)); + RegBuilder.CreateCall(RegTargetC, BinDesc); + RegBuilder.CreateCall(AtExitC, UnregFunc); + RegBuilder.CreateRetVoid(); + + // Add this function to global destructors. + appendToGlobalCtors(M, RegFunc, /*Priority*/ 1); + } + public: BinaryWrapper(StringRef Target, StringRef ToolName, StringRef SymPropBCFiles = "") @@ -1370,8 +1412,13 @@ class BinaryWrapper { if (EmitRegFuncs) { GlobalVariable *Desc = *DescOrErr; - createRegisterFunction(Kind, Desc); - createUnregisterFunction(Kind, Desc); + if (Kind == OffloadKind::SYCL && + Triple(M.getTargetTriple()).isOSWindows()) { + createSyclRegisterWithAtexitUnregister(Desc); + } else { + createRegisterFunction(Kind, Desc); + createUnregisterFunction(Kind, Desc); + } } } return &M; diff --git a/llvm/lib/Frontend/Offloading/SYCLOffloadWrapper.cpp b/llvm/lib/Frontend/Offloading/SYCLOffloadWrapper.cpp index 3d227d0c2e050..075647e3236ab 100644 --- a/llvm/lib/Frontend/Offloading/SYCLOffloadWrapper.cpp +++ b/llvm/lib/Frontend/Offloading/SYCLOffloadWrapper.cpp @@ -34,6 +34,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/LineIterator.h" #include "llvm/Support/PropertySetIO.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include #include @@ -734,6 +735,50 @@ struct Wrapper { // Add this function to global destructors. appendToGlobalDtors(M, Func, /*Priority*/ 1); } + + void createSyclRegisterWithAtexitUnregister(GlobalVariable *FatbinDesc) { + auto *UnregFuncTy = + FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); + auto *UnregFunc = + Function::Create(UnregFuncTy, GlobalValue::InternalLinkage, + "sycl.descriptor_unreg.atexit", &M); + UnregFunc->setSection(".text.startup"); + + // Declaration for __sycl_unregister_lib(void*). + auto *UnregTargetTy = + FunctionType::get(Type::getVoidTy(C), PointerType::getUnqual(C), false); + FunctionCallee UnregTargetC = + M.getOrInsertFunction("__sycl_unregister_lib", UnregTargetTy); + + // Body of the unregister wrapper. + IRBuilder<> UnregBuilder(BasicBlock::Create(C, "entry", UnregFunc)); + UnregBuilder.CreateCall(UnregTargetC, FatbinDesc); + UnregBuilder.CreateRetVoid(); + + auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); + auto *RegFunc = Function::Create(RegFuncTy, GlobalValue::InternalLinkage, + "sycl.descriptor_reg", &M); + RegFunc->setSection(".text.startup"); + + auto *RegTargetTy = + FunctionType::get(Type::getVoidTy(C), PointerType::getUnqual(C), false); + FunctionCallee RegTargetC = + M.getOrInsertFunction("__sycl_register_lib", RegTargetTy); + + // `atexit` takes a `void(*)()` function pointer arg and returns an i32. + FunctionType *AtExitTy = FunctionType::get( + Type::getInt32Ty(C), PointerType::getUnqual(C), false); + FunctionCallee AtExitC = M.getOrInsertFunction("atexit", AtExitTy); + + IRBuilder<> RegBuilder(BasicBlock::Create(C, "entry", RegFunc)); + RegBuilder.CreateCall(RegTargetC, FatbinDesc); + RegBuilder.CreateCall(AtExitC, UnregFunc); + RegBuilder.CreateRetVoid(); + + // Finally, add to global constructors. + appendToGlobalCtors(M, RegFunc, /*Priority*/ 1); + } + }; // end of Wrapper } // anonymous namespace @@ -747,7 +792,11 @@ Error llvm::offloading::wrapSYCLBinaries(llvm::Module &M, return createStringError(inconvertibleErrorCode(), "No binary descriptors created."); - W.createRegisterFatbinFunction(Desc); - W.createUnregisterFunction(Desc); + if (Triple(M.getTargetTriple()).isOSWindows()) { + W.createSyclRegisterWithAtexitUnregister(Desc); + } else { + W.createRegisterFatbinFunction(Desc); + W.createUnregisterFunction(Desc); + } return Error::success(); } diff --git a/sycl/source/detail/context_impl.cpp b/sycl/source/detail/context_impl.cpp index 5e027466d7949..6fb2dd375fe37 100644 --- a/sycl/source/detail/context_impl.cpp +++ b/sycl/source/detail/context_impl.cpp @@ -125,7 +125,8 @@ context_impl::~context_impl() { DeviceGlobalMapEntry *DGEntry = detail::ProgramManager::getInstance().getDeviceGlobalEntry( DeviceGlobal); - DGEntry->removeAssociatedResources(this); + if (DGEntry != nullptr) + DGEntry->removeAssociatedResources(this); } MCachedLibPrograms.clear(); // TODO catch an exception and put it to list of asynchronous exceptions diff --git a/sycl/source/detail/device_global_map.hpp b/sycl/source/detail/device_global_map.hpp index 256c48066ec87..afd780c97bcfe 100644 --- a/sycl/source/detail/device_global_map.hpp +++ b/sycl/source/detail/device_global_map.hpp @@ -87,6 +87,7 @@ class DeviceGlobalMap { }); if (findDevGlobalByValue != MPtr2DeviceGlobal.end()) MPtr2DeviceGlobal.erase(findDevGlobalByValue); + MDeviceGlobals.erase(DevGlobalIt); } } @@ -112,8 +113,7 @@ class DeviceGlobalMap { DeviceGlobalMapEntry *getEntry(const void *DeviceGlobalPtr) { std::lock_guard DeviceGlobalsGuard(MDeviceGlobalsMutex); auto Entry = MPtr2DeviceGlobal.find(DeviceGlobalPtr); - assert(Entry != MPtr2DeviceGlobal.end() && "Device global entry not found"); - return Entry->second; + return (Entry != MPtr2DeviceGlobal.end()) ? Entry->second : nullptr; } DeviceGlobalMapEntry * diff --git a/sycl/source/detail/device_global_map_entry.cpp b/sycl/source/detail/device_global_map_entry.cpp index 1f82a605056dc..25704caaee6de 100644 --- a/sycl/source/detail/device_global_map_entry.cpp +++ b/sycl/source/detail/device_global_map_entry.cpp @@ -21,6 +21,18 @@ DeviceGlobalUSMMem::~DeviceGlobalUSMMem() { // removeAssociatedResources is expected to have cleaned up both the pointer // and the event. When asserts are enabled the values are set, so we check // these here. + auto ContextImplPtr = MAllocatingContext.lock(); + if (ContextImplPtr) { + if (MPtr != nullptr) { + detail::usm::freeInternal(MPtr, ContextImplPtr.get()); + MPtr = nullptr; + } + if (MInitEvent != nullptr) { + ContextImplPtr->getAdapter().call(MInitEvent); + MInitEvent = nullptr; + } + } + assert(MPtr == nullptr && "MPtr has not been cleaned up."); assert(MInitEvent == nullptr && "MInitEvent has not been cleaned up."); } @@ -63,6 +75,7 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(queue_impl &QueueImpl) { assert(NewAllocIt.second && "USM allocation for device and context already happened."); DeviceGlobalUSMMem &NewAlloc = NewAllocIt.first->second; + NewAlloc.MAllocatingContext = CtxImpl.shared_from_this(); // Initialize here and save the event. { @@ -120,6 +133,7 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) { assert(NewAllocIt.second && "USM allocation for device and context already happened."); DeviceGlobalUSMMem &NewAlloc = NewAllocIt.first->second; + NewAlloc.MAllocatingContext = CtxImpl.shared_from_this(); if (MDeviceGlobalPtr) { // C++ guarantees members appear in memory in the order they are declared, @@ -161,12 +175,9 @@ void DeviceGlobalMapEntry::removeAssociatedResources( if (USMMem.MInitEvent != nullptr) CtxImpl->getAdapter().call( USMMem.MInitEvent); -#ifndef NDEBUG - // For debugging we set the event and memory to some recognizable values - // to allow us to check that this cleanup happens before erasure. + // Set to nullptr to avoid double free. USMMem.MPtr = nullptr; USMMem.MInitEvent = nullptr; -#endif MDeviceToUSMPtrMap.erase(USMPtrIt); } } @@ -185,12 +196,9 @@ void DeviceGlobalMapEntry::cleanup() { detail::usm::freeInternal(USMMem.MPtr, CtxImpl); if (USMMem.MInitEvent != nullptr) CtxImpl->getAdapter().call(USMMem.MInitEvent); -#ifndef NDEBUG - // For debugging we set the event and memory to some recognizable values - // to allow us to check that this cleanup happens before erasure. + // Set to nullptr to avoid double free. USMMem.MPtr = nullptr; USMMem.MInitEvent = nullptr; -#endif } MDeviceToUSMPtrMap.clear(); } diff --git a/sycl/source/detail/device_global_map_entry.hpp b/sycl/source/detail/device_global_map_entry.hpp index 1796e8d179db1..591d06cc2f7bf 100644 --- a/sycl/source/detail/device_global_map_entry.hpp +++ b/sycl/source/detail/device_global_map_entry.hpp @@ -46,6 +46,7 @@ struct DeviceGlobalUSMMem { std::mutex MInitEventMutex; ur_event_handle_t MInitEvent = nullptr; + std::weak_ptr MAllocatingContext; friend struct DeviceGlobalMapEntry; }; diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index 77f28a5131f8a..dd8a3dd72d6b3 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -3886,10 +3886,5 @@ extern "C" void __sycl_register_lib(sycl_device_binaries desc) { // Executed as a part of current module's (.exe, .dll) static initialization extern "C" void __sycl_unregister_lib(sycl_device_binaries desc) { - // Partial cleanup is not necessary at shutdown -#ifndef _WIN32 - if (!sycl::detail::GlobalHandler::instance().isOkToDefer()) - return; sycl::detail::ProgramManager::getInstance().removeImages(desc); -#endif } diff --git a/sycl/test-e2e/Basic/stream/zero_buffer_size.cpp b/sycl/test-e2e/Basic/stream/zero_buffer_size.cpp index be60c118db3e7..283eb7f1204f2 100644 --- a/sycl/test-e2e/Basic/stream/zero_buffer_size.cpp +++ b/sycl/test-e2e/Basic/stream/zero_buffer_size.cpp @@ -1,3 +1,6 @@ +// UNSUPPORTED: hip +// UNSUPPORTED-TRACKER: CMPLRLLVM-69478 + // RUN: %{build} -o %t.out // RUN: %{run} %t.out diff --git a/sycl/test-e2e/IntermediateLib/Inputs/incrementing_lib.cpp b/sycl/test-e2e/IntermediateLib/Inputs/incrementing_lib.cpp new file mode 100644 index 0000000000000..9c8faf38700cc --- /dev/null +++ b/sycl/test-e2e/IntermediateLib/Inputs/incrementing_lib.cpp @@ -0,0 +1,38 @@ +#include + +#if defined(_WIN32) +#define API_EXPORT __declspec(dllexport) +#else +#define API_EXPORT +#endif + +#ifndef INC +#define INC 1 +#endif + +#ifndef CLASSNAME +#define CLASSNAME same +#endif + +#ifdef WITH_DEVICE_GLOBALS +// Using device globals within the shared libraries only +// works if the names do not collide. Note that we cannot +// load a library multiple times if it has a device global. +#define CONCAT_HELPER(a, b) a##b +#define CONCAT(a, b) CONCAT_HELPER(a, b) + +using SomeProperties = decltype(sycl::ext::oneapi::experimental::properties{}); +sycl::ext::oneapi::experimental::device_global + CONCAT(DGVar, CLASSNAME) __attribute__((visibility("default"))); + +#endif // WITH_DEVICE_GLOBALS + +extern "C" API_EXPORT void performIncrementation(sycl::queue &q, + sycl::buffer &buf) { + sycl::range<1> r = buf.get_range(); + q.submit([&](sycl::handler &cgh) { + auto acc = buf.get_access(cgh); + cgh.parallel_for( + r, [=](sycl::id<1> idx) { acc[idx] += INC; }); + }); +} diff --git a/sycl/test-e2e/IntermediateLib/multi_lib_app.cpp b/sycl/test-e2e/IntermediateLib/multi_lib_app.cpp new file mode 100644 index 0000000000000..abcdca1926859 --- /dev/null +++ b/sycl/test-e2e/IntermediateLib/multi_lib_app.cpp @@ -0,0 +1,162 @@ +// UNSUPPORTED: cuda || hip +// UNSUPPORTED-TRACKER: CMPLRLLVM-69415 + +// DEFINE: %{fPIC_flag} = %if windows %{%} %else %{-fPIC%} +// DEFINE: %{shared_lib_ext} = %if windows %{dll%} %else %{so%} + +// clang-format off +// IMPORTANT -DSO_PATH='R"(%T)"' +// We need to capture %T, the build directory, in a string +// and the normal STRINGIFY() macros hack won't work. +// Because on Windows, the path delimiters are \, +// which C++ preprocessor converts to escape sequences, +// which becomes a nightmare. +// So the hack here is to put heredoc in the definition +// and use single quotes, which Python forgivingly accepts. +// clang-format on + +// RUN: %{build} %{fPIC_flag} -DSO_PATH='R"(%T)"' -o %t.out + +// RUN: %clangxx -fsycl %{fPIC_flag} -shared -DINC=1 -o %T/lib_a.%{shared_lib_ext} %S/Inputs/incrementing_lib.cpp +// RUN: %clangxx -fsycl %{fPIC_flag} -shared -DINC=2 -o %T/lib_b.%{shared_lib_ext} %S/Inputs/incrementing_lib.cpp +// RUN: %clangxx -fsycl %{fPIC_flag} -shared -DINC=4 -o %T/lib_c.%{shared_lib_ext} %S/Inputs/incrementing_lib.cpp + +// RUN: env UR_L0_LEAKS_DEBUG=1 %{run} %t.out + +// This test uses a kernel of the same name in three different shared libraries. +// It loads each library, calls the kernel, and checks that the incrementation +// is done correctly, and then unloads the library. +// It also reloads the first library after unloading it. +// This test ensures that __sycl_register_lib() and __sycl_unregister_lib() +// are called correctly, and that the device images are cleaned up properly. + + +#include + +using namespace sycl::ext::oneapi::experimental; + + +#ifdef _WIN32 +#include + +void *loadOsLibrary(const std::string &LibraryPath) { + HMODULE h = + LoadLibraryExA(LibraryPath.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH); + if (!h) { + std::cout << "LoadLibraryExA(" << LibraryPath + << ") failed with error code " << GetLastError() << std::endl; + } + return (void *)h; +} +int unloadOsLibrary(void *Library) { + return FreeLibrary((HMODULE)Library) ? 0 : 1; +} +void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) { + return (void *)GetProcAddress((HMODULE)Library, FunctionName.c_str()); +} + +#else +#include + +void *loadOsLibrary(const std::string &LibraryPath) { + void *so = dlopen(LibraryPath.c_str(), RTLD_NOW); + if (!so) { + char *Error = dlerror(); + std::cerr << "dlopen(" << LibraryPath << ") failed with <" + << (Error ? Error : "unknown error") << ">" << std::endl; + } + return so; +} + +int unloadOsLibrary(void *Library) { return dlclose(Library); } + +void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) { + return dlsym(Library, FunctionName.c_str()); +} +#endif + +// Define the function pointer type for performIncrementation +using IncFuncT = void(sycl::queue &, sycl::buffer &); + +void initializeBuffer(sycl::buffer &buf) { + auto acc = sycl::host_accessor(buf); + for (size_t i = 0; i < buf.size(); ++i) + acc[i] = 0; +} + +void checkIncrementation(sycl::buffer &buf, int val) { + auto acc = sycl::host_accessor(buf); + for (size_t i = 0; i < buf.size(); ++i) { + std::cout << acc[i] << " "; + assert(acc[i] == val); + } + std::cout << std::endl; +} + +int main() { + sycl::queue q; + + sycl::range<1> r(8); + sycl::buffer buf(r); + initializeBuffer(buf); + + std::string base_path = SO_PATH; + +#ifdef _WIN32 + std::string path_to_lib_a = base_path + "\\lib_a.dll"; + std::string path_to_lib_b = base_path + "\\lib_b.dll"; + std::string path_to_lib_c = base_path + "\\lib_c.dll"; +#else + std::string path_to_lib_a = base_path + "/lib_a.so"; + std::string path_to_lib_b = base_path + "/lib_b.so"; + std::string path_to_lib_c = base_path + "/lib_c.so"; +#endif + + std::cout << "paths: " << path_to_lib_a << std::endl; + std::cout << "SO_PATH: " << SO_PATH << std::endl; + + void *lib_a = loadOsLibrary(path_to_lib_a); + void *f = getOsLibraryFuncAddress(lib_a, "performIncrementation"); + if(!f){ + std::cout << "Cannot get performIncremenation function from .so/.dll" << std::endl; + return 1; + } + auto performIncrementationFuncA = reinterpret_cast(f); + performIncrementationFuncA(q, buf); // call the function from lib_a + q.wait(); + checkIncrementation(buf, 1); + unloadOsLibrary(lib_a); + std::cout << "lib_a done" << std::endl; + + + // Now RELOAD lib_a and try it again. + lib_a = loadOsLibrary(path_to_lib_a); + f = getOsLibraryFuncAddress(lib_a, "performIncrementation"); + performIncrementationFuncA = reinterpret_cast(f); + performIncrementationFuncA(q, buf); // call the function from lib_a + q.wait(); + checkIncrementation(buf, 1 + 1); + unloadOsLibrary(lib_a); + std::cout << "reload of lib_a done" << std::endl; + + + void *lib_b = loadOsLibrary(path_to_lib_b); + f = getOsLibraryFuncAddress(lib_b, "performIncrementation"); + auto performIncrementationFuncB = reinterpret_cast(f); + performIncrementationFuncB(q, buf); // call the function from lib_b + q.wait(); + checkIncrementation(buf, 1 + 1 + 2); + unloadOsLibrary(lib_b); + std::cout << "lib_b done" << std::endl; + + void *lib_c = loadOsLibrary(path_to_lib_c); + f = getOsLibraryFuncAddress(lib_c, "performIncrementation"); + auto performIncrementationFuncC = reinterpret_cast(f); + q.wait(); + performIncrementationFuncC(q, buf); // call the function from lib_c + checkIncrementation(buf, 1 + 1 + 2 + 4); + unloadOsLibrary(lib_c); + std::cout << "lib_c done" << std::endl; + + return 0; +}