Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refcount bug involving trampoline functions with PyObject * return type. #5156

Merged
merged 9 commits into from
Jun 11, 2024
15 changes: 13 additions & 2 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1339,13 +1339,24 @@ enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&,
// static_assert, even though if it's in dead code, so we provide a "trampoline" to pybind11::cast
// that only does anything in cases where pybind11::cast is valid.
template <typename T>
enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&) {
enable_if_t<cast_is_temporary_value_reference<T>::value
&& !detail::is_same_ignoring_cvref<T, PyObject *>::value,
T>
cast_safe(object &&) {
pybind11_fail("Internal error: cast_safe fallback invoked");
}
template <typename T>
enable_if_t<std::is_void<T>::value, void> cast_safe(object &&) {}
template <typename T>
enable_if_t<detail::none_of<cast_is_temporary_value_reference<T>, std::is_void<T>>::value, T>
enable_if_t<detail::is_same_ignoring_cvref<T, PyObject *>::value, PyObject *>
cast_safe(object &&o) {
return o.release().ptr();
}
template <typename T>
enable_if_t<detail::none_of<cast_is_temporary_value_reference<T>,
detail::is_same_ignoring_cvref<T, PyObject *>,
std::is_void<T>>::value,
T>
cast_safe(object &&o) {
return pybind11::cast<T>(std::move(o));
}
Expand Down
6 changes: 5 additions & 1 deletion include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -2868,10 +2868,14 @@ function get_override(const T *this_ptr, const char *name) {
= pybind11::get_override(static_cast<const cname *>(this), name); \
if (override) { \
auto o = override(__VA_ARGS__); \
if (pybind11::detail::cast_is_temporary_value_reference<ret_type>::value) { \
PYBIND11_WARNING_PUSH \
PYBIND11_WARNING_DISABLE_MSVC(4127) \
if (pybind11::detail::cast_is_temporary_value_reference<ret_type>::value \
&& !pybind11::detail::is_same_ignoring_cvref<ret_type, PyObject *>::value) { \
static pybind11::detail::override_caster_t<ret_type> caster; \
return pybind11::detail::cast_ref<ret_type>(std::move(o), caster); \
} \
PYBIND11_WARNING_POP \
return pybind11::detail::cast_safe<ret_type>(std::move(o)); \
} \
} while (false)
Expand Down
41 changes: 39 additions & 2 deletions tests/test_type_caster_pyobject_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
#include "pybind11_tests.h"

#include <cstddef>
#include <string>
#include <vector>

namespace {
namespace test_type_caster_pyobject_ptr {

std::vector<PyObject *> make_vector_pyobject_ptr(const py::object &ValueHolder) {
std::vector<PyObject *> vec_obj;
Expand All @@ -18,9 +19,39 @@ std::vector<PyObject *> make_vector_pyobject_ptr(const py::object &ValueHolder)
return vec_obj;
}

} // namespace
struct WithPyObjectPtrReturn {
#if defined(__clang_major__) && __clang_major__ < 4
WithPyObjectPtrReturn() = default;
WithPyObjectPtrReturn(const WithPyObjectPtrReturn &) = default;
#endif
virtual ~WithPyObjectPtrReturn() = default;
virtual PyObject *return_pyobject_ptr() const = 0;
};

struct WithPyObjectPtrReturnTrampoline : WithPyObjectPtrReturn {
PyObject *return_pyobject_ptr() const override {
PYBIND11_OVERRIDE_PURE(PyObject *, WithPyObjectPtrReturn, return_pyobject_ptr,
/* no arguments */);
}
};

std::string call_return_pyobject_ptr(const WithPyObjectPtrReturn *base_class_ptr) {
PyObject *returned_obj = base_class_ptr->return_pyobject_ptr();
#if !defined(PYPY_VERSION) // It is not worth the trouble doing something special for PyPy.
if (Py_REFCNT(returned_obj) != 1) {
py::pybind11_fail(__FILE__ ":" PYBIND11_TOSTRING(__LINE__));
}
#endif
auto ret_val = py::repr(returned_obj).cast<std::string>();
Py_DECREF(returned_obj);
return ret_val;
}

} // namespace test_type_caster_pyobject_ptr

TEST_SUBMODULE(type_caster_pyobject_ptr, m) {
using namespace test_type_caster_pyobject_ptr;

m.def("cast_from_pyobject_ptr", []() {
PyObject *ptr = PyLong_FromLongLong(6758L);
return py::cast(ptr, py::return_value_policy::take_ownership);
Expand Down Expand Up @@ -127,4 +158,10 @@ TEST_SUBMODULE(type_caster_pyobject_ptr, m) {
(void) py::cast(*ptr);
}
#endif

py::class_<WithPyObjectPtrReturn, WithPyObjectPtrReturnTrampoline>(m, "WithPyObjectPtrReturn")
.def(py::init<>())
.def("return_pyobject_ptr", &WithPyObjectPtrReturn::return_pyobject_ptr);

m.def("call_return_pyobject_ptr", call_return_pyobject_ptr);
}
16 changes: 16 additions & 0 deletions tests/test_type_caster_pyobject_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,19 @@ def test_return_list_pyobject_ptr_reference():
def test_type_caster_name_via_incompatible_function_arguments_type_error():
with pytest.raises(TypeError, match=r"1\. \(arg0: object, arg1: int\) -> None"):
m.pass_pyobject_ptr_and_int(ValueHolder(101), ValueHolder(202))


def test_trampoline_with_pyobject_ptr_return():
class Drvd(m.WithPyObjectPtrReturn):
def return_pyobject_ptr(self):
return ["11", "22", "33"]

# Basic health check: First make sure this works as expected.
d = Drvd()
assert d.return_pyobject_ptr() == ["11", "22", "33"]

while True:
# This failed before PR #5156: AddressSanitizer: heap-use-after-free ... in Py_DECREF
d_repr = m.call_return_pyobject_ptr(d)
assert d_repr == repr(["11", "22", "33"])
break # Comment out for manual leak checking.
Loading