diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index e8128710e2..bb705ab2d4 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -79,6 +79,9 @@ class type_caster> { explicit operator std::reference_wrapper() { return cast_op(subcaster); } }; +template +class type_caster : public type_caster {}; + #define PYBIND11_TYPE_CASTER(type, py_name) \ protected: \ type value; \ @@ -907,6 +910,12 @@ struct handle_type_name { template struct pyobject_caster { + template ::value, int> = 0> + pyobject_caster() : value() {} + + template ::value, int> = 0> + pyobject_caster() : value(reinterpret_steal(handle())) {} + template ::value, int> = 0> bool load(handle src, bool /* convert */) { value = src; diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 324fa932f1..a79b0caadf 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1784,25 +1784,37 @@ class kwargs : public dict { PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check) }; -class set : public object { +class anyset : public object { +protected: + PYBIND11_OBJECT(anyset, object, PyAnySet_Check) + +public: + size_t size() const { return static_cast(PySet_Size(m_ptr)); } + bool empty() const { return size() == 0; } + template + bool contains(T &&val) const { + return PySet_Contains(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 1; + } +}; + +class set : public anyset { public: - PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New) - set() : object(PySet_New(nullptr), stolen_t{}) { + PYBIND11_OBJECT_CVT(set, anyset, PySet_Check, PySet_New) + set() : anyset(PySet_New(nullptr), stolen_t{}) { if (!m_ptr) { pybind11_fail("Could not allocate set object!"); } } - size_t size() const { return (size_t) PySet_Size(m_ptr); } - bool empty() const { return size() == 0; } template bool add(T &&val) /* py-non-const */ { return PySet_Add(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 0; } void clear() /* py-non-const */ { PySet_Clear(m_ptr); } - template - bool contains(T &&val) const { - return PySet_Contains(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 1; - } +}; + +class frozenset : public anyset { +public: + PYBIND11_OBJECT_CVT(frozenset, anyset, PyFrozenSet_Check, PyFrozenSet_New) }; class function : public object { diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h index 51b57a92ba..547e5c0e8f 100644 --- a/include/pybind11/stl.h +++ b/include/pybind11/stl.h @@ -49,23 +49,29 @@ forwarded_type forward_like(U &&u) { return std::forward>(std::forward(u)); } -template +template +using make_key_caster = type_caster>::value, + const intrinsic_t, + intrinsic_t>>; + +template struct set_caster { using type = Type; - using key_conv = make_caster; + using key_conv = make_key_caster; bool load(handle src, bool convert) { - if (!isinstance(src)) { + if (!isinstance(src)) { return false; } - auto s = reinterpret_borrow(src); + auto s = reinterpret_borrow(src); value.clear(); for (auto entry : s) { key_conv conv; if (!conv.load(entry, convert)) { return false; } - value.insert(cast_op(std::move(conv))); + value.insert( + std::move(conv).operator typename key_conv::template cast_op_type()); } return true; } @@ -75,23 +81,27 @@ struct set_caster { if (!std::is_lvalue_reference::value) { policy = return_value_policy_override::policy(policy); } - pybind11::set s; + typename std::conditional::type s; for (auto &&value : src) { auto value_ = reinterpret_steal( key_conv::cast(forward_like(value), policy, parent)); - if (!value_ || !s.add(std::move(value_))) { + // pybind11::frozenset doesn't have add() for safety, so call PySet_Add directly. + if (!value_ + || PySet_Add(s.ptr(), detail::object_or_cast(std::move(value_)).ptr()) != 0) { return handle(); } } return s.release(); } - PYBIND11_TYPE_CASTER(type, const_name("Set[") + key_conv::name + const_name("]")); + PYBIND11_TYPE_CASTER(type, + const_name("FrozenSet[", "Set[") + key_conv::name + + const_name("]")); }; template struct map_caster { - using key_conv = make_caster; + using key_conv = make_key_caster; using value_conv = make_caster; bool load(handle src, bool convert) { @@ -106,7 +116,9 @@ struct map_caster { if (!kconv.load(it.first.ptr(), convert) || !vconv.load(it.second.ptr(), convert)) { return false; } - value.emplace(cast_op(std::move(kconv)), cast_op(std::move(vconv))); + value.emplace( + std::move(kconv).operator typename key_conv::template cast_op_type(), + cast_op(std::move(vconv))); } return true; } @@ -138,7 +150,7 @@ struct map_caster { + const_name("]")); }; -template +template struct list_caster { using value_conv = make_caster; @@ -174,7 +186,7 @@ struct list_caster { if (!std::is_lvalue_reference::value) { policy = return_value_policy_override::policy(policy); } - list l(src.size()); + conditional_t l(src.size()); ssize_t index = 0; for (auto &&value : src) { auto value_ = reinterpret_steal( @@ -182,24 +194,42 @@ struct list_caster { if (!value_) { return handle(); } - PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference + if (Const) { + PyTuple_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference + } else { + PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference + } } return l.release(); } - PYBIND11_TYPE_CASTER(Type, const_name("List[") + value_conv::name + const_name("]")); + PYBIND11_TYPE_CASTER(Type, + const_name("Tuple[", "List[") + value_conv::name + + const_name(", ...]", "]")); }; template struct type_caster> : list_caster, Type> {}; +template +struct type_caster> + : list_caster, Type, true> {}; + template struct type_caster> : list_caster, Type> {}; +template +struct type_caster> + : list_caster, Type, true> {}; + template struct type_caster> : list_caster, Type> {}; -template +template +struct type_caster> + : list_caster, Type, true> {}; + +template struct array_caster { using value_conv = make_caster; @@ -238,7 +268,7 @@ struct array_caster { template static handle cast(T &&src, return_value_policy policy, handle parent) { - list l(src.size()); + conditional_t l(src.size()); ssize_t index = 0; for (auto &&value : src) { auto value_ = reinterpret_steal( @@ -246,23 +276,31 @@ struct array_caster { if (!value_) { return handle(); } - PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference + if (Const) { + PyTuple_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference + } else { + PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference + } } return l.release(); } PYBIND11_TYPE_CASTER(ArrayType, - const_name("List[") + value_conv::name + const_name("Tuple[", "List[") + value_conv::name + const_name(const_name(""), const_name("[") + const_name() + const_name("]")) - + const_name("]")); + + const_name(", ...]", "]")); }; template struct type_caster> : array_caster, Type, false, Size> {}; +template +struct type_caster> + : array_caster, Type, false, Size, true> {}; + template struct type_caster> : array_caster, Type, true> {}; @@ -270,10 +308,18 @@ template struct type_caster> : set_caster, Key> {}; +template +struct type_caster> + : set_caster, Key, true> {}; + template struct type_caster> : set_caster, Key> {}; +template +struct type_caster> + : set_caster, Key, true> {}; + template struct type_caster> : map_caster, Key, Value> {}; diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index d1e9b81a73..8d296f655a 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -75,7 +75,7 @@ TEST_SUBMODULE(pytypes, m) { m.def("get_none", [] { return py::none(); }); m.def("print_none", [](const py::none &none) { py::print("none: {}"_s.format(none)); }); - // test_set + // test_set, test_frozenset m.def("get_set", []() { py::set set; set.add(py::str("key1")); @@ -83,14 +83,26 @@ TEST_SUBMODULE(pytypes, m) { set.add(std::string("key3")); return set; }); - m.def("print_set", [](const py::set &set) { + m.def("get_frozenset", []() { + py::set set; + set.add(py::str("key1")); + set.add("key2"); + set.add(std::string("key3")); + return py::frozenset(set); + }); + m.def("print_anyset", [](const py::anyset &set) { for (auto item : set) { py::print("key:", item); } }); - m.def("set_contains", - [](const py::set &set, const py::object &key) { return set.contains(key); }); - m.def("set_contains", [](const py::set &set, const char *key) { return set.contains(key); }); + m.def("anyset_size", [](const py::anyset &set) { return set.size(); }); + m.def("anyset_empty", [](const py::anyset &set) { return set.empty(); }); + m.def("anyset_contains", + [](const py::anyset &set, const py::object &key) { return set.contains(key); }); + m.def("anyset_contains", + [](const py::anyset &set, const char *key) { return set.contains(key); }); + m.def("set_add", [](py::set &set, const py::object &key) { set.add(key); }); + m.def("set_clear", [](py::set &set) { set.clear(); }); // test_dict m.def("get_dict", []() { return py::dict("key"_a = "value"); }); @@ -310,6 +322,7 @@ TEST_SUBMODULE(pytypes, m) { "list"_a = py::list(d["list"]), "dict"_a = py::dict(d["dict"]), "set"_a = py::set(d["set"]), + "frozenset"_a = py::frozenset(d["frozenset"]), "memoryview"_a = py::memoryview(d["memoryview"])); }); @@ -325,6 +338,7 @@ TEST_SUBMODULE(pytypes, m) { "list"_a = d["list"].cast(), "dict"_a = d["dict"].cast(), "set"_a = d["set"].cast(), + "frozenset"_a = d["frozenset"].cast(), "memoryview"_a = d["memoryview"].cast()); }); diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 5c715ada6b..a6adfdddad 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -66,11 +66,12 @@ def test_none(capture, doc): def test_set(capture, doc): s = m.get_set() + assert isinstance(s, set) assert s == {"key1", "key2", "key3"} + s.add("key4") with capture: - s.add("key4") - m.print_set(s) + m.print_anyset(s) assert ( capture.unordered == """ @@ -81,12 +82,43 @@ def test_set(capture, doc): """ ) - assert not m.set_contains(set(), 42) - assert m.set_contains({42}, 42) - assert m.set_contains({"foo"}, "foo") + m.set_add(s, "key5") + assert m.anyset_size(s) == 5 - assert doc(m.get_list) == "get_list() -> list" - assert doc(m.print_list) == "print_list(arg0: list) -> None" + m.set_clear(s) + assert m.anyset_empty(s) + + assert not m.anyset_contains(set(), 42) + assert m.anyset_contains({42}, 42) + assert m.anyset_contains({"foo"}, "foo") + + assert doc(m.get_set) == "get_set() -> set" + assert doc(m.print_anyset) == "print_anyset(arg0: anyset) -> None" + + +def test_frozenset(capture, doc): + s = m.get_frozenset() + assert isinstance(s, frozenset) + assert s == frozenset({"key1", "key2", "key3"}) + + with capture: + m.print_anyset(s) + assert ( + capture.unordered + == """ + key: key1 + key: key2 + key: key3 + """ + ) + assert m.anyset_size(s) == 3 + assert not m.anyset_empty(s) + + assert not m.anyset_contains(frozenset(), 42) + assert m.anyset_contains(frozenset({42}), 42) + assert m.anyset_contains(frozenset({"foo"}), "foo") + + assert doc(m.get_frozenset) == "get_frozenset() -> frozenset" def test_dict(capture, doc): @@ -302,6 +334,7 @@ def test_constructors(): list: range(3), dict: [("two", 2), ("one", 1), ("three", 3)], set: [4, 4, 5, 6, 6, 6], + frozenset: [4, 4, 5, 6, 6, 6], memoryview: b"abc", } inputs = {k.__name__: v for k, v in data.items()} diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index b56a91953b..d818bcc6eb 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -248,6 +248,22 @@ TEST_SUBMODULE(stl, m) { return v; }); + // test_frozen_key + m.def("cast_set_map", []() { + return std::map, std::string>{{{"key1", "key2"}, "value"}}; + }); + m.def("load_set_map", [](const std::map, std::string> &map) { + return map.at({"key1", "key2"}) == "value" && map.at({"key3"}) == "value2"; + }); + m.def("cast_set_set", []() { return std::set>{{"key1", "key2"}}; }); + m.def("load_set_set", [](const std::set> &set) { + return (set.count({"key1", "key2"}) != 0u) && (set.count({"key3"}) != 0u); + }); + m.def("cast_vector_set", []() { return std::set>{{1, 2}}; }); + m.def("load_vector_set", [](const std::set> &set) { + return (set.count({1, 2}) != 0u) && (set.count({3}) != 0u); + }); + pybind11::enum_(m, "EnumType") .value("kSet", EnumType::kSet) .value("kUnset", EnumType::kUnset); diff --git a/tests/test_stl.py b/tests/test_stl.py index 975860b85a..11a445019b 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -73,6 +73,7 @@ def test_set(doc): assert s == {"key1", "key2"} s.add("key3") assert m.load_set(s) + assert m.load_set(frozenset(s)) assert doc(m.cast_set) == "cast_set() -> Set[str]" assert doc(m.load_set) == "load_set(arg0: Set[str]) -> bool" @@ -97,6 +98,38 @@ def test_recursive_casting(): assert z[0].value == 7 and z[1].value == 42 +def test_frozen_key(doc): + """Test that we special-case C++ key types to Python immutable containers, e.g.: + std::map, V> <-> dict[frozenset[K], V] + std::set> <-> set[frozenset[T]] + std::set> <-> set[tuple[T, ...]] + """ + s = m.cast_set_map() + assert s == {frozenset({"key1", "key2"}): "value"} + s[frozenset({"key3"})] = "value2" + assert m.load_set_map(s) + assert doc(m.cast_set_map) == "cast_set_map() -> Dict[FrozenSet[str], str]" + assert ( + doc(m.load_set_map) == "load_set_map(arg0: Dict[FrozenSet[str], str]) -> bool" + ) + + s = m.cast_set_set() + assert s == {frozenset({"key1", "key2"})} + s.add(frozenset({"key3"})) + assert m.load_set_set(s) + assert doc(m.cast_set_set) == "cast_set_set() -> Set[FrozenSet[str]]" + assert doc(m.load_set_set) == "load_set_set(arg0: Set[FrozenSet[str]]) -> bool" + + s = m.cast_vector_set() + assert s == {(1, 2)} + s.add((3,)) + assert m.load_vector_set(s) + assert doc(m.cast_vector_set) == "cast_vector_set() -> Set[Tuple[int, ...]]" + assert ( + doc(m.load_vector_set) == "load_vector_set(arg0: Set[Tuple[int, ...]]) -> bool" + ) + + def test_move_out_container(): """Properties use the `reference_internal` policy by default. If the underlying function returns an rvalue, the policy is automatically changed to `move` to avoid referencing