Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[smart_holder] Fix handling of const unique_ptr<T, D> & (do not disown). #5332

Merged
merged 15 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,8 +1067,14 @@ struct move_only_holder_caster<
+ clean_type_id(typeinfo->cpptype->name()) + ")");
}

template <typename>
using cast_op_type = std::unique_ptr<type, deleter>;
template <typename T_>
using cast_op_type
= conditional_t<std::is_same<typename std::remove_volatile<T_>::type,
const std::unique_ptr<type, deleter> &>::value
|| std::is_same<typename std::remove_volatile<T_>::type,
const std::unique_ptr<const type, deleter> &>::value,
const std::unique_ptr<type, deleter> &,
std::unique_ptr<type, deleter>>;

explicit operator std::unique_ptr<type, deleter>() {
if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) {
Expand All @@ -1077,6 +1083,28 @@ struct move_only_holder_caster<
pybind11_fail("Expected to be UNREACHABLE: " __FILE__ ":" PYBIND11_TOSTRING(__LINE__));
}

explicit operator const std::unique_ptr<type, deleter> &() {
if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) {
// Get shared_ptr to ensure that the Python object is not disowned elsewhere.
shared_ptr_storage = sh_load_helper.load_as_shared_ptr(value);
// Build a temporary unique_ptr that is meant to never expire.
unique_ptr_storage = std::shared_ptr<std::unique_ptr<type, deleter>>(
new std::unique_ptr<type, deleter>{
sh_load_helper.template load_as_const_unique_ptr<deleter>(
shared_ptr_storage.get())},
[](std::unique_ptr<type, deleter> *ptr) {
if (!ptr) {
pybind11_fail("FATAL: `const std::unique_ptr<T, D> &` was disowned "
"(EXPECT UNDEFINED BEHAVIOR).");
}
(void) ptr->release();
delete ptr;
});
return *unique_ptr_storage;
}
pybind11_fail("Expected to be UNREACHABLE: " __FILE__ ":" PYBIND11_TOSTRING(__LINE__));
}

bool try_implicit_casts(handle src, bool convert) {
for (auto &cast : typeinfo->implicit_casts) {
move_only_holder_caster sub_caster(*cast.first);
Expand All @@ -1097,6 +1125,8 @@ struct move_only_holder_caster<
static bool try_direct_conversions(handle) { return false; }

smart_holder_type_caster_support::load_helper<remove_cv_t<type>> sh_load_helper; // Const2Mutbl
std::shared_ptr<type> shared_ptr_storage; // Serves as a pseudo lock.
std::shared_ptr<std::unique_ptr<type, deleter>> unique_ptr_storage;
};

#endif // PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT
Expand Down
6 changes: 4 additions & 2 deletions include/pybind11/detail/struct_smart_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,17 @@ struct smart_holder {
// Caller is responsible for precondition: ensure_compatible_rtti_uqp_del<T, D>() must succeed.
template <typename T, typename D>
std::unique_ptr<D> extract_deleter(const char *context) const {
auto *gd = std::get_deleter<guarded_delete>(vptr);
const auto *gd = std::get_deleter<guarded_delete>(vptr);
if (gd && gd->use_del_fun) {
const auto &custom_deleter_ptr = gd->del_fun.template target<custom_deleter<T, D>>();
if (custom_deleter_ptr == nullptr) {
throw std::runtime_error(
std::string("smart_holder::extract_deleter() precondition failure (") + context
+ ").");
}
return std::unique_ptr<D>(new D(std::move(custom_deleter_ptr->deleter)));
static_assert(std::is_copy_constructible<D>::value,
"Required for compatibility with smart_holder functionality.");
return std::unique_ptr<D>(new D(custom_deleter_ptr->deleter));
}
return nullptr;
}
Expand Down
13 changes: 13 additions & 0 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,19 @@ struct load_helper : value_and_holder_helper {

return result;
}

// This assumes load_as_shared_ptr succeeded(), and the returned shared_ptr is still alive.
// The returned unique_ptr is meant to never expire (the behavior is undefined otherwise).
template <typename D>
std::unique_ptr<T, D>
load_as_const_unique_ptr(T *raw_type_ptr, const char *context = "load_as_const_unique_ptr") {
if (!have_holder()) {
return unique_with_deleter<T, D>(nullptr, std::unique_ptr<D>());
}
holder().template ensure_compatible_rtti_uqp_del<T, D>(context);
return unique_with_deleter<T, D>(
raw_type_ptr, std::move(holder().template extract_deleter<T, D>(context)));
}
};

PYBIND11_NAMESPACE_END(smart_holder_type_caster_support)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_class_sh_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ std::string get_mtxt(atyp const &obj) { return obj.mtxt; }
std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast<std::ptrdiff_t>(&obj); }

std::unique_ptr<atyp> unique_ptr_roundtrip(std::unique_ptr<atyp> obj) { return obj; }

std::string pass_unique_ptr_cref(const std::unique_ptr<atyp> &obj) { return obj->mtxt; }

const std::unique_ptr<atyp> &rtrn_unique_ptr_cref(const std::string &mtxt) {
static std::unique_ptr<atyp> obj{new atyp{"static_ctor_arg"}};
if (!mtxt.empty()) {
obj->mtxt = mtxt;
}
return obj;
}

const std::unique_ptr<atyp> &unique_ptr_cref_roundtrip(const std::unique_ptr<atyp> &obj) {
return obj;
}
Expand Down Expand Up @@ -217,6 +228,9 @@ TEST_SUBMODULE(class_sh_basic, m) {
m.def("get_ptr", get_ptr); // pass_cref

m.def("unique_ptr_roundtrip", unique_ptr_roundtrip); // pass_uqmp, rtrn_uqmp

m.def("pass_unique_ptr_cref", pass_unique_ptr_cref);
m.def("rtrn_unique_ptr_cref", rtrn_unique_ptr_cref);
m.def("unique_ptr_cref_roundtrip", unique_ptr_cref_roundtrip);

py::classh<SharedPtrStash>(m, "SharedPtrStash")
Expand Down
34 changes: 23 additions & 11 deletions tests/test_class_sh_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,31 @@ def test_unique_ptr_roundtrip(num_round_trips=1000):
id_orig = id_rtrn


# This currently fails, because a unique_ptr is always loaded by value
# due to pybind11/detail/smart_holder_type_casters.h:689
# I think, we need to provide more cast operators.
@pytest.mark.skip()
def test_unique_ptr_cref_roundtrip():
def test_pass_unique_ptr_cref():
obj = m.atyp("ctor_arg")
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.get_mtxt(obj))
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.pass_unique_ptr_cref(obj))
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.get_mtxt(obj))


def test_rtrn_unique_ptr_cref():
obj0 = m.rtrn_unique_ptr_cref("")
assert m.get_mtxt(obj0) == "static_ctor_arg"
obj1 = m.rtrn_unique_ptr_cref("passed_mtxt_1")
assert m.get_mtxt(obj1) == "passed_mtxt_1"
assert m.get_mtxt(obj0) == "passed_mtxt_1"
assert obj0 is obj1


def test_unique_ptr_cref_roundtrip(num_round_trips=1000):
# Multiple roundtrips to stress-test implementation.
orig = m.atyp("passenger")
id_orig = id(orig)
mtxt_orig = m.get_mtxt(orig)

recycled = m.unique_ptr_cref_roundtrip(orig)
assert m.get_mtxt(orig) == mtxt_orig
assert m.get_mtxt(recycled) == mtxt_orig
assert id(recycled) == id_orig
recycled = orig
for _ in range(num_round_trips):
recycled = m.unique_ptr_cref_roundtrip(recycled)
assert recycled is orig
assert m.get_mtxt(recycled) == mtxt_orig


@pytest.mark.parametrize(
Expand Down
6 changes: 5 additions & 1 deletion tests/test_class_sh_trampoline_shared_from_this.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ long pass_shared_ptr(const std::shared_ptr<Sft> &obj) {
return sft.use_count();
}

void pass_unique_ptr_cref(const std::unique_ptr<Sft> &) {
std::string pass_unique_ptr_cref(const std::unique_ptr<Sft> &obj) {
return obj ? obj->history : "<NULLPTR>";
}
void pass_unique_ptr_rref(std::unique_ptr<Sft> &&) {
throw std::runtime_error("Expected to not be reached.");
}

Expand Down Expand Up @@ -138,6 +141,7 @@ TEST_SUBMODULE(class_sh_trampoline_shared_from_this, m) {
m.def("use_count", use_count);
m.def("pass_shared_ptr", pass_shared_ptr);
m.def("pass_unique_ptr_cref", pass_unique_ptr_cref);
m.def("pass_unique_ptr_rref", pass_unique_ptr_rref);
m.def("make_pure_cpp_sft_raw_ptr", make_pure_cpp_sft_raw_ptr);
m.def("make_pure_cpp_sft_unq_ptr", make_pure_cpp_sft_unq_ptr);
m.def("make_pure_cpp_sft_shd_ptr", make_pure_cpp_sft_shd_ptr);
Expand Down
4 changes: 3 additions & 1 deletion tests/test_class_sh_trampoline_shared_from_this.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ def test_pass_released_shared_ptr_as_unique_ptr():
obj = PySft("PySft")
stash1 = m.SftSharedPtrStash(1)
stash1.Add(obj) # Releases shared_ptr to C++.
assert m.pass_unique_ptr_cref(obj) == "PySft_Stash1Add"
assert obj.history == "PySft_Stash1Add"
with pytest.raises(ValueError) as exc_info:
m.pass_unique_ptr_cref(obj)
m.pass_unique_ptr_rref(obj)
assert str(exc_info.value) == (
"Python instance is currently owned by a std::shared_ptr."
)
Expand Down