Skip to content

bpo-42064: Pass module state to sqlite3 UDF callbacks #27456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 58 additions & 31 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,10 @@ set_sqlite_error(sqlite3_context *context, const char *msg)
else {
sqlite3_result_error(context, msg, -1);
}
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
assert(ctx->state != NULL);
if (ctx->state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
Expand All @@ -625,19 +627,18 @@ static void
_pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
{
PyObject* args;
PyObject* py_func;
PyObject* py_retval = NULL;
int ok;

PyGILState_STATE threadstate;

threadstate = PyGILState_Ensure();

py_func = (PyObject*)sqlite3_user_data(context);

args = _pysqlite_build_py_params(context, argc, argv);
if (args) {
py_retval = PyObject_CallObject(py_func, args);
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
py_retval = PyObject_CallObject(ctx->callable, args);
Py_DECREF(args);
}

Expand All @@ -657,20 +658,19 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
{
PyObject* args;
PyObject* function_result = NULL;
PyObject* aggregate_class;
PyObject** aggregate_instance;
PyObject* stepmethod = NULL;

PyGILState_STATE threadstate;

threadstate = PyGILState_Ensure();

aggregate_class = (PyObject*)sqlite3_user_data(context);

aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));

if (*aggregate_instance == NULL) {
*aggregate_instance = _PyObject_CallNoArg(aggregate_class);
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
*aggregate_instance = _PyObject_CallNoArg(ctx->callable);
if (!*aggregate_instance) {
set_sqlite_error(context,
"user-defined aggregate's '__init__' method raised error");
Expand Down Expand Up @@ -784,14 +784,35 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
Py_SETREF(self->cursors, new_list);
}

static void _destructor(void* args)
static callback_context *
create_callback_context(pysqlite_state *state, PyObject *callable)
{
// This function may be called without the GIL held, so we need to ensure
// that we destroy 'args' with the GIL
PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
Py_DECREF((PyObject*)args);
PyGILState_STATE gstate = PyGILState_Ensure();
callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
if (ctx != NULL) {
ctx->callable = Py_NewRef(callable);
ctx->state = state;
}
PyGILState_Release(gstate);
return ctx;
}

static void
free_callback_context(callback_context *ctx)
{
if (ctx != NULL) {
// This function may be called without the GIL held, so we need to
// ensure that we destroy 'ctx' with the GIL held.
PyGILState_STATE gstate = PyGILState_Ensure();
Py_DECREF(ctx->callable);
PyMem_Free(ctx);
PyGILState_Release(gstate);
}
}

static void _destructor(void* args)
{
free_callback_context((callback_context *)args);
}

/*[clinic input]
Expand Down Expand Up @@ -833,11 +854,11 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
flags |= SQLITE_DETERMINISTIC;
#endif
}
rc = sqlite3_create_function_v2(self->db,
name,
narg,
flags,
(void*)Py_NewRef(func),
callback_context *ctx = create_callback_context(self->state, func);
if (ctx == NULL) {
return NULL;
}
rc = sqlite3_create_function_v2(self->db, name, narg, flags, ctx,
_pysqlite_func_callback,
NULL,
NULL,
Expand Down Expand Up @@ -873,11 +894,12 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
return NULL;
}

rc = sqlite3_create_function_v2(self->db,
name,
n_arg,
SQLITE_UTF8,
(void*)Py_NewRef(aggregate_class),
callback_context *ctx = create_callback_context(self->state,
aggregate_class);
if (ctx == NULL) {
return NULL;
}
rc = sqlite3_create_function_v2(self->db, name, n_arg, SQLITE_UTF8, ctx,
0,
&_pysqlite_step_callback,
&_pysqlite_final_callback,
Expand Down Expand Up @@ -1439,7 +1461,6 @@ pysqlite_collation_callback(
int text1_length, const void* text1_data,
int text2_length, const void* text2_data)
{
PyObject* callback = (PyObject*)context;
PyObject* string1 = 0;
PyObject* string2 = 0;
PyGILState_STATE gilstate;
Expand All @@ -1459,8 +1480,10 @@ pysqlite_collation_callback(
goto finally; /* failed to allocate strings */
}

callback_context *ctx = (callback_context *)context;
assert(ctx != NULL);
PyObject *args[] = { string1, string2 }; // Borrowed refs.
retval = PyObject_Vectorcall(callback, args, 2, NULL);
retval = PyObject_Vectorcall(ctx->callable, args, 2, NULL);
if (retval == NULL) {
/* execution failed */
goto finally;
Expand Down Expand Up @@ -1690,6 +1713,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
return NULL;
}

callback_context *ctx = NULL;
int rc;
int flags = SQLITE_UTF8;
if (callable == Py_None) {
Expand All @@ -1701,8 +1725,11 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
PyErr_SetString(PyExc_TypeError, "parameter must be callable");
return NULL;
}
rc = sqlite3_create_collation_v2(self->db, name, flags,
Py_NewRef(callable),
ctx = create_callback_context(self->state, callable);
if (ctx == NULL) {
return NULL;
}
rc = sqlite3_create_collation_v2(self->db, name, flags, ctx,
&pysqlite_collation_callback,
&_destructor);
}
Expand All @@ -1713,7 +1740,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
* the context before returning.
*/
if (callable != Py_None) {
Py_DECREF(callable);
free_callback_context(ctx);
}
_pysqlite_seterror(self->state, self->db);
return NULL;
Expand Down
6 changes: 6 additions & 0 deletions Modules/_sqlite/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@

#include "sqlite3.h"

typedef struct _callback_context
{
PyObject *callable;
pysqlite_state *state;
} callback_context;

typedef struct
{
PyObject_HEAD
Expand Down