diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 9a971704e4..af0f88c219 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1639,6 +1639,27 @@ object object_api::call(Args &&...args) const { return operator()(std::forward(args)...); } +// Convert list -> tuple and set -> frozenset for use as keys in dict, set etc. +// https://mail.python.org/pipermail/python-dev/2005-October/057586.html +inline object freeze(object &&obj) { + if (isinstance(obj)) { + return tuple(std::move(obj)); + } + if (isinstance(obj)) { + return frozenset(std::move(obj)); + } + return std::move(obj); +} + +template +struct frozen_type_name { + static constexpr auto name = Caster::name; +}; +template +struct frozen_type_name> { + static constexpr auto name = Caster::frozen_name; +}; + PYBIND11_NAMESPACE_END(detail) template diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h index 597bce61d5..9b5a6bf5d5 100644 --- a/include/pybind11/stl.h +++ b/include/pybind11/stl.h @@ -79,14 +79,17 @@ struct set_caster { for (auto &&value : src) { auto value_ = reinterpret_steal( key_conv::cast(forward_like(value), policy, parent)); - if (!value_ || !s.add(std::move(value_))) { + if (!value_ || !s.add(freeze(std::move(value_)))) { return handle(); } } return s.release(); } - PYBIND11_TYPE_CASTER(type, const_name("Set[") + key_conv::name + const_name("]")); + PYBIND11_TYPE_CASTER(type, + const_name("Set[") + frozen_type_name::name + const_name("]")); + static constexpr auto frozen_name + = const_name("FrozenSet[") + frozen_type_name::name + const_name("]"); }; template @@ -128,14 +131,14 @@ struct map_caster { if (!key || !value) { return handle(); } - d[key] = value; + d[freeze(std::move(key))] = std::move(value); } return d.release(); } PYBIND11_TYPE_CASTER(Type, - const_name("Dict[") + key_conv::name + const_name(", ") + value_conv::name - + const_name("]")); + const_name("Dict[") + frozen_type_name::name + const_name(", ") + + value_conv::name + const_name("]")); }; template @@ -188,6 +191,8 @@ struct list_caster { } PYBIND11_TYPE_CASTER(Type, const_name("List[") + value_conv::name + const_name("]")); + static constexpr auto frozen_name + = const_name("Tuple[") + value_conv::name + const_name(", ...]"); }; template @@ -257,6 +262,11 @@ struct array_caster { const_name("[") + const_name() + const_name("]")) + const_name("]")); + static constexpr auto frozen_name + = const_name("Tuple[") + value_conv::name + + const_name(const_name(", ..."), + const_name("[") + const_name() + const_name("]")) + + const_name("]"); }; template diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index 38d32fda93..5c51db9b52 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -249,6 +249,22 @@ TEST_SUBMODULE(stl, m) { return v; }); + // test_frozen_key + m.def("cast_set_map", []() { + return std::map, std::string>{{{"key1", "key2"}, "value"}}; + }); + m.def("load_set_map", [](const std::map, std::string> &map) { + return map.at({"key1", "key2"}) == "value" && map.at({"key3"}) == "value2"; + }); + m.def("cast_set_set", []() { return std::set>{{"key1", "key2"}}; }); + m.def("load_set_set", [](const std::set> &set) { + return (set.count({"key1", "key2"}) != 0u) && (set.count({"key3"}) != 0u); + }); + m.def("cast_vector_set", []() { return std::set>{{1, 2}}; }); + m.def("load_vector_set", [](const std::set> &set) { + return (set.count({1, 2}) != 0u) && (set.count({3}) != 0u); + }); + pybind11::enum_(m, "EnumType") .value("kSet", EnumType::kSet) .value("kUnset", EnumType::kUnset); diff --git a/tests/test_stl.py b/tests/test_stl.py index d30c382113..11a445019b 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -98,6 +98,38 @@ def test_recursive_casting(): assert z[0].value == 7 and z[1].value == 42 +def test_frozen_key(doc): + """Test that we special-case C++ key types to Python immutable containers, e.g.: + std::map, V> <-> dict[frozenset[K], V] + std::set> <-> set[frozenset[T]] + std::set> <-> set[tuple[T, ...]] + """ + s = m.cast_set_map() + assert s == {frozenset({"key1", "key2"}): "value"} + s[frozenset({"key3"})] = "value2" + assert m.load_set_map(s) + assert doc(m.cast_set_map) == "cast_set_map() -> Dict[FrozenSet[str], str]" + assert ( + doc(m.load_set_map) == "load_set_map(arg0: Dict[FrozenSet[str], str]) -> bool" + ) + + s = m.cast_set_set() + assert s == {frozenset({"key1", "key2"})} + s.add(frozenset({"key3"})) + assert m.load_set_set(s) + assert doc(m.cast_set_set) == "cast_set_set() -> Set[FrozenSet[str]]" + assert doc(m.load_set_set) == "load_set_set(arg0: Set[FrozenSet[str]]) -> bool" + + s = m.cast_vector_set() + assert s == {(1, 2)} + s.add((3,)) + assert m.load_vector_set(s) + assert doc(m.cast_vector_set) == "cast_vector_set() -> Set[Tuple[int, ...]]" + assert ( + doc(m.load_vector_set) == "load_vector_set(arg0: Set[Tuple[int, ...]]) -> bool" + ) + + def test_move_out_container(): """Properties use the `reference_internal` policy by default. If the underlying function returns an rvalue, the policy is automatically changed to `move` to avoid referencing