Skip to content

Commit

Permalink
Support #[derive(FromRequest)] on enums (#1009)
Browse files Browse the repository at this point in the history
* Support `#[from_request(via(...))]` on enums

* Check `#[from_request]` on variants

* check for non enum/struct and clean up

* changelog

* changelog

* remove needless feature

* changelog ref
  • Loading branch information
davidpdrsn authored May 8, 2022
1 parent a3a32f4 commit 852e548
Show file tree
Hide file tree
Showing 14 changed files with 215 additions and 25 deletions.
2 changes: 2 additions & 0 deletions axum-macros/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001])
- **fixed:** Support wildcards in typed paths ([#1003])
- **added:** Support `#[derive(FromRequest)]` on enums using `#[from_request(via(OtherExtractor))]` ([#1009])

[#1001]: https://github.com/tokio-rs/axum/pull/1001
[#1003]: https://github.com/tokio-rs/axum/pull/1003
[#1009]: https://github.com/tokio-rs/axum/pull/1009

# 0.2.0 (31. March, 2022)

Expand Down
146 changes: 123 additions & 23 deletions axum-macros/src/from_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,77 @@ use self::attr::{
RejectionDeriveOptOuts,
};
use heck::ToUpperCamelCase;
use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{punctuated::Punctuated, spanned::Spanned, Token};

mod attr;

const GENERICS_ERROR: &str = "`#[derive(FromRequest)] doesn't support generics";
pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
match item {
syn::Item::Struct(item) => {
let syn::ItemStruct {
attrs,
ident,
generics,
fields,
semi_token: _,
vis,
struct_token: _,
} = item;

error_on_generics(generics)?;

match parse_container_attrs(&attrs)? {
FromRequestContainerAttr::Via(path) => {
impl_struct_by_extracting_all_at_once(ident, fields, path)
}
FromRequestContainerAttr::RejectionDerive(_, opt_outs) => {
impl_struct_by_extracting_each_field(ident, fields, vis, opt_outs)
}
FromRequestContainerAttr::None => impl_struct_by_extracting_each_field(
ident,
fields,
vis,
RejectionDeriveOptOuts::default(),
),
}
}
syn::Item::Enum(item) => {
let syn::ItemEnum {
attrs,
vis: _,
enum_token: _,
ident,
generics,
brace_token: _,
variants,
} = item;

error_on_generics(generics)?;

match parse_container_attrs(&attrs)? {
FromRequestContainerAttr::Via(path) => {
impl_enum_by_extracting_all_at_once(ident, variants, path)
}
FromRequestContainerAttr::RejectionDerive(rejection_derive, _) => {
Err(syn::Error::new_spanned(
rejection_derive,
"cannot use `rejection_derive` on enums",
))
}
FromRequestContainerAttr::None => Err(syn::Error::new(
Span::call_site(),
"missing `#[from_request(via(...))]`",
)),
}
}
_ => Err(syn::Error::new_spanned(item, "expected `struct` or `enum`")),
}
}

pub(crate) fn expand(item: syn::ItemStruct) -> syn::Result<TokenStream> {
let syn::ItemStruct {
attrs,
ident,
generics,
fields,
semi_token: _,
vis,
struct_token: _,
} = item;
fn error_on_generics(generics: syn::Generics) -> syn::Result<()> {
const GENERICS_ERROR: &str = "`#[derive(FromRequest)] doesn't support generics";

if !generics.params.is_empty() {
return Err(syn::Error::new_spanned(generics, GENERICS_ERROR));
Expand All @@ -30,18 +83,10 @@ pub(crate) fn expand(item: syn::ItemStruct) -> syn::Result<TokenStream> {
return Err(syn::Error::new_spanned(where_clause, GENERICS_ERROR));
}

match parse_container_attrs(&attrs)? {
FromRequestContainerAttr::Via(path) => impl_by_extracting_all_at_once(ident, fields, path),
FromRequestContainerAttr::RejectionDerive(opt_outs) => {
impl_by_extracting_each_field(ident, fields, vis, opt_outs)
}
FromRequestContainerAttr::None => {
impl_by_extracting_each_field(ident, fields, vis, RejectionDeriveOptOuts::default())
}
}
Ok(())
}

fn impl_by_extracting_each_field(
fn impl_struct_by_extracting_each_field(
ident: syn::Ident,
fields: syn::Fields,
vis: syn::Visibility,
Expand Down Expand Up @@ -413,7 +458,7 @@ fn rejection_variant_name(field: &syn::Field) -> syn::Result<syn::Ident> {
}
}

fn impl_by_extracting_all_at_once(
fn impl_struct_by_extracting_all_at_once(
ident: syn::Ident,
fields: syn::Fields,
path: syn::Path,
Expand Down Expand Up @@ -459,6 +504,61 @@ fn impl_by_extracting_all_at_once(
})
}

fn impl_enum_by_extracting_all_at_once(
ident: syn::Ident,
variants: Punctuated<syn::Variant, Token![,]>,
path: syn::Path,
) -> syn::Result<TokenStream> {
for variant in variants {
let FromRequestFieldAttr { via } = parse_field_attrs(&variant.attrs)?;
if let Some((via, _)) = via {
return Err(syn::Error::new_spanned(
via,
"`#[from_request(via(...))]` cannot be used on variants",
));
}

let fields = match variant.fields {
syn::Fields::Named(fields) => fields.named.into_iter(),
syn::Fields::Unnamed(fields) => fields.unnamed.into_iter(),
syn::Fields::Unit => Punctuated::<_, Token![,]>::new().into_iter(),
};

for field in fields {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
if let Some((via, _)) = via {
return Err(syn::Error::new_spanned(
via,
"`#[from_request(via(...))]` cannot be used inside variants",
));
}
}
}

let path_span = path.span();

Ok(quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
{
type Rejection = <#path<Self> as ::axum::extract::FromRequest<B>>::Rejection;

async fn from_request(
req: &mut ::axum::extract::RequestParts<B>,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<B>::from_request(req)
.await
.map(|#path(inner)| inner)
}
}
})
}

#[test]
fn ui() {
#[rustversion::stable]
Expand Down
6 changes: 4 additions & 2 deletions axum-macros/src/from_request/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub(crate) struct FromRequestFieldAttr {

pub(crate) enum FromRequestContainerAttr {
Via(syn::Path),
RejectionDerive(RejectionDeriveOptOuts),
RejectionDerive(kw::rejection_derive, RejectionDeriveOptOuts),
None,
}

Expand Down Expand Up @@ -91,7 +91,9 @@ pub(crate) fn parse_container_attrs(
}
}
(Some((_, _, path)), None) => Ok(FromRequestContainerAttr::Via(path)),
(None, Some((_, _, opt_outs))) => Ok(FromRequestContainerAttr::RejectionDerive(opt_outs)),
(None, Some((_, rejection_derive, opt_outs))) => Ok(
FromRequestContainerAttr::RejectionDerive(rejection_derive, opt_outs),
),
(None, None) => Ok(FromRequestContainerAttr::None),
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use axum_macros::FromRequest;

#[derive(FromRequest, Clone)]
#[from_request(via(axum::Extension))]
enum Extractor {
Foo {
#[from_request(via(axum::Extension))]
foo: (),
}
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: `#[from_request(via(...))]` cannot be used inside variants
--> tests/from_request/fail/enum_from_request_ident_in_variant.rs:7:24
|
7 | #[from_request(via(axum::Extension))]
| ^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use axum_macros::FromRequest;

#[derive(FromRequest, Clone)]
#[from_request(via(axum::Extension))]
enum Extractor {
#[from_request(via(axum::Extension))]
Foo,
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: `#[from_request(via(...))]` cannot be used on variants
--> tests/from_request/fail/enum_from_request_on_variant.rs:6:20
|
6 | #[from_request(via(axum::Extension))]
| ^^^
6 changes: 6 additions & 0 deletions axum-macros/tests/from_request/fail/enum_no_via.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
use axum_macros::FromRequest;

#[derive(FromRequest, Clone)]
enum Extractor {}

fn main() {}
7 changes: 7 additions & 0 deletions axum-macros/tests/from_request/fail/enum_no_via.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
error: missing `#[from_request(via(...))]`
--> tests/from_request/fail/enum_no_via.rs:3:10
|
3 | #[derive(FromRequest, Clone)]
| ^^^^^^^^^^^
|
= note: this error originates in the derive macro `FromRequest` (in Nightly builds, run with -Z macro-backtrace for more info)
7 changes: 7 additions & 0 deletions axum-macros/tests/from_request/fail/enum_rejection_derive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use axum_macros::FromRequest;

#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error))]
enum Extractor {}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: cannot use `rejection_derive` on enums
--> tests/from_request/fail/enum_rejection_derive.rs:4:16
|
4 | #[from_request(rejection_derive(!Error))]
| ^^^^^^^^^^^^^^^^
6 changes: 6 additions & 0 deletions axum-macros/tests/from_request/fail/not_enum_or_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
use axum_macros::FromRequest;

#[derive(FromRequest)]
union Extractor {}

fn main() {}
11 changes: 11 additions & 0 deletions axum-macros/tests/from_request/fail/not_enum_or_struct.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
error: expected `struct` or `enum`
--> tests/from_request/fail/not_enum_or_struct.rs:4:1
|
4 | union Extractor {}
| ^^^^^^^^^^^^^^^^^^

error: unions cannot have zero fields
--> tests/from_request/fail/not_enum_or_struct.rs:4:1
|
4 | union Extractor {}
| ^^^^^^^^^^^^^^^^^^
12 changes: 12 additions & 0 deletions axum-macros/tests/from_request/pass/enum_via.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use axum::{body::Body, routing::get, Extension, Router};
use axum_macros::FromRequest;

#[derive(FromRequest, Clone)]
#[from_request(via(Extension))]
enum Extractor {}

async fn foo(_: Extractor) {}

fn main() {
Router::<Body>::new().route("/", get(foo));
}

0 comments on commit 852e548

Please sign in to comment.