Skip to content

Commit

Permalink
Changing pybind11::str to only hold PyUnicodeObject (NOT also `by…
Browse files Browse the repository at this point in the history
…tes`).

The corresponding behavior changes are captured by changes in the tests. A significant effort was made to keep the test diffs minimal but also comprehensive and easy to read.

Note: Unlike PR #2256 (dropped), this PR only changes exactly one behavior. The two other behavior changes discussed under PR #2256 are avoided here (1. disabling implicit decoding from bytes to unicode; 2. list_caster behavior change). Based on this PR, those can be easily implemented if and when desired.
  • Loading branch information
Ralf W. Grosse-Kunstleve committed Aug 11, 2020
1 parent 161ad4e commit 60beaf3
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 12 deletions.
18 changes: 18 additions & 0 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1629,7 +1629,25 @@ struct pyobject_caster {
template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
bool load(handle src, bool /* convert */) { value = src; return static_cast<bool>(value); }

#ifdef PYBIND11_DISABLE_IMPLICIT_STR_FROM_BYTES
template <typename T = type, enable_if_t<std::is_base_of<object, T>::value, int> = 0>
#else
template <typename T = type, enable_if_t<std::is_same<T, str>::value, int> = 0>
bool load(handle src, bool /* convert */) {
if (isinstance<bytes>(src)) {
PyObject *str_from_bytes = PyUnicode_FromEncodedObject(src.ptr(), "utf-8", nullptr);
if (!str_from_bytes) throw error_already_set();
value = reinterpret_steal<type>(str_from_bytes);
return true;
}
if (!isinstance<type>(src))
return false;
value = reinterpret_borrow<type>(src);
return true;
}

template <typename T = type, enable_if_t<std::is_base_of<object, T>::value && !std::is_same<T, str>::value, int> = 0>
#endif
bool load(handle src, bool /* convert */) {
if (!isinstance<type>(src))
return false;
Expand Down
4 changes: 1 addition & 3 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -738,8 +738,6 @@ inline bool PyIterable_Check(PyObject *obj) {
inline bool PyNone_Check(PyObject *o) { return o == Py_None; }
inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; }

inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); }

inline bool PyStaticMethod_Check(PyObject *o) { return o->ob_type == &PyStaticMethod_Type; }

class kwargs_proxy : public handle {
Expand Down Expand Up @@ -885,7 +883,7 @@ class bytes;

class str : public object {
public:
PYBIND11_OBJECT_CVT(str, object, detail::PyUnicode_Check_Permissive, raw_str)
PYBIND11_OBJECT_CVT(str, object, PyUnicode_Check, raw_str)

str(const char *c, size_t n)
: object(PyUnicode_FromStringAndSize(c, (ssize_t) n), stolen_t{}) {
Expand Down
2 changes: 1 addition & 1 deletion include/pybind11/stl.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ template <typename Type, typename Value> struct list_caster {
using value_conv = make_caster<Value>;

bool load(handle src, bool convert) {
if (!isinstance<sequence>(src) || isinstance<str>(src))
if (!isinstance<sequence>(src) || isinstance<bytes>(src) || isinstance<str>(src))
return false;
auto s = reinterpret_borrow<sequence>(src);
value.clear();
Expand Down
26 changes: 18 additions & 8 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,21 @@ def test_pybind11_str_raw_str():
valid_orig = u"DZ"
valid_utf8 = valid_orig.encode("utf-8")
valid_cvt = cvt(valid_utf8)
assert type(valid_cvt) == bytes # Probably surprising.
assert valid_cvt == b'\xc7\xb1'
assert type(valid_cvt) == type(u"") # Py2 unicode, Py3 str, flake8 compatible
if str is bytes:
assert valid_cvt == valid_orig
else:
assert valid_cvt == u"b'\\xc7\\xb1'"

malformed_utf8 = b'\x80'
malformed_cvt = cvt(malformed_utf8)
assert type(malformed_cvt) == bytes # Probably surprising.
assert malformed_cvt == b'\x80'
if str is bytes:
with pytest.raises(UnicodeDecodeError) as excinfo:
cvt(malformed_utf8)
assert "invalid start byte" in str(excinfo)
else:
malformed_cvt = cvt(malformed_utf8)
assert type(valid_cvt) == type(u"")
assert malformed_cvt == u"b'\\x80'"


def test_implicit_casting():
Expand Down Expand Up @@ -392,19 +400,21 @@ def test_isinstance_string_types():
assert not m.isinstance_pybind11_bytes(u"")

assert m.isinstance_pybind11_str(u"")
assert m.isinstance_pybind11_str(b"") # Probably surprising.
assert not m.isinstance_pybind11_str(b"")


def test_pass_bytes_or_unicode_to_string_types():
assert m.pass_to_pybind11_bytes(b"Bytes") == 5
with pytest.raises(TypeError):
m.pass_to_pybind11_bytes(u"Str") # NO implicit encode

assert m.pass_to_pybind11_str(b"Bytes") == 5
assert m.pass_to_pybind11_str(b"Bytes") == 5 # implicit decode
assert m.pass_to_pybind11_str(u"Str") == 3

assert m.pass_to_std_string(b"Bytes") == 5
assert m.pass_to_std_string(u"Str") == 3

malformed_utf8 = b"\x80"
assert m.pass_to_pybind11_str(malformed_utf8) == 1 # NO decoding error
with pytest.raises(UnicodeDecodeError) as excinfo:
m.pass_to_pybind11_str(malformed_utf8)
assert 'invalid start byte' in str(excinfo.value)

0 comments on commit 60beaf3

Please sign in to comment.