From 4bb814c47395e0a2fd1434c0e36da509205d5455 Mon Sep 17 00:00:00 2001 From: Warm Beer Date: Tue, 13 Sep 2022 13:18:46 +0200 Subject: [PATCH] feat: auto detect `db_driver` from `connect_url` --- src/config.rs | 3 -- src/databases/database.rs | 106 ++++++++++++++++++-------------------- src/errors.rs | 3 +- src/main.rs | 6 ++- tests/databases/mod.rs | 10 ++-- tests/databases/mysql.rs | 3 +- tests/databases/sqlite.rs | 3 +- 7 files changed, 67 insertions(+), 67 deletions(-) diff --git a/src/config.rs b/src/config.rs index 89078f2d..cb3b5546 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,6 @@ use config::{ConfigError, Config, File}; use std::path::Path; use serde::{Serialize, Deserialize}; use tokio::sync::RwLock; -use crate::databases::database::DatabaseDriver; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Website { @@ -50,7 +49,6 @@ pub struct Auth { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Database { - pub db_driver: DatabaseDriver, pub connect_url: String, pub torrent_info_update_interval: u64, } @@ -105,7 +103,6 @@ impl Configuration { secret_key: "MaxVerstappenWC2021".to_string() }, database: Database { - db_driver: DatabaseDriver::Sqlite3, connect_url: "sqlite://data.db?mode=rwc".to_string(), torrent_info_update_interval: 3600 }, diff --git a/src/databases/database.rs b/src/databases/database.rs index c22f8202..856092d5 100644 --- a/src/databases/database.rs +++ b/src/databases/database.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use chrono::{NaiveDateTime}; use serde::{Serialize, Deserialize}; + use crate::databases::mysql::MysqlDatabase; use crate::databases::sqlite::SqliteDatabase; use crate::models::response::{TorrentsResponse}; @@ -9,18 +10,21 @@ use crate::models::torrent_file::{DbTorrentInfo, Torrent, TorrentFile}; use crate::models::tracker_key::TrackerKey; use crate::models::user::{User, UserAuthentication, UserCompact, UserProfile}; +/// Database drivers. #[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] pub enum DatabaseDriver { Sqlite3, Mysql } +/// Compact representation of torrent. #[derive(Debug, Serialize, sqlx::FromRow)] pub struct TorrentCompact { pub torrent_id: i64, pub info_hash: String, } +/// Torrent category. #[derive(Debug, Serialize, Deserialize, sqlx::FromRow)] pub struct Category { pub category_id: i64, @@ -28,6 +32,7 @@ pub struct Category { pub num_torrents: i64 } +/// Sorting options for torrents. #[derive(Clone, Copy, Debug, Deserialize)] pub enum Sorting { UploadedAsc, @@ -42,9 +47,11 @@ pub enum Sorting { SizeDesc, } +/// Database errors. #[derive(Debug)] pub enum DatabaseError { Error, + UnrecognizedDatabaseDriver, // when the db path does not start with sqlite or mysql UsernameTaken, EmailTaken, UserNotFound, @@ -55,127 +62,116 @@ pub enum DatabaseError { TorrentTitleAlreadyExists, } -pub async fn connect_database(db_driver: &DatabaseDriver, db_path: &str) -> Box { - // match &db_path.chars().collect::>() as &[char] { - // ['s', 'q', 'l', 'i', 't', 'e', ..] => { - // let db = SqliteDatabase::new(db_path).await; - // Ok(Box::new(db)) - // } - // ['m', 'y', 's', 'q', 'l', ..] => { - // let db = MysqlDatabase::new(db_path).await; - // Ok(Box::new(db)) - // } - // _ => { - // Err(()) - // } - // } - - match db_driver { - DatabaseDriver::Sqlite3 => { +/// Connect to a database. +pub async fn connect_database(db_path: &str) -> Result, DatabaseError> { + match &db_path.chars().collect::>() as &[char] { + ['s', 'q', 'l', 'i', 't', 'e', ..] => { let db = SqliteDatabase::new(db_path).await; - Box::new(db) + Ok(Box::new(db)) } - DatabaseDriver::Mysql => { + ['m', 'y', 's', 'q', 'l', ..] => { let db = MysqlDatabase::new(db_path).await; - Box::new(db) + Ok(Box::new(db)) + } + _ => { + Err(DatabaseError::UnrecognizedDatabaseDriver) } } } +/// Trait for database implementations. #[async_trait] pub trait Database: Sync + Send { - // return current database driver + /// Return current database driver. fn get_database_driver(&self) -> DatabaseDriver; - // add new user and get the newly inserted user_id + /// Add new user and return the newly inserted `user_id`. async fn insert_user_and_get_id(&self, username: &str, email: &str, password: &str) -> Result; - // get user profile by user_id + /// Get `User` from `user_id`. async fn get_user_from_id(&self, user_id: i64) -> Result; - // get user authentication by user_id + /// Get `UserAuthentication` from `user_id`. async fn get_user_authentication_from_id(&self, user_id: i64) -> Result; - // get user profile by username + /// Get `UserProfile` from `username`. async fn get_user_profile_from_username(&self, username: &str) -> Result; - // get user compact by user_id + /// Get `UserCompact` from `user_id`. async fn get_user_compact_from_id(&self, user_id: i64) -> Result; - // todo: change to get all tracker keys of user, no matter if they are still valid - // get a user's tracker key + /// Get a user's `TrackerKey`. async fn get_user_tracker_key(&self, user_id: i64) -> Option; - // count users + /// Get total user count. async fn count_users(&self) -> Result; - // todo: make DateTime struct for the date_expiry - // ban user + /// Ban user with `user_id`, `reason` and `date_expiry`. async fn ban_user(&self, user_id: i64, reason: &str, date_expiry: NaiveDateTime) -> Result<(), DatabaseError>; - // give a user administrator rights + /// Grant a user the administrator role. async fn grant_admin_role(&self, user_id: i64) -> Result<(), DatabaseError>; - // verify email + /// Verify a user's email with `user_id`. async fn verify_email(&self, user_id: i64) -> Result<(), DatabaseError>; - // create a new tracker key for a certain user + /// Link a `TrackerKey` to a certain user with `user_id`. async fn add_tracker_key(&self, user_id: i64, tracker_key: &TrackerKey) -> Result<(), DatabaseError>; - // delete user + /// Delete user and all related user data with `user_id`. async fn delete_user(&self, user_id: i64) -> Result<(), DatabaseError>; - // add new category + /// Add a new category and return `category_id`. async fn insert_category_and_get_id(&self, category_name: &str) -> Result; - // get category by id - async fn get_category_from_id(&self, id: i64) -> Result; + /// Get `Category` from `category_id`. + async fn get_category_from_id(&self, category_id: i64) -> Result; - // get category by name - async fn get_category_from_name(&self, category: &str) -> Result; + /// Get `Category` from `category_name`. + async fn get_category_from_name(&self, category_name: &str) -> Result; - // get all categories + /// Get all categories as `Vec`. async fn get_categories(&self) -> Result, DatabaseError>; - // delete category + /// Delete category with `category_name`. async fn delete_category(&self, category_name: &str) -> Result<(), DatabaseError>; - // get results of a torrent search in a paginated and sorted form + /// Get results of a torrent search in a paginated and sorted form as `TorrentsResponse` from `search`, `categories`, `sort`, `offset` and `page_size`. async fn get_torrents_search_sorted_paginated(&self, search: &Option, categories: &Option>, sort: &Sorting, offset: u64, page_size: u8) -> Result; - // add new torrent and get the newly inserted torrent_id + /// Add new torrent and return the newly inserted `torrent_id` with `torrent`, `uploader_id`, `category_id`, `title` and `description`. async fn insert_torrent_and_get_id(&self, torrent: &Torrent, uploader_id: i64, category_id: i64, title: &str, description: &str) -> Result; - // get torrent by id + /// Get `Torrent` from `torrent_id`. async fn get_torrent_from_id(&self, torrent_id: i64) -> Result; - // get torrent info by id + /// Get torrent's info as `DbTorrentInfo` from `torrent_id`. async fn get_torrent_info_from_id(&self, torrent_id: i64) -> Result; - // get torrent files by id + /// Get all torrent's files as `Vec` from `torrent_id`. async fn get_torrent_files_from_id(&self, torrent_id: i64) -> Result, DatabaseError>; - // get torrent announce urls by id + /// Get all torrent's announce urls as `Vec>` from `torrent_id`. async fn get_torrent_announce_urls_from_id(&self, torrent_id: i64) -> Result>, DatabaseError>; - // get torrent listing by id + /// Get `TorrentListing` from `torrent_id`. async fn get_torrent_listing_from_id(&self, torrent_id: i64) -> Result; - // get all torrents (torrent_id + info_hash) + /// Get all torrents as `Vec`. async fn get_all_torrents_compact(&self) -> Result, DatabaseError>; - // update a torrent's title + /// Update a torrent's title with `torrent_id` and `title`. async fn update_torrent_title(&self, torrent_id: i64, title: &str) -> Result<(), DatabaseError>; - // update a torrent's description + /// Update a torrent's description with `torrent_id` and `description`. async fn update_torrent_description(&self, torrent_id: i64, description: &str) -> Result<(), DatabaseError>; - // update the seeders and leechers info for a particular torrent + /// Update the seeders and leechers info for a torrent with `torrent_id`, `tracker_url`, `seeders` and `leechers`. async fn update_tracker_info(&self, torrent_id: i64, tracker_url: &str, seeders: i64, leechers: i64) -> Result<(), DatabaseError>; - // delete a torrent + /// Delete a torrent with `torrent_id`. async fn delete_torrent(&self, torrent_id: i64) -> Result<(), DatabaseError>; - // DELETES ALL DATABASE ROWS, ONLY CALL THIS IF YOU KNOW WHAT YOU'RE DOING! + /// DELETES ALL DATABASE ROWS, ONLY CALL THIS IF YOU KNOW WHAT YOU'RE DOING! async fn delete_all_database_rows(&self) -> Result<(), DatabaseError>; } diff --git a/src/errors.rs b/src/errors.rs index eef4f851..675ca05f 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -222,7 +222,8 @@ impl From for ServiceError { DatabaseError::CategoryNotFound => ServiceError::InvalidCategory, DatabaseError::TorrentNotFound => ServiceError::TorrentNotFound, DatabaseError::TorrentAlreadyExists => ServiceError::InfoHashAlreadyExists, - DatabaseError::TorrentTitleAlreadyExists => ServiceError::TorrentTitleAlreadyExists + DatabaseError::TorrentTitleAlreadyExists => ServiceError::TorrentTitleAlreadyExists, + DatabaseError::UnrecognizedDatabaseDriver => ServiceError::InternalServerError, } } } diff --git a/src/main.rs b/src/main.rs index 06304307..7fc5d0f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,7 +20,11 @@ async fn main() -> std::io::Result<()> { let settings = cfg.settings.read().await; - let database = Arc::new(connect_database(&settings.database.db_driver, &settings.database.connect_url).await); + let database = Arc::new(connect_database(&settings.database.connect_url) + .await + .expect("Database error.") + ); + let auth = Arc::new(AuthorizationService::new(cfg.clone(), database.clone())); let tracker_service = Arc::new(TrackerService::new(cfg.clone(), database.clone())); let mailer_service = Arc::new(MailerService::new(cfg.clone()).await); diff --git a/tests/databases/mod.rs b/tests/databases/mod.rs index 66f90e92..adf8bc52 100644 --- a/tests/databases/mod.rs +++ b/tests/databases/mod.rs @@ -1,5 +1,5 @@ use std::future::Future; -use torrust_index_backend::databases::database::{connect_database, Database, DatabaseDriver}; +use torrust_index_backend::databases::database::{connect_database, Database}; mod mysql; mod tests; @@ -19,8 +19,12 @@ async fn run_test<'a, T, F>(db_fn: T, db: &'a Box) } // runs all tests -pub async fn run_tests(db_driver: DatabaseDriver, db_path: &str) { - let db = connect_database(&db_driver, db_path).await; +pub async fn run_tests(db_path: &str) { + let db_res = connect_database(db_path).await; + + assert!(db_res.is_ok()); + + let db = db_res.unwrap(); run_test(tests::it_can_add_a_user, &db).await; run_test(tests::it_can_add_a_torrent_category, &db).await; diff --git a/tests/databases/mysql.rs b/tests/databases/mysql.rs index d64ac1b3..c0f78429 100644 --- a/tests/databases/mysql.rs +++ b/tests/databases/mysql.rs @@ -1,11 +1,10 @@ -use torrust_index_backend::databases::database::{DatabaseDriver}; use crate::databases::{run_tests}; const DATABASE_URL: &str = "mysql://root:password@localhost:3306/torrust-index_test"; #[tokio::test] async fn run_mysql_tests() { - run_tests(DatabaseDriver::Mysql, DATABASE_URL).await; + run_tests(DATABASE_URL).await; } diff --git a/tests/databases/sqlite.rs b/tests/databases/sqlite.rs index 7aab5b1d..940d7e6b 100644 --- a/tests/databases/sqlite.rs +++ b/tests/databases/sqlite.rs @@ -1,11 +1,10 @@ -use torrust_index_backend::databases::database::{DatabaseDriver}; use crate::databases::{run_tests}; const DATABASE_URL: &str = "sqlite::memory:"; #[tokio::test] async fn run_sqlite_tests() { - run_tests(DatabaseDriver::Sqlite3, DATABASE_URL).await; + run_tests(DATABASE_URL).await; }