Skip to content

Commit

Permalink
Let operator ? only wrap the error in Box if needed
Browse files Browse the repository at this point in the history
This leads to fewer needless allocations, and makes the operator `?`
usable in an no-alloc context.
  • Loading branch information
Kijewski committed Dec 21, 2024
1 parent bedc317 commit fe36573
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 27 deletions.
59 changes: 55 additions & 4 deletions rinja/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use alloc::boxed::Box;
use core::convert::Infallible;
use core::error::Error as StdError;
use core::fmt;
use core::marker::PhantomData;
#[cfg(feature = "std")]
use std::io;

Expand Down Expand Up @@ -121,7 +122,7 @@ impl From<Box<dyn StdError + Send + Sync>> for Error {
impl From<io::Error> for Error {
#[inline]
fn from(err: io::Error) -> Self {
from_from_io_error(err, MAX_ERROR_UNWRAP_COUNT)
error_from_io_error(err, MAX_ERROR_UNWRAP_COUNT)
}
}

Expand All @@ -145,7 +146,7 @@ fn error_from_stderror(err: Box<dyn StdError + Send + Sync>, unwraps: usize) ->
},
#[cfg(feature = "std")]
ErrorKind::Io => match err.downcast() {
Ok(err) => from_from_io_error(*err, unwraps),
Ok(err) => error_from_io_error(*err, unwraps),
Err(_) => Error::Fmt, // unreachable
},
ErrorKind::Rinja => match err.downcast() {
Expand All @@ -156,7 +157,7 @@ fn error_from_stderror(err: Box<dyn StdError + Send + Sync>, unwraps: usize) ->
}

#[cfg(feature = "std")]
fn from_from_io_error(err: io::Error, unwraps: usize) -> Error {
fn error_from_io_error(err: io::Error, unwraps: usize) -> Error {
let Some(inner) = err.get_ref() else {
return Error::custom(err);
};
Expand All @@ -182,7 +183,7 @@ fn from_from_io_error(err: io::Error, unwraps: usize) -> Error {
None => Error::Fmt, // unreachable
},
ErrorKind::Io => match err.downcast() {
Ok(inner) => from_from_io_error(inner, unwraps),
Ok(inner) => error_from_io_error(inner, unwraps),
Err(_) => Error::Fmt, // unreachable
},
}
Expand Down Expand Up @@ -244,3 +245,53 @@ const _: () = {
trait AssertSendSyncStatic: Send + Sync + 'static {}
impl AssertSendSyncStatic for Error {}
};

/// Helper trait to convert a custom `?` call into a [`crate::Result`]
pub trait ResultConverter {
/// Okay Value type of the output
type Value;
/// Input type
type Input;

/// Consume an interior mutable `self`, and turn it into a [`crate::Result`]
fn rinja_conv_result(self, result: Self::Input) -> Result<Self::Value, Error>;
}

/// Helper marker to be used with [`ResultConverter`]
#[derive(Debug, Clone, Copy)]
pub struct ErrorMarker<T>(PhantomData<Result<T>>);

impl<T> ErrorMarker<T> {
/// Get marker for a [`Result`] type
#[inline]
pub fn of(_: &T) -> Self {
Self(PhantomData)
}
}

#[cfg(feature = "alloc")]
impl<T, E> ResultConverter for &ErrorMarker<Result<T, E>>
where
E: Into<Box<dyn StdError + Send + Sync>>,
{
type Value = T;
type Input = Result<T, E>;

#[inline]
fn rinja_conv_result(self, result: Self::Input) -> Result<Self::Value, Error> {
result.map_err(Error::custom)
}
}

impl<T, E> ResultConverter for &&ErrorMarker<Result<T, E>>
where
E: Into<Error>,
{
type Value = T;
type Input = Result<T, E>;

#[inline]
fn rinja_conv_result(self, result: Self::Input) -> Result<Self::Value, Error> {
result.map_err(Into::into)
}
}
10 changes: 1 addition & 9 deletions rinja/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use core::iter::{Enumerate, Peekable};
use core::ops::Deref;
use core::pin::Pin;

pub use crate::error::{ErrorMarker, ResultConverter};
use crate::filters::FastWritable;

pub struct TemplateLoop<I>
Expand Down Expand Up @@ -267,12 +268,3 @@ impl<L: FastWritable, R: FastWritable> FastWritable for Concat<L, R> {
self.1.write_into(dest)
}
}

#[inline]
#[cfg(feature = "alloc")]
pub fn map_try<T, E>(result: Result<T, E>) -> Result<T, crate::Error>
where
E: Into<alloc::boxed::Box<dyn std::error::Error + Send + Sync>>,
{
result.map_err(crate::Error::custom)
}
12 changes: 3 additions & 9 deletions rinja_derive/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ impl<'a, 'h> Generator<'a, 'h> {
RinjaW: rinja::helpers::core::fmt::Write + ?rinja::helpers::core::marker::Sized\
{\
use rinja::filters::{AutoEscape as _, WriteWritable as _};\
use rinja::helpers::ResultConverter as _;
use rinja::helpers::core::fmt::Write as _;",
);

Expand Down Expand Up @@ -1496,16 +1497,9 @@ impl<'a, 'h> Generator<'a, 'h> {
buf: &mut Buffer,
expr: &WithSpan<'_, Expr<'_>>,
) -> Result<DisplayWrap, CompileError> {
if !cfg!(feature = "alloc") {
return Err(ctx.generate_error(
"the `?` operator requires the `alloc` feature to be enabled",
expr.span(),
));
}

buf.write("rinja::helpers::map_try(");
buf.write("match (");
self.visit_expr(ctx, buf, expr)?;
buf.write(")?");
buf.write(") { res => (&&rinja::helpers::ErrorMarker::of(&res)).rinja_conv_result(res)? }");
Ok(DisplayWrap::Unwrapped)
}

Expand Down
1 change: 1 addition & 0 deletions rinja_derive/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fn compare(jinja: &str, expected: &str, fields: &[(&str, &str)], size_hint: usiz
RinjaW: rinja::helpers::core::fmt::Write + ?rinja::helpers::core::marker::Sized,
{
use rinja::filters::{AutoEscape as _, WriteWritable as _};
use rinja::helpers::ResultConverter as _;
use rinja::helpers::core::fmt::Write as _;
#expected
rinja::Result::Ok(())
Expand Down
1 change: 1 addition & 0 deletions testing/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ core = { package = "intentionally-empty", version = "1.0.0" }
[dev-dependencies]
rinja = { path = "../rinja", version = "0.3.5", features = ["code-in-doc", "serde_json"] }

assert_matches = "1.5.0"
criterion = "0.5"
phf = { version = "0.11", features = ["macros" ] }
trybuild = "1.0.100"
Expand Down
160 changes: 155 additions & 5 deletions testing/tests/try.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::{fmt, io};

use assert_matches::assert_matches;
use rinja::Template;

#[test]
Expand All @@ -15,7 +18,7 @@ fn test_int_parser() {
}

let template = IntParserTemplate { s: "💯" };
assert!(matches!(template.render(), Err(rinja::Error::Custom(_))));
assert_matches!(template.render(), Err(rinja::Error::Custom(_)));
assert_eq!(
format!("{}", &template.render().unwrap_err()),
"invalid digit found in string"
Expand All @@ -34,17 +37,17 @@ fn fail_fmt() {
}

impl FailFmt {
fn value(&self) -> Result<&'static str, std::fmt::Error> {
fn value(&self) -> Result<&'static str, fmt::Error> {
if let Some(inner) = self.inner {
Ok(inner)
} else {
Err(std::fmt::Error)
Err(fmt::Error)
}
}
}

let template = FailFmt { inner: None };
assert!(matches!(template.render(), Err(rinja::Error::Custom(_))));
assert_matches!(template.render(), Err(rinja::Error::Fmt));
assert_eq!(
format!("{}", &template.render().unwrap_err()),
format!("{}", std::fmt::Error)
Expand Down Expand Up @@ -75,9 +78,156 @@ fn fail_str() {
}

let template = FailStr { value: false };
assert!(matches!(template.render(), Err(rinja::Error::Custom(_))));
assert_matches!(template.render(), Err(rinja::Error::Custom(_)));
assert_eq!(format!("{}", &template.render().unwrap_err()), "FAIL");

let template = FailStr { value: true };
assert_eq!(template.render().unwrap(), "hello world");
}

#[test]
fn error_conversion_from_fmt() {
#[derive(Template)]
#[template(source = "{{ value()? }}", ext = "txt")]
struct ResultTemplate {
succeed: bool,
}

impl ResultTemplate {
fn value(&self) -> Result<&'static str, fmt::Error> {
match self.succeed {
true => Ok("hello"),
false => Err(fmt::Error),
}
}
}

assert_matches!(
ResultTemplate { succeed: true }.render().as_deref(),
Ok("hello")
);
assert_matches!(
ResultTemplate { succeed: false }.render().as_deref(),
Err(rinja::Error::Fmt)
);
}

#[test]
fn error_conversion_from_rinja_custom() {
#[derive(Template)]
#[template(source = "{{ value()? }}", ext = "txt")]
struct ResultTemplate {
succeed: bool,
}

impl ResultTemplate {
fn value(&self) -> Result<&'static str, rinja::Error> {
match self.succeed {
true => Ok("hello"),
false => Err(rinja::Error::custom(CustomError)),
}
}
}

#[derive(Debug)]
struct CustomError;

impl fmt::Display for CustomError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("custom")
}
}

impl std::error::Error for CustomError {}

assert_matches!(
ResultTemplate { succeed: true }.render().as_deref(),
Ok("hello")
);

let err = match (ResultTemplate { succeed: false }.render().unwrap_err()) {
rinja::Error::Custom(err) => err,
err => panic!("Expected Error::Custom(_), got {err:#?}"),
};
assert!(err.is::<CustomError>());
}

#[test]
fn error_conversion_from_custom() {
#[derive(Template)]
#[template(source = "{{ value()? }}", ext = "txt")]
struct ResultTemplate {
succeed: bool,
}

impl ResultTemplate {
fn value(&self) -> Result<&'static str, CustomError> {
match self.succeed {
true => Ok("hello"),
false => Err(CustomError),
}
}
}

#[derive(Debug)]
struct CustomError;

impl fmt::Display for CustomError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("custom")
}
}

impl std::error::Error for CustomError {}

assert_matches!(
ResultTemplate { succeed: true }.render().as_deref(),
Ok("hello")
);

let err = match (ResultTemplate { succeed: false }.render().unwrap_err()) {
rinja::Error::Custom(err) => err,
err => panic!("Expected Error::Custom(_), got {err:#?}"),
};
assert!(err.is::<CustomError>());
}

#[test]
fn error_conversion_from_wrapped_in_io() {
#[derive(Template)]
#[template(source = "{{ value()? }}", ext = "txt")]
struct ResultTemplate {
succeed: bool,
}

impl ResultTemplate {
fn value(&self) -> Result<&'static str, io::Error> {
match self.succeed {
true => Ok("hello"),
false => Err(io::Error::new(io::ErrorKind::InvalidData, CustomError)),
}
}
}

#[derive(Debug)]
struct CustomError;

impl fmt::Display for CustomError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("custom")
}
}

impl std::error::Error for CustomError {}

assert_matches!(
ResultTemplate { succeed: true }.render().as_deref(),
Ok("hello")
);

let err = match (ResultTemplate { succeed: false }.render().unwrap_err()) {
rinja::Error::Custom(err) => err,
err => panic!("Expected Error::Custom(_), got {err:#?}"),
};
assert!(err.is::<CustomError>());
}

0 comments on commit fe36573

Please sign in to comment.