diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 88152aa2..e2c811c9 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -342,12 +342,15 @@ def _map_sql_type(self, param, parameters_list, i): # String mapping logic here is_unicode = self._is_unicode_string(param) - if len(param) > MAX_INLINE_CHAR: # Long strings + + # Computes UTF-16 code units (handles surrogate pairs) + utf16_len = sum(2 if ord(c) > 0xFFFF else 1 for c in param) + if utf16_len > MAX_INLINE_CHAR: # Long strings -> DAE if is_unicode: return ( ddbc_sql_const.SQL_WLONGVARCHAR.value, ddbc_sql_const.SQL_C_WCHAR.value, - len(param), + utf16_len, 0, True, ) @@ -358,8 +361,9 @@ def _map_sql_type(self, param, parameters_list, i): 0, True, ) - if is_unicode: # Short Unicode strings - utf16_len = len(param.encode("utf-16-le")) // 2 + + # Short strings + if is_unicode: return ( ddbc_sql_const.SQL_WVARCHAR.value, ddbc_sql_const.SQL_C_WCHAR.value, @@ -374,7 +378,7 @@ def _map_sql_type(self, param, parameters_list, i): 0, False, ) - + if isinstance(param, bytes): if len(param) > 8000: # Assuming VARBINARY(MAX) for long byte arrays return ( diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 69df4d49..d457e9cc 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -227,7 +227,27 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, // TODO: Add more data types like money, guid, interval, TVPs etc. switch (paramInfo.paramCType) { - case SQL_C_CHAR: + case SQL_C_CHAR: { + if (!py::isinstance(param) && !py::isinstance(param) && + !py::isinstance(param)) { + ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + } + if (paramInfo.isDAE) { + LOG("Parameter[{}] is marked for DAE streaming", paramIndex); + dataPtr = const_cast(reinterpret_cast(¶mInfos[paramIndex])); + strLenOrIndPtr = AllocateParamBuffer(paramBuffers); + *strLenOrIndPtr = SQL_LEN_DATA_AT_EXEC(0); + bufferLength = 0; + } else { + std::string* strParam = + AllocateParamBuffer(paramBuffers, param.cast()); + dataPtr = const_cast(static_cast(strParam->c_str())); + bufferLength = strParam->size() + 1; + strLenOrIndPtr = AllocateParamBuffer(paramBuffers); + *strLenOrIndPtr = SQL_NTS; + } + break; + } case SQL_C_BINARY: { if (!py::isinstance(param) && !py::isinstance(param) && !py::isinstance(param)) { @@ -1203,23 +1223,51 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, continue; } if (py::isinstance(pyObj)) { - std::wstring wstr = pyObj.cast(); + if (matchedInfo->paramCType == SQL_C_WCHAR) { + std::wstring wstr = pyObj.cast(); + const SQLWCHAR* dataPtr = nullptr; + size_t totalChars = 0; #if defined(__APPLE__) || defined(__linux__) - auto utf16Buf = WStringToSQLWCHAR(wstr); - const char* dataPtr = reinterpret_cast(utf16Buf.data()); - size_t totalBytes = (utf16Buf.size() - 1) * sizeof(SQLWCHAR); + std::vector sqlwStr = WStringToSQLWCHAR(wstr); + totalChars = sqlwStr.size() - 1; + dataPtr = sqlwStr.data(); #else - const char* dataPtr = reinterpret_cast(wstr.data()); - size_t totalBytes = wstr.size() * sizeof(wchar_t); + dataPtr = wstr.c_str(); + totalChars = wstr.size(); #endif - const size_t chunkSize = DAE_CHUNK_SIZE; - for (size_t offset = 0; offset < totalBytes; offset += chunkSize) { - size_t len = std::min(chunkSize, totalBytes - offset); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast(len)); - if (!SQL_SUCCEEDED(rc)) { - LOG("SQLPutData failed at offset {} of {}", offset, totalBytes); - return rc; + size_t offset = 0; + size_t chunkChars = DAE_CHUNK_SIZE / sizeof(SQLWCHAR); + while (offset < totalChars) { + size_t len = std::min(chunkChars, totalChars - offset); + size_t lenBytes = len * sizeof(SQLWCHAR); + if (lenBytes > static_cast(std::numeric_limits::max())) { + ThrowStdException("Chunk size exceeds maximum allowed by SQLLEN"); + } + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast(lenBytes)); + if (!SQL_SUCCEEDED(rc)) { + LOG("SQLPutData failed at offset {} of {}", offset, totalChars); + return rc; + } + offset += len; } + } else if (matchedInfo->paramCType == SQL_C_CHAR) { + std::string s = pyObj.cast(); + size_t totalBytes = s.size(); + const char* dataPtr = s.data(); + size_t offset = 0; + size_t chunkBytes = DAE_CHUNK_SIZE; + while (offset < totalBytes) { + size_t len = std::min(chunkBytes, totalBytes - offset); + + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast(len)); + if (!SQL_SUCCEEDED(rc)) { + LOG("SQLPutData failed at offset {} of {}", offset, totalBytes); + return rc; + } + offset += len; + } + } else { + ThrowStdException("Unsupported C type for str in DAE"); } } else { ThrowStdException("DAE only supported for str or bytes"); diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 22149ea5..9caa9114 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -13,7 +13,7 @@ from datetime import datetime, date, time import decimal from contextlib import closing -from mssql_python import Connection +from mssql_python import Connection, row # Setup test table TEST_TABLE = """ @@ -5124,6 +5124,186 @@ def test_emoji_round_trip(cursor, db_connection): except Exception as e: pytest.fail(f"Error for input {repr(text)}: {e}") +def test_varchar_max_insert_non_lob(cursor, db_connection): + """Test small VARCHAR(MAX) insert (non-LOB path).""" + try: + cursor.execute("CREATE TABLE #pytest_varchar_nonlob (col VARCHAR(MAX))") + db_connection.commit() + + small_str = "Hello, world!" # small, non-LOB + cursor.execute( + "INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)", + [small_str] + ) + db_connection.commit() + + empty_str = "" + cursor.execute( + "INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)", + [empty_str] + ) + db_connection.commit() + + # None value + cursor.execute( + "INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)", + [None] + ) + db_connection.commit() + + # Fetch commented for now + # cursor.execute("SELECT col FROM #pytest_varchar_nonlob") + # rows = cursor.fetchall() + # assert rows == [[small_str], [empty_str], [None]] + + finally: + pass + + +def test_varchar_max_insert_lob(cursor, db_connection): + """Test large VARCHAR(MAX) insert (LOB path).""" + try: + cursor.execute("CREATE TABLE #pytest_varchar_lob (col VARCHAR(MAX))") + db_connection.commit() + + large_str = "A" * 100_000 # > 8k to trigger LOB + cursor.execute( + "INSERT INTO #pytest_varchar_lob (col) VALUES (?)", + [large_str] + ) + db_connection.commit() + + # Fetch commented for now + # cursor.execute("SELECT col FROM #pytest_varchar_lob") + # rows = cursor.fetchall() + # assert rows == [[large_str]] + + finally: + pass + + +def test_nvarchar_max_insert_non_lob(cursor, db_connection): + """Test small NVARCHAR(MAX) insert (non-LOB path).""" + try: + cursor.execute("CREATE TABLE #pytest_nvarchar_nonlob (col NVARCHAR(MAX))") + db_connection.commit() + + small_str = "Unicode ✨ test" + cursor.execute( + "INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)", + [small_str] + ) + db_connection.commit() + + empty_str = "" + cursor.execute( + "INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)", + [empty_str] + ) + db_connection.commit() + + cursor.execute( + "INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)", + [None] + ) + db_connection.commit() + + # Fetch commented for now + # cursor.execute("SELECT col FROM #pytest_nvarchar_nonlob") + # rows = cursor.fetchall() + # assert rows == [[small_str], [empty_str], [None]] + + finally: + pass + + +def test_nvarchar_max_insert_lob(cursor, db_connection): + """Test large NVARCHAR(MAX) insert (LOB path).""" + try: + cursor.execute("CREATE TABLE #pytest_nvarchar_lob (col NVARCHAR(MAX))") + db_connection.commit() + + large_str = "📝" * 50_000 # each emoji = 2 UTF-16 code units, total > 100k bytes + cursor.execute( + "INSERT INTO #pytest_nvarchar_lob (col) VALUES (?)", + [large_str] + ) + db_connection.commit() + + # Fetch commented for now + # cursor.execute("SELECT col FROM #pytest_nvarchar_lob") + # rows = cursor.fetchall() + # assert rows == [[large_str]] + + finally: + pass + +def test_nvarchar_max_boundary(cursor, db_connection): + """Test NVARCHAR(MAX) at LOB boundary sizes.""" + try: + cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_boundary") + cursor.execute("CREATE TABLE #pytest_nvarchar_boundary (col NVARCHAR(MAX))") + db_connection.commit() + + # 4k BMP chars = 8k bytes + cursor.execute("INSERT INTO #pytest_nvarchar_boundary (col) VALUES (?)", ["A" * 4096]) + # 4k emojis = 8k UTF-16 code units (16k bytes) + cursor.execute("INSERT INTO #pytest_nvarchar_boundary (col) VALUES (?)", ["📝" * 4096]) + db_connection.commit() + + # Fetch commented for now + # cursor.execute("SELECT col FROM #pytest_nvarchar_boundary") + # rows = cursor.fetchall() + # assert rows == [["A" * 4096], ["📝" * 4096]] + finally: + pass + + +def test_nvarchar_max_chunk_edge(cursor, db_connection): + """Test NVARCHAR(MAX) insert slightly larger than a chunk.""" + try: + cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_chunk") + cursor.execute("CREATE TABLE #pytest_nvarchar_chunk (col NVARCHAR(MAX))") + db_connection.commit() + + chunk_size = 8192 # bytes + test_str = "📝" * ((chunk_size // 4) + 3) # slightly > 1 chunk + cursor.execute("INSERT INTO #pytest_nvarchar_chunk (col) VALUES (?)", [test_str]) + db_connection.commit() + + # Fetch commented for now + # cursor.execute("SELECT col FROM #pytest_nvarchar_chunk") + # row = cursor.fetchone() + # assert row[0] == test_str + finally: + pass + +def test_empty_string_chunk(cursor, db_connection): + """Test inserting empty strings into VARCHAR(MAX) and NVARCHAR(MAX).""" + try: + cursor.execute("DROP TABLE IF EXISTS #pytest_empty_string") + cursor.execute(""" + CREATE TABLE #pytest_empty_string ( + varchar_col VARCHAR(MAX), + nvarchar_col NVARCHAR(MAX) + ) + """) + db_connection.commit() + + empty_varchar = "" + empty_nvarchar = "" + cursor.execute( + "INSERT INTO #pytest_empty_string (varchar_col, nvarchar_col) VALUES (?, ?)", + [empty_varchar, empty_nvarchar] + ) + db_connection.commit() + + cursor.execute("SELECT LEN(varchar_col), LEN(nvarchar_col) FROM #pytest_empty_string") + row = tuple(int(x) for x in cursor.fetchone()) + assert row == (0, 0), f"Expected lengths (0,0), got {row}" + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_empty_string") + db_connection.commit() def test_close(db_connection): """Test closing the cursor"""