Skip to content

Commit

Permalink
Wrap errors in dataclass/attrs post-init methods
Browse files Browse the repository at this point in the history
When decoding/converting to a dataclass/attrs type with a
`__post_init__`/`__attrs_post_init__` method, we now wrap all
`ValueError`/`TypeError` exceptions in a `ValidationError`. This mirrors
the behavior used in the new `Struct.__post_init__` support, and helps
provide more uniform error handling support.
  • Loading branch information
jcrist committed Jul 6, 2023
1 parent d4ce822 commit 524de6b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
11 changes: 8 additions & 3 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -4746,7 +4746,6 @@ ms_error_with_path(const char *msg, PathNode *path) {
return NULL;
}

/* TODO */
static MS_NOINLINE void
ms_maybe_wrap_validation_error(PathNode *path) {
PyObject *exc_type, *exc, *tb;
Expand Down Expand Up @@ -7972,7 +7971,10 @@ DataclassInfo_post_decode(DataclassInfo *self, PyObject *obj, PathNode *path) {
}
if (self->post_init != NULL) {
PyObject *res = CALL_ONE_ARG(self->post_init, obj);
if (res == NULL) return -1;
if (res == NULL) {
ms_maybe_wrap_validation_error(path);
return -1;
}
Py_DECREF(res);
}
return 0;
Expand Down Expand Up @@ -19260,7 +19262,10 @@ convert_object_to_dataclass(
}
if (info->post_init != NULL) {
PyObject *res = CALL_ONE_ARG(info->post_init, out);
if (res == NULL) goto error;
if (res == NULL) {
ms_maybe_wrap_validation_error(path);
goto error;
}
Py_DECREF(res);
}
Py_LeaveRecursiveCall();
Expand Down
32 changes: 24 additions & 8 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2617,16 +2617,24 @@ def __post_init__(self):
assert res.a == 1
assert called

def test_decode_dataclass_post_init_errors(self, proto):
@pytest.mark.parametrize("exc_class", [ValueError, TypeError, OSError])
def test_decode_dataclass_post_init_errors(self, proto, exc_class):
@dataclass
class Example:
a: int

def __post_init__(self):
raise ValueError("Oh no!")
raise exc_class("Oh no!")

with pytest.raises(ValueError, match="Oh no!"):
proto.decode(proto.encode({"a": 1}), type=Example)
expected = (
ValidationError if exc_class in (ValueError, TypeError) else exc_class
)

with pytest.raises(expected, match="Oh no!") as rec:
proto.decode(proto.encode([{"a": 1}]), type=List[Example])

if expected is ValidationError:
assert "- at `$[0]`" in str(rec.value)

def test_decode_dataclass_not_object(self, proto):
@dataclass
Expand Down Expand Up @@ -2811,16 +2819,24 @@ def __attrs_post_init__(self):
assert res.a == 1
assert called

def test_decode_attrs_post_init_errors(self, proto):
@pytest.mark.parametrize("exc_class", [ValueError, TypeError, OSError])
def test_decode_attrs_post_init_errors(self, proto, exc_class):
@attrs.define
class Example:
a: int

def __attrs_post_init__(self):
raise ValueError("Oh no!")
raise exc_class("Oh no!")

with pytest.raises(ValueError, match="Oh no!"):
proto.decode(proto.encode({"a": 1}), type=Example)
expected = (
ValidationError if exc_class in (ValueError, TypeError) else exc_class
)

with pytest.raises(expected, match="Oh no!") as rec:
proto.decode(proto.encode([{"a": 1}]), type=List[Example])

if expected is ValidationError:
assert "- at `$[0]`" in str(rec.value)

def test_decode_attrs_pre_init(self, proto):
called = False
Expand Down
4 changes: 2 additions & 2 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ def __post_init__(self):

msg = mapcls(a=1)

with pytest.raises(ValueError, match="Oh no!"):
with pytest.raises(ValidationError, match="Oh no!"):
convert(msg, Example, from_attributes=from_attributes)

@mapcls_and_from_attributes
Expand Down Expand Up @@ -1529,7 +1529,7 @@ class Example:
def __attrs_post_init__(self):
raise ValueError("Oh no!")

with pytest.raises(ValueError, match="Oh no!"):
with pytest.raises(ValidationError, match="Oh no!"):
convert(mapcls(a=1), Example, from_attributes=from_attributes)

def test_attrs_to_attrs(self):
Expand Down

0 comments on commit 524de6b

Please sign in to comment.