diff --git a/garde/tests/rules/byte_length.rs b/garde/tests/rules/byte_length.rs index 0257928..139259a 100644 --- a/garde/tests/rules/byte_length.rs +++ b/garde/tests/rules/byte_length.rs @@ -1,7 +1,10 @@ use super::util; + +const UWU: usize = 101; + #[derive(Debug, garde::Validate)] struct Test<'a> { - #[garde(byte_length(min = 10, max = 100))] + #[garde(byte_length(min = 10, max = UWU - 1))] field: &'a str, #[garde(inner(length(min = 10, max = 100)))] diff --git a/garde_derive/src/check.rs b/garde_derive/src/check.rs index a5d25a5..5fff8fd 100644 --- a/garde_derive/src/check.rs +++ b/garde_derive/src/check.rs @@ -5,7 +5,7 @@ use syn::parse_quote; use syn::spanned::Spanned; use crate::model; -use crate::util::MaybeFoldError; +use crate::util::{default_ctx_name, MaybeFoldError}; pub fn check(input: model::Input) -> syn::Result { let model::Input { @@ -25,7 +25,7 @@ pub fn check(input: model::Input) -> syn::Result { Ok(v) => v, Err(e) => { error.maybe_fold(e); - parse_quote!(()) + (parse_quote!(()), default_ctx_name()) } }; @@ -92,7 +92,7 @@ fn check_attrs(attrs: &[(Span, model::Attr)]) -> syn::Result<()> { } } -fn get_context(attrs: &[(Span, model::Attr)]) -> syn::Result { +fn get_context(attrs: &[(Span, model::Attr)]) -> syn::Result<(syn::Type, syn::Ident)> { #![allow(clippy::single_match)] let error = None; @@ -100,7 +100,7 @@ fn get_context(attrs: &[(Span, model::Attr)]) -> syn::Result { for (_, attr) in attrs { match attr { - model::Attr::Context(ty) => context = Some(ty), + model::Attr::Context(ty, ident) => context = Some((ty, ident)), _ => {} } } @@ -110,8 +110,8 @@ fn get_context(attrs: &[(Span, model::Attr)]) -> syn::Result { } match context { - Some(v) => Ok((**v).clone()), - None => Ok(parse_quote!(())), + Some((ty, id)) => Ok(((**ty).clone(), (*id).clone())), + None => Ok((parse_quote!(()), default_ctx_name())), } } @@ -122,7 +122,7 @@ fn get_options(attrs: &[(Span, model::Attr)]) -> model::Options { for (_, attr) in attrs { match attr { - model::Attr::Context(_) => {} + model::Attr::Context(..) => {} model::Attr::AllowUnvalidated => options.allow_unvalidated = true, } } @@ -307,8 +307,8 @@ fn check_rule( IpV6 => apply!(rule_set, IpV6(), span), CreditCard => apply!(rule_set, CreditCard(), span), PhoneNumber => apply!(rule_set, PhoneNumber(), span), - Length(v) => apply!(rule_set, Length(check_range(v)?), span), - ByteLength(v) => apply!(rule_set, ByteLength(check_range(v)?), span), + Length(v) => apply!(rule_set, Length(check_range_generic(v)?), span), + ByteLength(v) => apply!(rule_set, ByteLength(check_range_generic(v)?), span), Range(v) => apply!(rule_set, Range(check_range_not_ord(v)?), span), Contains(v) => apply!(rule_set, Contains(v), span), Prefix(v) => apply!(rule_set, Prefix(v), span), @@ -339,6 +339,63 @@ trait CheckRange: Sized { fn check_range(self) -> syn::Result>; } +fn check_range_generic( + range: model::Range>, +) -> syn::Result>> +where + L: PartialOrd, +{ + macro_rules! map_validate_range { + ($value:expr, $wrapper:expr) => {{ + match $value { + model::ValidateRange::GreaterThan(v) => { + model::ValidateRange::GreaterThan($wrapper(v)) + } + model::ValidateRange::LowerThan(v) => model::ValidateRange::LowerThan($wrapper(v)), + model::ValidateRange::Between(v1, v2) => { + model::ValidateRange::Between($wrapper(v1), $wrapper(v2)) + } + } + }}; + } + + let range = match (range.span, range.min, range.max) { + (span, Some(model::Either::Left(min)), Some(model::Either::Left(max))) => { + map_validate_range!( + check_range(model::Range { + span, + min: Some(min), + max: Some(max) + })?, + model::Either::Left + ) + } + (span, Some(model::Either::Left(min)), None) => { + map_validate_range!( + check_range(model::Range { + span, + min: Some(min), + max: None, + })?, + model::Either::Left + ) + } + (span, None, Some(model::Either::Left(max))) => { + map_validate_range!( + check_range(model::Range { + span, + min: None, + max: Some(max), + })?, + model::Either::Left + ) + } + (span, min, max) => check_range_not_ord(model::Range { span, min, max })?, + }; + + Ok(range) +} + fn check_range(range: model::Range) -> syn::Result> where T: PartialOrd, diff --git a/garde_derive/src/emit.rs b/garde_derive/src/emit.rs index 9e9c3e7..36a807e 100644 --- a/garde_derive/src/emit.rs +++ b/garde_derive/src/emit.rs @@ -13,7 +13,7 @@ pub fn emit(input: model::Validate) -> TokenStream2 { impl ToTokens for model::Validate { fn to_tokens(&self, tokens: &mut TokenStream2) { let ident = &self.ident; - let context_ty = &self.context; + let (context_ty, context_ident) = &self.context; let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); let kind = &self.kind; @@ -22,7 +22,9 @@ impl ToTokens for model::Validate { type Context = #context_ty ; #[allow(clippy::needless_borrow)] - fn validate(&self, __garde_user_ctx: &Self::Context) -> ::core::result::Result<(), ::garde::error::Errors> { + fn validate(&self, #context_ident: &Self::Context) -> ::core::result::Result<(), ::garde::error::Errors> { + let __garde_user_ctx = &#context_ident; + ( #kind ) diff --git a/garde_derive/src/model.rs b/garde_derive/src/model.rs index a637e9f..a26d1cd 100644 --- a/garde_derive/src/model.rs +++ b/garde_derive/src/model.rs @@ -12,7 +12,7 @@ pub struct Input { #[repr(u8)] pub enum Attr { - Context(Box), + Context(Box, Ident), AllowUnvalidated, } @@ -27,7 +27,7 @@ impl Attr { pub fn name(&self) -> &'static str { match self { - Attr::Context(_) => "context", + Attr::Context(..) => "context", Attr::AllowUnvalidated => "allow_unvalidated", } } @@ -100,8 +100,8 @@ pub enum RawRuleKind { IpV6, CreditCard, PhoneNumber, - Length(Range), - ByteLength(Range), + Length(Range>), + ByteLength(Range>), Range(Range), Contains(Expr), Prefix(Expr), @@ -111,6 +111,24 @@ pub enum RawRuleKind { Inner(List), } +pub enum Either { + Left(L), + Right(R), +} + +impl quote::ToTokens for Either +where + L: quote::ToTokens, + R: quote::ToTokens, +{ + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + match self { + Self::Left(left) => left.to_tokens(tokens), + Self::Right(right) => right.to_tokens(tokens), + } + } +} + pub enum Pattern { Lit(Str), Expr(Expr), @@ -135,7 +153,7 @@ pub struct List { pub struct Validate { pub ident: Ident, pub generics: Generics, - pub context: Type, + pub context: (Type, Ident), pub kind: ValidateKind, pub options: Options, } @@ -211,8 +229,8 @@ pub enum ValidateRule { IpV6, CreditCard, PhoneNumber, - Length(ValidateRange), - ByteLength(ValidateRange), + Length(ValidateRange>), + ByteLength(ValidateRange>), Range(ValidateRange), Contains(Expr), Prefix(Expr), diff --git a/garde_derive/src/syntax.rs b/garde_derive/src/syntax.rs index 66ded19..19d9d58 100644 --- a/garde_derive/src/syntax.rs +++ b/garde_derive/src/syntax.rs @@ -5,11 +5,12 @@ use syn::ext::IdentExt; use syn::parse::Parse; use syn::punctuated::Punctuated; use syn::spanned::Spanned; +use syn::token::As; use syn::{DeriveInput, Token, Type}; use crate::model; use crate::model::List; -use crate::util::MaybeFoldError; +use crate::util::{default_ctx_name, MaybeFoldError}; pub fn parse(input: DeriveInput) -> syn::Result { let mut error = None; @@ -90,7 +91,13 @@ impl Parse for model::Attr { let content; syn::parenthesized!(content in input); let ty = content.parse::()?; - Ok(model::Attr::Context(Box::new(ty))) + let ident = if content.parse::().is_ok() { + content.parse()? + } else { + default_ctx_name() + }; + + Ok(model::Attr::Context(Box::new(ty), ident)) } "allow_unvalidated" => Ok(model::Attr::AllowUnvalidated), _ => Err(syn::Error::new(ident.span(), "unrecognized attribute")), @@ -383,7 +390,11 @@ where } } - Ok(model::Range { span, min, max }) + if let Some(error) = error { + Err(error) + } else { + Ok(model::Range { span, min, max }) + } } } @@ -404,6 +415,18 @@ trait FromExpr: Sized { fn from_expr(v: syn::Expr) -> syn::Result; } +impl FromExpr for model::Either +where + L: FromExpr, + R: FromExpr, +{ + fn from_expr(v: syn::Expr) -> syn::Result { + L::from_expr(v.clone()) + .map(model::Either::Left) + .or_else(|_| R::from_expr(v).map(model::Either::Right)) + } +} + impl FromExpr for syn::Expr { fn from_expr(v: syn::Expr) -> syn::Result { Ok(v) diff --git a/garde_derive/src/util.rs b/garde_derive/src/util.rs index 1868a6d..830db26 100644 --- a/garde_derive/src/util.rs +++ b/garde_derive/src/util.rs @@ -12,3 +12,7 @@ impl MaybeFoldError for Option { } } } + +pub fn default_ctx_name() -> syn::Ident { + syn::Ident::new("__garde_user_ctx", proc_macro2::Span::call_site()) +}