From f3c13a3769a490b012d25196e9564f95131ce145 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Thu, 1 Aug 2024 22:53:55 +0200 Subject: [PATCH] feat: add form support and better routing; implement TODO app example --- Cargo.toml | 5 + examples/hello-world/src/main.rs | 8 +- examples/todo-list/Cargo.toml | 12 + examples/todo-list/src/main.rs | 92 +++++++ examples/todo-list/templates/index.html | 28 ++ flareon-macros/Cargo.toml | 15 +- flareon-macros/src/form.rs | 310 +++++++++++++++++++++ flareon-macros/src/lib.rs | 38 ++- flareon-macros/tests/compile_tests.rs | 5 + flareon-macros/tests/ui/derive_form.rs | 16 ++ flareon/Cargo.toml | 4 + flareon/src/error.rs | 30 +++ flareon/src/forms.rs | 342 ++++++++++++++++++++++++ flareon/src/lib.rs | 149 ++++------- flareon/src/prelude.rs | 6 +- flareon/src/private.rs | 5 + flareon/src/request.rs | 99 +++++++ flareon/src/router.rs | 180 +++++++++++++ flareon/src/router/path.rs | 332 +++++++++++++++++++++++ 19 files changed, 1563 insertions(+), 113 deletions(-) create mode 100644 examples/todo-list/Cargo.toml create mode 100644 examples/todo-list/src/main.rs create mode 100644 examples/todo-list/templates/index.html create mode 100644 flareon-macros/src/form.rs create mode 100644 flareon-macros/tests/compile_tests.rs create mode 100644 flareon-macros/tests/ui/derive_form.rs create mode 100644 flareon/src/error.rs create mode 100644 flareon/src/forms.rs create mode 100644 flareon/src/private.rs create mode 100644 flareon/src/request.rs create mode 100644 flareon/src/router.rs create mode 100644 flareon/src/router/path.rs diff --git a/Cargo.toml b/Cargo.toml index c6fe397..73a8ae3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "flareon-orm", # Examples "examples/hello-world", + "examples/todo-list", ] resolver = "2" @@ -15,6 +16,7 @@ edition = "2021" license = "MIT OR Apache-2.0" [workspace.dependencies] +askama = "0.12.1" async-trait = "0.1.80" axum = "0.7.5" bytes = "1.6.1" @@ -22,6 +24,9 @@ chrono = { version = "0.4.38", features = ["serde"] } clap = { version = "4.5.8", features = ["derive", "env"] } derive_builder = "0.20.0" env_logger = "0.11.3" +flareon = { path = "flareon" } +flareon_macros = { path = "flareon-macros" } +form_urlencoded = "1.2.1" indexmap = "2.2.6" itertools = "0.13.0" log = "0.4.22" diff --git a/examples/hello-world/src/main.rs b/examples/hello-world/src/main.rs index c0d5667..540822f 100644 --- a/examples/hello-world/src/main.rs +++ b/examples/hello-world/src/main.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use flareon::prelude::{ - Body, Error, FlareonApp, FlareonProject, Request, Response, Route, StatusCode, -}; +use flareon::prelude::{Body, Error, FlareonApp, FlareonProject, Response, StatusCode}; +use flareon::request::Request; +use flareon::router::Route; -fn return_hello(_request: Request) -> Result { +async fn return_hello(_request: Request) -> Result { Ok(Response::new_html( StatusCode::OK, Body::fixed("

Hello Flareon!

".as_bytes().to_vec()), diff --git a/examples/todo-list/Cargo.toml b/examples/todo-list/Cargo.toml new file mode 100644 index 0000000..8db67d6 --- /dev/null +++ b/examples/todo-list/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "example-todo-list" +version = "0.1.0" +publish = false +description = "TODO List - Flareon example." +edition = "2021" + +[dependencies] +askama = "0.12.1" +flareon = { path = "../../flareon" } +tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } +env_logger = "0.11.5" diff --git a/examples/todo-list/src/main.rs b/examples/todo-list/src/main.rs new file mode 100644 index 0000000..8649f08 --- /dev/null +++ b/examples/todo-list/src/main.rs @@ -0,0 +1,92 @@ +use std::sync::Arc; + +use askama::Template; +use flareon::forms::Form; +use flareon::prelude::{Body, Error, FlareonApp, FlareonProject, Response, Route, StatusCode}; +use flareon::request::Request; +use flareon::reverse; +use tokio::sync::RwLock; + +#[derive(Debug, Clone)] +struct TodoItem { + title: String, +} + +#[derive(Debug, Template)] +#[template(path = "index.html")] +struct IndexTemplate<'a> { + request: &'a Request, + todo_items: Vec, +} + +static TODOS: RwLock> = RwLock::const_new(Vec::new()); + +async fn index(request: Request) -> Result { + let todo_items = (*TODOS.read().await).clone(); + let index_template = IndexTemplate { + request: &request, + todo_items, + }; + let rendered = index_template.render().unwrap(); + + Ok(Response::new_html( + StatusCode::OK, + Body::fixed(rendered.as_bytes().to_vec()), + )) +} + +#[derive(Debug, Form)] +struct TodoForm { + #[form(opt(max_length = 100))] + title: String, +} + +async fn add_todo(mut request: Request) -> Result { + let todo_form = TodoForm::from_request(&mut request).await.unwrap(); + + { + let mut todos = TODOS.write().await; + todos.push(TodoItem { + title: todo_form.title, + }); + } + + Ok(reverse!(request, "index")) +} + +async fn remove_todo(request: Request) -> Result { + let todo_id = request.path_param("todo_id").expect("todo_id not found"); + let todo_id = todo_id.parse::().expect("todo_id is not a number"); + + { + let mut todos = TODOS.write().await; + todos.remove(todo_id); + } + + Ok(reverse!(request, "index")) +} + +#[tokio::main] +async fn main() { + env_logger::init(); + + let todo_app = FlareonApp::builder() + .urls([ + Route::with_handler_and_name("/", Arc::new(Box::new(index)), "index"), + Route::with_handler_and_name("/todos/add", Arc::new(Box::new(add_todo)), "add-todo"), + Route::with_handler_and_name( + "/todos/:todo_id/remove", + Arc::new(Box::new(remove_todo)), + "remove-todo", + ), + ]) + .build() + .unwrap(); + + let todo_project = FlareonProject::builder() + .register_app_with_views(todo_app, "") + .build() + .unwrap(); + + flareon::run(todo_project, "127.0.0.1:8000").await.unwrap(); +} diff --git a/examples/todo-list/templates/index.html b/examples/todo-list/templates/index.html new file mode 100644 index 0000000..1a973dd --- /dev/null +++ b/examples/todo-list/templates/index.html @@ -0,0 +1,28 @@ +{% let request = request %} + + + + + + + TODO List + + +

TODO List

+
+ + +
+
    + {% for todo in todo_items %} +
  • + {% let todo_id = loop.index0 %} +
    todo_id) }}" method="post"> + {{ todo.title }} + +
    +
  • + {% endfor %} +
+ + diff --git a/flareon-macros/Cargo.toml b/flareon-macros/Cargo.toml index e33eb64..f36566d 100644 --- a/flareon-macros/Cargo.toml +++ b/flareon-macros/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "flareon-macros" +name = "flareon_macros" version = "0.1.0" edition.workspace = true license.workspace = true @@ -8,4 +8,17 @@ description = "Modern web framework focused on speed and ease of use - macros." [lib] proc-macro = true +[[test]] +name = "tests" +path = "tests/compile_tests.rs" + [dependencies] +darling = "0.20.10" +proc-macro-crate = "3.1.0" +proc-macro2 = "1.0.86" +quote = "1.0.36" +syn = { version = "2.0.74", features = ["full"] } + +[dev-dependencies] +flareon.workspace = true +trybuild = { version = "1.0.99", features = ["diff"] } diff --git a/flareon-macros/src/form.rs b/flareon-macros/src/form.rs new file mode 100644 index 0000000..8194ddd --- /dev/null +++ b/flareon-macros/src/form.rs @@ -0,0 +1,310 @@ +use std::collections::HashMap; + +use darling::{FromDeriveInput, FromField}; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; + +use crate::flareon_ident; + +pub fn form_for_struct(ast: syn::DeriveInput) -> TokenStream { + let opts = match FormOpts::from_derive_input(&ast) { + Ok(val) => val, + Err(err) => { + return err.write_errors(); + } + }; + + let mut builder = opts.as_form_derive_builder(); + for field in opts.fields() { + builder.push_field(field); + } + + quote!(#builder) +} + +#[derive(Debug, FromDeriveInput)] +#[darling(forward_attrs(allow, doc, cfg), supports(struct_named))] +pub struct FormOpts { + ident: syn::Ident, + data: darling::ast::Data, +} + +impl FormOpts { + fn fields(&self) -> Vec<&Field> { + self.data + .as_ref() + .take_struct() + .expect("Only structs are supported") + .fields + } + + fn field_count(&self) -> usize { + self.fields().len() + } + + fn as_form_derive_builder(&self) -> FormDeriveBuilder { + FormDeriveBuilder { + name: self.ident.clone(), + context_struct_name: format_ident!("{}Context", self.ident), + context_struct_errors_name: format_ident!("{}ContextErrors", self.ident), + context_struct_field_iterator_name: format_ident!("{}ContextFieldIterator", self.ident), + fields_as_struct_fields: Vec::with_capacity(self.field_count()), + fields_as_struct_fields_new: Vec::with_capacity(self.field_count()), + fields_as_context_from_request: Vec::with_capacity(self.field_count()), + fields_as_from_context: Vec::with_capacity(self.field_count()), + fields_as_errors: Vec::with_capacity(self.field_count()), + fields_as_get_errors: Vec::with_capacity(self.field_count()), + fields_as_get_errors_mut: Vec::with_capacity(self.field_count()), + fields_as_iterator_next: Vec::with_capacity(self.field_count()), + } + } +} + +#[derive(Debug, Clone, FromField)] +#[darling(attributes(form))] +pub struct Field { + ident: Option, + ty: syn::Type, + opt: Option>, +} + +#[derive(Debug)] +struct FormDeriveBuilder { + name: Ident, + context_struct_name: Ident, + context_struct_errors_name: Ident, + context_struct_field_iterator_name: Ident, + fields_as_struct_fields: Vec, + fields_as_struct_fields_new: Vec, + fields_as_context_from_request: Vec, + fields_as_from_context: Vec, + fields_as_errors: Vec, + fields_as_get_errors: Vec, + fields_as_get_errors_mut: Vec, + fields_as_iterator_next: Vec, +} + +impl ToTokens for FormDeriveBuilder { + fn to_tokens(&self, tokens: &mut TokenStream) { + tokens.append_all(self.build_form_impl()); + tokens.append_all(self.build_form_context_impl()); + tokens.append_all(self.build_errors_struct()); + tokens.append_all(self.build_context_field_iterator_impl()); + } +} + +impl FormDeriveBuilder { + fn push_field(&mut self, field: &Field) { + let crate_ident = flareon_ident(); + let name = field.ident.as_ref().unwrap(); + let ty = &field.ty; + let index = self.fields_as_struct_fields.len(); + let opt = &field.opt; + + self.fields_as_struct_fields + .push(quote!(#name: <#ty as #crate_ident::forms::AsFormField>::Type)); + + self.fields_as_struct_fields_new.push({ + let custom_options_setters: Vec<_> = if let Some(opt) = opt { + opt.iter() + .map(|(key, value)| quote!(custom_options.#key = Some(#value))) + .collect() + } else { + Vec::new() + }; + quote!(#name: { + let options = #crate_ident::forms::FormFieldOptions { + id: stringify!(#name).to_owned(), + }; + type Field = <#ty as #crate_ident::forms::AsFormField>::Type; + type CustomOptions = ::CustomOptions; + let mut custom_options: CustomOptions = ::core::default::Default::default(); + #( #custom_options_setters; )* + ::with_options(options, custom_options) + }) + }); + + self.fields_as_context_from_request + .push(quote!(stringify!(#name) => { + #crate_ident::forms::FormField::set_value(&mut self.#name, value) + })); + + self.fields_as_from_context.push(quote!(#name: <#ty as #crate_ident::forms::AsFormField>::clean_value(&context.#name).unwrap())); + + self.fields_as_errors + .push(quote!(#name: Vec<#crate_ident::forms::FormFieldValidationError>)); + + self.fields_as_get_errors + .push(quote!(stringify!(#name) => self.__errors.#name.as_slice())); + + self.fields_as_get_errors_mut + .push(quote!(stringify!(#name) => self.__errors.#name.as_mut())); + + self.fields_as_iterator_next.push( + quote!(#index => Some(&self.context.#name as &'a dyn #crate_ident::forms::DynFormField)), + ); + } + + fn build_form_impl(&self) -> TokenStream { + let crate_ident = flareon_ident(); + let name = &self.name; + let context_struct_name = &self.context_struct_name; + let fields_as_from_context = &self.fields_as_from_context; + + quote! { + #[#crate_ident::private::async_trait] + #[automatically_derived] + impl #crate_ident::forms::Form for #name { + type Context = #context_struct_name; + + async fn from_request( + request: &mut #crate_ident::request::Request + ) -> Result> { + let mut context = ::build_context(request).await?; + + Ok(Self { + #( #fields_as_from_context, )* + }) + } + } + } + } + + fn build_form_context_impl(&self) -> TokenStream { + let crate_ident = flareon_ident(); + + let context_struct_name = &self.context_struct_name; + let context_struct_errors_name = &self.context_struct_errors_name; + let context_struct_field_iterator_name = &self.context_struct_field_iterator_name; + + let fields_as_struct_fields = &self.fields_as_struct_fields; + let fields_as_struct_fields_new = &self.fields_as_struct_fields_new; + let fields_as_context_from_request = &self.fields_as_context_from_request; + let fields_as_get_errors = &self.fields_as_get_errors; + let fields_as_get_errors_mut = &self.fields_as_get_errors_mut; + + quote! { + #[derive(::core::fmt::Debug)] + struct #context_struct_name { + __errors: #context_struct_errors_name, + #( #fields_as_struct_fields, )* + } + + #[automatically_derived] + impl #crate_ident::forms::FormContext for #context_struct_name { + fn new() -> Self { + Self { + __errors: ::core::default::Default::default(), + #( #fields_as_struct_fields_new, )* + } + } + + fn fields(&self) -> impl Iterator + '_ { + #context_struct_field_iterator_name { + context: self, + index: 0, + } + } + + fn set_value( + &mut self, + field_id: &str, + value: ::std::borrow::Cow, + ) -> Result<(), #crate_ident::forms::FormFieldValidationError> { + match field_id { + #( #fields_as_context_from_request, )* + _ => {} + } + Ok(()) + } + + fn get_errors( + &self, + target: #crate_ident::forms::FormErrorTarget + ) -> &[#crate_ident::forms::FormFieldValidationError] { + match target { + #crate_ident::forms::FormErrorTarget::Field(field_id) => { + match field_id { + #( #fields_as_get_errors, )* + _ => { + panic!("Unknown field name passed to get_errors: `{}`", field_id); + } + } + } + #crate_ident::forms::FormErrorTarget::Form => { + self.__errors.__form.as_slice() + } + } + } + + fn get_errors_mut( + &mut self, + target: #crate_ident::forms::FormErrorTarget + ) -> &mut Vec<#crate_ident::forms::FormFieldValidationError> { + match target { + #crate_ident::forms::FormErrorTarget::Field(field_id) => { + match field_id { + #( #fields_as_get_errors_mut, )* + _ => { + panic!("Unknown field name passed to get_errors_mut: `{}`", field_id); + } + } + } + #crate_ident::forms::FormErrorTarget::Form => { + self.__errors.__form.as_mut() + } + } + } + } + } + } + + fn build_errors_struct(&self) -> TokenStream { + let crate_ident = flareon_ident(); + let context_struct_errors_name = &self.context_struct_errors_name; + let fields_as_errors = &self.fields_as_errors; + + quote! { + #[derive(::core::fmt::Debug, ::core::default::Default)] + struct #context_struct_errors_name { + __form: Vec<#crate_ident::forms::FormFieldValidationError>, + #( #fields_as_errors, )* + } + } + } + + fn build_context_field_iterator_impl(&self) -> TokenStream { + let crate_ident = flareon_ident(); + let context_struct_name = &self.context_struct_name; + let context_struct_field_iterator_name = &self.context_struct_field_iterator_name; + let fields_as_iterator_next = &self.fields_as_iterator_next; + + quote! { + #[derive(::core::fmt::Debug)] + struct #context_struct_field_iterator_name<'a> { + context: &'a #context_struct_name, + index: usize, + } + + #[automatically_derived] + impl<'a> Iterator for #context_struct_field_iterator_name<'a> { + type Item = &'a dyn #crate_ident::forms::DynFormField; + + fn next(&mut self) -> Option { + let result = match self.index { + #( #fields_as_iterator_next, )* + _ => None, + }; + + if result.is_some() { + self.index += 1; + } else { + self.index = 0; + } + + result + } + } + } + } +} diff --git a/flareon-macros/src/lib.rs b/flareon-macros/src/lib.rs index 6136834..12c9241 100644 --- a/flareon-macros/src/lib.rs +++ b/flareon-macros/src/lib.rs @@ -1,6 +1,38 @@ +mod form; + use proc_macro::TokenStream; +use proc_macro_crate::crate_name; +use quote::quote; +use syn::parse_macro_input; + +use crate::form::form_for_struct; + +/// Derive the [`Form`] trait for a struct. +/// +/// This macro will generate an implementation of the [`Form`] trait for the +/// given named struct. Note that all the fields of the struct **must** +/// implement the [`AsFormField`] trait. +/// +/// [`Form`]: trait.Form.html +/// [`AsFormField`]: trait.AsFormField.html +#[proc_macro_derive(Form, attributes(form))] +pub fn derive_form(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::DeriveInput); + let token_stream = form_for_struct(ast); + println!("{}", token_stream); + token_stream.into() +} -#[proc_macro] -pub fn flareon(_input: TokenStream) -> TokenStream { - unimplemented!() +pub(crate) fn flareon_ident() -> proc_macro2::TokenStream { + let flareon_crate = crate_name("flareon").expect("flareon is not present in `Cargo.toml`"); + match flareon_crate { + proc_macro_crate::FoundCrate::Itself => { + quote! { ::flareon } + } + proc_macro_crate::FoundCrate::Name(name) => { + let ident = syn::Ident::new(&name, proc_macro2::Span::call_site()); + quote! { ::#ident } + } + } + .into() } diff --git a/flareon-macros/tests/compile_tests.rs b/flareon-macros/tests/compile_tests.rs new file mode 100644 index 0000000..d067819 --- /dev/null +++ b/flareon-macros/tests/compile_tests.rs @@ -0,0 +1,5 @@ +#[test] +fn test_derive_form() { + let t = trybuild::TestCases::new(); + t.pass("tests/ui/derive_form.rs"); +} diff --git a/flareon-macros/tests/ui/derive_form.rs b/flareon-macros/tests/ui/derive_form.rs new file mode 100644 index 0000000..f8a49c7 --- /dev/null +++ b/flareon-macros/tests/ui/derive_form.rs @@ -0,0 +1,16 @@ +use flareon::forms::Form; +use flareon::request::Request; + +#[derive(Debug, Form)] +struct MyForm { + name: String, + name2: std::string::String, +} + +#[allow(unused)] +async fn test_endpoint(mut request: Request) { + let form = MyForm::from_request(&mut request).await.unwrap(); + println!("name = {}, name2 = {}", form.name, form.name2); +} + +fn main() {} diff --git a/flareon/Cargo.toml b/flareon/Cargo.toml index 25d8064..cd1df40 100644 --- a/flareon/Cargo.toml +++ b/flareon/Cargo.toml @@ -6,11 +6,15 @@ license.workspace = true description = "Modern web framework focused on speed and ease of use." [dependencies] +askama.workspace = true async-trait.workspace = true axum.workspace = true bytes.workspace = true derive_builder.workspace = true +flareon_macros.workspace = true +form_urlencoded.workspace = true indexmap.workspace = true log.workspace = true +regex.workspace = true thiserror.workspace = true tokio.workspace = true diff --git a/flareon/src/error.rs b/flareon/src/error.rs new file mode 100644 index 0000000..c08f41e --- /dev/null +++ b/flareon/src/error.rs @@ -0,0 +1,30 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum Error { + #[error("Could not retrieve request body: {source}")] + ReadRequestBody { + #[from] + source: axum::Error, + }, + #[error("Invalid content type; expected {expected}, found {actual}")] + InvalidContentType { + expected: &'static str, + actual: String, + }, + #[error("Could not create a response object: {0}")] + ResponseBuilder(#[from] axum::http::Error), + #[error("Failed to reverse route `{view_name}` due to view not existing")] + NoViewToReverse { view_name: String }, + #[error("Failed to reverse route: {0}")] + ReverseError(#[from] crate::router::path::ReverseError), + #[error("Failed to render template: {0}")] + TemplateRender(#[from] askama::Error), +} + +impl From for askama::Error { + fn from(value: Error) -> Self { + askama::Error::Custom(Box::new(value)) + } +} diff --git a/flareon/src/forms.rs b/flareon/src/forms.rs new file mode 100644 index 0000000..1e0aa8e --- /dev/null +++ b/flareon/src/forms.rs @@ -0,0 +1,342 @@ +use std::borrow::Cow; + +use async_trait::async_trait; +pub use flareon_macros::Form; +use thiserror::Error; + +use crate::request::Request; + +/// Error occurred while processing a form. +#[derive(Debug, Error)] +pub enum FormError { + /// An error occurred while processing the request, before validating the + /// form data. + #[error("Request error: {error}")] + RequestError { + #[from] + error: crate::Error, + }, + /// The form failed to validate. + #[error("The form failed to validate")] + ValidationError { context: T::Context }, +} + +const FORM_FIELD_REQUIRED: &str = "This field is required."; + +/// An error that can occur when validating a form field. +#[derive(Debug, Error)] +#[error("{message}")] +pub struct FormFieldValidationError { + message: Cow<'static, str>, +} + +#[derive(Debug)] +pub enum FormErrorTarget<'a> { + Field(&'a str), + Form, +} + +impl FormFieldValidationError { + /// Creates a new `FormFieldValidationError` from a `String`. + #[must_use] + pub const fn from_string(message: String) -> Self { + Self { + message: Cow::Owned(message), + } + } + + /// Creates a new `FormFieldValidationError` from a static string. + #[must_use] + pub const fn from_static(message: &'static str) -> Self { + Self { + message: Cow::Borrowed(message), + } + } +} + +/// A trait for types that can be used as forms. +/// +/// This trait is used to define a type that can be used as a form. It provides +/// a way to create a form from a request, build a context from the request, and +/// validate the form. +/// +/// # Deriving +/// +/// This trait can, and should be derived using the [`Form`](derive@Form) derive +/// macro. This macro generates the implementation of the trait for the type, +/// including the implementation of the [`FormContext`] trait for the context +/// type. +/// +/// ```rust +/// use flareon::forms::Form; +/// +/// #[derive(Form)] +/// struct MyForm { +/// #[form(opt(max_length = 100))] +/// name: String, +/// } +/// ``` +#[async_trait] +pub trait Form: Sized { + /// The context type associated with the form. + type Context: FormContext; + + /// Creates a form from a request. + async fn from_request(request: &mut Request) -> Result>; + + /// Builds the context for the form from a request. + async fn build_context(request: &mut Request) -> Result> { + let form_data = request + .form_data() + .await + .map_err(|error| FormError::RequestError { error })?; + + let mut context = Self::Context::new(); + let mut has_errors = false; + + for (field_id, value) in Request::query_pairs(&form_data) { + let field_id = field_id.as_ref(); + + if let Err(err) = context.set_value(field_id, value) { + context.add_error(FormErrorTarget::Field(field_id), err); + has_errors = true; + } + } + + if has_errors { + Err(FormError::ValidationError { context }) + } else { + Ok(context) + } + } +} + +/// A trait for form contexts. +/// +/// A form context is used to store the state of a form, such as the values of +/// the fields and any errors that occur during validation. This trait is used +/// to define the interface for a form context, which is used to interact with +/// the form fields and errors. +/// +/// This trait is typically not implemented directly; instead, its +/// implementations are generated automatically through the +/// [`Form`](derive@Form) derive macro. +pub trait FormContext: Sized { + /// Creates a new form context without any initial form data. + fn new() -> Self; + + /// Returns an iterator over the fields in the form. + fn fields(&self) -> impl Iterator + '_; + + /// Sets the value of a form field. + fn set_value( + &mut self, + field_id: &str, + value: Cow, + ) -> Result<(), FormFieldValidationError>; + + /// Adds a validation error to the form context. + fn add_error(&mut self, target: FormErrorTarget, error: FormFieldValidationError) { + self.get_errors_mut(target).push(error); + } + + /// Returns the validation errors for a target in the form context. + fn get_errors(&self, target: FormErrorTarget) -> &[FormFieldValidationError]; + + /// Returns a mutable reference to the validation errors for a target in the + /// form context. + fn get_errors_mut(&mut self, target: FormErrorTarget) -> &mut Vec; +} + +/// Generic options valid for all types of form fields. +#[derive(Debug)] +pub struct FormFieldOptions { + pub id: String, +} + +/// A form field. +/// +/// This trait is used to define a type of field that can be used in a form. It +/// is used to render the field in an HTML form, set the value of the field, and +/// validate it. Typically, the implementors of this trait are used indirectly +/// through the [`Form`] trait and field types that implement [`AsFormField`]. +pub trait FormField: Sized { + /// Custom options for the form field, unique for each field type. + type CustomOptions: Default; + + /// Creates a new form field with the given options. + fn with_options(options: FormFieldOptions, custom_options: Self::CustomOptions) -> Self; + + /// Returns the generic options for the form field. + fn options(&self) -> &FormFieldOptions; + + /// Returns the ID of the form field. + fn id(&self) -> &str { + &self.options().id + } + + /// Sets the string value of the form field. + /// + /// This method should convert the value to the appropriate type for the + /// field, such as a number for a number field. + fn set_value(&mut self, value: Cow); + + /// Renders the form field as an HTML string. + fn render(&self) -> String; +} + +/// A version of [`FormField`] that can be used in a dynamic context. +/// +/// This trait is used to allow a form field to be used in a dynamic context, +/// such as when using Form field iterator. It provides access to the field's +/// options, value, and rendering, among others. +/// +/// This trait is implemented for all types that implement [`FormField`]. +pub trait DynFormField { + fn dyn_options(&self) -> &FormFieldOptions; + + fn dyn_id(&self) -> &str; + + fn dyn_set_value(&mut self, value: Cow); + + fn dyn_render(&self) -> String; +} + +impl DynFormField for T { + fn dyn_options(&self) -> &FormFieldOptions { + FormField::options(self) + } + + fn dyn_id(&self) -> &str { + FormField::id(self) + } + + fn dyn_set_value(&mut self, value: Cow) { + FormField::set_value(self, value) + } + + fn dyn_render(&self) -> String { + FormField::render(self) + } +} + +/// A trait for types that can be used as form fields. +/// +/// This trait uses [`FormField`] to define a type that can be used as a form +/// field. It provides a way to clean the value of the field, which is used to +/// validate the field's value before converting to the final type. +pub trait AsFormField { + type Type: FormField; + + fn clean_value(field: &Self::Type) -> Result + where + Self: Sized; +} + +/// A form field for a string. +#[derive(Debug)] +pub struct CharField { + options: FormFieldOptions, + custom_options: CharFieldOptions, + value: Option, +} + +/// Custom options for a `CharField`. +#[derive(Debug, Default)] +pub struct CharFieldOptions { + /// The maximum length of the field. Used to set the `maxlength` attribute + /// in the HTML input element. + pub max_length: Option, +} + +impl CharFieldOptions { + /// Sets the maximum length for the `CharField`. + pub fn set_max_length(&mut self, max_length: u32) { + self.max_length = Some(max_length); + } +} + +impl FormField for CharField { + type CustomOptions = CharFieldOptions; + + fn with_options(options: FormFieldOptions, custom_options: Self::CustomOptions) -> Self { + Self { + options, + custom_options, + value: None, + } + } + + fn options(&self) -> &FormFieldOptions { + &self.options + } + + fn set_value(&mut self, value: Cow) { + self.value = Some(value.into_owned()); + } + + fn render(&self) -> String { + let mut tag = HtmlTag::input("text"); + tag.attr("name", self.id()); + if let Some(max_length) = self.custom_options.max_length { + tag.attr("maxlength", &max_length.to_string()); + } + tag.render() + } +} + +impl AsFormField for String { + type Type = CharField; + + fn clean_value(field: &Self::Type) -> Result { + if let Some(value) = &field.value { + Ok(value.clone()) + } else { + Err(FormFieldValidationError::from_static(FORM_FIELD_REQUIRED)) + } + } +} + +/// A helper struct for rendering HTML tags. +#[derive(Debug)] +struct HtmlTag { + tag: String, + attributes: Vec<(String, String)>, +} + +impl HtmlTag { + #[must_use] + fn new(tag: &str) -> Self { + Self { + tag: tag.to_string(), + attributes: Vec::new(), + } + } + + #[must_use] + fn input(input_type: &str) -> Self { + let mut input = Self::new("input"); + input.attr("type", input_type); + input + } + + fn attr(&mut self, key: &str, value: &str) -> &mut Self { + if self.attributes.iter().any(|(k, _)| k == key) { + panic!("Attribute already exists: {}", key); + } + self.attributes.push((key.to_string(), value.to_string())); + self + } + + #[must_use] + fn render(&self) -> String { + let mut result = format!("<{} ", self.tag); + + for (key, value) in &self.attributes { + result.push_str(&format!("{}=\"{}\" ", key, value)); + } + + result.push_str(" />"); + result + } +} diff --git a/flareon/src/lib.rs b/flareon/src/lib.rs index 1b41d28..9bba5b2 100644 --- a/flareon/src/lib.rs +++ b/flareon/src/lib.rs @@ -1,6 +1,15 @@ +extern crate self as flareon; + +mod error; +pub mod forms; pub mod prelude; +#[doc(hidden)] +pub mod private; +pub mod request; +pub mod router; use std::fmt::{Debug, Formatter}; +use std::future::Future; use std::io::Read; use std::sync::Arc; @@ -8,60 +17,29 @@ use async_trait::async_trait; use axum::handler::HandlerWithoutStateExt; use bytes::Bytes; use derive_builder::Builder; +pub use error::Error; use indexmap::IndexMap; use log::info; -use thiserror::Error; +use request::Request; +use router::{Route, Router}; + +pub type Result = std::result::Result; pub type StatusCode = axum::http::StatusCode; #[async_trait] pub trait RequestHandler { - async fn handle(&self, request: Request) -> Result; -} - -#[derive(Clone, Debug)] -pub struct Router { - urls: Vec, -} - -impl Router { - #[must_use] - pub fn with_urls>>(urls: T) -> Self { - Self { urls: urls.into() } - } - - async fn route(&self, request: Request, request_path: &str) -> Result { - for route in &self.urls { - if request_path.starts_with(&route.url) { - let request_path = &request_path[route.url.len()..]; - match &route.view { - RouteInner::Handler(handler) => return handler.handle(request).await, - RouteInner::Router(router) => { - return Box::pin(router.route(request, request_path)).await - } - } - } - } - - unimplemented!("404 handler is not implemented yet") - } -} - -#[async_trait] -impl RequestHandler for Router { - async fn handle(&self, request: Request) -> Result { - let path = request.uri().path().to_owned(); - self.route(request, &path).await - } + async fn handle(&self, request: Request) -> Result; } #[async_trait] -impl RequestHandler for T +impl RequestHandler for T where - T: Fn(Request) -> Result + Send + Sync, + T: Fn(Request) -> R + Clone + Send + Sync + 'static, + R: for<'a> Future> + Send, { - async fn handle(&self, request: Request) -> Result { - self(request) + async fn handle(&self, request: Request) -> Result { + self(request).await } } @@ -100,50 +78,6 @@ impl FlareonAppBuilder { } } -#[derive(Clone)] -pub struct Route { - url: String, - view: RouteInner, -} - -impl Route { - #[must_use] - pub fn with_handler>( - url: T, - view: Arc>, - ) -> Self { - Self { - url: url.into(), - view: RouteInner::Handler(view), - } - } - - #[must_use] - pub fn with_router>(url: T, router: Router) -> Self { - Self { - url: url.into(), - view: RouteInner::Router(router), - } - } -} - -#[derive(Clone)] -enum RouteInner { - Handler(Arc>), - Router(Router), -} - -impl Debug for Route { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match &self.view { - RouteInner::Handler(_) => f.debug_tuple("Handler").field(&"handler(...)").finish(), - RouteInner::Router(router) => f.debug_tuple("Router").field(router).finish(), - } - } -} - -pub type Request = axum::extract::Request; - type HeadersMap = IndexMap; #[derive(Debug)] @@ -155,6 +89,8 @@ pub struct Response { const CONTENT_TYPE_HEADER: &str = "Content-Type"; const HTML_CONTENT_TYPE: &str = "text/html"; +const FORM_CONTENT_TYPE: &str = "application/x-www-form-urlencoded"; +const LOCATION_HEADER: &str = "Location"; impl Response { #[must_use] @@ -166,6 +102,17 @@ impl Response { } } + #[must_use] + pub fn new_redirect>(location: T) -> Self { + let mut headers = HeadersMap::new(); + headers.insert(LOCATION_HEADER.to_owned(), location.into()); + Self { + status: StatusCode::SEE_OTHER, + headers, + body: Body::empty(), + } + } + #[must_use] fn html_headers() -> HeadersMap { let mut headers = HeadersMap::new(); @@ -200,12 +147,6 @@ impl Body { } } -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("Could not create a response object: {0}")] - ResponseBuilder(#[from] axum::http::Error), -} - #[derive(Clone, Debug)] pub struct FlareonProject { apps: Vec, @@ -230,15 +171,13 @@ impl FlareonProjectBuilder { #[must_use] pub fn register_app_with_views(&mut self, app: FlareonApp, url_prefix: &str) -> &mut Self { let new = self; - new.urls.push(Route::with_handler( - url_prefix, - Arc::new(Box::new(app.router.clone())), - )); + new.urls + .push(Route::with_router(url_prefix, app.router.clone())); new.apps.push(app); new } - pub fn build(&self) -> Result { + pub fn build(&self) -> Result { Ok(FlareonProject { apps: self.apps.clone(), router: Router::with_urls(self.urls.clone()), @@ -257,17 +196,23 @@ impl FlareonProject { pub fn builder() -> FlareonProjectBuilder { FlareonProjectBuilder::default() } + + #[must_use] + pub fn router(&self) -> &Router { + &self.router + } } -pub async fn run(mut project: FlareonProject, address_str: &str) -> Result<(), Error> { +pub async fn run(mut project: FlareonProject, address_str: &str) -> Result<()> { for app in &mut project.apps { info!("Initializing app: {:?}", app); } + let project = Arc::new(project); let listener = tokio::net::TcpListener::bind(address_str).await.unwrap(); let handler = |request: axum::extract::Request| async move { - pass_to_axum(&project, request) + pass_to_axum(&project, Request::new(request, project.clone())) .await .unwrap_or_else(handle_response_error) }; @@ -279,9 +224,9 @@ pub async fn run(mut project: FlareonProject, address_str: &str) -> Result<(), E } async fn pass_to_axum( - project: &FlareonProject, - request: axum::extract::Request, -) -> Result { + project: &Arc, + request: Request, +) -> Result { let response = project.router.handle(request).await?; let mut builder = axum::http::Response::builder().status(response.status); diff --git a/flareon/src/prelude.rs b/flareon/src/prelude.rs index 6c2ec5f..8b8afc4 100644 --- a/flareon/src/prelude.rs +++ b/flareon/src/prelude.rs @@ -1,3 +1,3 @@ -pub use crate::{ - Body, Error, FlareonApp, FlareonProject, Request, RequestHandler, Response, Route, StatusCode, -}; +pub use crate::request::Request; +pub use crate::router::Route; +pub use crate::{Body, Error, FlareonApp, FlareonProject, RequestHandler, Response, StatusCode}; diff --git a/flareon/src/private.rs b/flareon/src/private.rs new file mode 100644 index 0000000..623d168 --- /dev/null +++ b/flareon/src/private.rs @@ -0,0 +1,5 @@ +/// Re-exports of some of the Flareon dependencies that are used in the macros. +/// +/// This is to avoid the need to add them as dependencies to the crate that uses +/// the macros. +pub use async_trait::async_trait; diff --git a/flareon/src/request.rs b/flareon/src/request.rs new file mode 100644 index 0000000..b846d1f --- /dev/null +++ b/flareon/src/request.rs @@ -0,0 +1,99 @@ +use std::borrow::Cow; +use std::sync::Arc; + +use bytes::Bytes; +use indexmap::IndexMap; + +use crate::{Error, FlareonProject, FORM_CONTENT_TYPE}; + +#[derive(Debug)] +pub struct Request { + inner: axum::extract::Request, + project: Arc, + pub(crate) path_params: IndexMap, +} + +impl Request { + #[must_use] + pub fn new(inner: axum::extract::Request, project: Arc) -> Self { + Self { + inner, + project, + path_params: IndexMap::new(), + } + } + + #[must_use] + pub fn inner(&self) -> &axum::extract::Request { + &self.inner + } + + #[must_use] + pub fn project(&self) -> &FlareonProject { + &self.project + } + + #[must_use] + pub fn uri(&self) -> &axum::http::Uri { + self.inner.uri() + } + + #[must_use] + pub fn method(&self) -> &axum::http::Method { + self.inner.method() + } + + #[must_use] + pub fn headers(&self) -> &axum::http::HeaderMap { + self.inner.headers() + } + + #[must_use] + pub fn content_type(&self) -> Option<&axum::http::HeaderValue> { + self.inner.headers().get(axum::http::header::CONTENT_TYPE) + } + + pub async fn form_data(&mut self) -> Result { + if self.method() == axum::http::Method::GET { + if let Some(query) = self.inner.uri().query() { + return Ok(Bytes::copy_from_slice(query.as_bytes())); + } + + Ok(Bytes::new()) + } else { + self.expect_content_type(FORM_CONTENT_TYPE)?; + + let body = std::mem::take(self.inner.body_mut()); + let bytes = axum::body::to_bytes(body, usize::MAX) + .await + .map_err(|err| Error::ReadRequestBody { source: err })?; + + Ok(bytes) + } + } + + fn expect_content_type(&mut self, expected: &'static str) -> Result<(), Error> { + let content_type = self + .content_type() + .map(|value| String::from_utf8_lossy(value.as_bytes())) + .unwrap_or("".into()); + if self.content_type() == Some(&axum::http::HeaderValue::from_static(expected)) { + Ok(()) + } else { + Err(Error::InvalidContentType { + expected, + actual: content_type.into_owned(), + }) + } + } + + #[must_use] + pub fn query_pairs(bytes: &Bytes) -> impl Iterator, Cow)> { + form_urlencoded::parse(bytes.as_ref()) + } + + #[must_use] + pub fn path_param(&self, name: &str) -> Option<&str> { + self.path_params.get(name).map(|s| s.as_str()) + } +} diff --git a/flareon/src/router.rs b/flareon/src/router.rs new file mode 100644 index 0000000..5f03b46 --- /dev/null +++ b/flareon/src/router.rs @@ -0,0 +1,180 @@ +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use axum::http::StatusCode; +use bytes::Bytes; +use log::debug; + +use crate::request::Request; +use crate::router::path::{PathMatcher, ReverseParamMap}; +use crate::{Body, Error, RequestHandler, Response, Result}; + +pub mod path; + +#[derive(Clone, Debug)] +pub struct Router { + urls: Vec, + names: HashMap>, +} + +impl Router { + #[must_use] + pub fn with_urls>>(urls: T) -> Self { + let urls = urls.into(); + let mut names = HashMap::new(); + + for url in &urls { + if let Some(name) = &url.name { + names.insert(name.clone(), url.url.clone()); + } + } + + Self { urls, names } + } + + async fn route(&self, mut request: Request, request_path: &str) -> Result { + debug!("Routing request to {}", request_path); + + for route in &self.urls { + if let Some(matches) = route.url.capture(request_path) { + let matches_fully = matches.matches_fully(); + for param in matches.params { + request + .path_params + .insert(param.name.to_owned(), param.value); + } + + match &route.view { + RouteInner::Handler(handler) => { + if matches_fully { + return handler.handle(request).await; + } + } + RouteInner::Router(router) => { + return Box::pin(router.route(request, matches.remaining_path)).await + } + } + } + } + + debug!("Not found: {}", request_path); + Ok(handle_not_found()) + } + + pub async fn handle(&self, request: Request) -> Result { + let path = request.uri().path().to_owned(); + self.route(request, &path).await + } + + /// Get a URL for a view by name. + /// + /// Instead of using this method directly, consider using the + /// [`reverse!`](crate::reverse) macro which provides much more + /// ergonomic way to call this. + pub fn reverse(&self, name: &str, params: &ReverseParamMap) -> Result { + self.reverse_option(name, params)? + .ok_or_else(|| Error::NoViewToReverse { + view_name: name.to_owned(), + }) + } + + pub fn reverse_option(&self, name: &str, params: &ReverseParamMap) -> Result> { + let url = self.names.get(name).map(|matcher| matcher.reverse(params)); + if let Some(url) = url { + return Ok(Some(url?)); + } + + for route in &self.urls { + if let RouteInner::Router(router) = &route.view { + if let Some(url) = router.reverse_option(name, params)? { + return Ok(Some(route.url.reverse(params)? + &url)); + } + } + } + Ok(None) + } +} + +#[derive(Debug, Clone)] +pub struct Route { + url: Arc, + view: RouteInner, + name: Option, +} + +impl Route { + #[must_use] + pub fn with_handler(url: &str, view: Arc>) -> Self { + Self { + url: Arc::new(PathMatcher::new(url)), + view: RouteInner::Handler(view), + name: None, + } + } + + #[must_use] + pub fn with_handler_and_name>( + url: &str, + view: Arc>, + name: T, + ) -> Self { + Self { + url: Arc::new(PathMatcher::new(url)), + view: RouteInner::Handler(view), + name: Some(name.into()), + } + } + + #[must_use] + pub fn with_router(url: &str, router: Router) -> Self { + Self { + url: Arc::new(PathMatcher::new(url)), + view: RouteInner::Router(router), + name: None, + } + } +} + +fn handle_not_found() -> Response { + Response::new_html( + StatusCode::NOT_FOUND, + Body::Fixed(Bytes::from("404 Not Found")), + ) +} + +#[derive(Clone)] +enum RouteInner { + Handler(Arc>), + Router(Router), +} + +impl Debug for RouteInner { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match &self { + RouteInner::Handler(_) => f.debug_tuple("Handler").field(&"handler(...)").finish(), + RouteInner::Router(router) => f.debug_tuple("Router").field(router).finish(), + } + } +} + +#[macro_export] +macro_rules! reverse { + ($request:expr, $view_name:literal $(, $($key:expr => $value:expr),* )?) => { + ::flareon::Response::new_redirect($crate::reverse_str!( + $request, + $view_name, + $( $($key => $value),* )? + )) + }; +} + +#[macro_export] +macro_rules! reverse_str { + ( $request:expr, $view_name:literal $(, $($key:expr => $value:expr),* )? ) => { + $request + .project() + .router() + .reverse($view_name, &$crate::reverse_param_map!($( $($key => $value),* )?))? + }; +} diff --git a/flareon/src/router/path.rs b/flareon/src/router/path.rs new file mode 100644 index 0000000..05dd199 --- /dev/null +++ b/flareon/src/router/path.rs @@ -0,0 +1,332 @@ +use std::collections::HashMap; +use std::fmt::Display; + +use log::debug; +use regex::Regex; +use thiserror::Error; + +#[derive(Debug, Clone)] +pub(super) struct PathMatcher { + parts: Vec, +} + +impl PathMatcher { + #[must_use] + pub fn new>(path_pattern: T) -> Self { + let path_pattern = path_pattern.into(); + + let mut last_end = 0; + let mut parts = Vec::new(); + let param_regex = Regex::new(":([^/]+)").expect("Invalid regex"); + for capture in param_regex.captures_iter(&path_pattern) { + let full_match = capture.get(0).expect("Could not get regex match"); + let start = full_match.start(); + if start > last_end { + parts.push(PathPart::Literal(path_pattern[last_end..start].to_string())); + } + + let name = capture + .get(1) + .expect("Could not get regex capture") + .as_str() + .to_owned(); + if !Self::is_param_name_valid(&name) { + panic!("Invalid parameter name: `{}`", name); + } + parts.push(PathPart::Param { name }); + last_end = start + full_match.len(); + } + if last_end < path_pattern.len() { + parts.push(PathPart::Literal(path_pattern[last_end..].to_string())); + } + + Self { parts } + } + + fn is_param_name_valid(name: &str) -> bool { + if name.is_empty() { + return false; + } + let first_char = name.chars().next().expect("Empty string"); + if !first_char.is_alphabetic() && first_char != '_' { + return false; + } + for ch in name.chars() { + if !ch.is_alphanumeric() && ch != '_' { + return false; + } + } + true + } + + #[must_use] + pub fn capture<'matcher, 'path>( + &'matcher self, + path: &'path str, + ) -> Option> { + debug!("Matching path `{}` against pattern `{}`", path, self); + + let mut current_path = path; + let mut params = Vec::with_capacity(self.param_len()); + for part in &self.parts { + match part { + PathPart::Literal(s) => { + if !current_path.starts_with(s) { + return None; + } + current_path = ¤t_path[s.len()..]; + } + PathPart::Param { name } => { + let next_slash = current_path.find('/'); + let value = if let Some(next_slash) = next_slash { + ¤t_path[..next_slash] + } else { + current_path + }; + if value.is_empty() { + return None; + } + params.push(PathParam::new(name, value)); + current_path = ¤t_path[value.len()..]; + } + } + } + + Some(CaptureResult::new(params, current_path)) + } + + pub fn reverse(&self, params: &ReverseParamMap) -> Result { + let mut result = String::new(); + + for part in &self.parts { + match part { + PathPart::Literal(s) => result.push_str(s), + PathPart::Param { name } => { + let value = params + .get(name) + .ok_or_else(|| ReverseError::MissingParam(name.clone()))?; + result.push_str(value); + } + } + } + + Ok(result) + } + + #[must_use] + fn param_len(&self) -> usize { + self.parts + .iter() + .map(|part| match part { + PathPart::Literal(..) => 0, + PathPart::Param { .. } => 1, + }) + .sum() + } +} + +#[derive(Debug)] +pub struct ReverseParamMap { + params: HashMap, +} + +impl Default for ReverseParamMap { + fn default() -> Self { + Self::new() + } +} + +impl ReverseParamMap { + #[must_use] + pub fn new() -> Self { + Self { + params: HashMap::new(), + } + } + + pub fn insert(&mut self, key: K, value: V) { + self.params.insert(key.to_string(), value.to_string()); + } + + #[must_use] + fn get(&self, key: &str) -> Option<&String> { + self.params.get(key) + } +} + +#[macro_export] +macro_rules! reverse_param_map { + ( $($key:expr => $value:expr),* ) => {{ + let mut map = $crate::router::path::ReverseParamMap::new(); + $( + map.insert($key, $value); + )* + map + }}; +} + +#[derive(Debug, Error)] +pub enum ReverseError { + #[error("Missing parameter for reverse: `{0}`")] + MissingParam(String), +} + +impl Display for PathMatcher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for part in &self.parts { + write!(f, "{}", part)?; + } + Ok(()) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub(super) struct CaptureResult<'matcher, 'path> { + pub(super) params: Vec>, + pub(super) remaining_path: &'path str, +} + +impl<'matcher, 'path> CaptureResult<'matcher, 'path> { + #[must_use] + fn new(params: Vec>, remaining_path: &'path str) -> Self { + Self { + params, + remaining_path, + } + } + + #[must_use] + pub fn matches_fully(&self) -> bool { + self.remaining_path.is_empty() + } +} + +#[derive(Debug, Clone)] +enum PathPart { + Literal(String), + Param { name: String }, +} + +impl Display for PathPart { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PathPart::Literal(s) => write!(f, "{}", s), + PathPart::Param { name } => write!(f, ":{}", name), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct PathParam<'a> { + pub(super) name: &'a str, + pub(super) value: String, +} + +impl<'a> PathParam<'a> { + #[must_use] + pub fn new(name: &'a str, value: &str) -> Self { + Self { + name, + value: value.to_string(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_parser_no_params() { + let path_parser = PathMatcher::new("/users"); + assert_eq!( + path_parser.capture("/users"), + Some(CaptureResult::new(vec![], "")) + ); + assert_eq!(path_parser.capture("/test"), None); + } + + #[test] + fn test_path_parser_single_param() { + let path_parser = PathMatcher::new("/users/:id"); + assert_eq!( + path_parser.capture("/users/123"), + Some(CaptureResult::new(vec![PathParam::new("id", "123")], "")) + ); + assert_eq!( + path_parser.capture("/users/123/"), + Some(CaptureResult::new(vec![PathParam::new("id", "123")], "/")) + ); + assert_eq!( + path_parser.capture("/users/123/abc"), + Some(CaptureResult::new( + vec![PathParam::new("id", "123")], + "/abc" + )) + ); + assert_eq!(path_parser.capture("/users/"), None); + } + + #[test] + fn test_path_parser_multiple_params() { + let path_parser = PathMatcher::new("/users/:id/posts/:post_id"); + assert_eq!( + path_parser.capture("/users/123/posts/456"), + Some(CaptureResult::new( + vec![ + PathParam::new("id", "123"), + PathParam::new("post_id", "456"), + ], + "" + )) + ); + assert_eq!( + path_parser.capture("/users/123/posts/456/abc"), + Some(CaptureResult::new( + vec![ + PathParam::new("id", "123"), + PathParam::new("post_id", "456"), + ], + "/abc" + )) + ); + } + + #[test] + fn reverse_with_valid_params() { + let path_parser = PathMatcher::new("/users/:id/posts/:post_id"); + let mut params = ReverseParamMap::new(); + params.insert("id", "123"); + params.insert("post_id", "456"); + assert_eq!( + path_parser.reverse(¶ms).unwrap(), + "/users/123/posts/456" + ); + } + + #[test] + fn reverse_with_missing_param() { + let path_parser = PathMatcher::new("/users/:id/posts/:post_id"); + let mut params = ReverseParamMap::new(); + params.insert("id", "123"); + let result = path_parser.reverse(¶ms); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "Missing parameter for reverse: `post_id`" + ); + } + + #[test] + fn reverse_with_extra_param() { + let path_parser = PathMatcher::new("/users/:id/posts/:post_id"); + let mut params = ReverseParamMap::new(); + params.insert("id", "123"); + params.insert("post_id", "456"); + params.insert("extra", "789"); + assert_eq!( + path_parser.reverse(¶ms).unwrap(), + "/users/123/posts/456" + ); + } +}