diff --git a/axum-debug/Cargo.toml b/axum-debug/Cargo.toml index b61093c2e3..f2c174b339 100644 --- a/axum-debug/Cargo.toml +++ b/axum-debug/Cargo.toml @@ -15,6 +15,7 @@ version = "0.1.0" proc-macro = true [dependencies] +proc-macro-crate = "1.1.0" proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["full"] } diff --git a/axum-debug/src/lib.rs b/axum-debug/src/lib.rs index d0fdba77e5..d66ff61cf9 100644 --- a/axum-debug/src/lib.rs +++ b/axum-debug/src/lib.rs @@ -254,11 +254,13 @@ pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream { #[cfg(debug_assertions)] mod debug { use proc_macro2::TokenStream; - use quote::{format_ident, quote_spanned}; + use proc_macro_crate::FoundCrate; + use quote::{format_ident, quote, quote_spanned}; use syn::{parse_macro_input, spanned::Spanned, FnArg, Ident, ItemFn, ReturnType, Signature}; pub(crate) fn apply_debug_handler(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let function = parse_macro_input!(input as ItemFn); + let axum = import_axum(); let vis = &function.vis; let sig = &function.sig; @@ -284,8 +286,8 @@ mod debug { } let check_trait = check_trait_code(sig, &generics); - let check_return = check_return_code(sig, &generics); - let check_params = match check_params_code(sig, &generics) { + let check_return = check_return_code(sig, &generics, &axum); + let check_params = match check_params_code(sig, &generics, &axum) { Ok(tokens) => tokens, Err(err) => return err.into_compile_error().into(), }; @@ -336,7 +338,7 @@ mod debug { } } - fn check_trait_code(sig: &Signature, generics: &[Ident]) -> proc_macro2::TokenStream { + fn check_trait_code(sig: &Signature, generics: &[Ident]) -> TokenStream { let ident = &sig.ident; let span = ident.span(); @@ -353,7 +355,7 @@ mod debug { } } - fn check_return_code(sig: &Signature, generics: &[Ident]) -> proc_macro2::TokenStream { + fn check_return_code(sig: &Signature, generics: &[Ident], axum: &TokenStream) -> TokenStream { let span = match &sig.output { ReturnType::Default => sig.output.span(), ReturnType::Type(_, ty) => ty.span(), @@ -368,7 +370,7 @@ mod debug { where F: ::std::ops::FnOnce(#(#generics),*) -> Fut, Fut: ::std::future::Future, - Res: ::axum::response::IntoResponse, + Res: #axum::response::IntoResponse, {} } } @@ -377,6 +379,7 @@ mod debug { fn check_params_code( sig: &Signature, generics: &[Ident], + axum: &TokenStream, ) -> Result, syn::Error> { let ident = &sig.ident; generics @@ -402,7 +405,7 @@ mod debug { where F: ::std::ops::FnOnce(#(#generics),*) -> Fut, Fut: ::std::future::Future, - #generic: ::axum::extract::FromRequest + Send, + #generic: #axum::extract::FromRequest + Send, {} } }; @@ -411,6 +414,22 @@ mod debug { }) .collect() } + + fn import_axum() -> TokenStream { + match proc_macro_crate::crate_name("axum") { + Ok(FoundCrate::Name(name)) => { + // Use renamed crate name if axum was renamed via `Cargo.toml` + let name = format_ident!("{}", name); + quote! { ::#name } + } + // No match arm for `Ok(FoundCrate::Itself)` because it is pretty much pointless: + // https://github.com/bkchr/proc-macro-crate/issues/11 + _ => { + // Fall back to plain `axum` + quote! { ::axum } + } + } + } } #[test]