diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index f6aca7a3..a6f5bb64 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -332,10 +332,11 @@ def _map_sql_type(self, param, parameters_list, i): # TODO: revisit if len(param) > 4000: # Long strings if is_unicode: + utf16_len = len(param.encode("utf-16-le")) // 2 return ( ddbc_sql_const.SQL_WLONGVARCHAR.value, ddbc_sql_const.SQL_C_WCHAR.value, - len(param), + utf16_len, 0, ) return ( @@ -345,10 +346,11 @@ def _map_sql_type(self, param, parameters_list, i): 0, ) if is_unicode: # Short Unicode strings + utf16_len = len(param.encode("utf-16-le")) // 2 return ( ddbc_sql_const.SQL_WVARCHAR.value, ddbc_sql_const.SQL_C_WCHAR.value, - len(param), + utf16_len, 0, ) return ( diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index a1136ab8..d0a20dbd 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -275,15 +275,19 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, AllocateParamBuffer>(paramBuffers); // Reserve space and convert from wstring to SQLWCHAR array - sqlwcharBuffer->resize(strParam->size() + 1, 0); // +1 for null terminator - - // Convert each wchar_t (4 bytes on macOS) to SQLWCHAR (2 bytes) - for (size_t i = 0; i < strParam->size(); i++) { - (*sqlwcharBuffer)[i] = static_cast((*strParam)[i]); + std::vector utf16 = WStringToSQLWCHAR(*strParam); + if (utf16.size() < strParam->size()) { + LOG("Warning: UTF-16 encoding shrank string? input={} output={}", + strParam->size(), utf16.size()); + } + if (utf16.size() > strParam->size() * 2 + 1) { + LOG("Warning: UTF-16 expansion unusually large: input={} output={}", + strParam->size(), utf16.size()); } + *sqlwcharBuffer = std::move(utf16); // Use the SQLWCHAR buffer instead of the wstring directly dataPtr = sqlwcharBuffer->data(); - bufferLength = (strParam->size() + 1) * sizeof(SQLWCHAR); + bufferLength = sqlwcharBuffer->size() * sizeof(SQLWCHAR); LOG("macOS: Created SQLWCHAR buffer for parameter with size: {} bytes", bufferLength); #else // On Windows, wchar_t and SQLWCHAR are the same size, so direct cast works @@ -1705,7 +1709,16 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p if (numCharsInData < dataBuffer.size()) { // SQLGetData will null-terminate the data #if defined(__APPLE__) || defined(__linux__) - row.append(SQLWCHARToWString(dataBuffer.data(), SQL_NTS)); + auto raw_bytes = reinterpret_cast(dataBuffer.data()); + size_t actualBufferSize = dataBuffer.size() * sizeof(SQLWCHAR); + if (dataLen < 0 || static_cast(dataLen) > actualBufferSize) { + LOG("Error: py::bytes creation request exceeds buffer size. dataLen={} buffer={}", + dataLen, actualBufferSize); + ThrowStdException("Invalid buffer length for py::bytes"); + } + py::bytes py_bytes(raw_bytes, dataLen); + py::str decoded = py_bytes.attr("decode")("utf-16-le"); + row.append(decoded); #else row.append(std::wstring(dataBuffer.data())); #endif diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 1bb3efb0..f28f610c 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -33,33 +33,107 @@ using namespace pybind11::literals; #include #if defined(__APPLE__) || defined(__linux__) - // macOS-specific headers - #include +#include + +// Unicode constants for surrogate ranges and max scalar value +constexpr uint32_t UNICODE_SURROGATE_HIGH_START = 0xD800; +constexpr uint32_t UNICODE_SURROGATE_HIGH_END = 0xDBFF; +constexpr uint32_t UNICODE_SURROGATE_LOW_START = 0xDC00; +constexpr uint32_t UNICODE_SURROGATE_LOW_END = 0xDFFF; +constexpr uint32_t UNICODE_MAX_CODEPOINT = 0x10FFFF; +constexpr uint32_t UNICODE_REPLACEMENT_CHAR = 0xFFFD; + +// Validate whether a code point is a legal Unicode scalar value +// (excludes surrogate halves and values beyond U+10FFFF) +inline bool IsValidUnicodeScalar(uint32_t cp) { + return cp <= UNICODE_MAX_CODEPOINT && + !(cp >= UNICODE_SURROGATE_HIGH_START && cp <= UNICODE_SURROGATE_LOW_END); +} - inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { - if (!sqlwStr) return std::wstring(); +inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { + if (!sqlwStr) return std::wstring(); - if (length == SQL_NTS) { - size_t i = 0; - while (sqlwStr[i] != 0) ++i; - length = i; - } + if (length == SQL_NTS) { + size_t i = 0; + while (sqlwStr[i] != 0) ++i; + length = i; + } + std::wstring result; + result.reserve(length); - std::wstring result; - result.reserve(length); + if constexpr (sizeof(SQLWCHAR) == 2) { + // Decode UTF-16 to UTF-32 (with surrogate pair handling) + for (size_t i = 0; i < length; ++i) { + uint16_t wc = static_cast(sqlwStr[i]); + // Check if this is a high surrogate (U+D800โ€“U+DBFF) + if (wc >= UNICODE_SURROGATE_HIGH_START && wc <= UNICODE_SURROGATE_HIGH_END && i + 1 < length) { + uint16_t low = static_cast(sqlwStr[i + 1]); + // Check if the next code unit is a low surrogate (U+DC00โ€“U+DFFF) + if (low >= UNICODE_SURROGATE_LOW_START && low <= UNICODE_SURROGATE_LOW_END) { + // Combine surrogate pair into a single code point + uint32_t cp = (((wc - UNICODE_SURROGATE_HIGH_START) << 10) | (low - UNICODE_SURROGATE_LOW_START)) + 0x10000; + result.push_back(static_cast(cp)); + ++i; // Skip the low surrogate + continue; + } + } + // If valid scalar then append, else append replacement char (U+FFFD) + if (IsValidUnicodeScalar(wc)) { + result.push_back(static_cast(wc)); + } else { + result.push_back(static_cast(UNICODE_REPLACEMENT_CHAR)); + } + } + } else { + // SQLWCHAR is UTF-32, so just copy with validation for (size_t i = 0; i < length; ++i) { - result.push_back(static_cast(sqlwStr[i])); + uint32_t cp = static_cast(sqlwStr[i]); + if (IsValidUnicodeScalar(cp)) { + result.push_back(static_cast(cp)); + } else { + result.push_back(static_cast(UNICODE_REPLACEMENT_CHAR)); + } } - return result; } + return result; +} - inline std::vector WStringToSQLWCHAR(const std::wstring& str) { - std::vector result(str.size() + 1, 0); // +1 for null terminator - for (size_t i = 0; i < str.size(); ++i) { - result[i] = static_cast(str[i]); +inline std::vector WStringToSQLWCHAR(const std::wstring& str) { + std::vector result; + result.reserve(str.size() + 2); + if constexpr (sizeof(SQLWCHAR) == 2) { + // Encode UTF-32 to UTF-16 + for (wchar_t wc : str) { + uint32_t cp = static_cast(wc); + if (!IsValidUnicodeScalar(cp)) { + cp = UNICODE_REPLACEMENT_CHAR; + } + if (cp <= 0xFFFF) { + // Fits in a single UTF-16 code unit + result.push_back(static_cast(cp)); + } else { + // Encode as surrogate pair + cp -= 0x10000; + SQLWCHAR high = static_cast((cp >> 10) + UNICODE_SURROGATE_HIGH_START); + SQLWCHAR low = static_cast((cp & 0x3FF) + UNICODE_SURROGATE_LOW_START); + result.push_back(high); + result.push_back(low); + } + } + } else { + // Encode UTF-32 directly + for (wchar_t wc : str) { + uint32_t cp = static_cast(wc); + if (IsValidUnicodeScalar(cp)) { + result.push_back(static_cast(cp)); + } else { + result.push_back(static_cast(UNICODE_REPLACEMENT_CHAR)); + } } - return result; } + result.push_back(0); // null terminator + return result; +} #endif #if defined(__APPLE__) || defined(__linux__) diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index a4a1c8f4..22149ea5 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -5086,6 +5086,45 @@ def test_tables_cleanup(cursor, db_connection): except Exception as e: pytest.fail(f"Test cleanup failed: {e}") +def test_emoji_round_trip(cursor, db_connection): + """Test round-trip of emoji and special characters""" + test_inputs = [ + "Hello ๐Ÿ˜„", + "Flags ๐Ÿ‡ฎ๐Ÿ‡ณ๐Ÿ‡บ๐Ÿ‡ธ", + "Family ๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆ", + "Skin tone ๐Ÿ‘๐Ÿฝ", + "Brain ๐Ÿง ", + "Ice ๐ŸงŠ", + "Melting face ๐Ÿซ ", + "Accented รฉรผรฑรง", + "Chinese: ไธญๆ–‡", + "Japanese: ๆ—ฅๆœฌ่ชž", + "Hello ๐Ÿš€ World", + "admin๐Ÿ”’user", + "1๐Ÿš€' OR '1'='1", + ] + + cursor.execute(""" + CREATE TABLE #pytest_emoji_test ( + id INT IDENTITY PRIMARY KEY, + content NVARCHAR(MAX) + ); + """) + db_connection.commit() + + for text in test_inputs: + try: + cursor.execute("INSERT INTO #pytest_emoji_test (content) OUTPUT INSERTED.id VALUES (?)", [text]) + inserted_id = cursor.fetchone()[0] + cursor.execute("SELECT content FROM #pytest_emoji_test WHERE id = ?", [inserted_id]) + result = cursor.fetchone() + assert result is not None, f"No row returned for ID {inserted_id}" + assert result[0] == text, f"Mismatch! Sent: {text}, Got: {result[0]}" + + except Exception as e: + pytest.fail(f"Error for input {repr(text)}: {e}") + + def test_close(db_connection): """Test closing the cursor""" try: @@ -5095,4 +5134,4 @@ def test_close(db_connection): except Exception as e: pytest.fail(f"Cursor close test failed: {e}") finally: - cursor = db_connection.cursor() + cursor = db_connection.cursor() \ No newline at end of file