diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index a6f5bb64..5bdeaed9 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -18,6 +18,8 @@ from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError from .row import Row +# Constants for string handling +MAX_INLINE_CHAR = 4000 # NVARCHAR/VARCHAR inline limit; this triggers NVARCHAR(MAX)/VARCHAR(MAX) + DAE class Cursor: """ @@ -233,10 +235,11 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_DEFAULT.value, 1, 0, + False, ) if isinstance(param, bool): - return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0 + return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0, False if isinstance(param, int): if 0 <= param <= 255: @@ -245,6 +248,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_TINYINT.value, 3, 0, + False, ) if -32768 <= param <= 32767: return ( @@ -252,6 +256,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_SHORT.value, 5, 0, + False, ) if -2147483648 <= param <= 2147483647: return ( @@ -259,12 +264,14 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_LONG.value, 10, 0, + False, ) return ( ddbc_sql_const.SQL_BIGINT.value, ddbc_sql_const.SQL_C_SBIGINT.value, 19, 0, + False, ) if isinstance(param, float): @@ -273,6 +280,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_DOUBLE.value, 15, 0, + False, ) if isinstance(param, decimal.Decimal): @@ -284,6 +292,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_NUMERIC.value, parameters_list[i].precision, parameters_list[i].scale, + False, ) if isinstance(param, str): @@ -297,6 +306,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_WCHAR.value, len(param), 0, + False, ) # Attempt to parse as date, datetime, datetime2, timestamp, smalldatetime or time @@ -309,6 +319,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_TYPE_DATE.value, 10, 0, + False, ) if self._parse_datetime(param): parameters_list[i] = self._parse_datetime(param) @@ -317,6 +328,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value, 26, 6, + False, ) if self._parse_time(param): parameters_list[i] = self._parse_time(param) @@ -325,25 +337,26 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_TYPE_TIME.value, 8, 0, + False, ) # String mapping logic here is_unicode = self._is_unicode_string(param) - # TODO: revisit - if len(param) > 4000: # Long strings + if len(param) > MAX_INLINE_CHAR: # Long strings if is_unicode: - utf16_len = len(param.encode("utf-16-le")) // 2 return ( ddbc_sql_const.SQL_WLONGVARCHAR.value, ddbc_sql_const.SQL_C_WCHAR.value, - utf16_len, + len(param), 0, + True, ) return ( ddbc_sql_const.SQL_LONGVARCHAR.value, ddbc_sql_const.SQL_C_CHAR.value, len(param), 0, + True, ) if is_unicode: # Short Unicode strings utf16_len = len(param.encode("utf-16-le")) // 2 @@ -352,12 +365,14 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_WCHAR.value, utf16_len, 0, + False, ) return ( ddbc_sql_const.SQL_VARCHAR.value, ddbc_sql_const.SQL_C_CHAR.value, len(param), 0, + False, ) if isinstance(param, bytes): @@ -367,12 +382,14 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_BINARY.value, len(param), 0, + False, ) return ( ddbc_sql_const.SQL_BINARY.value, ddbc_sql_const.SQL_C_BINARY.value, len(param), 0, + False, ) if isinstance(param, bytearray): @@ -382,12 +399,14 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_BINARY.value, len(param), 0, + True, ) return ( ddbc_sql_const.SQL_BINARY.value, ddbc_sql_const.SQL_C_BINARY.value, len(param), 0, + False, ) if isinstance(param, datetime.datetime): @@ -396,6 +415,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value, 26, 6, + False, ) if isinstance(param, datetime.date): @@ -404,6 +424,7 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_TYPE_DATE.value, 10, 0, + False, ) if isinstance(param, datetime.time): @@ -412,14 +433,11 @@ def _map_sql_type(self, param, parameters_list, i): ddbc_sql_const.SQL_C_TYPE_TIME.value, 8, 0, + False, ) - return ( - ddbc_sql_const.SQL_VARCHAR.value, - ddbc_sql_const.SQL_C_CHAR.value, - len(str(param)), - 0, - ) + # For safety: unknown/unhandled Python types should not silently go to SQL + raise TypeError("Unsupported parameter type: The driver cannot safely convert it to a SQL type.") def _initialize_cursor(self) -> None: """ @@ -495,7 +513,7 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i paraminfo. """ paraminfo = param_info() - sql_type, c_type, column_size, decimal_digits = self._map_sql_type( + sql_type, c_type, column_size, decimal_digits, is_dae = self._map_sql_type( parameter, parameters_list, i ) paraminfo.paramCType = c_type @@ -503,6 +521,11 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value paraminfo.columnSize = column_size paraminfo.decimalDigits = decimal_digits + paraminfo.isDAE = is_dae + + if is_dae: + paraminfo.dataPtr = parameter # Will be converted to py::object* in C++ + return paraminfo def _initialize_description(self): @@ -762,9 +785,16 @@ def execute( self.is_stmt_prepared, use_prepare, ) - + # Check return code + try: + # Check for errors but don't raise exceptions for info/warning messages - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + except Exception as e: + log('warning', "Execute failed, resetting cursor: %s", e) + self._reset_cursor() + raise + # Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.) if self.hstmt: diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index d0a20dbd..8a88688a 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -30,7 +30,7 @@ #ifndef ARCHITECTURE #define ARCHITECTURE "win64" // Default to win64 if not defined during compilation #endif - +#define DAE_CHUNK_SIZE 8192 //------------------------------------------------------------------------------------------------- // Class definitions //------------------------------------------------------------------------------------------------- @@ -43,9 +43,9 @@ struct ParamInfo { SQLSMALLINT paramSQLType; SQLULEN columnSize; SQLSMALLINT decimalDigits; - // TODO: Reuse python buffer for large data using Python buffer protocol - // Stores pointer to the python object that holds parameter value - // py::object* dataPtr; + SQLLEN strLenOrInd = 0; // Required for DAE + bool isDAE = false; // Indicates if we need to stream + py::object dataPtr; }; // Mirrors the SQL_NUMERIC_STRUCT. But redefined to replace val char array @@ -134,6 +134,10 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr; // Diagnostic APIs SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr; + +// DAE APIs +SQLParamDataFunc SQLParamData_ptr = nullptr; +SQLPutDataFunc SQLPutData_ptr = nullptr; SQLTablesFunc SQLTables_ptr = nullptr; namespace { @@ -245,57 +249,26 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, !py::isinstance(param)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } - std::wstring* strParam = - AllocateParamBuffer(paramBuffers, param.cast()); - if (strParam->size() > 4096 /* TODO: Fix max length */) { - ThrowStdException( - "Streaming parameters is not yet supported. Parameter size" - " must be less than 8192 bytes"); - } - - // Log detailed parameter information - LOG("SQL_C_WCHAR Parameter[{}]: Length={}, Content='{}'", - paramIndex, - strParam->size(), - (strParam->size() <= 100 - ? WideToUTF8(std::wstring(strParam->begin(), strParam->end())) - : WideToUTF8(std::wstring(strParam->begin(), strParam->begin() + 100)) + "...")); - - // Log each character's code point for debugging - if (strParam->size() <= 20) { - for (size_t i = 0; i < strParam->size(); i++) { - unsigned char ch = static_cast((*strParam)[i]); - LOG(" char[{}] = {} ({})", i, static_cast(ch), DescribeChar(ch)); - } + if (paramInfo.isDAE) { + // deferred execution + LOG("Parameter[{}] is marked for DAE streaming", paramIndex); + dataPtr = const_cast(reinterpret_cast(¶mInfos[paramIndex])); + strLenOrIndPtr = AllocateParamBuffer(paramBuffers); + *strLenOrIndPtr = SQL_LEN_DATA_AT_EXEC(0); + bufferLength = 0; + } else { + // Normal small-string case + std::wstring* strParam = + AllocateParamBuffer(paramBuffers, param.cast()); + LOG("SQL_C_WCHAR Parameter[{}]: Length={}, isDAE={}", paramIndex, strParam->size(), paramInfo.isDAE); + std::vector* sqlwcharBuffer = + AllocateParamBuffer>(paramBuffers, WStringToSQLWCHAR(*strParam)); + dataPtr = sqlwcharBuffer->data(); + bufferLength = sqlwcharBuffer->size() * sizeof(SQLWCHAR); + strLenOrIndPtr = AllocateParamBuffer(paramBuffers); + *strLenOrIndPtr = SQL_NTS; + } -#if defined(__APPLE__) || defined(__linux__) - // On macOS/Linux, we need special handling for wide characters - // Create a properly encoded SQLWCHAR buffer for the parameter - std::vector* sqlwcharBuffer = - AllocateParamBuffer>(paramBuffers); - - // Reserve space and convert from wstring to SQLWCHAR array - std::vector utf16 = WStringToSQLWCHAR(*strParam); - if (utf16.size() < strParam->size()) { - LOG("Warning: UTF-16 encoding shrank string? input={} output={}", - strParam->size(), utf16.size()); - } - if (utf16.size() > strParam->size() * 2 + 1) { - LOG("Warning: UTF-16 expansion unusually large: input={} output={}", - strParam->size(), utf16.size()); - } - *sqlwcharBuffer = std::move(utf16); - // Use the SQLWCHAR buffer instead of the wstring directly - dataPtr = sqlwcharBuffer->data(); - bufferLength = sqlwcharBuffer->size() * sizeof(SQLWCHAR); - LOG("macOS: Created SQLWCHAR buffer for parameter with size: {} bytes", bufferLength); -#else - // On Windows, wchar_t and SQLWCHAR are the same size, so direct cast works - dataPtr = const_cast(static_cast(strParam->c_str())); - bufferLength = (strParam->size() + 1 /* null terminator */) * sizeof(wchar_t); -#endif - strLenOrIndPtr = AllocateParamBuffer(paramBuffers); - *strLenOrIndPtr = SQL_NTS; break; } case SQL_C_BIT: { @@ -791,6 +764,9 @@ DriverHandle LoadDriverOrThrowException() { SQLFreeStmt_ptr = GetFunctionPointer(handle, "SQLFreeStmt"); SQLGetDiagRec_ptr = GetFunctionPointer(handle, "SQLGetDiagRecW"); + + SQLParamData_ptr = GetFunctionPointer(handle, "SQLParamData"); + SQLPutData_ptr = GetFunctionPointer(handle, "SQLPutData"); SQLTables_ptr = GetFunctionPointer(handle, "SQLTablesW"); bool success = @@ -802,7 +778,8 @@ DriverHandle LoadDriverOrThrowException() { SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && - SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLTables_ptr; + SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLParamData_ptr && + SQLPutData_ptr && SQLTables_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -1176,16 +1153,63 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } rc = SQLExecute_ptr(hStmt); + if (rc == SQL_NEED_DATA) { + LOG("Beginning SQLParamData/SQLPutData loop for DAE."); + SQLPOINTER paramToken = nullptr; + while ((rc = SQLParamData_ptr(hStmt, ¶mToken)) == SQL_NEED_DATA) { + // Finding the paramInfo that matches the returned token + const ParamInfo* matchedInfo = nullptr; + for (auto& info : paramInfos) { + if (reinterpret_cast(const_cast(&info)) == paramToken) { + matchedInfo = &info; + break; + } + } + if (!matchedInfo) { + ThrowStdException("Unrecognized paramToken returned by SQLParamData"); + } + const py::object& pyObj = matchedInfo->dataPtr; + if (pyObj.is_none()) { + SQLPutData_ptr(hStmt, nullptr, 0); + continue; + } + if (py::isinstance(pyObj)) { + std::wstring wstr = pyObj.cast(); +#if defined(__APPLE__) || defined(__linux__) + auto utf16Buf = WStringToSQLWCHAR(wstr); + const char* dataPtr = reinterpret_cast(utf16Buf.data()); + size_t totalBytes = (utf16Buf.size() - 1) * sizeof(SQLWCHAR); +#else + const char* dataPtr = reinterpret_cast(wstr.data()); + size_t totalBytes = wstr.size() * sizeof(wchar_t); +#endif + const size_t chunkSize = DAE_CHUNK_SIZE; + for (size_t offset = 0; offset < totalBytes; offset += chunkSize) { + size_t len = std::min(chunkSize, totalBytes - offset); + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast(len)); + if (!SQL_SUCCEEDED(rc)) { + LOG("SQLPutData failed at offset {} of {}", offset, totalBytes); + return rc; + } + } + } else { + ThrowStdException("DAE only supported for str or bytes"); + } + } + if (!SQL_SUCCEEDED(rc)) { + LOG("SQLParamData final rc: {}", rc); + return rc; + } + LOG("DAE complete, SQLExecute resumed internally."); + } if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) { LOG("DDBCSQLExecute: Error during execution of the statement"); return rc; } - // TODO: Handle huge input parameters by checking rc == SQL_NEED_DATA // Unbind the bound buffers for all parameters coz the buffers' memory will // be freed when this function exits (parambuffers goes out of scope) rc = SQLFreeStmt_ptr(hStmt, SQL_RESET_PARAMS); - return rc; } } @@ -2731,8 +2755,11 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def_readwrite("paramCType", &ParamInfo::paramCType) .def_readwrite("paramSQLType", &ParamInfo::paramSQLType) .def_readwrite("columnSize", &ParamInfo::columnSize) - .def_readwrite("decimalDigits", &ParamInfo::decimalDigits); - + .def_readwrite("decimalDigits", &ParamInfo::decimalDigits) + .def_readwrite("strLenOrInd", &ParamInfo::strLenOrInd) + .def_readwrite("dataPtr", &ParamInfo::dataPtr) + .def_readwrite("isDAE", &ParamInfo::isDAE); + // Define numeric data class py::class_(m, "NumericData") .def(py::init<>()) diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index f28f610c..2ae13459 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -32,6 +32,14 @@ using namespace pybind11::literals; #include #include +#if defined(_WIN32) +inline std::vector WStringToSQLWCHAR(const std::wstring& str) { + std::vector result(str.begin(), str.end()); + result.push_back(0); + return result; +} +#endif + #if defined(__APPLE__) || defined(__linux__) #include @@ -203,6 +211,9 @@ typedef SQLRETURN (SQL_API* SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT); typedef SQLRETURN (SQL_API* SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*, SQLINTEGER*, SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*); +// DAE APIs +typedef SQLRETURN (SQL_API* SQLParamDataFunc)(SQLHSTMT, SQLPOINTER*); +typedef SQLRETURN (SQL_API* SQLPutDataFunc)(SQLHSTMT, SQLPOINTER, SQLLEN); //------------------------------------------------------------------------------------------------- // Extern function pointer declarations (defined in ddbc_bindings.cpp) //------------------------------------------------------------------------------------------------- @@ -246,6 +257,10 @@ extern SQLFreeStmtFunc SQLFreeStmt_ptr; // Diagnostic APIs extern SQLGetDiagRecFunc SQLGetDiagRec_ptr; +// DAE APIs +extern SQLParamDataFunc SQLParamData_ptr; +extern SQLPutDataFunc SQLPutData_ptr; + // Logging utility template void LOG(const std::string& formatString, Args&&... args);