Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: auto detect db_driver from connect_url #70

Merged
merged 1 commit into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use config::{ConfigError, Config, File};
use std::path::Path;
use serde::{Serialize, Deserialize};
use tokio::sync::RwLock;
use crate::databases::database::DatabaseDriver;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Website {
Expand Down Expand Up @@ -50,7 +49,6 @@ pub struct Auth {

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Database {
pub db_driver: DatabaseDriver,
pub connect_url: String,
pub torrent_info_update_interval: u64,
}
Expand Down Expand Up @@ -105,7 +103,6 @@ impl Configuration {
secret_key: "MaxVerstappenWC2021".to_string()
},
database: Database {
db_driver: DatabaseDriver::Sqlite3,
connect_url: "sqlite://data.db?mode=rwc".to_string(),
torrent_info_update_interval: 3600
},
Expand Down
106 changes: 51 additions & 55 deletions src/databases/database.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use async_trait::async_trait;
use chrono::{NaiveDateTime};
use serde::{Serialize, Deserialize};

use crate::databases::mysql::MysqlDatabase;
use crate::databases::sqlite::SqliteDatabase;
use crate::models::response::{TorrentsResponse};
Expand All @@ -9,25 +10,29 @@ use crate::models::torrent_file::{DbTorrentInfo, Torrent, TorrentFile};
use crate::models::tracker_key::TrackerKey;
use crate::models::user::{User, UserAuthentication, UserCompact, UserProfile};

/// Database drivers.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum DatabaseDriver {
Sqlite3,
Mysql
}

/// Compact representation of torrent.
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct TorrentCompact {
pub torrent_id: i64,
pub info_hash: String,
}

/// Torrent category.
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct Category {
pub category_id: i64,
pub name: String,
pub num_torrents: i64
}

/// Sorting options for torrents.
#[derive(Clone, Copy, Debug, Deserialize)]
pub enum Sorting {
UploadedAsc,
Expand All @@ -42,9 +47,11 @@ pub enum Sorting {
SizeDesc,
}

/// Database errors.
#[derive(Debug)]
pub enum DatabaseError {
Error,
UnrecognizedDatabaseDriver, // when the db path does not start with sqlite or mysql
UsernameTaken,
EmailTaken,
UserNotFound,
Expand All @@ -55,127 +62,116 @@ pub enum DatabaseError {
TorrentTitleAlreadyExists,
}

pub async fn connect_database(db_driver: &DatabaseDriver, db_path: &str) -> Box<dyn Database> {
// match &db_path.chars().collect::<Vec<char>>() as &[char] {
// ['s', 'q', 'l', 'i', 't', 'e', ..] => {
// let db = SqliteDatabase::new(db_path).await;
// Ok(Box::new(db))
// }
// ['m', 'y', 's', 'q', 'l', ..] => {
// let db = MysqlDatabase::new(db_path).await;
// Ok(Box::new(db))
// }
// _ => {
// Err(())
// }
// }

match db_driver {
DatabaseDriver::Sqlite3 => {
/// Connect to a database.
pub async fn connect_database(db_path: &str) -> Result<Box<dyn Database>, DatabaseError> {
match &db_path.chars().collect::<Vec<char>>() as &[char] {
['s', 'q', 'l', 'i', 't', 'e', ..] => {
let db = SqliteDatabase::new(db_path).await;
Box::new(db)
Ok(Box::new(db))
}
DatabaseDriver::Mysql => {
['m', 'y', 's', 'q', 'l', ..] => {
let db = MysqlDatabase::new(db_path).await;
Box::new(db)
Ok(Box::new(db))
}
_ => {
Err(DatabaseError::UnrecognizedDatabaseDriver)
}
}
}

/// Trait for database implementations.
#[async_trait]
pub trait Database: Sync + Send {
// return current database driver
/// Return current database driver.
fn get_database_driver(&self) -> DatabaseDriver;

// add new user and get the newly inserted user_id
/// Add new user and return the newly inserted `user_id`.
async fn insert_user_and_get_id(&self, username: &str, email: &str, password: &str) -> Result<i64, DatabaseError>;

// get user profile by user_id
/// Get `User` from `user_id`.
async fn get_user_from_id(&self, user_id: i64) -> Result<User, DatabaseError>;

// get user authentication by user_id
/// Get `UserAuthentication` from `user_id`.
async fn get_user_authentication_from_id(&self, user_id: i64) -> Result<UserAuthentication, DatabaseError>;

// get user profile by username
/// Get `UserProfile` from `username`.
async fn get_user_profile_from_username(&self, username: &str) -> Result<UserProfile, DatabaseError>;

// get user compact by user_id
/// Get `UserCompact` from `user_id`.
async fn get_user_compact_from_id(&self, user_id: i64) -> Result<UserCompact, DatabaseError>;

// todo: change to get all tracker keys of user, no matter if they are still valid
// get a user's tracker key
/// Get a user's `TrackerKey`.
async fn get_user_tracker_key(&self, user_id: i64) -> Option<TrackerKey>;

// count users
/// Get total user count.
async fn count_users(&self) -> Result<i64, DatabaseError>;

// todo: make DateTime struct for the date_expiry
// ban user
/// Ban user with `user_id`, `reason` and `date_expiry`.
async fn ban_user(&self, user_id: i64, reason: &str, date_expiry: NaiveDateTime) -> Result<(), DatabaseError>;

// give a user administrator rights
/// Grant a user the administrator role.
async fn grant_admin_role(&self, user_id: i64) -> Result<(), DatabaseError>;

// verify email
/// Verify a user's email with `user_id`.
async fn verify_email(&self, user_id: i64) -> Result<(), DatabaseError>;

// create a new tracker key for a certain user
/// Link a `TrackerKey` to a certain user with `user_id`.
async fn add_tracker_key(&self, user_id: i64, tracker_key: &TrackerKey) -> Result<(), DatabaseError>;

// delete user
/// Delete user and all related user data with `user_id`.
async fn delete_user(&self, user_id: i64) -> Result<(), DatabaseError>;

// add new category
/// Add a new category and return `category_id`.
async fn insert_category_and_get_id(&self, category_name: &str) -> Result<i64, DatabaseError>;

// get category by id
async fn get_category_from_id(&self, id: i64) -> Result<Category, DatabaseError>;
/// Get `Category` from `category_id`.
async fn get_category_from_id(&self, category_id: i64) -> Result<Category, DatabaseError>;

// get category by name
async fn get_category_from_name(&self, category: &str) -> Result<Category, DatabaseError>;
/// Get `Category` from `category_name`.
async fn get_category_from_name(&self, category_name: &str) -> Result<Category, DatabaseError>;

// get all categories
/// Get all categories as `Vec<Category>`.
async fn get_categories(&self) -> Result<Vec<Category>, DatabaseError>;

// delete category
/// Delete category with `category_name`.
async fn delete_category(&self, category_name: &str) -> Result<(), DatabaseError>;

// get results of a torrent search in a paginated and sorted form
/// Get results of a torrent search in a paginated and sorted form as `TorrentsResponse` from `search`, `categories`, `sort`, `offset` and `page_size`.
async fn get_torrents_search_sorted_paginated(&self, search: &Option<String>, categories: &Option<Vec<String>>, sort: &Sorting, offset: u64, page_size: u8) -> Result<TorrentsResponse, DatabaseError>;

// add new torrent and get the newly inserted torrent_id
/// Add new torrent and return the newly inserted `torrent_id` with `torrent`, `uploader_id`, `category_id`, `title` and `description`.
async fn insert_torrent_and_get_id(&self, torrent: &Torrent, uploader_id: i64, category_id: i64, title: &str, description: &str) -> Result<i64, DatabaseError>;

// get torrent by id
/// Get `Torrent` from `torrent_id`.
async fn get_torrent_from_id(&self, torrent_id: i64) -> Result<Torrent, DatabaseError>;

// get torrent info by id
/// Get torrent's info as `DbTorrentInfo` from `torrent_id`.
async fn get_torrent_info_from_id(&self, torrent_id: i64) -> Result<DbTorrentInfo, DatabaseError>;

// get torrent files by id
/// Get all torrent's files as `Vec<TorrentFile>` from `torrent_id`.
async fn get_torrent_files_from_id(&self, torrent_id: i64) -> Result<Vec<TorrentFile>, DatabaseError>;

// get torrent announce urls by id
/// Get all torrent's announce urls as `Vec<Vec<String>>` from `torrent_id`.
async fn get_torrent_announce_urls_from_id(&self, torrent_id: i64) -> Result<Vec<Vec<String>>, DatabaseError>;

// get torrent listing by id
/// Get `TorrentListing` from `torrent_id`.
async fn get_torrent_listing_from_id(&self, torrent_id: i64) -> Result<TorrentListing, DatabaseError>;

// get all torrents (torrent_id + info_hash)
/// Get all torrents as `Vec<TorrentCompact>`.
async fn get_all_torrents_compact(&self) -> Result<Vec<TorrentCompact>, DatabaseError>;

// update a torrent's title
/// Update a torrent's title with `torrent_id` and `title`.
async fn update_torrent_title(&self, torrent_id: i64, title: &str) -> Result<(), DatabaseError>;

// update a torrent's description
/// Update a torrent's description with `torrent_id` and `description`.
async fn update_torrent_description(&self, torrent_id: i64, description: &str) -> Result<(), DatabaseError>;

// update the seeders and leechers info for a particular torrent
/// Update the seeders and leechers info for a torrent with `torrent_id`, `tracker_url`, `seeders` and `leechers`.
async fn update_tracker_info(&self, torrent_id: i64, tracker_url: &str, seeders: i64, leechers: i64) -> Result<(), DatabaseError>;

// delete a torrent
/// Delete a torrent with `torrent_id`.
async fn delete_torrent(&self, torrent_id: i64) -> Result<(), DatabaseError>;

// DELETES ALL DATABASE ROWS, ONLY CALL THIS IF YOU KNOW WHAT YOU'RE DOING!
/// DELETES ALL DATABASE ROWS, ONLY CALL THIS IF YOU KNOW WHAT YOU'RE DOING!
async fn delete_all_database_rows(&self) -> Result<(), DatabaseError>;
}
3 changes: 2 additions & 1 deletion src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ impl From<DatabaseError> for ServiceError {
DatabaseError::CategoryNotFound => ServiceError::InvalidCategory,
DatabaseError::TorrentNotFound => ServiceError::TorrentNotFound,
DatabaseError::TorrentAlreadyExists => ServiceError::InfoHashAlreadyExists,
DatabaseError::TorrentTitleAlreadyExists => ServiceError::TorrentTitleAlreadyExists
DatabaseError::TorrentTitleAlreadyExists => ServiceError::TorrentTitleAlreadyExists,
DatabaseError::UnrecognizedDatabaseDriver => ServiceError::InternalServerError,
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ async fn main() -> std::io::Result<()> {

let settings = cfg.settings.read().await;

let database = Arc::new(connect_database(&settings.database.db_driver, &settings.database.connect_url).await);
let database = Arc::new(connect_database(&settings.database.connect_url)
.await
.expect("Database error.")
);

let auth = Arc::new(AuthorizationService::new(cfg.clone(), database.clone()));
let tracker_service = Arc::new(TrackerService::new(cfg.clone(), database.clone()));
let mailer_service = Arc::new(MailerService::new(cfg.clone()).await);
Expand Down
10 changes: 7 additions & 3 deletions tests/databases/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::future::Future;
use torrust_index_backend::databases::database::{connect_database, Database, DatabaseDriver};
use torrust_index_backend::databases::database::{connect_database, Database};

mod mysql;
mod tests;
Expand All @@ -19,8 +19,12 @@ async fn run_test<'a, T, F>(db_fn: T, db: &'a Box<dyn Database>)
}

// runs all tests
pub async fn run_tests(db_driver: DatabaseDriver, db_path: &str) {
let db = connect_database(&db_driver, db_path).await;
pub async fn run_tests(db_path: &str) {
let db_res = connect_database(db_path).await;

assert!(db_res.is_ok());

let db = db_res.unwrap();

run_test(tests::it_can_add_a_user, &db).await;
run_test(tests::it_can_add_a_torrent_category, &db).await;
Expand Down
3 changes: 1 addition & 2 deletions tests/databases/mysql.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use torrust_index_backend::databases::database::{DatabaseDriver};
use crate::databases::{run_tests};

const DATABASE_URL: &str = "mysql://root:password@localhost:3306/torrust-index_test";

#[tokio::test]
async fn run_mysql_tests() {
run_tests(DatabaseDriver::Mysql, DATABASE_URL).await;
run_tests(DATABASE_URL).await;
}


3 changes: 1 addition & 2 deletions tests/databases/sqlite.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use torrust_index_backend::databases::database::{DatabaseDriver};
use crate::databases::{run_tests};

const DATABASE_URL: &str = "sqlite::memory:";

#[tokio::test]
async fn run_sqlite_tests() {
run_tests(DatabaseDriver::Sqlite3, DATABASE_URL).await;
run_tests(DATABASE_URL).await;
}