Skip to content

Commit

Permalink
Implicit conversions to bool + np.bool_ conversion (#925)
Browse files Browse the repository at this point in the history
This adds support for implicit conversions to bool from Python types
with `__bool__` (Python 3) or `__nonzero__` (Python 2) attributes, and
adds direct (i.e. non-converting) support for numpy bools.
  • Loading branch information
aldanor authored and jagerman committed Jul 23, 2017
1 parent a03408c commit e07f758
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 2 deletions.
30 changes: 28 additions & 2 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1049,11 +1049,37 @@ template <> class type_caster<std::nullptr_t> : public void_caster<std::nullptr_

template <> class type_caster<bool> {
public:
bool load(handle src, bool) {
bool load(handle src, bool convert) {
if (!src) return false;
else if (src.ptr() == Py_True) { value = true; return true; }
else if (src.ptr() == Py_False) { value = false; return true; }
else return false;
else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) {
// (allow non-implicit conversion for numpy booleans)

Py_ssize_t res = -1;
if (src.is_none()) {
res = 0; // None is implicitly converted to False
}
#if defined(PYPY_VERSION)
// On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists
else if (hasattr(src, PYBIND11_BOOL_ATTR)) {
res = PyObject_IsTrue(src.ptr());
}
#else
// Alternate approach for CPython: this does the same as the above, but optimized
// using the CPython API so as to avoid an unneeded attribute lookup.
else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) {
if (PYBIND11_NB_BOOL(tp_as_number)) {
res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
}
}
#endif
if (res == 0 || res == 1) {
value = (bool) res;
return true;
}
}
return false;
}
static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
return handle(src ? Py_True : Py_False).inc_ref();
Expand Down
5 changes: 5 additions & 0 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,11 @@
#define PYBIND11_SLICE_OBJECT PyObject
#define PYBIND11_FROM_STRING PyUnicode_FromString
#define PYBIND11_STR_TYPE ::pybind11::str
#define PYBIND11_BOOL_ATTR "__bool__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
#define PYBIND11_PLUGIN_IMPL(name) \
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()

#else
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
Expand All @@ -171,6 +174,8 @@
#define PYBIND11_SLICE_OBJECT PySliceObject
#define PYBIND11_FROM_STRING PyString_FromString
#define PYBIND11_STR_TYPE ::pybind11::bytes
#define PYBIND11_BOOL_ATTR "__nonzero__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
#define PYBIND11_PLUGIN_IMPL(name) \
static PyObject *pybind11_init_wrapper(); \
extern "C" PYBIND11_EXPORT void init##name() { \
Expand Down
4 changes: 4 additions & 0 deletions tests/test_builtin_casters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ TEST_SUBMODULE(builtin_casters, m) {
m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile
m.def("cast_nullptr_t", []() { return std::nullptr_t{}; });

// test_bool_caster
m.def("bool_passthrough", [](bool arg) { return arg; });
m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert());

// test_reference_wrapper
m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); });
m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); });
Expand Down
55 changes: 55 additions & 0 deletions tests/test_builtin_casters.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,58 @@ def test_complex_cast():
"""std::complex casts"""
assert m.complex_cast(1) == "1.0"
assert m.complex_cast(2j) == "(0.0, 2.0)"


def test_bool_caster():
"""Test bool caster implicit conversions."""
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert

def require_implicit(v):
pytest.raises(TypeError, noconvert, v)

def cant_convert(v):
pytest.raises(TypeError, convert, v)

# straight up bool
assert convert(True) is True
assert convert(False) is False
assert noconvert(True) is True
assert noconvert(False) is False

# None requires implicit conversion
require_implicit(None)
assert convert(None) is False

class A(object):
def __init__(self, x):
self.x = x

def __nonzero__(self):
return self.x

def __bool__(self):
return self.x

class B(object):
pass

# Arbitrary objects are not accepted
cant_convert(object())
cant_convert(B())

# Objects with __nonzero__ / __bool__ defined can be converted
require_implicit(A(True))
assert convert(A(True)) is True
assert convert(A(False)) is False


@pytest.requires_numpy
def test_numpy_bool():
import numpy as np
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert

# np.bool_ is not considered implicit
assert convert(np.bool_(True)) is True
assert convert(np.bool_(False)) is False
assert noconvert(np.bool_(True)) is True
assert noconvert(np.bool_(False)) is False

0 comments on commit e07f758

Please sign in to comment.