Skip to content

Commit 3cb7c25

Browse files
EthanSteinbergpre-commit-ci[bot]Skylion007
authored andcommitted
Fix functional.h bug + introduce test to verify that it is fixed (#4254)
* Illustrate bug in functional.h * style: pre-commit fixes * Make functional casting more robust / add workaround * Make function_record* casting even more robust * See if this fixes PyPy issue * It still fails on PyPy sadly * Do not make new CTOR just yet * Fix test * Add name to ensure correctness * style: pre-commit fixes * Clean up tests + remove ifdef guards * Add comments * Improve comments, error handling, and safety * Fix compile error * Fix magic logic * Extract helper function * Fix func signature * move to local internals * style: pre-commit fixes * Switch to simpler design * style: pre-commit fixes * Move to function_record * style: pre-commit fixes * Switch to internals, update tests and docs * Fix lint * Oops, forgot to resolve last comment * Fix typo * Update in response to comments * Implement suggestion to improve test * Update comment * Simple fixes Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
1 parent 5b092ed commit 3cb7c25

File tree

5 files changed

+124
-12
lines changed

5 files changed

+124
-12
lines changed

include/pybind11/detail/internals.h

+31
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ using ExceptionTranslator = void (*)(std::exception_ptr);
4444

4545
PYBIND11_NAMESPACE_BEGIN(detail)
4646

47+
constexpr const char *internals_function_record_capsule_name = "pybind11_function_record_capsule";
48+
4749
// Forward declarations
4850
inline PyTypeObject *make_static_property_type();
4951
inline PyTypeObject *make_default_metaclass();
@@ -183,6 +185,16 @@ struct internals {
183185
# endif // PYBIND11_INTERNALS_VERSION > 4
184186
// Unused if PYBIND11_SIMPLE_GIL_MANAGEMENT is defined:
185187
PyInterpreterState *istate = nullptr;
188+
189+
# if PYBIND11_INTERNALS_VERSION > 4
190+
// Note that we have to use a std::string to allocate memory to ensure a unique address
191+
// We want unique addresses since we use pointer equality to compare function records
192+
std::string function_record_capsule_name = internals_function_record_capsule_name;
193+
# endif
194+
195+
internals() = default;
196+
internals(const internals &other) = delete;
197+
internals &operator=(const internals &other) = delete;
186198
~internals() {
187199
# if PYBIND11_INTERNALS_VERSION > 4
188200
PYBIND11_TLS_FREE(loader_life_support_tls_key);
@@ -559,6 +571,25 @@ const char *c_str(Args &&...args) {
559571
return strings.front().c_str();
560572
}
561573

574+
inline const char *get_function_record_capsule_name() {
575+
#if PYBIND11_INTERNALS_VERSION > 4
576+
return get_internals().function_record_capsule_name.c_str();
577+
#else
578+
return nullptr;
579+
#endif
580+
}
581+
582+
// Determine whether or not the following capsule contains a pybind11 function record.
583+
// Note that we use `internals` to make sure that only ABI compatible records are touched.
584+
//
585+
// This check is currently used in two places:
586+
// - An important optimization in functional.h to avoid overhead in C++ -> Python -> C++
587+
// - The sibling feature of cpp_function to allow overloads
588+
inline bool is_function_record_capsule(const capsule &cap) {
589+
// Pointer equality as we rely on internals() to ensure unique pointers
590+
return cap.name() == get_function_record_capsule_name();
591+
}
592+
562593
PYBIND11_NAMESPACE_END(detail)
563594

564595
/// Returns a named pointer that is shared among all extension modules (using the same

include/pybind11/functional.h

+9-2
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,16 @@ struct type_caster<std::function<Return(Args...)>> {
4848
*/
4949
if (auto cfunc = func.cpp_function()) {
5050
auto *cfunc_self = PyCFunction_GET_SELF(cfunc.ptr());
51-
if (isinstance<capsule>(cfunc_self)) {
51+
if (cfunc_self == nullptr) {
52+
PyErr_Clear();
53+
} else if (isinstance<capsule>(cfunc_self)) {
5254
auto c = reinterpret_borrow<capsule>(cfunc_self);
53-
auto *rec = (function_record *) c;
55+
56+
function_record *rec = nullptr;
57+
// Check that we can safely reinterpret the capsule into a function_record
58+
if (detail::is_function_record_capsule(c)) {
59+
rec = c.get_pointer<function_record>();
60+
}
5461

5562
while (rec != nullptr) {
5663
if (rec->is_stateless

include/pybind11/pybind11.h

+34-10
Original file line numberDiff line numberDiff line change
@@ -469,13 +469,20 @@ class cpp_function : public function {
469469
if (rec->sibling) {
470470
if (PyCFunction_Check(rec->sibling.ptr())) {
471471
auto *self = PyCFunction_GET_SELF(rec->sibling.ptr());
472-
capsule rec_capsule = isinstance<capsule>(self) ? reinterpret_borrow<capsule>(self)
473-
: capsule(self);
474-
chain = (detail::function_record *) rec_capsule;
475-
/* Never append a method to an overload chain of a parent class;
476-
instead, hide the parent's overloads in this case */
477-
if (!chain->scope.is(rec->scope)) {
472+
if (!isinstance<capsule>(self)) {
478473
chain = nullptr;
474+
} else {
475+
auto rec_capsule = reinterpret_borrow<capsule>(self);
476+
if (detail::is_function_record_capsule(rec_capsule)) {
477+
chain = rec_capsule.get_pointer<detail::function_record>();
478+
/* Never append a method to an overload chain of a parent class;
479+
instead, hide the parent's overloads in this case */
480+
if (!chain->scope.is(rec->scope)) {
481+
chain = nullptr;
482+
}
483+
} else {
484+
chain = nullptr;
485+
}
479486
}
480487
}
481488
// Don't trigger for things like the default __init__, which are wrapper_descriptors
@@ -497,6 +504,7 @@ class cpp_function : public function {
497504

498505
capsule rec_capsule(unique_rec.release(),
499506
[](void *ptr) { destruct((detail::function_record *) ptr); });
507+
rec_capsule.set_name(detail::get_function_record_capsule_name());
500508
guarded_strdup.release();
501509

502510
object scope_module;
@@ -662,10 +670,13 @@ class cpp_function : public function {
662670
/// Main dispatch logic for calls to functions bound using pybind11
663671
static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) {
664672
using namespace detail;
673+
assert(isinstance<capsule>(self));
665674

666675
/* Iterator over the list of potentially admissible overloads */
667-
const function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
676+
const function_record *overloads = reinterpret_cast<function_record *>(
677+
PyCapsule_GetPointer(self, get_function_record_capsule_name())),
668678
*it = overloads;
679+
assert(overloads != nullptr);
669680

670681
/* Need to know how many arguments + keyword arguments there are to pick the right
671682
overload */
@@ -2126,9 +2137,22 @@ class class_ : public detail::generic_type {
21262137

21272138
static detail::function_record *get_function_record(handle h) {
21282139
h = detail::get_function(h);
2129-
return h ? (detail::function_record *) reinterpret_borrow<capsule>(
2130-
PyCFunction_GET_SELF(h.ptr()))
2131-
: nullptr;
2140+
if (!h) {
2141+
return nullptr;
2142+
}
2143+
2144+
handle func_self = PyCFunction_GET_SELF(h.ptr());
2145+
if (!func_self) {
2146+
throw error_already_set();
2147+
}
2148+
if (!isinstance<capsule>(func_self)) {
2149+
return nullptr;
2150+
}
2151+
auto cap = reinterpret_borrow<capsule>(func_self);
2152+
if (!detail::is_function_record_capsule(cap)) {
2153+
return nullptr;
2154+
}
2155+
return cap.get_pointer<detail::function_record>();
21322156
}
21332157
};
21342158

tests/test_callbacks.cpp

+37
Original file line numberDiff line numberDiff line change
@@ -240,4 +240,41 @@ TEST_SUBMODULE(callbacks, m) {
240240
f();
241241
}
242242
});
243+
244+
auto *custom_def = []() {
245+
static PyMethodDef def;
246+
def.ml_name = "example_name";
247+
def.ml_doc = "Example doc";
248+
def.ml_meth = [](PyObject *, PyObject *args) -> PyObject * {
249+
if (PyTuple_Size(args) != 1) {
250+
throw std::runtime_error("Invalid number of arguments for example_name");
251+
}
252+
PyObject *first = PyTuple_GetItem(args, 0);
253+
if (!PyLong_Check(first)) {
254+
throw std::runtime_error("Invalid argument to example_name");
255+
}
256+
auto result = py::cast(PyLong_AsLong(first) * 9);
257+
return result.release().ptr();
258+
};
259+
def.ml_flags = METH_VARARGS;
260+
return &def;
261+
}();
262+
263+
// rec_capsule with name that has the same value (but not pointer) as our internal one
264+
// This capsule should be detected by our code as foreign and not inspected as the pointers
265+
// shouldn't match
266+
constexpr const char *rec_capsule_name
267+
= pybind11::detail::internals_function_record_capsule_name;
268+
py::capsule rec_capsule(std::malloc(1), [](void *data) { std::free(data); });
269+
rec_capsule.set_name(rec_capsule_name);
270+
m.add_object("custom_function", PyCFunction_New(custom_def, rec_capsule.ptr()));
271+
272+
// This test requires a new ABI version to pass
273+
#if PYBIND11_INTERNALS_VERSION > 4
274+
// rec_capsule with nullptr name
275+
py::capsule rec_capsule2(std::malloc(1), [](void *data) { std::free(data); });
276+
m.add_object("custom_function2", PyCFunction_New(custom_def, rec_capsule2.ptr()));
277+
#else
278+
m.add_object("custom_function2", py::none());
279+
#endif
243280
}

tests/test_callbacks.py

+13
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,16 @@ def test_callback_num_times():
193193
if len(rates) > 1:
194194
print("Min Mean Max")
195195
print(f"{min(rates):6.3f} {sum(rates) / len(rates):6.3f} {max(rates):6.3f}")
196+
197+
198+
def test_custom_func():
199+
assert m.custom_function(4) == 36
200+
assert m.roundtrip(m.custom_function)(4) == 36
201+
202+
203+
@pytest.mark.skipif(
204+
m.custom_function2 is None, reason="Current PYBIND11_INTERNALS_VERSION too low"
205+
)
206+
def test_custom_func2():
207+
assert m.custom_function2(3) == 27
208+
assert m.roundtrip(m.custom_function2)(3) == 27

0 commit comments

Comments
 (0)