Skip to content

Commit

Permalink
Add automatic type recognition for axum (#668)
Browse files Browse the repository at this point in the history
Add automatic request body recognition and improve existing path params
recognition for axum.
  • Loading branch information
juhaku authored Jul 9, 2023
1 parent d008ff4 commit 41d8f58
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 34 deletions.
2 changes: 1 addition & 1 deletion utoipa-gen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ decimal = []
rocket_extras = ["regex", "syn/extra-traits"]
non_strict_integers = []
uuid = ["dep:uuid", "utoipa/uuid"]
axum_extras = ["syn/extra-traits"]
axum_extras = ["regex", "syn/extra-traits"]
time = []
smallvec = []
repr = []
Expand Down
19 changes: 18 additions & 1 deletion utoipa-gen/src/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ pub struct ValueArgument<'a> {
pub type_tree: Option<TypeTree<'a>>,
}

#[cfg(feature = "actix_extras")]
impl<'v> From<(MacroArg, TypeTree<'v>)> for ValueArgument<'v> {
fn from((macro_arg, primitive_arg): (MacroArg, TypeTree<'v>)) -> Self {
Self {
name: match macro_arg {
MacroArg::Path(path) => Some(Cow::Owned(path.name)),
},
type_tree: Some(primitive_arg),
argument_in: ArgumentIn::Path,
}
}
}

#[cfg_attr(
not(any(
feature = "actix_extras",
Expand Down Expand Up @@ -272,7 +285,11 @@ pub struct PathOperations;
)))]
impl ArgumentResolver for PathOperations {}

#[cfg(not(any(feature = "actix_extras", feature = "rocket_extras")))]
#[cfg(not(any(
feature = "actix_extras",
feature = "rocket_extras",
feature = "axum_extras"
)))]
impl PathResolver for PathOperations {}

#[cfg(not(any(feature = "actix_extras", feature = "rocket_extras")))]
Expand Down
120 changes: 90 additions & 30 deletions utoipa-gen/src/ext/axum.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use std::borrow::Cow;

use regex::Captures;
use syn::{punctuated::Punctuated, token::Comma};

use crate::component::TypeTree;
use crate::component::{TypeTree, ValueType};

use super::{
fn_arg::{self, FnArg, FnArgType},
ArgumentResolver, PathOperations, ValueArgument,
ArgValue, ArgumentResolver, MacroArg, MacroPath, PathOperations, PathResolver, ValueArgument,
};

// axum framework is only able to resolve handler function arguments.
// `PathResolver` and `PathOperationResolver` is not supported in axum.
impl ArgumentResolver for PathOperations {
fn resolve_arguments(
args: &'_ Punctuated<syn::FnArg, Comma>,
_: Option<Vec<super::MacroArg>>, // ignored, cannot be provided
macro_args: Option<Vec<super::MacroArg>>,
) -> (
Option<Vec<super::ValueArgument<'_>>>,
Option<Vec<super::IntoParamsType<'_>>>,
Expand All @@ -23,50 +24,83 @@ impl ArgumentResolver for PathOperations {
let (into_params_args, value_args): (Vec<FnArg>, Vec<FnArg>) =
fn_arg::get_fn_args(args).partition(fn_arg::is_into_params);

// TODO value args resolve request body
let (value_args, body) = split_value_args_and_request_body(value_args);

(
Some(get_value_arguments(value_args).collect()),
Some(
value_args
.zip(macro_args.unwrap_or_default().into_iter())
.map(|(value_arg, macro_arg)| ValueArgument {
name: match macro_arg {
MacroArg::Path(path) => Some(Cow::Owned(path.name)),
},
argument_in: value_arg.argument_in,
type_tree: value_arg.type_tree,
})
.collect(),
),
Some(
into_params_args
.into_iter()
.flat_map(fn_arg::with_parameter_in)
.map(Into::into)
.collect(),
),
None,
body.into_iter().next().map(Into::into),
)
}
}

fn get_value_arguments(value_args: Vec<FnArg>) -> impl Iterator<Item = super::ValueArgument<'_>> {
value_args
fn split_value_args_and_request_body(
value_args: Vec<FnArg>,
) -> (
impl Iterator<Item = super::ValueArgument<'_>>,
impl Iterator<Item = TypeTree<'_>>,
) {
let (path_args, body_types): (Vec<FnArg>, Vec<FnArg>) = value_args
.into_iter()
.filter(|arg| arg.ty.is("Path"))
.flat_map(|path_arg| match path_arg.arg_type {
FnArgType::Single(name) => path_arg
.ty
.children
.expect("Path argument must have children")
.into_iter()
.map(|ty| to_value_argument(Some(Cow::Owned(name.to_string())), ty))
.collect::<Vec<_>>(),
FnArgType::Destructed(tuple) => tuple
.iter()
.zip(
path_arg
.ty
.children
.expect("Path argument must have children")
.filter(|arg| {
arg.ty.is("Path") || arg.ty.is("Json") || arg.ty.is("Form") || arg.ty.is("Bytes")
})
.partition(|arg| arg.ty.is("Path"));

(
path_args
.into_iter()
.filter(|arg| arg.ty.is("Path"))
.flat_map(|path_arg| {
match (
path_arg.arg_type,
path_arg.ty.children.expect("Path must have children"),
) {
(FnArgType::Single(name), path_children) => path_children
.into_iter()
.flat_map(|child| {
.flat_map(|ty| match ty.value_type {
ValueType::Tuple => ty
.children
.expect("ValueType::Tuple will always have children")
.into_iter()
.map(|ty| to_value_argument(None, ty))
.collect(),
ValueType::Primitive => {
vec![to_value_argument(Some(Cow::Owned(name.to_string())), ty)]
}
ValueType::Object | ValueType::Value => unreachable!("Cannot get here"),
})
.collect::<Vec<_>>(),
(FnArgType::Destructed(tuple), path_children) => tuple
.iter()
.zip(path_children.into_iter().flat_map(|child| {
child
.children
.expect("ValueType::Tuple will always have children")
}),
)
.map(|(name, ty)| to_value_argument(Some(Cow::Owned(name.to_string())), ty))
.collect::<Vec<_>>(),
})
}))
.map(|(name, ty)| to_value_argument(Some(Cow::Owned(name.to_string())), ty))
.collect::<Vec<_>>(),
}
}),
body_types.into_iter().map(|body| body.ty),
)
}

fn to_value_argument<'a>(name: Option<Cow<'a, str>>, ty: TypeTree<'a>) -> ValueArgument<'a> {
Expand All @@ -76,3 +110,29 @@ fn to_value_argument<'a>(name: Option<Cow<'a, str>>, ty: TypeTree<'a>) -> ValueA
argument_in: super::ArgumentIn::Path,
}
}

impl PathResolver for PathOperations {
fn resolve_path(path: &Option<String>) -> Option<MacroPath> {
path.as_ref().map(|path| {
let regex = regex::Regex::new(r"\{[a-zA-Z0-9][^{}]*}").unwrap();

let mut args = Vec::<MacroArg>::with_capacity(regex.find_iter(path).count());
MacroPath {
path: regex
.replace_all(path, |captures: &Captures| {
let capture = &captures[0];
let original_name = String::from(capture);

args.push(MacroArg::Path(ArgValue {
name: String::from(&capture[1..capture.len() - 1]),
original_name,
}));
// otherwise return the capture itself
capture.to_string()
})
.to_string(),
args,
}
})
}
}
1 change: 0 additions & 1 deletion utoipa-gen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1311,7 +1311,6 @@ pub fn path(attr: TokenStream, item: TokenStream) -> TokenStream {
}

let mut resolved_operation = PathOperations::resolve_operation(&ast_fn);

let resolved_path = PathOperations::resolve_path(
&resolved_operation
.as_mut()
Expand Down
88 changes: 87 additions & 1 deletion utoipa-gen/tests/path_derive_axum_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::{Arc, Mutex};
use assert_json_diff::{assert_json_eq, assert_json_matches, CompareMode, Config, NumericMode};
use axum::{
extract::{Path, Query},
Extension,
Extension, Json,
};
use serde::Deserialize;
use serde_json::json;
Expand Down Expand Up @@ -371,3 +371,89 @@ fn derive_path_query_params_with_named_struct_destructed() {
])
)
}

#[test]
fn path_with_path_query_body_resolved() {
#[derive(utoipa::ToSchema, serde::Serialize, serde::Deserialize)]
struct Item(String);

#[allow(unused)]
struct Error;

#[derive(serde::Serialize, serde::Deserialize, IntoParams)]
struct Filter {
age: i32,
status: String,
}

#[utoipa::path(path = "/item/{id}/{name}", post)]
#[allow(unused)]
async fn post_item(
_path: Path<(i32, String)>,
_query: Query<Filter>,
_body: Json<Item>,
) -> Result<Json<Item>, Error> {
Ok(Json(Item(String::new())))
}

#[derive(utoipa::OpenApi)]
#[openapi(paths(post_item))]
struct Doc;

let doc = serde_json::to_value(Doc::openapi()).unwrap();
let operation = doc.pointer("/paths/~1item~1{id}~1{name}/post").unwrap();

assert_json_eq!(
&operation.pointer("/parameters").unwrap(),
json!([
{
"in": "path",
"name": "id",
"required": true,
"schema": {
"format": "int32",
"type": "integer"
}
},
{
"in": "path",
"name": "name",
"required": true,
"schema": {
"type": "string"
}
},
{
"in": "query",
"name": "age",
"required": true,
"schema": {
"format": "int32",
"type": "integer"
}
},
{
"in": "query",
"name": "status",
"required": true,
"schema": {
"type": "string"
}
}
])
);
assert_json_eq!(
&operation.pointer("/requestBody"),
json!({
"description": "",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Item"
}
}
},
"required": true,
})
)
}

0 comments on commit 41d8f58

Please sign in to comment.