Skip to content

Commit 66e0f80

Browse files
committed
First approximation.
1 parent 88b019a commit 66e0f80

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

Diff for: tests/test_native_enum.cpp

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// #include <pybind11/native_enum.h>
2+
3+
#include "pybind11_tests.h"
4+
5+
namespace test_native_enum {
6+
7+
// https://en.cppreference.com/w/cpp/language/enum
8+
9+
// enum that takes 16 bits
10+
enum smallenum : std::int16_t { a, b, c };
11+
12+
// color may be red (value 0), yellow (value 1), green (value 20), or blue (value 21)
13+
enum color { red, yellow, green = 20, blue };
14+
15+
// altitude may be altitude::high or altitude::low
16+
enum class altitude : char {
17+
high = 'h',
18+
low = 'l', // trailing comma only allowed after CWG518
19+
};
20+
21+
// the constant d is 0, the constant e is 1, the constant f is 3
22+
enum { d, e, f = e + 2 };
23+
24+
int pass_color(color e) { return static_cast<int>(e); }
25+
color return_color(int i) { return static_cast<color>(i); }
26+
27+
py::handle wrap_color(py::module_ m) {
28+
auto enum_module = py::module_::import("enum");
29+
auto int_enum = enum_module.attr("IntEnum");
30+
using u_t = std::underlying_type<color>::type;
31+
auto members = py::make_tuple(py::make_tuple("red", static_cast<u_t>(color::red)),
32+
py::make_tuple("yellow", static_cast<u_t>(color::yellow)),
33+
py::make_tuple("green", static_cast<u_t>(color::green)),
34+
py::make_tuple("blue", static_cast<u_t>(color::blue)));
35+
auto int_enum_color = int_enum("color", members);
36+
int_enum_color.attr("__module__") = m;
37+
m.attr("color") = int_enum_color;
38+
return int_enum_color.release(); // Intentionally leak Python reference.
39+
}
40+
41+
} // namespace test_native_enum
42+
43+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
44+
PYBIND11_NAMESPACE_BEGIN(detail)
45+
46+
using namespace test_native_enum;
47+
48+
template <>
49+
struct type_caster<color> {
50+
static handle native_type;
51+
52+
static handle cast(const color &src, return_value_policy /* policy */, handle /* parent */) {
53+
auto u_v = static_cast<std::underlying_type<color>::type>(src);
54+
return native_type(u_v).release();
55+
}
56+
57+
bool load(handle src, bool /* convert */) {
58+
if (!isinstance(src, native_type)) {
59+
return false;
60+
}
61+
value = static_cast<color>(py::cast<std::underlying_type<color>::type>(src.attr("value")));
62+
return true;
63+
}
64+
65+
PYBIND11_TYPE_CASTER(color, const_name("<enum 'color'>"));
66+
};
67+
68+
handle type_caster<color>::native_type = nullptr;
69+
70+
PYBIND11_NAMESPACE_END(detail)
71+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
72+
73+
TEST_SUBMODULE(native_enum, m) {
74+
using namespace test_native_enum;
75+
76+
py::detail::type_caster<color>::native_type = wrap_color(m);
77+
78+
m.def("pass_color", pass_color);
79+
m.def("return_color", return_color);
80+
}

Diff for: tests/test_native_enum.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import enum
2+
3+
import pytest
4+
5+
from pybind11_tests import native_enum as m
6+
7+
COLOR_MEMBERS = (
8+
("red", 0),
9+
("yellow", 1),
10+
("green", 20),
11+
("blue", 21),
12+
)
13+
14+
15+
def test_enum_color_type():
16+
assert isinstance(m.color, enum.EnumMeta)
17+
18+
19+
@pytest.mark.parametrize("name,value", COLOR_MEMBERS)
20+
def test_enum_color_members(name, value):
21+
assert m.color[name] == value
22+
23+
24+
@pytest.mark.parametrize("name,value", COLOR_MEMBERS)
25+
def test_pass_color_success(name, value):
26+
assert m.pass_color(m.color[name]) == value
27+
28+
29+
def test_pass_color_fail():
30+
with pytest.raises(TypeError) as excinfo:
31+
m.pass_color(None)
32+
assert "<enum 'color'>" in str(excinfo.value)
33+
34+
35+
@pytest.mark.parametrize("name,value", COLOR_MEMBERS)
36+
def test_return_color_success(name, value):
37+
assert m.return_color(value) == m.color[name]
38+
39+
40+
def test_return_color_fail():
41+
with pytest.raises(ValueError) as excinfo_direct:
42+
m.color(2)
43+
with pytest.raises(ValueError) as excinfo_cast:
44+
m.return_color(2)
45+
assert str(excinfo_cast.value) == str(excinfo_direct.value)

0 commit comments

Comments
 (0)