diff --git a/.gitignore b/.gitignore index 6985cf1..42901ef 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,7 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + +# Test databases +*.db +*.sqlite3 diff --git a/Cargo.toml b/Cargo.toml index 73a8ae3..88087f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ members = [ "flareon-admin", "flareon-auth", "flareon-macros", - "flareon-orm", # Examples "examples/hello-world", "examples/todo-list", @@ -22,17 +21,22 @@ axum = "0.7.5" bytes = "1.6.1" chrono = { version = "0.4.38", features = ["serde"] } clap = { version = "4.5.8", features = ["derive", "env"] } +convert_case = "0.6.0" derive_builder = "0.20.0" +derive_more = { version = "1.0.0", features = ["full"] } env_logger = "0.11.3" flareon = { path = "flareon" } flareon_macros = { path = "flareon-macros" } +flareon_orm = { path = "flareon-orm" } form_urlencoded = "1.2.1" indexmap = "2.2.6" itertools = "0.13.0" log = "0.4.22" regex = "1.10.5" +sea-query = "0.32.0-rc.1" +sea-query-binder = { version = "0.7.0-rc.1", features = ["sqlx-any", "runtime-tokio"] } serde = "1.0.203" slug = "0.1.5" -tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } -tower = "0.4.13" +sqlx = { version = "0.8.0", features = ["runtime-tokio", "sqlite"] } thiserror = "1.0.61" +tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } diff --git a/examples/todo-list/src/main.rs b/examples/todo-list/src/main.rs index 8649f08..c51aede 100644 --- a/examples/todo-list/src/main.rs +++ b/examples/todo-list/src/main.rs @@ -1,14 +1,18 @@ use std::sync::Arc; use askama::Template; +use flareon::db::query::ExprEq; +use flareon::db::{model, Database, Model}; use flareon::forms::Form; use flareon::prelude::{Body, Error, FlareonApp, FlareonProject, Response, Route, StatusCode}; use flareon::request::Request; use flareon::reverse; -use tokio::sync::RwLock; +use tokio::sync::OnceCell; #[derive(Debug, Clone)] +#[model] struct TodoItem { + id: i32, title: String, } @@ -19,10 +23,12 @@ struct IndexTemplate<'a> { todo_items: Vec, } -static TODOS: RwLock> = RwLock::const_new(Vec::new()); +static DB: OnceCell = OnceCell::const_new(); async fn index(request: Request) -> Result { - let todo_items = (*TODOS.read().await).clone(); + let db = DB.get().unwrap(); + + let todo_items = TodoItem::objects().all(db).await.unwrap(); let index_template = IndexTemplate { request: &request, todo_items, @@ -45,10 +51,14 @@ async fn add_todo(mut request: Request) -> Result { let todo_form = TodoForm::from_request(&mut request).await.unwrap(); { - let mut todos = TODOS.write().await; - todos.push(TodoItem { + let db = DB.get().unwrap(); + TodoItem { + id: 0, title: todo_form.title, - }); + } + .save(db) + .await + .unwrap(); } Ok(reverse!(request, "index")) @@ -56,11 +66,15 @@ async fn add_todo(mut request: Request) -> Result { async fn remove_todo(request: Request) -> Result { let todo_id = request.path_param("todo_id").expect("todo_id not found"); - let todo_id = todo_id.parse::().expect("todo_id is not a number"); + let todo_id = todo_id.parse::().expect("todo_id is not a number"); { - let mut todos = TODOS.write().await; - todos.remove(todo_id); + let db = DB.get().unwrap(); + TodoItem::objects() + .filter(::Fields::ID.eq(todo_id)) + .delete(db) + .await + .unwrap(); } Ok(reverse!(request, "index")) @@ -70,6 +84,19 @@ async fn remove_todo(request: Request) -> Result { async fn main() { env_logger::init(); + let db = DB + .get_or_init(|| async { Database::new("sqlite::memory:").await.unwrap() }) + .await; + db.execute( + r" + CREATE TABLE todo_item ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL + );", + ) + .await + .unwrap(); + let todo_app = FlareonApp::builder() .urls([ Route::with_handler_and_name("/", Arc::new(Box::new(index)), "index"), diff --git a/examples/todo-list/templates/index.html b/examples/todo-list/templates/index.html index 1a973dd..f179d26 100644 --- a/examples/todo-list/templates/index.html +++ b/examples/todo-list/templates/index.html @@ -16,7 +16,7 @@

TODO List

    {% for todo in todo_items %}
  • - {% let todo_id = loop.index0 %} + {% let todo_id = todo.id %}
    todo_id) }}" method="post"> {{ todo.title }} diff --git a/flareon-admin/src/lib.rs b/flareon-admin/src/lib.rs index b93cf3f..53f8cad 100644 --- a/flareon-admin/src/lib.rs +++ b/flareon-admin/src/lib.rs @@ -1,3 +1,4 @@ +#[must_use] pub fn add(left: u64, right: u64) -> u64 { left + right } diff --git a/flareon-auth/src/lib.rs b/flareon-auth/src/lib.rs index b93cf3f..53f8cad 100644 --- a/flareon-auth/src/lib.rs +++ b/flareon-auth/src/lib.rs @@ -1,3 +1,4 @@ +#[must_use] pub fn add(left: u64, right: u64) -> u64 { left + right } diff --git a/flareon-macros/Cargo.toml b/flareon-macros/Cargo.toml index f36566d..01ed193 100644 --- a/flareon-macros/Cargo.toml +++ b/flareon-macros/Cargo.toml @@ -13,6 +13,7 @@ name = "tests" path = "tests/compile_tests.rs" [dependencies] +convert_case.workspace = true darling = "0.20.10" proc-macro-crate = "3.1.0" proc-macro2 = "1.0.86" diff --git a/flareon-macros/src/lib.rs b/flareon-macros/src/lib.rs index f332dc1..d4afb20 100644 --- a/flareon-macros/src/lib.rs +++ b/flareon-macros/src/lib.rs @@ -1,11 +1,15 @@ mod form; +mod model; +use darling::ast::NestedMeta; +use darling::Error; use proc_macro::TokenStream; use proc_macro_crate::crate_name; use quote::quote; use syn::parse_macro_input; use crate::form::impl_form_for_struct; +use crate::model::impl_model_for_struct; /// Derive the [`Form`] trait for a struct. /// @@ -22,6 +26,19 @@ pub fn derive_form(input: TokenStream) -> TokenStream { token_stream.into() } +#[proc_macro_attribute] +pub fn model(args: TokenStream, input: TokenStream) -> TokenStream { + let attr_args = match NestedMeta::parse_meta_list(args.into()) { + Ok(v) => v, + Err(e) => { + return TokenStream::from(Error::from(e).write_errors()); + } + }; + let ast = parse_macro_input!(input as syn::DeriveInput); + let token_stream = impl_model_for_struct(attr_args, ast); + token_stream.into() +} + pub(crate) fn flareon_ident() -> proc_macro2::TokenStream { let flareon_crate = crate_name("flareon").expect("flareon is not present in `Cargo.toml`"); match flareon_crate { diff --git a/flareon-macros/src/model.rs b/flareon-macros/src/model.rs new file mode 100644 index 0000000..a0887c3 --- /dev/null +++ b/flareon-macros/src/model.rs @@ -0,0 +1,178 @@ +use convert_case::{Case, Casing}; +use darling::ast::NestedMeta; +use darling::{FromDeriveInput, FromField}; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; + +use crate::flareon_ident; + +pub fn impl_model_for_struct(_args: Vec, ast: syn::DeriveInput) -> TokenStream { + let opts = match ModelOpts::from_derive_input(&ast) { + Ok(val) => val, + Err(err) => { + return err.write_errors(); + } + }; + + let mut builder = opts.as_model_builder(); + for field in opts.fields() { + builder.push_field(field); + } + + quote!(#ast #builder) +} + +#[derive(Debug, FromDeriveInput)] +#[darling(forward_attrs(allow, doc, cfg), supports(struct_named))] +struct ModelOpts { + ident: syn::Ident, + data: darling::ast::Data, +} + +impl ModelOpts { + fn fields(&self) -> Vec<&Field> { + self.data + .as_ref() + .take_struct() + .expect("Only structs are supported") + .fields + } + + fn field_count(&self) -> usize { + self.fields().len() + } + + fn as_model_builder(&self) -> ModelBuilder { + let table_name = self.ident.to_string().to_case(Case::Snake); + + ModelBuilder { + name: self.ident.clone(), + table_name, + fields_struct_name: format_ident!("{}Fields", self.ident), + fields_as_columns: Vec::with_capacity(self.field_count()), + fields_as_from_db: Vec::with_capacity(self.field_count()), + fields_as_get_values: Vec::with_capacity(self.field_count()), + fields_as_field_refs: Vec::with_capacity(self.field_count()), + } + } +} + +#[derive(Debug, Clone, FromField)] +#[darling(attributes(form))] +struct Field { + ident: Option, + ty: syn::Type, +} + +#[derive(Debug)] +struct ModelBuilder { + name: Ident, + table_name: String, + fields_struct_name: Ident, + fields_as_columns: Vec, + fields_as_from_db: Vec, + fields_as_get_values: Vec, + fields_as_field_refs: Vec, +} + +impl ToTokens for ModelBuilder { + fn to_tokens(&self, tokens: &mut TokenStream) { + tokens.append_all(self.build_model_impl()); + tokens.append_all(self.build_fields_struct()); + } +} + +impl ModelBuilder { + fn push_field(&mut self, field: &Field) { + let orm_ident = orm_ident(); + + let name = field.ident.as_ref().unwrap(); + let const_name = format_ident!("{}", name.to_string().to_case(Case::UpperSnake)); + let ty = &field.ty; + let index = self.fields_as_columns.len(); + + let column_name = name.to_string().to_case(Case::Snake); + let is_auto = column_name == "id"; + + { + let mut field_as_column = quote!(#orm_ident::Column::new( + #orm_ident::Identifier::new(#column_name) + )); + if is_auto { + field_as_column.append_all(quote!(.auto(true))); + } + self.fields_as_columns.push(field_as_column); + } + + self.fields_as_from_db.push(quote!( + #name: db_row.get::<#ty>(#index)? + )); + + self.fields_as_get_values.push(quote!( + #index => &self.#name as &dyn #orm_ident::ValueRef + )); + + self.fields_as_field_refs.push(quote!( + pub const #const_name: #orm_ident::query::FieldRef<#ty> = + #orm_ident::query::FieldRef::<#ty>::new(#orm_ident::Identifier::new(#column_name)); + )); + } + + fn build_model_impl(&self) -> TokenStream { + let orm_ident = orm_ident(); + + let name = &self.name; + let table_name = &self.table_name; + let fields_struct_name = &self.fields_struct_name; + let fields_as_columns = &self.fields_as_columns; + let fields_as_from_db = &self.fields_as_from_db; + let fields_as_get_values = &self.fields_as_get_values; + + quote! { + #[automatically_derived] + impl #orm_ident::Model for #name { + type Fields = #fields_struct_name; + + const COLUMNS: &'static [#orm_ident::Column] = &[ + #(#fields_as_columns,)* + ]; + const TABLE_NAME: #orm_ident::Identifier = #orm_ident::Identifier::new(#table_name); + + fn from_db(db_row: #orm_ident::Row) -> #orm_ident::Result { + Ok(Self { + #(#fields_as_from_db,)* + }) + } + + fn get_values(&self, columns: &[usize]) -> Vec<&dyn #orm_ident::ValueRef> { + columns + .iter() + .map(|&column| match column { + #(#fields_as_get_values,)* + _ => panic!("Unknown column index: {}", column), + }) + .collect() + } + } + } + } + + fn build_fields_struct(&self) -> TokenStream { + let fields_struct_name = &self.fields_struct_name; + let fields_as_field_refs = &self.fields_as_field_refs; + + quote! { + #[derive(::core::fmt::Debug)] + pub struct #fields_struct_name; + + impl #fields_struct_name { + #(#fields_as_field_refs)* + } + } + } +} + +fn orm_ident() -> TokenStream { + let crate_ident = flareon_ident(); + quote! { #crate_ident::db } +} diff --git a/flareon-macros/tests/compile_tests.rs b/flareon-macros/tests/compile_tests.rs index d067819..30db393 100644 --- a/flareon-macros/tests/compile_tests.rs +++ b/flareon-macros/tests/compile_tests.rs @@ -3,3 +3,9 @@ fn test_derive_form() { let t = trybuild::TestCases::new(); t.pass("tests/ui/derive_form.rs"); } + +#[test] +fn test_attr_model() { + let t = trybuild::TestCases::new(); + t.pass("tests/ui/attr_model.rs"); +} diff --git a/flareon-macros/tests/ui/attr_model.rs b/flareon-macros/tests/ui/attr_model.rs new file mode 100644 index 0000000..0b7f457 --- /dev/null +++ b/flareon-macros/tests/ui/attr_model.rs @@ -0,0 +1,14 @@ +use flareon::db::{model, Model}; + +#[derive(Debug)] +#[model] +struct MyModel { + id: i32, + name: std::string::String, + description: String, + visits: i32, +} + +fn main() { + println!("{:?}", MyModel::TABLE_NAME); +} diff --git a/flareon-orm/Cargo.toml b/flareon-orm/Cargo.toml deleted file mode 100644 index 3b85638..0000000 --- a/flareon-orm/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[package] -name = "flareon-orm" -version = "0.1.0" -edition.workspace = true -license.workspace = true -description = "Modern web framework focused on speed and ease of use - ORM." - -[dependencies] diff --git a/flareon-orm/src/lib.rs b/flareon-orm/src/lib.rs deleted file mode 100644 index b93cf3f..0000000 --- a/flareon-orm/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub fn add(left: u64, right: u64) -> u64 { - left + right -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} diff --git a/flareon/Cargo.toml b/flareon/Cargo.toml index cd1df40..9da559a 100644 --- a/flareon/Cargo.toml +++ b/flareon/Cargo.toml @@ -11,10 +11,14 @@ async-trait.workspace = true axum.workspace = true bytes.workspace = true derive_builder.workspace = true +derive_more.workspace = true flareon_macros.workspace = true form_urlencoded.workspace = true indexmap.workspace = true log.workspace = true regex.workspace = true +sea-query-binder.workspace = true +sea-query.workspace = true +sqlx.workspace = true thiserror.workspace = true tokio.workspace = true diff --git a/flareon/src/db.rs b/flareon/src/db.rs new file mode 100644 index 0000000..3d1717b --- /dev/null +++ b/flareon/src/db.rs @@ -0,0 +1,409 @@ +mod fields; +pub mod query; + +use std::fmt::Write; + +use async_trait::async_trait; +use derive_more::{Debug, Deref, Display}; +pub use flareon_macros::model; +use log::debug; +use query::Query; +use sea_query::{Iden, QueryBuilder, SchemaBuilder, SimpleExpr, SqliteQueryBuilder}; +use sea_query_binder::SqlxBinder; +use sqlx::any::AnyPoolOptions; +use sqlx::pool::PoolConnection; +use sqlx::{AnyPool, Type}; +use thiserror::Error; + +/// An error that can occur when interacting with the database. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum DatabaseError { + #[error("Database engine error: {0}")] + DatabaseEngineError(#[from] sqlx::Error), + #[error("Error when building query: {0}")] + QueryBuildingError(#[from] sea_query::error::Error), +} + +pub type Result = std::result::Result; + +/// A model trait for database models. +/// +/// This trait is used to define a model that can be stored in a database. +/// It is used to define the structure of the model, the table name, and the +/// columns. +/// +/// # Deriving +/// +/// This trait can, and should be derived using the [`model`] attribute macro. +/// This macro generates the implementation of the trait for the type, including +/// the implementation of the `Fields` helper struct. +/// +/// ```rust +/// use flareon::db::model; +/// +/// #[model] +/// struct MyModel { +/// id: i32, +/// name: String, +/// } +/// ``` +#[async_trait] +pub trait Model: Sized + Send { + /// A helper structure for the fields of the model. + /// + /// This structure should a constant [`FieldRef`](query::FieldRef) instance + /// for each field in the model. Note that the names of the fields + /// should be written in UPPER_SNAKE_CASE, just like other constants in + /// Rust. + type Fields; + + /// The name of the table in the database. + const TABLE_NAME: Identifier; + + /// The columns of the model. + const COLUMNS: &'static [Column]; + + /// Creates a model instance from a database row. + fn from_db(db_row: Row) -> Result; + + /// Gets the values of the model for the given columns. + fn get_values(&self, columns: &[usize]) -> Vec<&dyn ValueRef>; + + /// Returns a query for all objects of this model. + #[must_use] + fn objects() -> Query { + Query::new() + } + + /// Saves the model to the database. + async fn save(&mut self, db: &Database) -> Result<()> { + db.insert(self).await?; + Ok(()) + } +} + +/// An identifier structure that holds a table or column name as static string. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deref, Display)] +pub struct Identifier(pub &'static str); + +impl Identifier { + /// Creates a new identifier from a static string. + #[must_use] + pub const fn new(s: &'static str) -> Self { + Self(s) + } + + /// Returns the inner string of the identifier. + #[must_use] + pub fn as_str(&self) -> &str { + self.0 + } +} + +impl From<&'static str> for Identifier { + fn from(s: &'static str) -> Self { + Self(s) + } +} + +impl Iden for Identifier { + fn unquoted(&self, s: &mut dyn Write) { + s.write_str(self.as_str()).unwrap(); + } +} + +/// A column structure that holds the name of the column and some additional +/// schema information. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Column { + name: Identifier, + auto_value: bool, + null: bool, +} + +impl Column { + #[must_use] + pub const fn new(name: Identifier) -> Self { + Self { + name, + auto_value: false, + null: false, + } + } + + #[must_use] + pub const fn auto(mut self, auto_value: bool) -> Self { + self.auto_value = auto_value; + self + } + + #[must_use] + pub const fn null(mut self, null: bool) -> Self { + self.null = null; + self + } +} + +/// A row structure that holds the data of a single row retrieved from the +/// database. +#[derive(Debug)] +pub struct Row { + #[debug("...")] + inner: sqlx::any::AnyRow, +} + +impl Row { + #[must_use] + fn new(inner: sqlx::any::AnyRow) -> Self { + Self { inner } + } + + pub fn get<'r, T: FromDbValue<'r>>(&'r self, index: usize) -> Result { + Ok(T::from_sqlx(sqlx::Row::try_get::( + &self.inner, + index, + )?)) + } +} + +/// A trait for converting a database value to a Rust value. +pub trait FromDbValue<'r> { + type SqlxType: sqlx::decode::Decode<'r, sqlx::any::Any> + Type; + + fn from_sqlx(value: Self::SqlxType) -> Self; +} + +/// A trait for converting a Rust value to a database value. +pub trait ValueRef: Send + Sync { + fn as_sea_query_value(&self) -> sea_query::Value; +} + +/// A database connection structure that holds the connection to the database. +/// +/// It is used to execute queries and interact with the database. The connection +/// is established when the structure is created and closed when +/// [`Self::close()`] is called. +#[derive(Debug)] +pub struct Database { + _url: String, + db_connection: AnyPool, + #[debug("...")] + query_builder: Box, + #[debug("...")] + _schema_builder: Box, +} + +impl Database { + pub async fn new>(url: T) -> Result { + sqlx::any::install_default_drivers(); + + let url = url.into(); + + let db_conn = AnyPoolOptions::new() + .max_connections(1) + .connect(&url) + .await?; + let (query_builder, schema_builder): ( + Box, + Box, + ) = { + if url.starts_with("sqlite:") { + (Box::new(SqliteQueryBuilder), Box::new(SqliteQueryBuilder)) + } else { + todo!("Other databases are not supported yet"); + } + }; + + Ok(Self { + _url: url, + db_connection: db_conn, + query_builder, + _schema_builder: schema_builder, + }) + } + + pub async fn close(self) -> Result<()> { + self.db_connection.close().await; + Ok(()) + } + + pub async fn execute(&self, query: &str) -> Result { + let sqlx_query = sqlx::query(query); + + self.execute_sqlx(sqlx_query).await + } + + async fn execute_sqlx<'a, A>( + &self, + sqlx_query: sqlx::query::Query<'a, sqlx::any::Any, A>, + ) -> Result + where + A: 'a + sqlx::IntoArguments<'a, sqlx::any::Any>, + { + let result = sqlx_query.execute(&mut *self.conn().await?).await?; + let result = QueryResult { + rows_affected: RowsNum(result.rows_affected()), + }; + + debug!("Rows affected: {}", result.rows_affected.0); + Ok(result) + } + + async fn conn(&self) -> Result> { + Ok(self.db_connection.acquire().await?) + } + + pub async fn insert(&self, data: &mut T) -> Result<()> { + let non_auto_column_identifiers = T::COLUMNS + .iter() + .filter_map(|column| { + if column.auto_value { + None + } else { + Some(Identifier::from(column.name.as_str())) + } + }) + .collect::>(); + let value_indices = T::COLUMNS + .iter() + .enumerate() + .filter_map(|(i, column)| if column.auto_value { None } else { Some(i) }) + .collect::>(); + let values = data.get_values(&value_indices); + + let (sql, values) = sea_query::Query::insert() + .into_table(T::TABLE_NAME) + .columns(non_auto_column_identifiers) + .values( + values + .into_iter() + .map(|value| SimpleExpr::Value(value.as_sea_query_value())) + .collect::>(), + )? + .returning_col(Identifier::new("id")) + .build_any_sqlx(self.query_builder.as_ref()); + + debug!("Insert query: {}", sql); + + let row = sqlx::query_with(&sql, values) + .fetch_one(&mut *self.conn().await?) + .await?; + let id: i64 = sqlx::Row::try_get(&row, 0)?; + debug!("Inserted row with id: {}", id); + + Ok(()) + } + + pub async fn query(&self, query: &Query) -> Result> { + let columns_to_get: Vec<_> = T::COLUMNS.iter().map(|column| column.name).collect(); + let mut select = sea_query::Query::select(); + select.columns(columns_to_get).from(T::TABLE_NAME); + query.modify_statement(&mut select); + let (sql, values) = select.build_any_sqlx(self.query_builder.as_ref()); + + debug!("Select query: {}", sql); + + let rows: Vec = sqlx::query_with(&sql, values) + .fetch_all(&mut *self.conn().await?) + .await? + .into_iter() + .map(|row| T::from_db(Row::new(row))) + .collect::>()?; + + Ok(rows) + } + + pub async fn delete(&self, query: &Query) -> Result { + let mut delete = sea_query::Query::delete(); + delete.from_table(T::TABLE_NAME); + query.modify_statement(&mut delete); + let (sql, values) = delete.build_any_sqlx(self.query_builder.as_ref()); + + debug!("Delete query: {}", sql); + + self.execute_sqlx(sqlx::query_with(&sql, values)).await + } +} + +/// Result of a query execution. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct QueryResult { + rows_affected: RowsNum, +} + +impl QueryResult { + /// Returns the number of rows affected by the query. + #[must_use] + pub fn rows_affected(&self) -> RowsNum { + self.rows_affected + } +} + +/// A structure that holds the number of rows. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deref, Display)] +pub struct RowsNum(pub u64); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_identifier() { + let id = Identifier::new("test"); + assert_eq!(id.as_str(), "test"); + } + + #[test] + fn test_column() { + let column = Column::new(Identifier::new("test")); + assert_eq!(column.name.as_str(), "test"); + assert!(!column.auto_value); + assert!(!column.null); + } + + #[derive(std::fmt::Debug, PartialEq)] + #[model] + struct TestModel { + id: i32, + name: String, + } + + #[tokio::test] + async fn test_model_crud() { + let db = test_db().await; + + db.execute( + r" + CREATE TABLE test_model ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL + );", + ) + .await + .unwrap(); + + assert_eq!(TestModel::objects().all(&db).await.unwrap(), vec![]); + + let mut model = TestModel { + id: 0, + name: "test".to_owned(), + }; + model.save(&db).await.unwrap(); + + assert_eq!(TestModel::objects().all(&db).await.unwrap().len(), 1); + + use crate::db::query::ExprEq; + TestModel::objects() + .filter(::Fields::ID.eq(1)) + .delete(&db) + .await + .unwrap(); + + assert_eq!(TestModel::objects().all(&db).await.unwrap(), vec![]); + } + + async fn test_db() -> Database { + Database::new("sqlite::memory:").await.unwrap() + } +} diff --git a/flareon/src/db/fields.rs b/flareon/src/db/fields.rs new file mode 100644 index 0000000..13b6e27 --- /dev/null +++ b/flareon/src/db/fields.rs @@ -0,0 +1,31 @@ +use sea_query::Value; + +use crate::db::{FromDbValue, ValueRef}; + +impl FromDbValue<'_> for i32 { + type SqlxType = i32; + + fn from_sqlx(value: Self::SqlxType) -> Self { + value + } +} + +impl ValueRef for i32 { + fn as_sea_query_value(&self) -> Value { + (*self).into() + } +} + +impl ValueRef for String { + fn as_sea_query_value(&self) -> Value { + self.into() + } +} + +impl FromDbValue<'_> for String { + type SqlxType = String; + + fn from_sqlx(value: Self::SqlxType) -> Self { + value + } +} diff --git a/flareon/src/db/query.rs b/flareon/src/db/query.rs new file mode 100644 index 0000000..6dc5c81 --- /dev/null +++ b/flareon/src/db/query.rs @@ -0,0 +1,102 @@ +use std::marker::PhantomData; + +use derive_more::Debug; +use sea_query::IntoColumnRef; + +use crate::db; +use crate::db::{Database, FromDbValue, Identifier, Model, QueryResult, ValueRef}; + +#[derive(Debug)] +pub struct Query { + filter: Option, + phantom_data: PhantomData, +} + +impl Default for Query { + fn default() -> Self { + Self::new() + } +} + +impl Query { + #[must_use] + pub fn new() -> Self { + Self { + filter: None, + phantom_data: PhantomData, + } + } + + pub fn filter(&mut self, filter: Expr) -> &mut Self { + self.filter = Some(filter); + self + } + + pub async fn all(&self, db: &Database) -> db::Result> { + db.query(self).await + } + + pub async fn delete(&self, db: &Database) -> db::Result { + db.delete(self).await + } + + pub(super) fn modify_statement(&self, statement: &mut S) { + if let Some(filter) = &self.filter { + statement.and_where(filter.as_sea_query_expr()); + } + } +} + +#[derive(Debug)] +pub enum Expr { + Column(Identifier), + Value(#[debug("{}", _0.as_sea_query_value())] Box), + Eq(Box, Box), +} + +impl Expr { + #[must_use] + pub fn value(value: T) -> Self { + Self::Value(Box::new(value)) + } + + #[must_use] + pub fn eq(lhs: Self, rhs: Self) -> Self { + Self::Eq(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn as_sea_query_expr(&self) -> sea_query::SimpleExpr { + match self { + Self::Column(identifier) => identifier.into_column_ref().into(), + Self::Eq(lhs, rhs) => lhs.as_sea_query_expr().eq(rhs.as_sea_query_expr()), + Self::Value(value) => value.as_sea_query_value().into(), + } + } +} + +#[derive(Debug)] +pub struct FieldRef { + identifier: Identifier, + phantom_data: PhantomData, +} + +impl<'a, T: FromDbValue<'a> + ValueRef> FieldRef { + #[must_use] + pub const fn new(identifier: Identifier) -> Self { + Self { + identifier, + phantom_data: PhantomData, + } + } +} + +pub trait ExprEq { + fn eq(self, other: T) -> Expr; +} + +impl ExprEq for FieldRef { + fn eq(self, other: T) -> Expr { + Expr::eq(Expr::Column(self.identifier), Expr::value(other)) + } +} diff --git a/flareon/src/error.rs b/flareon/src/error.rs index c08f41e..c214e7f 100644 --- a/flareon/src/error.rs +++ b/flareon/src/error.rs @@ -21,6 +21,8 @@ pub enum Error { ReverseError(#[from] crate::router::path::ReverseError), #[error("Failed to render template: {0}")] TemplateRender(#[from] askama::Error), + #[error("Database error: {0}")] + DatabaseError(#[from] crate::db::DatabaseError), } impl From for askama::Error { diff --git a/flareon/src/forms.rs b/flareon/src/forms.rs index 585d319..10377bd 100644 --- a/flareon/src/forms.rs +++ b/flareon/src/forms.rs @@ -242,7 +242,7 @@ pub struct CharField { } /// Custom options for a `CharField`. -#[derive(Debug, Default)] +#[derive(Debug, Default, Copy, Clone)] pub struct CharFieldOptions { /// The maximum length of the field. Used to set the `maxlength` attribute /// in the HTML input element. diff --git a/flareon/src/lib.rs b/flareon/src/lib.rs index 9bba5b2..ff0010b 100644 --- a/flareon/src/lib.rs +++ b/flareon/src/lib.rs @@ -1,5 +1,18 @@ +#![warn( + missing_debug_implementations, + missing_copy_implementations, + trivial_casts, + trivial_numeric_casts, + unreachable_pub, + unsafe_code, + unstable_features, + unused_import_braces, + unused_qualifications +)] + extern crate self as flareon; +pub mod db; mod error; pub mod forms; pub mod prelude; @@ -23,7 +36,7 @@ use log::info; use request::Request; use router::{Route, Router}; -pub type Result = std::result::Result; +pub type Result = std::result::Result; pub type StatusCode = axum::http::StatusCode; diff --git a/flareon/src/request.rs b/flareon/src/request.rs index f672572..a4b08ef 100644 --- a/flareon/src/request.rs +++ b/flareon/src/request.rs @@ -92,6 +92,6 @@ impl Request { #[must_use] pub fn path_param(&self, name: &str) -> Option<&str> { - self.path_params.get(name).map(std::string::String::as_str) + self.path_params.get(name).map(String::as_str) } } diff --git a/flareon/src/router/path.rs b/flareon/src/router/path.rs index 0b65da7..e57c3f6 100644 --- a/flareon/src/router/path.rs +++ b/flareon/src/router/path.rs @@ -12,7 +12,7 @@ pub(super) struct PathMatcher { impl PathMatcher { #[must_use] - pub fn new>(path_pattern: T) -> Self { + pub(crate) fn new>(path_pattern: T) -> Self { let path_pattern = path_pattern.into(); let mut last_end = 0; @@ -61,7 +61,7 @@ impl PathMatcher { } #[must_use] - pub fn capture<'matcher, 'path>( + pub(crate) fn capture<'matcher, 'path>( &'matcher self, path: &'path str, ) -> Option> { @@ -96,7 +96,7 @@ impl PathMatcher { Some(CaptureResult::new(params, current_path)) } - pub fn reverse(&self, params: &ReverseParamMap) -> Result { + pub(crate) fn reverse(&self, params: &ReverseParamMap) -> Result { let mut result = String::new(); for part in &self.parts { @@ -197,7 +197,7 @@ impl<'matcher, 'path> CaptureResult<'matcher, 'path> { } #[must_use] - pub fn matches_fully(&self) -> bool { + pub(crate) fn matches_fully(&self) -> bool { self.remaining_path.is_empty() } } @@ -225,7 +225,7 @@ pub(super) struct PathParam<'a> { impl<'a> PathParam<'a> { #[must_use] - pub fn new(name: &'a str, value: &str) -> Self { + pub(crate) fn new(name: &'a str, value: &str) -> Self { Self { name, value: value.to_string(),