Skip to content

Commit

Permalink
Add argument name to error messages in conversion errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Askaholic committed Oct 13, 2020
1 parent 968e4c8 commit 067c5d0
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 12 deletions.
10 changes: 7 additions & 3 deletions pyo3-derive-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,14 @@ fn impl_arg_param(

let ty = arg.ty;
let name = arg.name;
let transform_error = quote! {
|e| pyo3::derive_utils::argument_extraction_error(_py, stringify!(#name), e)
};

if spec.is_args(&name) {
return quote! {
let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref())?;
let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref())
.map_err(#transform_error)?;
};
} else if spec.is_kwargs(&name) {
return quote! {
Expand Down Expand Up @@ -518,15 +522,15 @@ fn impl_arg_param(

quote! {
let #mut_ _tmp: #target_ty = match #arg_value {
Some(_obj) => _obj.extract()?,
Some(_obj) => _obj.extract().map_err(#transform_error)?,
None => #default,
};
let #arg_name = #borrow_tmp;
}
} else {
quote! {
let #arg_name = match #arg_value {
Some(_obj) => _obj.extract()?,
Some(_obj) => _obj.extract().map_err(#transform_error)?,
None => #default,
};
}
Expand Down
12 changes: 12 additions & 0 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ pub fn parse_fn_args<'p>(
Ok((args, kwargs))
}

/// Add the argument name to the error message of an error which occurred during argument extraction
pub fn argument_extraction_error(
py: Python,
arg_name: &str,
original_error: PyErr,
) -> PyErr {
PyErr::from_type(
original_error.ptype(py),
format!("argument '{}': {}", arg_name, original_error.instance(py)),
)
}

/// `Sync` wrapper of `ffi::PyModuleDef`.
#[doc(hidden)]
pub struct ModuleDef(UnsafeCell<ffi::PyModuleDef>);
Expand Down
55 changes: 55 additions & 0 deletions tests/test_pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,61 @@ use pyo3::{raw_pycfunction, wrap_pyfunction};

mod common;

#[pyfunction]
fn conversion_error(str_arg: &str, int_arg: i64) {
println!("{:?} {:?}", str_arg, int_arg);
}

#[test]
fn test_conversion_error() {
let gil = Python::acquire_gil();
let py = gil.python();

let conversion_error = wrap_pyfunction!(conversion_error)(py).unwrap();
py_expect_exception!(
py,
conversion_error,
"conversion_error('100, -100)",
PyTypeError,
"argument 'str_arg': Can't convert 100 to PyString"
);
py_expect_exception!(
py,
conversion_error,
"conversion_error('a string', 'another string')",
PyTypeError,
"argument 'int_arg': 'str' object cannot be interpreted as an integer"
);
}

#[pyfunction]
#[text_signature = "(arg1, arg2)"]
fn conversion_error_signature(tuple_arg: (&str, f64), option_arg: Option<i64>) {
println!("{:?} {:?}", tuple_arg, option_arg);
}

#[test]
fn test_conversion_error_signature() {
let gil = Python::acquire_gil();
let py = gil.python();

let conversion_error_signature = wrap_pyfunction!(conversion_error_signature)(py).unwrap();
py_expect_exception!(
py,
conversion_error_signature,
"conversion_error_signature('a string', 'another string')",
PyTypeError,
"argument 'arg1': Can't convert 'a string' to PyTuple"
);
py_expect_exception!(
py,
conversion_error_signature,
"conversion_error_signature('100, -100)",
PyTypeError,
"argument 'arg2': Can't convert '-100' to Option<PyLong>"
);
}

#[pyfunction(arg = "true")]
fn optional_bool(arg: Option<bool>) -> String {
format!("{:?}", arg)
Expand Down
13 changes: 4 additions & 9 deletions tests/test_string.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use pyo3::prelude::*;
use pyo3::py_run;
use pyo3::wrap_pyfunction;

mod common;
Expand All @@ -15,15 +14,11 @@ fn test_unicode_encode_error() {
let py = gil.python();

let take_str = wrap_pyfunction!(take_str)(py).unwrap();
py_run!(
py_expect_exception!(
py,
take_str,
r#"
try:
take_str('\ud800')
except UnicodeEncodeError as e:
error_msg = "'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed"
assert str(e) == error_msg
"#
"take_str('\\ud800')",
PyUnicodeEncodeError,
"argument '_s': 'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed"
);
}

0 comments on commit 067c5d0

Please sign in to comment.