diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index 8c60bdcf5d70aa..97121eea1be5fb 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -20,6 +20,7 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. +import contextlib import unittest import sqlite3 as sqlite @@ -200,6 +201,16 @@ def progress(): self.assertEqual(action, 0, "progress handler was not cleared") class TraceCallbackTests(unittest.TestCase): + @contextlib.contextmanager + def check_stmt_trace(self, cx, expected): + try: + traced = [] + cx.set_trace_callback(lambda stmt: traced.append(stmt)) + yield + finally: + self.assertEqual(traced, expected) + cx.set_trace_callback(None) + def test_trace_callback_used(self): """ Test that the trace callback is invoked once it is set. @@ -261,6 +272,21 @@ def trace(statement): cur.execute(queries[1]) self.assertEqual(traced_statements, queries) + def test_trace_expanded_sql(self): + expected = [ + "create table t(t)", + "BEGIN ", + "insert into t values(0)", + "insert into t values(1)", + "insert into t values(2)", + "COMMIT", + ] + cx = sqlite.connect(":memory:") + with self.check_stmt_trace(cx, expected): + with cx: + cx.execute("create table t(t)") + cx.executemany("insert into t values(?)", ((v,) for v in range(3))) + def suite(): tests = [ diff --git a/Misc/NEWS.d/next/Library/2021-09-08-16-21-03.bpo-45138.yghUrK.rst b/Misc/NEWS.d/next/Library/2021-09-08-16-21-03.bpo-45138.yghUrK.rst new file mode 100644 index 00000000000000..906ed4c4db43c8 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-09-08-16-21-03.bpo-45138.yghUrK.rst @@ -0,0 +1,3 @@ +Fix a regression in the :mod:`sqlite3` trace callback where bound parameters +were not expanded in the passed statement string. The regression was introduced +in Python 3.10 by :issue:`40318`. Patch by Erlend E. Aasland. diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index c9c10b41398e26..68c5aee79ab15f 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1050,33 +1050,65 @@ static int _progress_handler(void* user_arg) * may change in future releases. Callback implementations should return zero * to ensure future compatibility. */ -static int _trace_callback(unsigned int type, void* user_arg, void* prepared_statement, void* statement_string) +static int +_trace_callback(unsigned int type, void *callable, void *stmt, void *sql) #else -static void _trace_callback(void* user_arg, const char* statement_string) +static void +_trace_callback(void *callable, const char *sql) #endif { - PyObject *py_statement = NULL; - PyObject *ret = NULL; - - PyGILState_STATE gilstate; - #ifdef HAVE_TRACE_V2 if (type != SQLITE_TRACE_STMT) { return 0; } #endif - gilstate = PyGILState_Ensure(); - py_statement = PyUnicode_DecodeUTF8(statement_string, - strlen(statement_string), "replace"); + PyGILState_STATE gilstate = PyGILState_Ensure(); + + PyObject *py_statement = NULL; +#ifdef HAVE_TRACE_V2 + const char *expanded_sql = sqlite3_expanded_sql((sqlite3_stmt *)stmt); + if (expanded_sql == NULL) { + sqlite3 *db = sqlite3_db_handle((sqlite3_stmt *)stmt); + if (sqlite3_errcode(db) == SQLITE_NOMEM) { + (void)PyErr_NoMemory(); + goto exit; + } + + PyErr_SetString(pysqlite_DataError, + "Expanded SQL string exceeds the maximum string length"); + if (_pysqlite_enable_callback_tracebacks) { + PyErr_Print(); + } else { + PyErr_Clear(); + } + + // Fall back to unexpanded sql + py_statement = PyUnicode_FromString((const char *)sql); + } + else { + py_statement = PyUnicode_FromString(expanded_sql); + sqlite3_free((void *)expanded_sql); + } +#else + if (sql == NULL) { + PyErr_SetString(pysqlite_DataError, + "Expanded SQL string exceeds the maximum string length"); + if (_pysqlite_enable_callback_tracebacks) { + PyErr_Print(); + } else { + PyErr_Clear(); + } + goto exit; + } + py_statement = PyUnicode_FromString(sql); +#endif if (py_statement) { - ret = PyObject_CallOneArg((PyObject*)user_arg, py_statement); + PyObject *ret = PyObject_CallOneArg((PyObject *)callable, py_statement); Py_DECREF(py_statement); + Py_XDECREF(ret); } - - if (ret) { - Py_DECREF(ret); - } else { + if (PyErr_Occurred()) { if (_pysqlite_enable_callback_tracebacks) { PyErr_Print(); } else { @@ -1084,6 +1116,7 @@ static void _trace_callback(void* user_arg, const char* statement_string) } } +exit: PyGILState_Release(gilstate); #ifdef HAVE_TRACE_V2 return 0;