Skip to content

Commit

Permalink
protect nanobind state capsule using 'name' parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Oct 28, 2022
1 parent 620a4aa commit 42db2fd
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/nb_internals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,10 @@ static void internals_make() {
if (!dict)
fail("nanobind::detail::internals_make(): PyInterpreterState_GetDict() failed!");

PyObject *capsule = PyCapsule_New(internals_p, nullptr, nullptr);
const char *internals_id = NB_INTERNALS_ID;
PyObject *capsule = PyCapsule_New(internals_p, internals_id, nullptr);
PyObject *nb_module = PyModule_NewObject(nb_name.ptr());
int rv = PyDict_SetItemString(dict, NB_INTERNALS_ID, capsule);
int rv = PyDict_SetItemString(dict, internals_id, capsule);
if (rv || !capsule || !nb_module)
fail("nanobind::detail::internals_make(): allocation failed!");
Py_DECREF(capsule);
Expand Down Expand Up @@ -488,10 +489,11 @@ static void internals_fetch() {
if (!dict)
fail("nanobind::detail::internals_fetch(): PyInterpreterState_GetDict() failed!");

PyObject *capsule = PyDict_GetItemString(dict, NB_INTERNALS_ID);
const char *internals_id = NB_INTERNALS_ID;
PyObject *capsule = PyDict_GetItemString(dict, internals_id);

if (capsule) {
internals_p = (nb_internals *) PyCapsule_GetPointer(capsule, nullptr);
internals_p = (nb_internals *) PyCapsule_GetPointer(capsule, internals_id);
if (!internals_p)
fail("nanobind::detail::internals_fetch(): capsule pointer is NULL!");
return;
Expand Down

0 comments on commit 42db2fd

Please sign in to comment.