diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 5bdeaed9..88152aa2 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -231,7 +231,7 @@ def _map_sql_type(self, param, parameters_list, i): """ if param is None: return ( - ddbc_sql_const.SQL_VARCHAR.value, # TODO: Add SQLDescribeParam to get correct type + ddbc_sql_const.SQL_VARCHAR.value, ddbc_sql_const.SQL_C_DEFAULT.value, 1, 0, diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index e78dea3f..69df4d49 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -140,6 +140,8 @@ SQLParamDataFunc SQLParamData_ptr = nullptr; SQLPutDataFunc SQLPutData_ptr = nullptr; SQLTablesFunc SQLTables_ptr = nullptr; +SQLDescribeParamFunc SQLDescribeParam_ptr = nullptr; + namespace { const char* GetSqlCTypeAsString(const SQLSMALLINT cType) { @@ -212,12 +214,12 @@ std::string DescribeChar(unsigned char ch) { // Given a list of parameters and their ParamInfo, calls SQLBindParameter on each of them with // appropriate arguments SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, - const std::vector& paramInfos, + std::vector& paramInfos, std::vector>& paramBuffers) { LOG("Starting parameter binding. Number of parameters: {}", params.size()); for (int paramIndex = 0; paramIndex < params.size(); paramIndex++) { const auto& param = params[paramIndex]; - const ParamInfo& paramInfo = paramInfos[paramIndex]; + ParamInfo& paramInfo = paramInfos[paramIndex]; LOG("Binding parameter {} - C Type: {}, SQL Type: {}", paramIndex, paramInfo.paramCType, paramInfo.paramSQLType); void* dataPtr = nullptr; SQLLEN bufferLength = 0; @@ -283,11 +285,37 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, if (!py::isinstance(param)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } - // TODO: This wont work for None values added to BINARY/VARBINARY columns. None values - // of binary columns need to have C type = SQL_C_BINARY & SQL type = SQL_BINARY + SQLSMALLINT sqlType = paramInfo.paramSQLType; + SQLULEN columnSize = paramInfo.columnSize; + SQLSMALLINT decimalDigits = paramInfo.decimalDigits; + if (sqlType == SQL_UNKNOWN_TYPE) { + SQLSMALLINT describedType; + SQLULEN describedSize; + SQLSMALLINT describedDigits; + SQLSMALLINT nullable; + RETCODE rc = SQLDescribeParam_ptr( + hStmt, + static_cast(paramIndex + 1), + &describedType, + &describedSize, + &describedDigits, + &nullable + ); + if (!SQL_SUCCEEDED(rc)) { + LOG("SQLDescribeParam failed for parameter {} with error code {}", paramIndex, rc); + return rc; + } + sqlType = describedType; + columnSize = describedSize; + decimalDigits = describedDigits; + } dataPtr = nullptr; strLenOrIndPtr = AllocateParamBuffer(paramBuffers); *strLenOrIndPtr = SQL_NULL_DATA; + bufferLength = 0; + paramInfo.paramSQLType = sqlType; + paramInfo.columnSize = columnSize; + paramInfo.decimalDigits = decimalDigits; break; } case SQL_C_STINYINT: @@ -767,6 +795,8 @@ DriverHandle LoadDriverOrThrowException() { SQLPutData_ptr = GetFunctionPointer(handle, "SQLPutData"); SQLTables_ptr = GetFunctionPointer(handle, "SQLTablesW"); + SQLDescribeParam_ptr = GetFunctionPointer(handle, "SQLDescribeParam"); + bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr && @@ -777,7 +807,8 @@ DriverHandle LoadDriverOrThrowException() { SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLParamData_ptr && - SQLPutData_ptr && SQLTables_ptr; + SQLPutData_ptr && SQLTables_ptr && + SQLDescribeParam_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -1072,7 +1103,7 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, // be prepared in a previous call. SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, const std::wstring& query /* TODO: Use SQLTCHAR? */, - const py::list& params, const std::vector& paramInfos, + const py::list& params, std::vector& paramInfos, py::list& isStmtPrepared, const bool usePrepare = true) { LOG("Execute SQL Query - {}", query.c_str()); if (!SQLPrepare_ptr) { diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 2ae13459..521a007b 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -211,6 +211,8 @@ typedef SQLRETURN (SQL_API* SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT); typedef SQLRETURN (SQL_API* SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*, SQLINTEGER*, SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*); +typedef SQLRETURN (SQL_API* SQLDescribeParamFunc)(SQLHSTMT, SQLUSMALLINT, SQLSMALLINT*, SQLULEN*, SQLSMALLINT*, SQLSMALLINT*); + // DAE APIs typedef SQLRETURN (SQL_API* SQLParamDataFunc)(SQLHSTMT, SQLPOINTER*); typedef SQLRETURN (SQL_API* SQLPutDataFunc)(SQLHSTMT, SQLPOINTER, SQLLEN); @@ -257,6 +259,8 @@ extern SQLFreeStmtFunc SQLFreeStmt_ptr; // Diagnostic APIs extern SQLGetDiagRecFunc SQLGetDiagRec_ptr; +extern SQLDescribeParamFunc SQLDescribeParam_ptr; + // DAE APIs extern SQLParamDataFunc SQLParamData_ptr; extern SQLPutDataFunc SQLPutData_ptr;