Skip to content

Commit

Permalink
Attach python lifetime to shared_ptr passed to C++
Browse files Browse the repository at this point in the history
- Reference cycles are possible as a result, but shared_ptr is already susceptible to this in C++
  • Loading branch information
virtuald committed Feb 1, 2021
1 parent 721834b commit 9955bcc
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 0 deletions.
31 changes: 31 additions & 0 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1563,6 +1563,9 @@ struct copyable_holder_caster : public type_caster_base<type> {
throw cast_error("Unable to load a custom holder type from a default-holder instance");
}

// holders that are not std::shared_ptr
template <typename T = holder_type,
detail::enable_if_t<!is_shared_ptr<T>::value, int> = 0>
bool load_value(value_and_holder &&v_h) {
if (v_h.holder_constructed()) {
value = v_h.value_ptr();
Expand All @@ -1578,6 +1581,34 @@ struct copyable_holder_caster : public type_caster_base<type> {
}
}

// holders that are std::shared_ptr
template <typename T = holder_type,
detail::enable_if_t<is_shared_ptr<T>::value, int> = 0>
bool load_value(value_and_holder &&v_h) {
if (v_h.holder_constructed()) {
value = v_h.value_ptr();

// The shared_ptr is always given to C++ code, so we construct a new shared_ptr
// that is given a custom deleter. The custom deleter increments the python
// reference count to bind the python instance lifetime with the lifetime
// of the shared_ptr.
//
// This enables things like passing the last python reference of a subclass to a
// C++ function without the python reference dying.
//
// Reference cycles will cause a leak, but this is a limitation of shared_ptr
holder = holder_type((type*)value, [unused = reinterpret_borrow<object>((PyObject*)v_h.inst)](type*){});
return true;
} else {
throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
#if defined(NDEBUG)
"(compile in debug mode for type information)");
#else
"of type '" + type_id<holder_type>() + "''");
#endif
}
}

template <typename T = holder_type, detail::enable_if_t<!std::is_constructible<T, const T &, type*>::value, int> = 0>
bool try_implicit_casts(handle, bool) { return false; }

Expand Down
29 changes: 29 additions & 0 deletions tests/test_smart_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,4 +397,33 @@ TEST_SUBMODULE(smart_ptr, m) {
list.append(py::cast(e));
return list;
});

// For testing whether a python subclass of a C++ object dies when the
// last python reference is lost
struct SpBase {
// returns true if the base virtual function is called
virtual bool is_base_used() { return true; }
virtual ~SpBase() = default;
};

struct PySpBase : SpBase {
bool is_base_used() override { PYBIND11_OVERRIDE(bool, SpBase, is_base_used); }
};

struct SpBaseTester {
std::shared_ptr<SpBase> get_object() { return m_obj; }
void set_object(std::shared_ptr<SpBase> obj) { m_obj = obj; }
bool is_base_used() { return m_obj->is_base_used(); }
std::shared_ptr<SpBase> m_obj;
};

py::class_<SpBase, std::shared_ptr<SpBase>, PySpBase>(m, "SpBase")
.def(py::init<>())
.def("is_base_used", &SpBase::is_base_used);

py::class_<SpBaseTester>(m, "SpBaseTester")
.def(py::init<>())
.def("get_object", &SpBaseTester::get_object)
.def("set_object", &SpBaseTester::set_object)
.def("is_base_used", &SpBaseTester::is_base_used);
}
44 changes: 44 additions & 0 deletions tests/test_smart_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,47 @@ def test_shared_ptr_gc():
pytest.gc_collect()
for i, v in enumerate(el.get()):
assert i == v.value()


def test_shared_ptr_cpp_arg():
import weakref

class PyChild(m.SpBase):
def is_base_used(self):
return False

tester = m.SpBaseTester()

obj = PyChild()
objref = weakref.ref(obj)

tester.set_object(obj)
del obj
pytest.gc_collect()

# python reference is still around since C++ has it now
assert objref() is not None
assert tester.is_base_used() == False
assert tester.get_object() is objref()


def test_shared_ptr_arg_identity():
import weakref

tester = m.SpBaseTester()

obj = m.SpBase()
objref = weakref.ref(obj)

tester.set_object(obj)
del obj
pytest.gc_collect()

# python reference is still around since C++ has it
assert objref() is not None
assert tester.get_object() is objref()

# python reference disappears once the C++ object releases it
tester.set_object(None)
pytest.gc_collect()
assert objref() is None

0 comments on commit 9955bcc

Please sign in to comment.