Skip to content

Support frozenset, tuple as dict keys #3886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
51563bc
Support frozenset, tuple as dict keys
ecatmur Apr 19, 2022
d0f9f2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2022
038904a
Fix for C++11 mode
ecatmur Apr 19, 2022
27986dd
Merge branch 'frozenset' of https://github.com/ecatmur/pybind11 into …
ecatmur Apr 19, 2022
a56f91c
protect non-const methods
ecatmur Apr 19, 2022
cd09b3a
formatting
ecatmur Apr 19, 2022
6c045a5
Revert "protect non-const methods"
ecatmur Apr 19, 2022
bb123d1
Revert "Revert "protect non-const methods""
ecatmur Apr 19, 2022
6b55983
Move add() and clear() to set
ecatmur Apr 19, 2022
3eb1a8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2022
1b7a941
Only use const type_caster for class types
ecatmur Apr 21, 2022
0ebef3b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2022
fcaff44
More tests for tuple -> list, frozenset -> set
ecatmur Apr 21, 2022
2096750
Merge branch 'frozenset' of https://github.com/ecatmur/pybind11 into …
ecatmur Apr 21, 2022
e828031
Add frozenset, and allow it cast to std::set
ecatmur Apr 24, 2022
f2db7bb
Rename set_base to any_set to match Python C API
ecatmur Apr 24, 2022
ef92aa5
PR: static_cast, anyset
ecatmur May 1, 2022
736f293
Add tests for frozenset
ecatmur May 1, 2022
faf8a51
Remove frozenset default ctor, add tests
ecatmur May 1, 2022
05b6147
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2022
26a29f4
Merge remote-tracking branch 'upstream/master' into frozenset-core
ecatmur May 1, 2022
0fb3a4f
Merge branch 'frozenset-core' into frozenset
ecatmur May 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ class type_caster<std::reference_wrapper<type>> {
explicit operator std::reference_wrapper<type>() { return cast_op<type &>(subcaster); }
};

template <typename type>
class type_caster<const type> : public type_caster<type> {};

#define PYBIND11_TYPE_CASTER(type, py_name) \
protected: \
type value; \
Expand Down Expand Up @@ -907,6 +910,12 @@ struct handle_type_name<kwargs> {

template <typename type>
struct pyobject_caster {
template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
pyobject_caster() : value() {}

template <typename T = type, enable_if_t<std::is_base_of<object, T>::value, int> = 0>
pyobject_caster() : value(reinterpret_steal<type>(handle())) {}

template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
bool load(handle src, bool /* convert */) {
value = src;
Expand Down
30 changes: 21 additions & 9 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1784,25 +1784,37 @@ class kwargs : public dict {
PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check)
};

class set : public object {
class anyset : public object {
protected:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like having add as a public function on frozen-sets. We can still call the C-API directly on the pointer for casters, but other code shouldn't be able to add objects to frozen sets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved non-const methods to protected

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They would still need to be public for normal set?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ecatmur Why not just move them to set subclass and use the raw C-API to interact with the frozen set.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I get it now. Fixed.

PYBIND11_OBJECT(anyset, object, PyAnySet_Check)

public:
size_t size() const { return static_cast<size_t>(PySet_Size(m_ptr)); }
bool empty() const { return size() == 0; }
template <typename T>
bool contains(T &&val) const {
return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
}
};

class set : public anyset {
public:
PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New)
set() : object(PySet_New(nullptr), stolen_t{}) {
PYBIND11_OBJECT_CVT(set, anyset, PySet_Check, PySet_New)
set() : anyset(PySet_New(nullptr), stolen_t{}) {
if (!m_ptr) {
pybind11_fail("Could not allocate set object!");
}
}
size_t size() const { return (size_t) PySet_Size(m_ptr); }
bool empty() const { return size() == 0; }
template <typename T>
bool add(T &&val) /* py-non-const */ {
return PySet_Add(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 0;
}
void clear() /* py-non-const */ { PySet_Clear(m_ptr); }
template <typename T>
bool contains(T &&val) const {
return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
}
};

class frozenset : public anyset {
public:
PYBIND11_OBJECT_CVT(frozenset, anyset, PyFrozenSet_Check, PyFrozenSet_New)
};

class function : public object {
Expand Down
84 changes: 65 additions & 19 deletions include/pybind11/stl.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,29 @@ forwarded_type<T, U> forward_like(U &&u) {
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
}

template <typename Type, typename Key>
template <typename Key>
using make_key_caster = type_caster<conditional_t<std::is_class<intrinsic_t<Key>>::value,
const intrinsic_t<Key>,
intrinsic_t<Key>>>;

template <typename Type, typename Key, bool Const = false>
struct set_caster {
using type = Type;
using key_conv = make_caster<Key>;
using key_conv = make_key_caster<Key>;

bool load(handle src, bool convert) {
if (!isinstance<pybind11::set>(src)) {
if (!isinstance<anyset>(src)) {
return false;
}
auto s = reinterpret_borrow<pybind11::set>(src);
auto s = reinterpret_borrow<anyset>(src);
value.clear();
for (auto entry : s) {
key_conv conv;
if (!conv.load(entry, convert)) {
return false;
}
value.insert(cast_op<Key &&>(std::move(conv)));
value.insert(
std::move(conv).operator typename key_conv::template cast_op_type<Key &&>());
}
return true;
}
Expand All @@ -75,23 +81,27 @@ struct set_caster {
if (!std::is_lvalue_reference<T>::value) {
policy = return_value_policy_override<Key>::policy(policy);
}
pybind11::set s;
typename std::conditional<Const, frozenset, set>::type s;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(
key_conv::cast(forward_like<T>(value), policy, parent));
if (!value_ || !s.add(std::move(value_))) {
// pybind11::frozenset doesn't have add() for safety, so call PySet_Add directly.
if (!value_
|| PySet_Add(s.ptr(), detail::object_or_cast(std::move(value_)).ptr()) != 0) {
return handle();
}
}
return s.release();
}

PYBIND11_TYPE_CASTER(type, const_name("Set[") + key_conv::name + const_name("]"));
PYBIND11_TYPE_CASTER(type,
const_name<Const>("FrozenSet[", "Set[") + key_conv::name
+ const_name("]"));
};

template <typename Type, typename Key, typename Value>
struct map_caster {
using key_conv = make_caster<Key>;
using key_conv = make_key_caster<Key>;
using value_conv = make_caster<Value>;

bool load(handle src, bool convert) {
Expand All @@ -106,7 +116,9 @@ struct map_caster {
if (!kconv.load(it.first.ptr(), convert) || !vconv.load(it.second.ptr(), convert)) {
return false;
}
value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
value.emplace(
std::move(kconv).operator typename key_conv::template cast_op_type<Key &&>(),
cast_op<Value &&>(std::move(vconv)));
}
return true;
}
Expand Down Expand Up @@ -138,7 +150,7 @@ struct map_caster {
+ const_name("]"));
};

template <typename Type, typename Value>
template <typename Type, typename Value, bool Const = false>
struct list_caster {
using value_conv = make_caster<Value>;

Expand Down Expand Up @@ -174,32 +186,50 @@ struct list_caster {
if (!std::is_lvalue_reference<T>::value) {
policy = return_value_policy_override<Value>::policy(policy);
}
list l(src.size());
conditional_t<Const, tuple, list> l(src.size());
ssize_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(
value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_) {
return handle();
}
PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
if (Const) {
PyTuple_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
} else {
PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
}
}
return l.release();
}

PYBIND11_TYPE_CASTER(Type, const_name("List[") + value_conv::name + const_name("]"));
PYBIND11_TYPE_CASTER(Type,
const_name<Const>("Tuple[", "List[") + value_conv::name
+ const_name<Const>(", ...]", "]"));
};

template <typename Type, typename Alloc>
struct type_caster<std::vector<Type, Alloc>> : list_caster<std::vector<Type, Alloc>, Type> {};

template <typename Type, typename Alloc>
struct type_caster<const std::vector<Type, Alloc>>
: list_caster<std::vector<Type, Alloc>, Type, true> {};

template <typename Type, typename Alloc>
struct type_caster<std::deque<Type, Alloc>> : list_caster<std::deque<Type, Alloc>, Type> {};

template <typename Type, typename Alloc>
struct type_caster<const std::deque<Type, Alloc>>
: list_caster<std::deque<Type, Alloc>, Type, true> {};

template <typename Type, typename Alloc>
struct type_caster<std::list<Type, Alloc>> : list_caster<std::list<Type, Alloc>, Type> {};

template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0>
template <typename Type, typename Alloc>
struct type_caster<const std::list<Type, Alloc>>
: list_caster<std::list<Type, Alloc>, Type, true> {};

template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0, bool Const = false>
struct array_caster {
using value_conv = make_caster<Value>;

Expand Down Expand Up @@ -238,42 +268,58 @@ struct array_caster {

template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
list l(src.size());
conditional_t<Const, tuple, list> l(src.size());
ssize_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(
value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_) {
return handle();
}
PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
if (Const) {
PyTuple_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
} else {
PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
}
}
return l.release();
}

PYBIND11_TYPE_CASTER(ArrayType,
const_name("List[") + value_conv::name
const_name<Const>("Tuple[", "List[") + value_conv::name
+ const_name<Resizable>(const_name(""),
const_name("[") + const_name<Size>()
+ const_name("]"))
+ const_name("]"));
+ const_name<Const>(", ...]", "]"));
};

template <typename Type, size_t Size>
struct type_caster<std::array<Type, Size>>
: array_caster<std::array<Type, Size>, Type, false, Size> {};

template <typename Type, size_t Size>
struct type_caster<const std::array<Type, Size>>
: array_caster<std::array<Type, Size>, Type, false, Size, true> {};

template <typename Type>
struct type_caster<std::valarray<Type>> : array_caster<std::valarray<Type>, Type, true> {};

template <typename Key, typename Compare, typename Alloc>
struct type_caster<std::set<Key, Compare, Alloc>>
: set_caster<std::set<Key, Compare, Alloc>, Key> {};

template <typename Key, typename Compare, typename Alloc>
struct type_caster<const std::set<Key, Compare, Alloc>>
: set_caster<std::set<Key, Compare, Alloc>, Key, true> {};

template <typename Key, typename Hash, typename Equal, typename Alloc>
struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> {};

template <typename Key, typename Hash, typename Equal, typename Alloc>
struct type_caster<const std::unordered_set<Key, Hash, Equal, Alloc>>
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key, true> {};

template <typename Key, typename Value, typename Compare, typename Alloc>
struct type_caster<std::map<Key, Value, Compare, Alloc>>
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> {};
Expand Down
24 changes: 19 additions & 5 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,34 @@ TEST_SUBMODULE(pytypes, m) {
m.def("get_none", [] { return py::none(); });
m.def("print_none", [](const py::none &none) { py::print("none: {}"_s.format(none)); });

// test_set
// test_set, test_frozenset
m.def("get_set", []() {
py::set set;
set.add(py::str("key1"));
set.add("key2");
set.add(std::string("key3"));
return set;
});
m.def("print_set", [](const py::set &set) {
m.def("get_frozenset", []() {
py::set set;
set.add(py::str("key1"));
set.add("key2");
set.add(std::string("key3"));
return py::frozenset(set);
});
m.def("print_anyset", [](const py::anyset &set) {
for (auto item : set) {
py::print("key:", item);
}
});
m.def("set_contains",
[](const py::set &set, const py::object &key) { return set.contains(key); });
m.def("set_contains", [](const py::set &set, const char *key) { return set.contains(key); });
m.def("anyset_size", [](const py::anyset &set) { return set.size(); });
m.def("anyset_empty", [](const py::anyset &set) { return set.empty(); });
m.def("anyset_contains",
[](const py::anyset &set, const py::object &key) { return set.contains(key); });
m.def("anyset_contains",
[](const py::anyset &set, const char *key) { return set.contains(key); });
m.def("set_add", [](py::set &set, const py::object &key) { set.add(key); });
m.def("set_clear", [](py::set &set) { set.clear(); });

// test_dict
m.def("get_dict", []() { return py::dict("key"_a = "value"); });
Expand Down Expand Up @@ -310,6 +322,7 @@ TEST_SUBMODULE(pytypes, m) {
"list"_a = py::list(d["list"]),
"dict"_a = py::dict(d["dict"]),
"set"_a = py::set(d["set"]),
"frozenset"_a = py::frozenset(d["frozenset"]),
"memoryview"_a = py::memoryview(d["memoryview"]));
});

Expand All @@ -325,6 +338,7 @@ TEST_SUBMODULE(pytypes, m) {
"list"_a = d["list"].cast<py::list>(),
"dict"_a = d["dict"].cast<py::dict>(),
"set"_a = d["set"].cast<py::set>(),
"frozenset"_a = d["frozenset"].cast<py::frozenset>(),
"memoryview"_a = d["memoryview"].cast<py::memoryview>());
});

Expand Down
Loading