diff --git a/lib/DxcSupport/Unicode.cpp b/lib/DxcSupport/Unicode.cpp index 1481ae27ff..1392219085 100644 --- a/lib/DxcSupport/Unicode.cpp +++ b/lib/DxcSupport/Unicode.cpp @@ -54,7 +54,8 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/, size_t rv; const char *prevLocale = setlocale(LC_ALL, nullptr); setlocale(LC_ALL, "en_US.UTF-8"); - if (lpMultiByteStr[cbMultiByte - 1] != '\0') { + const bool bIsNullTerminated = lpMultiByteStr[cbMultiByte - 1] == '\0'; + if (!bIsNullTerminated) { char *srcStr = (char *)malloc((cbMultiByte + 1) * sizeof(char)); strncpy(srcStr, lpMultiByteStr, cbMultiByte); srcStr[cbMultiByte] = '\0'; @@ -67,9 +68,9 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/, if (prevLocale) setlocale(LC_ALL, prevLocale); - if (rv == (size_t)cbMultiByte) - return rv; - return rv + 1; // mbstowcs excludes the terminating character + if (bIsNullTerminated) + return rv + 1; // mbstowcs excludes the terminating character + return rv; } // WideCharToMultiByte is a Windows-specific method. @@ -110,7 +111,8 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/, size_t rv; const char *prevLocale = setlocale(LC_ALL, nullptr); setlocale(LC_ALL, "en_US.UTF-8"); - if (lpWideCharStr[cchWideChar - 1] != L'\0') { + const bool bIsNullTerminated = lpWideCharStr[cchWideChar - 1] == L'\0'; + if (!bIsNullTerminated) { wchar_t *srcStr = (wchar_t *)malloc((cchWideChar + 1) * sizeof(wchar_t)); wcsncpy(srcStr, lpWideCharStr, cchWideChar); srcStr[cchWideChar] = L'\0'; @@ -123,9 +125,9 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/, if (prevLocale) setlocale(LC_ALL, prevLocale); - if (rv == (size_t)cchWideChar) - return rv; - return rv + 1; // mbstowcs excludes the terminating character + if (bIsNullTerminated) + return rv + 1; // mbstowcs excludes the terminating character + return rv; } #endif // _WIN32