diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index fe9cc435..7a6dbfaa 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -373,16 +373,16 @@ def _map_sql_type(self, param, parameters_list, i): if utf16_len > MAX_INLINE_CHAR: # Long strings -> DAE if is_unicode: return ( - ddbc_sql_const.SQL_WLONGVARCHAR.value, + ddbc_sql_const.SQL_WVARCHAR.value, ddbc_sql_const.SQL_C_WCHAR.value, - utf16_len, + 0, 0, True, ) return ( - ddbc_sql_const.SQL_LONGVARCHAR.value, + ddbc_sql_const.SQL_VARCHAR.value, ddbc_sql_const.SQL_C_CHAR.value, - len(param), + 0, 0, True, ) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 0fc56fd9..789d3863 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -31,6 +31,7 @@ #define ARCHITECTURE "win64" // Default to win64 if not defined during compilation #endif #define DAE_CHUNK_SIZE 8192 +#define SQL_MAX_LOB_SIZE 8000 //------------------------------------------------------------------------------------------------- // Class definitions //------------------------------------------------------------------------------------------------- @@ -1722,8 +1723,119 @@ SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) { return SQLFetch_ptr(StatementHandle->get()); } +static py::object FetchLobColumnData(SQLHSTMT hStmt, + SQLUSMALLINT colIndex, + SQLSMALLINT cType, + bool isWideChar, + bool isBinary) +{ + std::vector buffer; + SQLRETURN ret = SQL_SUCCESS_WITH_INFO; + int loopCount = 0; + + while (true) { + ++loopCount; + std::vector chunk(DAE_CHUNK_SIZE, 0); + SQLLEN actualRead = 0; + ret = SQLGetData_ptr(hStmt, + colIndex, + cType, + chunk.data(), + DAE_CHUNK_SIZE, + &actualRead); + + if (ret == SQL_ERROR || !SQL_SUCCEEDED(ret) && ret != SQL_SUCCESS_WITH_INFO) { + std::ostringstream oss; + oss << "Error fetching LOB for column " << colIndex + << ", cType=" << cType + << ", loop=" << loopCount + << ", SQLGetData return=" << ret; + LOG(oss.str()); + ThrowStdException(oss.str()); + } + if (actualRead == SQL_NULL_DATA) { + LOG("Loop {}: Column {} is NULL", loopCount, colIndex); + return py::none(); + } + + size_t bytesRead = 0; + if (actualRead >= 0) { + bytesRead = static_cast(actualRead); + if (bytesRead > DAE_CHUNK_SIZE) { + bytesRead = DAE_CHUNK_SIZE; + } + } else { + // fallback: use full buffer size if actualRead is unknown + bytesRead = DAE_CHUNK_SIZE; + } + + // For character data, trim trailing null terminators + if (!isBinary && bytesRead > 0) { + if (!isWideChar) { + // Narrow characters + while (bytesRead > 0 && chunk[bytesRead - 1] == '\0') { + --bytesRead; + } + if (bytesRead < DAE_CHUNK_SIZE) { + LOG("Loop {}: Trimmed null terminator (narrow)", loopCount); + } + } else { + // Wide characters + size_t wcharSize = sizeof(SQLWCHAR); + if (bytesRead >= wcharSize) { + auto sqlwBuf = reinterpret_cast(chunk.data()); + size_t wcharCount = bytesRead / wcharSize; + while (wcharCount > 0 && sqlwBuf[wcharCount - 1] == 0) { + --wcharCount; + bytesRead -= wcharSize; + } + if (bytesRead < DAE_CHUNK_SIZE) { + LOG("Loop {}: Trimmed null terminator (wide)", loopCount); + } + } + } + } + if (bytesRead > 0) { + buffer.insert(buffer.end(), chunk.begin(), chunk.begin() + bytesRead); + LOG("Loop {}: Appended {} bytes", loopCount, bytesRead); + } + if (ret == SQL_SUCCESS) { + LOG("Loop {}: SQL_SUCCESS → no more data", loopCount); + break; + } + } + LOG("FetchLobColumnData: Total bytes collected = {}", buffer.size()); + + if (buffer.empty()) { + if (isBinary) { + return py::bytes(""); + } + return py::str(""); + } + if (isWideChar) { +#if defined(_WIN32) + std::wstring wstr(reinterpret_cast(buffer.data()), buffer.size() / sizeof(wchar_t)); + std::string utf8str = WideToUTF8(wstr); + return py::str(utf8str); +#else + // Linux/macOS handling + size_t wcharCount = buffer.size() / sizeof(SQLWCHAR); + const SQLWCHAR* sqlwBuf = reinterpret_cast(buffer.data()); + std::wstring wstr = SQLWCHARToWString(sqlwBuf, wcharCount); + std::string utf8str = WideToUTF8(wstr); + return py::str(utf8str); +#endif + } + if (isBinary) { + LOG("FetchLobColumnData: Returning binary of {} bytes", buffer.size()); + return py::bytes(buffer.data(), buffer.size()); + } + std::string str(buffer.data(), buffer.size()); + LOG("FetchLobColumnData: Returning narrow string of length {}", str.length()); + return py::str(str); +} + // Helper function to retrieve column data -// TODO: Handle variable length data correctly SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, py::list& row) { LOG("Get data from columns"); if (!SQLGetData_ptr) { @@ -1746,7 +1858,6 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p if (!SQL_SUCCEEDED(ret)) { LOG("Error retrieving data for column - {}, SQLDescribeCol return code - {}", i, ret); row.append(py::none()); - // TODO: Do we want to continue in this case or return? continue; } @@ -1754,121 +1865,106 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - // TODO: revisit - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; - std::vector dataBuffer(fetchBufferSize); - SQLLEN dataLen; - // TODO: Handle the return code better - ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), dataBuffer.size(), - &dataLen); - - if (SQL_SUCCEEDED(ret)) { - // TODO: Refactor these if's across other switches to avoid code duplication - // columnSize is in chars, dataLen is in bytes - if (dataLen > 0) { - uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); - // NOTE: dataBuffer.size() includes null-terminator, dataLen doesn't. Hence use '<'. - if (numCharsInData < dataBuffer.size()) { - // SQLGetData will null-terminate the data -#if defined(__APPLE__) || defined(__linux__) - std::string fullStr(reinterpret_cast(dataBuffer.data())); - row.append(fullStr); - LOG("macOS/Linux: Appended CHAR string of length {} to result row", fullStr.length()); -#else - row.append(std::string(reinterpret_cast(dataBuffer.data()))); -#endif - } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << dataBuffer.size()-1 << ") is smaller, & data " - << "to be retrieved is longer (" << numCharsInData << "). ColumnID - " - << i << ", datatype - " << dataType; - ThrowStdException(oss.str()); + if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > SQL_MAX_LOB_SIZE) { + LOG("Streaming LOB for column {}", i); + row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false)); + } else { + uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; + std::vector dataBuffer(fetchBufferSize); + SQLLEN dataLen; + ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), dataBuffer.size(), + &dataLen); + if (SQL_SUCCEEDED(ret)) { + // columnSize is in chars, dataLen is in bytes + if (dataLen > 0) { + uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); + if (numCharsInData < dataBuffer.size()) { + // SQLGetData will null-terminate the data + #if defined(__APPLE__) || defined(__linux__) + std::string fullStr(reinterpret_cast(dataBuffer.data())); + row.append(fullStr); + LOG("macOS/Linux: Appended CHAR string of length {} to result row", fullStr.length()); + #else + row.append(std::string(reinterpret_cast(dataBuffer.data()))); + #endif + } else { + // Buffer too small, fallback to streaming + LOG("CHAR column {} data truncated, using streaming LOB", i); + row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false)); + } + } else if (dataLen == SQL_NULL_DATA) { + LOG("Column {} is NULL (CHAR)", i); + row.append(py::none()); + } else if (dataLen == 0) { + row.append(py::str("")); + } else if (dataLen == SQL_NO_TOTAL) { + LOG("SQLGetData couldn't determine the length of the data. " + "Returning NULL value instead. Column ID - {}, Data Type - {}", i, dataType); + row.append(py::none()); + } else if (dataLen < 0) { + LOG("SQLGetData returned an unexpected negative data length. " + "Raising exception. Column ID - {}, Data Type - {}, Data Length - {}", + i, dataType, dataLen); + ThrowStdException("SQLGetData returned an unexpected negative data length"); } - } else if (dataLen == SQL_NULL_DATA) { - row.append(py::none()); - } else if (dataLen == 0) { - // Handle zero-length (non-NULL) data - row.append(std::string("")); - } else if (dataLen == SQL_NO_TOTAL) { - // This means the length of the data couldn't be determined - LOG("SQLGetData couldn't determine the length of the data. " - "Returning NULL value instead. Column ID - {}, Data Type - {}", i, dataType); - } else if (dataLen < 0) { - // This is unexpected - LOG("SQLGetData returned an unexpected negative data length. " - "Raising exception. Column ID - {}, Data Type - {}, Data Length - {}", - i, dataType, dataLen); - ThrowStdException("SQLGetData returned an unexpected negative data length"); + } else { + LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " + "code - {}. Returning NULL value instead", + i, dataType, ret); + row.append(py::none()); } - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); } break; } case SQL_WCHAR: case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: { - // TODO: revisit - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; - std::vector dataBuffer(fetchBufferSize); - SQLLEN dataLen; - ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(), - dataBuffer.size() * sizeof(SQLWCHAR), &dataLen); - - if (SQL_SUCCEEDED(ret)) { - // TODO: Refactor these if's across other switches to avoid code duplication - if (dataLen > 0) { - uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); - if (numCharsInData < dataBuffer.size()) { - // SQLGetData will null-terminate the data + case SQL_WLONGVARCHAR: { + if (columnSize == SQL_NO_TOTAL || columnSize > 4000) { + LOG("Streaming LOB for column {} (NVARCHAR)", i); + row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false)); + } else { + uint64_t fetchBufferSize = (columnSize + 1) * sizeof(SQLWCHAR); // +1 for null terminator + std::vector dataBuffer(columnSize + 1); + SQLLEN dataLen; + ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(), fetchBufferSize, &dataLen); + if (SQL_SUCCEEDED(ret)) { + if (dataLen > 0) { + uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); + if (numCharsInData < dataBuffer.size()) { #if defined(__APPLE__) || defined(__linux__) - auto raw_bytes = reinterpret_cast(dataBuffer.data()); - size_t actualBufferSize = dataBuffer.size() * sizeof(SQLWCHAR); - if (dataLen < 0 || static_cast(dataLen) > actualBufferSize) { - LOG("Error: py::bytes creation request exceeds buffer size. dataLen={} buffer={}", - dataLen, actualBufferSize); - ThrowStdException("Invalid buffer length for py::bytes"); - } - py::bytes py_bytes(raw_bytes, dataLen); - py::str decoded = py_bytes.attr("decode")("utf-16-le"); - row.append(decoded); + const SQLWCHAR* sqlwBuf = reinterpret_cast(dataBuffer.data()); + std::wstring wstr = SQLWCHARToWString(sqlwBuf, numCharsInData); + std::string utf8str = WideToUTF8(wstr); + row.append(py::str(utf8str)); #else - row.append(std::wstring(dataBuffer.data())); + std::wstring wstr(reinterpret_cast(dataBuffer.data())); + row.append(py::cast(wstr)); #endif - } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << dataBuffer.size()-1 << ") is smaller, & data " - << "to be retrieved is longer (" << numCharsInData << "). ColumnID - " - << i << ", datatype - " << dataType; - ThrowStdException(oss.str()); + LOG("Appended NVARCHAR string of length {} to result row", numCharsInData); + } else { + // Buffer too small, fallback to streaming + LOG("NVARCHAR column {} data truncated, using streaming LOB", i); + row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false)); + } + } else if (dataLen == SQL_NULL_DATA) { + LOG("Column {} is NULL (CHAR)", i); + row.append(py::none()); + } else if (dataLen == 0) { + row.append(py::str("")); + } else if (dataLen == SQL_NO_TOTAL) { + LOG("SQLGetData couldn't determine the length of the NVARCHAR data. Returning NULL. Column ID - {}", i); + row.append(py::none()); + } else if (dataLen < 0) { + LOG("SQLGetData returned an unexpected negative data length. " + "Raising exception. Column ID - {}, Data Type - {}, Data Length - {}", + i, dataType, dataLen); + ThrowStdException("SQLGetData returned an unexpected negative data length"); } - } else if (dataLen == SQL_NULL_DATA) { - row.append(py::none()); - } else if (dataLen == 0) { - // Handle zero-length (non-NULL) data - row.append(py::str("")); - } else if (dataLen < 0) { - // This is unexpected - LOG("SQLGetData returned an unexpected negative data length. " - "Raising exception. Column ID - {}, Data Type - {}, Data Length - {}", - i, dataType, dataLen); - ThrowStdException("SQLGetData returned an unexpected negative data length"); + } else { + LOG("Error retrieving data for column {} (NVARCHAR), SQLGetData return code {}", i, ret); + row.append(py::none()); } - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } + } break; } case SQL_INTEGER: { @@ -2316,7 +2412,7 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column // Fetch rows in batches // TODO: Move to anonymous namespace, since it is not used outside this file SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, - py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched) { + py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched, const std::vector& lobColumns) { LOG("Fetching data in batches"); SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0); if (ret == SQL_NO_DATA) { @@ -2376,25 +2472,19 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - // TODO: variable length data needs special handling, this logic wont suffice SQLULEN columnSize = columnMeta["ColumnSize"].cast(); HandleZeroColumnSizeAtFetch(columnSize); uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); + bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' - if (numCharsInData < fetchBufferSize) { + if (!isLob && numCharsInData < fetchBufferSize) { // SQLFetch will nullterminate the data row.append(std::string( reinterpret_cast(&buffers.charBuffers[col - 1][i * fetchBufferSize]), numCharsInData)); } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data " - << "to be retrieved is longer (" << numCharsInData << "). ColumnID - " - << col << ", datatype - " << dataType; - ThrowStdException(oss.str()); + row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false)); } break; } @@ -2406,8 +2496,9 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum HandleZeroColumnSizeAtFetch(columnSize); uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); + bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' - if (numCharsInData < fetchBufferSize) { + if (!isLob && numCharsInData < fetchBufferSize) { // SQLFetch will nullterminate the data #if defined(__APPLE__) || defined(__linux__) // Use unix-specific conversion to handle the wchar_t/SQLWCHAR size difference @@ -2421,13 +2512,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum numCharsInData)); #endif } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data " - << "to be retrieved is longer (" << numCharsInData << "). ColumnID - " - << col << ", datatype - " << dataType; - ThrowStdException(oss.str()); + row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false)); } break; } @@ -2513,21 +2598,15 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: { - // TODO: variable length data needs special handling, this logic wont suffice SQLULEN columnSize = columnMeta["ColumnSize"].cast(); HandleZeroColumnSizeAtFetch(columnSize); - if (static_cast(dataLen) <= columnSize) { + bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); + if (!isLob && static_cast(dataLen) <= columnSize) { row.append(py::bytes(reinterpret_cast( &buffers.charBuffers[col - 1][i * columnSize]), dataLen)); } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data " - << "to be retrieved is longer (" << dataLen << "). ColumnID - " - << col << ", datatype - " << dataType; - ThrowStdException(oss.str()); + row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true)); } break; } @@ -2656,6 +2735,35 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; } + std::vector lobColumns; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto colMeta = columnNames[i].cast(); + SQLSMALLINT dataType = colMeta["DataType"].cast(); + SQLULEN columnSize = colMeta["ColumnSize"].cast(); + + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || + dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) && + (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { + lobColumns.push_back(i + 1); // 1-based + } + } + + // If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap + if (!lobColumns.empty()) { + LOG("LOB columns detected → using per-row SQLGetData path"); + while (true) { + ret = SQLFetch_ptr(hStmt); + if (ret == SQL_NO_DATA) break; + if (!SQL_SUCCEEDED(ret)) return ret; + + py::list row; + SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly + rows.append(row); + } + return SQL_SUCCESS; + } + // Initialize column buffers ColumnBuffers buffers(numCols, fetchSize); @@ -2670,7 +2778,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); - ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched); + ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns); if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) { LOG("Error when fetching data"); return ret; @@ -2749,6 +2857,35 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { } LOG("Fetching data in batch sizes of {}", fetchSize); + std::vector lobColumns; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto colMeta = columnNames[i].cast(); + SQLSMALLINT dataType = colMeta["DataType"].cast(); + SQLULEN columnSize = colMeta["ColumnSize"].cast(); + + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || + dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) && + (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { + lobColumns.push_back(i + 1); // 1-based + } + } + + // If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap + if (!lobColumns.empty()) { + LOG("LOB columns detected → using per-row SQLGetData path"); + while (true) { + ret = SQLFetch_ptr(hStmt); + if (ret == SQL_NO_DATA) break; + if (!SQL_SUCCEEDED(ret)) return ret; + + py::list row; + SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly + rows.append(row); + } + return SQL_SUCCESS; + } + ColumnBuffers buffers(numCols, fetchSize); // Bind columns @@ -2763,7 +2900,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); while (ret != SQL_NO_DATA) { - ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched); + ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns); if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) { LOG("Error when fetching data"); return ret; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 521a007b..fe4e8400 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -38,6 +38,18 @@ inline std::vector WStringToSQLWCHAR(const std::wstring& str) { result.push_back(0); return result; } + +inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { + if (!sqlwStr) return std::wstring(); + + if (length == SQL_NTS) { + size_t i = 0; + while (sqlwStr[i] != 0) ++i; + length = i; + } + return std::wstring(reinterpret_cast(sqlwStr), length); +} + #endif #if defined(__APPLE__) || defined(__linux__) @@ -60,7 +72,6 @@ inline bool IsValidUnicodeScalar(uint32_t cp) { inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { if (!sqlwStr) return std::wstring(); - if (length == SQL_NTS) { size_t i = 0; while (sqlwStr[i] != 0) ++i; @@ -68,29 +79,28 @@ inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = S } std::wstring result; result.reserve(length); - if constexpr (sizeof(SQLWCHAR) == 2) { - // Decode UTF-16 to UTF-32 (with surrogate pair handling) - for (size_t i = 0; i < length; ++i) { + for (size_t i = 0; i < length; ) { // Use a manual increment to handle skipping uint16_t wc = static_cast(sqlwStr[i]); - // Check if this is a high surrogate (U+D800–U+DBFF) - if (wc >= UNICODE_SURROGATE_HIGH_START && wc <= UNICODE_SURROGATE_HIGH_END && i + 1 < length) { + // Check for high surrogate and valid low surrogate + if (wc >= UNICODE_SURROGATE_HIGH_START && wc <= UNICODE_SURROGATE_HIGH_END && (i + 1 < length)) { uint16_t low = static_cast(sqlwStr[i + 1]); - // Check if the next code unit is a low surrogate (U+DC00–U+DFFF) if (low >= UNICODE_SURROGATE_LOW_START && low <= UNICODE_SURROGATE_LOW_END) { - // Combine surrogate pair into a single code point + // Combine into a single code point uint32_t cp = (((wc - UNICODE_SURROGATE_HIGH_START) << 10) | (low - UNICODE_SURROGATE_LOW_START)) + 0x10000; result.push_back(static_cast(cp)); - ++i; // Skip the low surrogate + i += 2; // Move past both surrogates continue; } } - // If valid scalar then append, else append replacement char (U+FFFD) + // If we reach here, it's not a valid surrogate pair or is a BMP character. + // Check if it's a valid scalar and append, otherwise append replacement char. if (IsValidUnicodeScalar(wc)) { result.push_back(static_cast(wc)); } else { result.push_back(static_cast(UNICODE_REPLACEMENT_CHAR)); } + ++i; // Move to the next code unit } } else { // SQLWCHAR is UTF-32, so just copy with validation @@ -346,6 +356,7 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET inline std::string WideToUTF8(const std::wstring& wstr) { if (wstr.empty()) return {}; + #if defined(_WIN32) int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), static_cast(wstr.size()), nullptr, 0, nullptr, nullptr); if (size_needed == 0) return {}; @@ -354,8 +365,34 @@ inline std::string WideToUTF8(const std::wstring& wstr) { if (converted == 0) return {}; return result; #else - std::wstring_convert> converter; - return converter.to_bytes(wstr); + // Manual UTF-32 to UTF-8 conversion for macOS/Linux + std::string utf8_string; + utf8_string.reserve(wstr.size() * 4); // Reserve enough space for worst case (4 bytes per character) + + for (wchar_t wc : wstr) { + uint32_t code_point = static_cast(wc); + + if (code_point <= 0x7F) { + // 1-byte UTF-8 sequence for ASCII characters + utf8_string += static_cast(code_point); + } else if (code_point <= 0x7FF) { + // 2-byte UTF-8 sequence + utf8_string += static_cast(0xC0 | ((code_point >> 6) & 0x1F)); + utf8_string += static_cast(0x80 | (code_point & 0x3F)); + } else if (code_point <= 0xFFFF) { + // 3-byte UTF-8 sequence + utf8_string += static_cast(0xE0 | ((code_point >> 12) & 0x0F)); + utf8_string += static_cast(0x80 | ((code_point >> 6) & 0x3F)); + utf8_string += static_cast(0x80 | (code_point & 0x3F)); + } else if (code_point <= 0x10FFFF) { + // 4-byte UTF-8 sequence for characters like emojis (e.g., U+1F604) + utf8_string += static_cast(0xF0 | ((code_point >> 18) & 0x07)); + utf8_string += static_cast(0x80 | ((code_point >> 12) & 0x3F)); + utf8_string += static_cast(0x80 | ((code_point >> 6) & 0x3F)); + utf8_string += static_cast(0x80 | (code_point & 0x3F)); + } + } + return utf8_string; #endif } diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index fe5cd4a1..1f6d6630 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -523,60 +523,6 @@ def test_varbinary_full_capacity(cursor, db_connection): cursor.execute("DROP TABLE #pytest_varbinary_test") db_connection.commit() -def test_varchar_max(cursor, db_connection): - """Test SQL_VARCHAR with MAX length""" - try: - cursor.execute("CREATE TABLE #pytest_varchar_test (varchar_column VARCHAR(MAX))") - db_connection.commit() - cursor.execute("INSERT INTO #pytest_varchar_test (varchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) - db_connection.commit() - expectedRows = 2 - # fetchone test - cursor.execute("SELECT varchar_column FROM #pytest_varchar_test") - rows = [] - for i in range(0, expectedRows): - rows.append(cursor.fetchone()) - assert cursor.fetchone() == None, "varchar_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == ["ABCDEFGHI"], "SQL_VARCHAR parsing failed for fetchone - row 0" - assert rows[1] == [None], "SQL_VARCHAR parsing failed for fetchone - row 1" - # fetchall test - cursor.execute("SELECT varchar_column FROM #pytest_varchar_test") - rows = cursor.fetchall() - assert rows[0] == ["ABCDEFGHI"], "SQL_VARCHAR parsing failed for fetchall - row 0" - assert rows[1] == [None], "SQL_VARCHAR parsing failed for fetchall - row 1" - except Exception as e: - pytest.fail(f"SQL_VARCHAR parsing test failed: {e}") - finally: - cursor.execute("DROP TABLE #pytest_varchar_test") - db_connection.commit() - -def test_wvarchar_max(cursor, db_connection): - """Test SQL_WVARCHAR with MAX length""" - try: - cursor.execute("CREATE TABLE #pytest_wvarchar_test (wvarchar_column NVARCHAR(MAX))") - db_connection.commit() - cursor.execute("INSERT INTO #pytest_wvarchar_test (wvarchar_column) VALUES (?), (?)", ["!@#$%^&*()_+", None]) - db_connection.commit() - expectedRows = 2 - # fetchone test - cursor.execute("SELECT wvarchar_column FROM #pytest_wvarchar_test") - rows = [] - for i in range(0, expectedRows): - rows.append(cursor.fetchone()) - assert cursor.fetchone() == None, "wvarchar_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == ["!@#$%^&*()_+"], "SQL_WVARCHAR parsing failed for fetchone - row 0" - assert rows[1] == [None], "SQL_WVARCHAR parsing failed for fetchone - row 1" - # fetchall test - cursor.execute("SELECT wvarchar_column FROM #pytest_wvarchar_test") - rows = cursor.fetchall() - assert rows[0] == ["!@#$%^&*()_+"], "SQL_WVARCHAR parsing failed for fetchall - row 0" - assert rows[1] == [None], "SQL_WVARCHAR parsing failed for fetchall - row 1" - except Exception as e: - pytest.fail(f"SQL_WVARCHAR parsing test failed: {e}") - finally: - cursor.execute("DROP TABLE #pytest_wvarchar_test") - db_connection.commit() - def test_varbinary_max(cursor, db_connection): """Test SQL_VARBINARY with MAX length""" try: @@ -5680,186 +5626,44 @@ def test_emoji_round_trip(cursor, db_connection): except Exception as e: pytest.fail(f"Error for input {repr(text)}: {e}") -def test_varchar_max_insert_non_lob(cursor, db_connection): - """Test small VARCHAR(MAX) insert (non-LOB path).""" - try: - cursor.execute("CREATE TABLE #pytest_varchar_nonlob (col VARCHAR(MAX))") - db_connection.commit() - - small_str = "Hello, world!" # small, non-LOB - cursor.execute( - "INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)", - [small_str] - ) - db_connection.commit() - - empty_str = "" - cursor.execute( - "INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)", - [empty_str] - ) - db_connection.commit() - - # None value - cursor.execute( - "INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)", - [None] - ) - db_connection.commit() - - # Fetch commented for now - # cursor.execute("SELECT col FROM #pytest_varchar_nonlob") - # rows = cursor.fetchall() - # assert rows == [[small_str], [empty_str], [None]] - - finally: - pass - - -def test_varchar_max_insert_lob(cursor, db_connection): - """Test large VARCHAR(MAX) insert (LOB path).""" - try: - cursor.execute("CREATE TABLE #pytest_varchar_lob (col VARCHAR(MAX))") - db_connection.commit() - - large_str = "A" * 100_000 # > 8k to trigger LOB - cursor.execute( - "INSERT INTO #pytest_varchar_lob (col) VALUES (?)", - [large_str] - ) - db_connection.commit() - - # Fetch commented for now - # cursor.execute("SELECT col FROM #pytest_varchar_lob") - # rows = cursor.fetchall() - # assert rows == [[large_str]] - - finally: - pass - - -def test_nvarchar_max_insert_non_lob(cursor, db_connection): - """Test small NVARCHAR(MAX) insert (non-LOB path).""" - try: - cursor.execute("CREATE TABLE #pytest_nvarchar_nonlob (col NVARCHAR(MAX))") - db_connection.commit() - - small_str = "Unicode ✨ test" - cursor.execute( - "INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)", - [small_str] - ) - db_connection.commit() - - empty_str = "" - cursor.execute( - "INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)", - [empty_str] - ) - db_connection.commit() - - cursor.execute( - "INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)", - [None] - ) - db_connection.commit() - - # Fetch commented for now - # cursor.execute("SELECT col FROM #pytest_nvarchar_nonlob") - # rows = cursor.fetchall() - # assert rows == [[small_str], [empty_str], [None]] - - finally: - pass - - -def test_nvarchar_max_insert_lob(cursor, db_connection): - """Test large NVARCHAR(MAX) insert (LOB path).""" +def test_varcharmax_transaction_rollback(cursor, db_connection): + """Test that inserting a large VARCHAR(MAX) within a transaction that is rolled back + does not persist the data, ensuring transactional integrity.""" try: - cursor.execute("CREATE TABLE #pytest_nvarchar_lob (col NVARCHAR(MAX))") + cursor.execute("DROP TABLE IF EXISTS #pytest_varcharmax") + cursor.execute("CREATE TABLE #pytest_varcharmax (col VARCHAR(MAX))") db_connection.commit() - - large_str = "📝" * 50_000 # each emoji = 2 UTF-16 code units, total > 100k bytes - cursor.execute( - "INSERT INTO #pytest_nvarchar_lob (col) VALUES (?)", - [large_str] - ) - db_connection.commit() - - # Fetch commented for now - # cursor.execute("SELECT col FROM #pytest_nvarchar_lob") - # rows = cursor.fetchall() - # assert rows == [[large_str]] + db_connection.autocommit = False + rollback_str = "ROLLBACK" * 2000 + cursor.execute("INSERT INTO #pytest_varcharmax VALUES (?)", [rollback_str]) + db_connection.rollback() + cursor.execute("SELECT COUNT(*) FROM #pytest_varcharmax WHERE col = ?", [rollback_str]) + assert cursor.fetchone()[0] == 0 finally: - pass - -def test_nvarchar_max_boundary(cursor, db_connection): - """Test NVARCHAR(MAX) at LOB boundary sizes.""" - try: - cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_boundary") - cursor.execute("CREATE TABLE #pytest_nvarchar_boundary (col NVARCHAR(MAX))") - db_connection.commit() - - # 4k BMP chars = 8k bytes - cursor.execute("INSERT INTO #pytest_nvarchar_boundary (col) VALUES (?)", ["A" * 4096]) - # 4k emojis = 8k UTF-16 code units (16k bytes) - cursor.execute("INSERT INTO #pytest_nvarchar_boundary (col) VALUES (?)", ["📝" * 4096]) + db_connection.autocommit = True # reset state + cursor.execute("DROP TABLE IF EXISTS #pytest_varcharmax") db_connection.commit() - - # Fetch commented for now - # cursor.execute("SELECT col FROM #pytest_nvarchar_boundary") - # rows = cursor.fetchall() - # assert rows == [["A" * 4096], ["📝" * 4096]] - finally: - pass - -def test_nvarchar_max_chunk_edge(cursor, db_connection): - """Test NVARCHAR(MAX) insert slightly larger than a chunk.""" +def test_nvarcharmax_transaction_rollback(cursor, db_connection): + """Test that inserting a large NVARCHAR(MAX) within a transaction that is rolled back + does not persist the data, ensuring transactional integrity.""" try: - cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_chunk") - cursor.execute("CREATE TABLE #pytest_nvarchar_chunk (col NVARCHAR(MAX))") + cursor.execute("DROP TABLE IF EXISTS #pytest_nvarcharmax") + cursor.execute("CREATE TABLE #pytest_nvarcharmax (col NVARCHAR(MAX))") db_connection.commit() - chunk_size = 8192 # bytes - test_str = "📝" * ((chunk_size // 4) + 3) # slightly > 1 chunk - cursor.execute("INSERT INTO #pytest_nvarchar_chunk (col) VALUES (?)", [test_str]) - db_connection.commit() - - # Fetch commented for now - # cursor.execute("SELECT col FROM #pytest_nvarchar_chunk") - # row = cursor.fetchone() - # assert row[0] == test_str + db_connection.autocommit = False + rollback_str = "ROLLBACK" * 2000 + cursor.execute("INSERT INTO #pytest_nvarcharmax VALUES (?)", [rollback_str]) + db_connection.rollback() + cursor.execute("SELECT COUNT(*) FROM #pytest_nvarcharmax WHERE col = ?", [rollback_str]) + assert cursor.fetchone()[0] == 0 finally: - pass - -def test_empty_string_chunk(cursor, db_connection): - """Test inserting empty strings into VARCHAR(MAX) and NVARCHAR(MAX).""" - try: - cursor.execute("DROP TABLE IF EXISTS #pytest_empty_string") - cursor.execute(""" - CREATE TABLE #pytest_empty_string ( - varchar_col VARCHAR(MAX), - nvarchar_col NVARCHAR(MAX) - ) - """) - db_connection.commit() - - empty_varchar = "" - empty_nvarchar = "" - cursor.execute( - "INSERT INTO #pytest_empty_string (varchar_col, nvarchar_col) VALUES (?, ?)", - [empty_varchar, empty_nvarchar] - ) + db_connection.autocommit = True + cursor.execute("DROP TABLE IF EXISTS #pytest_nvarcharmax") db_connection.commit() - cursor.execute("SELECT LEN(varchar_col), LEN(nvarchar_col) FROM #pytest_empty_string") - row = tuple(int(x) for x in cursor.fetchone()) - assert row == (0, 0), f"Expected lengths (0,0), got {row}" - finally: - cursor.execute("DROP TABLE IF EXISTS #pytest_empty_string") - db_connection.commit() def test_empty_char_single_and_batch_fetch(cursor, db_connection): """Test that empty CHAR data is handled correctly in both single and batch fetch""" @@ -6547,7 +6351,259 @@ def test_only_null_and_empty_binary(cursor, db_connection): finally: drop_table_if_exists(cursor, "#pytest_null_empty_binary") db_connection.commit() + +# ---------------------- VARCHAR(MAX) ---------------------- + +def test_varcharmax_short_fetch(cursor, db_connection): + """Small VARCHAR(MAX), fetchone/fetchall/fetchmany.""" + try: + cursor.execute("DROP TABLE IF EXISTS #pytest_varcharmax") + cursor.execute("CREATE TABLE #pytest_varcharmax (col VARCHAR(MAX))") + db_connection.commit() + + values = ["hello", "world"] + for val in values: + cursor.execute("INSERT INTO #pytest_varcharmax VALUES (?)", [val]) + db_connection.commit() + + # fetchone + cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY col") + row1 = cursor.fetchone()[0] + row2 = cursor.fetchone()[0] + assert {row1, row2} == set(values) + assert cursor.fetchone() is None + + # fetchall + cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY col") + all_rows = [r[0] for r in cursor.fetchall()] + assert set(all_rows) == set(values) + + # fetchmany + cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY col") + many = [r[0] for r in cursor.fetchmany(1)] + assert many[0] in values + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_varcharmax") + db_connection.commit() + + +def test_varcharmax_empty_string(cursor, db_connection): + """Empty string in VARCHAR(MAX).""" + try: + cursor.execute("CREATE TABLE #pytest_varcharmax (col VARCHAR(MAX))") + db_connection.commit() + cursor.execute("INSERT INTO #pytest_varcharmax VALUES (?)", [""]) + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_varcharmax") + assert cursor.fetchone()[0] == "" + finally: + cursor.execute("DROP TABLE #pytest_varcharmax") + db_connection.commit() + + +def test_varcharmax_null(cursor, db_connection): + """NULL in VARCHAR(MAX).""" + try: + cursor.execute("CREATE TABLE #pytest_varcharmax (col VARCHAR(MAX))") + db_connection.commit() + cursor.execute("INSERT INTO #pytest_varcharmax VALUES (?)", [None]) + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_varcharmax") + assert cursor.fetchone()[0] is None + finally: + cursor.execute("DROP TABLE #pytest_varcharmax") + db_connection.commit() + + +def test_varcharmax_boundary(cursor, db_connection): + """Boundary at 8000 (inline limit).""" + try: + boundary_str = "X" * 8000 + cursor.execute("CREATE TABLE #pytest_varcharmax (col VARCHAR(MAX))") + db_connection.commit() + cursor.execute("INSERT INTO #pytest_varcharmax VALUES (?)", [boundary_str]) + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_varcharmax") + assert cursor.fetchone()[0] == boundary_str + finally: + cursor.execute("DROP TABLE #pytest_varcharmax") + db_connection.commit() + + +def test_varcharmax_streaming(cursor, db_connection): + """Streaming fetch > 8k with all fetch modes.""" + try: + values = ["Y" * 8100, "Z" * 10000] + cursor.execute("CREATE TABLE #pytest_varcharmax (col VARCHAR(MAX))") + db_connection.commit() + for v in values: + cursor.execute("INSERT INTO #pytest_varcharmax VALUES (?)", [v]) + db_connection.commit() + + # --- fetchall --- + cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY LEN(col)") + rows = [r[0] for r in cursor.fetchall()] + assert rows == sorted(values, key=len) + + # --- fetchone --- + cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY LEN(col)") + r1 = cursor.fetchone()[0] + r2 = cursor.fetchone()[0] + assert {r1, r2} == set(values) + assert cursor.fetchone() is None + + # --- fetchmany --- + cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY LEN(col)") + batch = [r[0] for r in cursor.fetchmany(1)] + assert batch[0] in values + finally: + cursor.execute("DROP TABLE #pytest_varcharmax") + db_connection.commit() + + +def test_varcharmax_large(cursor, db_connection): + """Very large VARCHAR(MAX).""" + try: + large_str = "L" * 100_000 + cursor.execute("CREATE TABLE #pytest_varcharmax (col VARCHAR(MAX))") + db_connection.commit() + cursor.execute("INSERT INTO #pytest_varcharmax VALUES (?)", [large_str]) + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_varcharmax") + assert cursor.fetchone()[0] == large_str + finally: + cursor.execute("DROP TABLE #pytest_varcharmax") + db_connection.commit() + + +# ---------------------- NVARCHAR(MAX) ---------------------- + +def test_nvarcharmax_short_fetch(cursor, db_connection): + """Small NVARCHAR(MAX), unicode, fetch modes.""" + try: + values = ["hello", "world_ß"] + cursor.execute("CREATE TABLE #pytest_nvarcharmax (col NVARCHAR(MAX))") + db_connection.commit() + for v in values: + cursor.execute("INSERT INTO #pytest_nvarcharmax VALUES (?)", [v]) + db_connection.commit() + + # fetchone + cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY col") + r1 = cursor.fetchone()[0] + r2 = cursor.fetchone()[0] + assert {r1, r2} == set(values) + assert cursor.fetchone() is None + + # fetchall + cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY col") + all_rows = [r[0] for r in cursor.fetchall()] + assert set(all_rows) == set(values) + + # fetchmany + cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY col") + many = [r[0] for r in cursor.fetchmany(1)] + assert many[0] in values + finally: + cursor.execute("DROP TABLE #pytest_nvarcharmax") + db_connection.commit() + + +def test_nvarcharmax_empty_string(cursor, db_connection): + """Empty string in NVARCHAR(MAX).""" + try: + cursor.execute("CREATE TABLE #pytest_nvarcharmax (col NVARCHAR(MAX))") + db_connection.commit() + cursor.execute("INSERT INTO #pytest_nvarcharmax VALUES (?)", [""]) + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_nvarcharmax") + assert cursor.fetchone()[0] == "" + finally: + cursor.execute("DROP TABLE #pytest_nvarcharmax") + db_connection.commit() + + +def test_nvarcharmax_null(cursor, db_connection): + """NULL in NVARCHAR(MAX).""" + try: + cursor.execute("CREATE TABLE #pytest_nvarcharmax (col NVARCHAR(MAX))") + db_connection.commit() + cursor.execute("INSERT INTO #pytest_nvarcharmax VALUES (?)", [None]) + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_nvarcharmax") + assert cursor.fetchone()[0] is None + finally: + cursor.execute("DROP TABLE #pytest_nvarcharmax") + db_connection.commit() + + +def test_nvarcharmax_boundary(cursor, db_connection): + """Boundary at 4000 characters (inline limit).""" + try: + boundary_str = "X" * 4000 + cursor.execute("CREATE TABLE #pytest_nvarcharmax (col NVARCHAR(MAX))") + db_connection.commit() + cursor.execute("INSERT INTO #pytest_nvarcharmax VALUES (?)", [boundary_str]) + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_nvarcharmax") + assert cursor.fetchone()[0] == boundary_str + finally: + cursor.execute("DROP TABLE #pytest_nvarcharmax") + db_connection.commit() + + +def test_nvarcharmax_streaming(cursor, db_connection): + """Streaming fetch > 4k unicode with all fetch modes.""" + try: + values = ["Ω" * 4100, "漢" * 5000] + cursor.execute("CREATE TABLE #pytest_nvarcharmax (col NVARCHAR(MAX))") + db_connection.commit() + for v in values: + cursor.execute("INSERT INTO #pytest_nvarcharmax VALUES (?)", [v]) + db_connection.commit() + # --- fetchall --- + cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY LEN(col)") + rows = [r[0] for r in cursor.fetchall()] + assert rows == sorted(values, key=len) + + # --- fetchone --- + cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY LEN(col)") + r1 = cursor.fetchone()[0] + r2 = cursor.fetchone()[0] + assert {r1, r2} == set(values) + assert cursor.fetchone() is None + + # --- fetchmany --- + cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY LEN(col)") + batch = [r[0] for r in cursor.fetchmany(1)] + assert batch[0] in values + finally: + cursor.execute("DROP TABLE #pytest_nvarcharmax") + db_connection.commit() + + +def test_nvarcharmax_large(cursor, db_connection): + """Very large NVARCHAR(MAX).""" + try: + large_str = "漢" * 50_000 + cursor.execute("CREATE TABLE #pytest_nvarcharmax (col NVARCHAR(MAX))") + db_connection.commit() + cursor.execute("INSERT INTO #pytest_nvarcharmax VALUES (?)", [large_str]) + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_nvarcharmax") + assert cursor.fetchone()[0] == large_str + finally: + cursor.execute("DROP TABLE #pytest_nvarcharmax") + db_connection.commit() def test_money_smallmoney_insert_fetch(cursor, db_connection): """Test inserting and retrieving valid MONEY and SMALLMONEY values including boundaries and typical data""" @@ -6776,4 +6832,4 @@ def test_close(db_connection): except Exception as e: pytest.fail(f"Cursor close test failed: {e}") finally: - cursor = db_connection.cursor() \ No newline at end of file + cursor = db_connection.cursor()