diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d163863 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +build/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..6eb545a --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,5 @@ +set(CMAKE_CXX_STANDARD 20) + +add_library(computecore SHARED + src/computecore.cpp +) diff --git a/README.md b/README.md new file mode 100644 index 0000000..1c4e5e2 --- /dev/null +++ b/README.md @@ -0,0 +1,19 @@ +## Support more than 64 cores + +If you want to use more than 64 cores in WSL2 (e.g. if you have AMD Threadripper 7980x), you can follow the steps: + +1. Generate project + + mkdir build + cd build + cmake .. + +1. Open `build/Project.sln` and generate a x64 release build. + +1. Launch task manager, kill `wslservice.exe` process. + +1. Copy the generated `build/Release/computecore.dll` to `C:\Program Files\WSL\`. + +1. Launch any WSL2 distribution, type `nproc` or `lscpu`, you should see all cores. + + ![use more than 64 cores in WSL2](img/use.more.than.64.cores.png) \ No newline at end of file diff --git a/img/use.more.than.64.cores.png b/img/use.more.than.64.cores.png new file mode 100644 index 0000000..cf70c88 Binary files /dev/null and b/img/use.more.than.64.cores.png differ diff --git a/src/computecore.cpp b/src/computecore.cpp new file mode 100644 index 0000000..6f42e71 --- /dev/null +++ b/src/computecore.cpp @@ -0,0 +1,311 @@ +#include +#include +#include +#include + +#include +#include + +static HMODULE g_hComputenetwork; + +// Hook up all apis used by wslservice.exe in computecore.dll. +#define HookApis() \ + HookApi(HcsCloseComputeSystem); \ + HookApi(HcsCloseOperation); \ + HookApi(HcsCreateComputeSystem); \ + HookApi(HcsCreateOperation); \ + HookApi(HcsGetComputeSystemProperties); \ + HookApi(HcsGetServiceProperties); \ + HookApi(HcsGrantVmAccess); \ + HookApi(HcsModifyComputeSystem); \ + HookApi(HcsOpenComputeSystem); \ + HookApi(HcsRevokeVmAccess); \ + HookApi(HcsSetComputeSystemCallback); \ + HookApi(HcsStartComputeSystem); \ + HookApi(HcsTerminateComputeSystem); \ + HookApi(HcsWaitForOperationResult); \ + +#define HookApi(api) \ + decltype(&api) g_##api; + + HookApis() + +#undef HookApi + +// Helper function to replace substring. +static void Replace(std::wstring& str, const std::wstring& from, const std::wstring& to) { + auto start_pos = str.find(from); + if (start_pos != std::string::npos) { + str.replace(start_pos, from.length(), to); + } +} + +// Helper function to get all logic processors cross all cpu groups. +static uint32_t GetLogicCores() +{ + // GetLogicalProcessorInformationEx should return false, and GetLastError() should return ERROR_INSUFFICIENT_BUFFER + DWORD length = 0; + if (GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &length) || + GetLastError() != ERROR_INSUFFICIENT_BUFFER) { + return 0; + } + + // Get the processor information. + auto pBuffer = std::unique_ptr((uint8_t*)malloc(length), free); + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*)pBuffer.get(), &length)) { + return 0; + } + + // Enumrate all processors. + auto totalLogicProcessors = 0u; + auto pCurrent = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*)nullptr; + for (auto offset = 0u; offset < length; offset += pCurrent->Size) { + pCurrent = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*)(pBuffer.get() + offset); + for (auto i = 0u; i < pCurrent->Processor.GroupCount; ++i) { + totalLogicProcessors += (uint32_t)std::bitsetProcessor.GroupMask[i].Mask) * 8>{pCurrent->Processor.GroupMask[i].Mask}.count(); + } + } + return totalLogicProcessors; +} + +extern "C" { + +BOOL APIENTRY DllMain( HMODULE hModule, + DWORD ul_reason_for_call, + LPVOID lpReserved + ) +{ + switch (ul_reason_for_call) + { + case DLL_PROCESS_ATTACH: + { + // Get system folder, should be c:\\windows\\system32. + auto len = GetSystemDirectoryA(nullptr, 0); + if (!len) { + return false; + } + + std::vector path; + path.resize(len); + if (!GetSystemDirectoryA(path.data(), (UINT)path.size())) { + return false; + } + + // Get computecore.dll path, assume it should be under c:\\windows\\system32\\computecore.dll. + auto systemPath = std::filesystem::path{ path.data() } / "computecore.dll"; + if (!(g_hComputenetwork = LoadLibraryA(systemPath.string().c_str()))) { + return false; + } + +#define HookApi(api) \ + if (!(g_##api = (decltype(&api))GetProcAddress(g_hComputenetwork, #api))) { \ + return false; \ + } + + HookApis() + +#undef HookApi + } + break; + + + break; + + case DLL_THREAD_ATTACH: + break; + + case DLL_THREAD_DETACH: + break; + + case DLL_PROCESS_DETACH: + if (g_hComputenetwork) { + FreeLibrary(g_hComputenetwork); + } + break; + } + return TRUE; +} + +#pragma comment(linker, "/export:HcsCloseComputeSystem") +void +WINAPI +HcsCloseComputeSystem( + _In_ _Post_invalid_ HCS_SYSTEM computeSystem +) +{ + return g_HcsCloseComputeSystem(computeSystem); +} + +#pragma comment(linker, "/export:HcsCloseOperation") +void +WINAPI +HcsCloseOperation( + _In_ HCS_OPERATION operation +) +{ + return g_HcsCloseOperation(operation); +} + +#pragma comment(linker, "/export:HcsCreateComputeSystem") +HRESULT +WINAPI +HcsCreateComputeSystem( + _In_ PCWSTR id, + _In_ PCWSTR configuration, + _In_ HCS_OPERATION operation, + _In_opt_ const SECURITY_DESCRIPTOR* securityDescriptor, + _Out_ HCS_SYSTEM* computeSystem +) +{ + // GetSystemInfo() only returns the number of cores in the current cpu group. + // In windows, the max number of cores in a single cpu group is 64. So the number + // of dwNumberOfProcessors in SYSTEM_INFO won't bigger than 64. For this case, we + // invoke GetLogicCores() to get the total cpu cores crossing all cpu groups. + SYSTEM_INFO sysInfo{}; + GetSystemInfo(&sysInfo); + + auto logicCores = GetLogicCores(); + if (logicCores > 64 && sysInfo.dwNumberOfProcessors < logicCores) + { + auto config = std::wstring{ configuration }; + Replace(config, std::format(L"nr_cpus={}", sysInfo.dwNumberOfProcessors), std::format(L"nr_cpus={}", logicCores)); + Replace(config, std::format(L"\"Count\":{}", sysInfo.dwNumberOfProcessors), std::format(L"\"Count\":{}", logicCores)); + return g_HcsCreateComputeSystem(id, config.c_str(), operation, securityDescriptor, computeSystem); + } + else + { + return g_HcsCreateComputeSystem(id, configuration, operation, securityDescriptor, computeSystem); + } +} + +#pragma comment(linker, "/export:HcsCreateOperation") +HCS_OPERATION +WINAPI +HcsCreateOperation( + _In_opt_ const void* context, + _In_opt_ HCS_OPERATION_COMPLETION callback +) +{ + return g_HcsCreateOperation(context, callback); +} + +#pragma comment(linker, "/export:HcsGetComputeSystemProperties") +HRESULT +WINAPI +HcsGetComputeSystemProperties( + _In_ HCS_SYSTEM computeSystem, + _In_ HCS_OPERATION operation, + _In_opt_ PCWSTR propertyQuery +) +{ + return g_HcsGetComputeSystemProperties(computeSystem, operation, propertyQuery); +} + +#pragma comment(linker, "/export:HcsGetServiceProperties") +HRESULT +WINAPI +HcsGetServiceProperties( + _In_opt_ PCWSTR propertyQuery, + _Outptr_ PWSTR* result +) +{ + return g_HcsGetServiceProperties(propertyQuery, result); +} + +#pragma comment(linker, "/export:HcsGrantVmAccess") +HRESULT +WINAPI +HcsGrantVmAccess( + _In_ PCWSTR vmId, + _In_ PCWSTR filePath +) +{ + return g_HcsGrantVmAccess(vmId, filePath); +} + +#pragma comment(linker, "/export:HcsModifyComputeSystem") +HRESULT +WINAPI +HcsModifyComputeSystem( + _In_ HCS_SYSTEM computeSystem, + _In_ HCS_OPERATION operation, + _In_ PCWSTR configuration, + _In_opt_ HANDLE identity +) +{ + return g_HcsModifyComputeSystem(computeSystem, operation, configuration, identity); +} + +#pragma comment(linker, "/export:HcsOpenComputeSystem") +HRESULT +WINAPI +HcsOpenComputeSystem( + _In_ PCWSTR id, + _In_ DWORD requestedAccess, + _Out_ HCS_SYSTEM* computeSystem +) +{ + return g_HcsOpenComputeSystem(id, requestedAccess, computeSystem); +} + +#pragma comment(linker, "/export:HcsRevokeVmAccess") +HRESULT +WINAPI +HcsRevokeVmAccess( + _In_ PCWSTR vmId, + _In_ PCWSTR filePath +) +{ + return g_HcsRevokeVmAccess(vmId, filePath); +} + +#pragma comment(linker, "/export:HcsSetComputeSystemCallback") +HRESULT +WINAPI +HcsSetComputeSystemCallback( + _In_ HCS_SYSTEM computeSystem, + _In_ HCS_EVENT_OPTIONS callbackOptions, + _In_opt_ const void* context, + _In_ HCS_EVENT_CALLBACK callback +) +{ + return g_HcsSetComputeSystemCallback(computeSystem, callbackOptions, context, callback); +} + +#pragma comment(linker, "/export:HcsStartComputeSystem") +HRESULT +WINAPI +HcsStartComputeSystem( + _In_ HCS_SYSTEM computeSystem, + _In_ HCS_OPERATION operation, + _In_opt_ PCWSTR options +) +{ + return g_HcsStartComputeSystem(computeSystem, operation, options); +} + +#pragma comment(linker, "/export:HcsTerminateComputeSystem") +HRESULT +WINAPI +HcsTerminateComputeSystem( + _In_ HCS_SYSTEM computeSystem, + _In_ HCS_OPERATION operation, + _In_opt_ PCWSTR options +) +{ + return g_HcsTerminateComputeSystem(computeSystem, operation, options); +} + +#pragma comment(linker, "/export:HcsWaitForOperationResult") +HRESULT +WINAPI +HcsWaitForOperationResult( + _In_ HCS_OPERATION operation, + _In_ DWORD timeoutMs, + _Outptr_opt_ PWSTR* resultDocument +) +{ + return g_HcsWaitForOperationResult(operation, timeoutMs, resultDocument); +} + +}