Skip to content

Commit

Permalink
feat: [#426] add TSL support
Browse files Browse the repository at this point in the history
You can provide a certificate and certificate key files to run the API
with HTTPs.
  • Loading branch information
josecelano committed May 15, 2024
1 parent 969ffff commit 284d235
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 22 deletions.
10 changes: 7 additions & 3 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ use crate::services::torrent::{
use crate::services::user::{self, DbBannedUserList, DbUserProfileRepository, DbUserRepository};
use crate::services::{proxy, settings, torrent};
use crate::tracker::statistics_importer::StatisticsImporter;
use crate::web::api::server::signals::Halted;
use crate::web::api::server::v1::auth::Authentication;
use crate::web::api::Version;
use crate::{console, mailer, tracker, web};

pub struct Running {
pub api_socket_addr: SocketAddr,
pub api_server: Option<JoinHandle<std::result::Result<(), std::io::Error>>>,
pub api_server: JoinHandle<std::result::Result<(), std::io::Error>>,
pub api_server_halt_task: tokio::sync::oneshot::Sender<Halted>,
pub tracker_data_importer_handle: tokio::task::JoinHandle<()>,
}

Expand Down Expand Up @@ -56,6 +58,7 @@ pub async fn run(configuration: Configuration, api_version: &Version) -> Running
// From [net] config
let net_ip = "0.0.0.0".to_string();
let net_port = settings.net.port;
let opt_net_tsl = settings.net.tsl.clone();

// IMPORTANT: drop settings before starting server to avoid read locks that
// leads to requests hanging.
Expand Down Expand Up @@ -168,12 +171,13 @@ pub async fn run(configuration: Configuration, api_version: &Version) -> Running
);

// Start API server
let running_api = web::api::start(app_data, &net_ip, net_port, api_version).await;
let running_api = web::api::start(app_data, &net_ip, net_port, opt_net_tsl, api_version).await;

// Full running application
Running {
api_socket_addr: running_api.socket_addr,
api_server: running_api.api_server,
api_server: running_api.task,
api_server_halt_task: running_api.halt_task,
tracker_data_importer_handle: tracker_statistics_importer_handle,
}
}
4 changes: 3 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ async fn main() -> Result<(), std::io::Error> {

let app = app::run(configuration, &api_version).await;

assert!(!app.api_server_halt_task.is_closed(), "Halt channel should be open");

match api_version {
Version::V1 => app.api_server.unwrap().await.expect("the API server was dropped"),
Version::V1 => app.api_server.await.expect("the API server was dropped"),
}
}
16 changes: 13 additions & 3 deletions src/web/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use std::sync::Arc;

use tokio::task::JoinHandle;

use self::server::signals::Halted;
use crate::common::AppData;
use crate::config::Tsl;
use crate::web::api;

/// API versions.
Expand All @@ -26,14 +28,22 @@ pub enum Version {
pub struct Running {
/// The socket address the API server is listening on.
pub socket_addr: SocketAddr,
/// The channel sender to send halt signal to the server.
pub halt_task: tokio::sync::oneshot::Sender<Halted>,
/// The handle for the running API server.
pub api_server: Option<JoinHandle<Result<(), std::io::Error>>>,
pub task: JoinHandle<Result<(), std::io::Error>>,
}

/// Starts the API server.
#[must_use]
pub async fn start(app_data: Arc<AppData>, net_ip: &str, net_port: u16, implementation: &Version) -> api::Running {
pub async fn start(
app_data: Arc<AppData>,
net_ip: &str,
net_port: u16,
opt_tsl: Option<Tsl>,
implementation: &Version,
) -> api::Running {
match implementation {
Version::V1 => server::start(app_data, net_ip, net_port).await,
Version::V1 => server::start(app_data, net_ip, net_port, opt_tsl).await,
}
}
103 changes: 89 additions & 14 deletions src/web/api/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,48 @@ pub mod signals;
pub mod v1;

use std::net::SocketAddr;
use std::panic::Location;
use std::sync::Arc;

use axum_server::tls_rustls::RustlsConfig;
use axum_server::Handle;
use log::info;
use log::{error, info};
use thiserror::Error;
use tokio::sync::oneshot::{Receiver, Sender};
use torrust_index_located_error::LocatedError;
use v1::routes::router;

use self::signals::{Halted, Started};
use super::Running;
use crate::common::AppData;
use crate::config::Tsl;
use crate::web::api::server::custom_axum::TimeoutAcceptor;
use crate::web::api::server::signals::graceful_shutdown;

pub type DynError = Arc<dyn std::error::Error + Send + Sync>;

/// Starts the API server.
///
/// # Panics
///
/// Panics if the API server can't be started.
pub async fn start(app_data: Arc<AppData>, net_ip: &str, net_port: u16) -> Running {
pub async fn start(app_data: Arc<AppData>, net_ip: &str, net_port: u16, opt_tsl: Option<Tsl>) -> Running {
let config_socket_addr: SocketAddr = format!("{net_ip}:{net_port}")
.parse()
.expect("API server socket address to be valid.");

let opt_rust_tls_config = make_rust_tls(&opt_tsl)
.await
.map(|tls| tls.expect("it should have a valid net tls configuration"));

let (tx_start, rx) = tokio::sync::oneshot::channel::<Started>();
let (_tx_halt, rx_halt) = tokio::sync::oneshot::channel::<Halted>();
let (tx_halt, rx_halt) = tokio::sync::oneshot::channel::<Halted>();

// Run the API server
let join_handle = tokio::spawn(async move {
info!("Starting API server with net config: {} ...", config_socket_addr);

start_server(config_socket_addr, app_data.clone(), tx_start, rx_halt).await;
start_server(config_socket_addr, app_data.clone(), tx_start, rx_halt, opt_rust_tls_config).await;

info!("API server stopped");

Expand All @@ -42,13 +53,18 @@ pub async fn start(app_data: Arc<AppData>, net_ip: &str, net_port: u16) -> Runni

// Wait until the API server is running
let bound_addr = match rx.await {
Ok(msg) => msg.socket_addr,
Err(e) => panic!("API server start. The API server was dropped: {e}"),
Ok(started) => started.socket_addr,
Err(err) => {
let msg = format!("Unable to start API server: {err}");
error!("{}", msg);
panic!("{}", msg);
}
};

Running {
socket_addr: bound_addr,
api_server: Some(join_handle),
halt_task: tx_halt,
task: join_handle,
}
}

Expand All @@ -57,6 +73,7 @@ async fn start_server(
app_data: Arc<AppData>,
tx_start: Sender<Started>,
rx_halt: Receiver<Halted>,
rust_tls_config: Option<RustlsConfig>,
) {
let router = router(app_data);
let socket = std::net::TcpListener::bind(config_socket_addr).expect("Could not bind tcp_listener to address.");
Expand All @@ -70,16 +87,74 @@ async fn start_server(
format!("Shutting down API server on socket address: {address}"),
));

info!("API server listening on http://{}", address); // # DevSkim: ignore DS137138
let tls = rust_tls_config.clone();
let protocol = if tls.is_some() { "https" } else { "http" };

info!("API server listening on {}://{}", protocol, address); // # DevSkim: ignore DS137138

tx_start
.send(Started { socket_addr: address })
.expect("the API server should not be dropped");

custom_axum::from_tcp_with_timeouts(socket)
.handle(handle)
.acceptor(TimeoutAcceptor)
.serve(router.into_make_service_with_connect_info::<std::net::SocketAddr>())
.await
.expect("API server should be running");
match tls {
Some(tls) => custom_axum::from_tcp_rustls_with_timeouts(socket, tls)
.handle(handle)
.acceptor(TimeoutAcceptor)
.serve(router.into_make_service_with_connect_info::<std::net::SocketAddr>())
.await
.expect("API server should be running"),
None => custom_axum::from_tcp_with_timeouts(socket)
.handle(handle)
.acceptor(TimeoutAcceptor)
.serve(router.into_make_service_with_connect_info::<std::net::SocketAddr>())
.await
.expect("API server should be running"),
};
}

#[derive(Error, Debug)]
pub enum Error {
/// Enabled tls but missing config.
#[error("tls config missing")]
MissingTlsConfig { location: &'static Location<'static> },

/// Unable to parse tls Config.
#[error("bad tls config: {source}")]
BadTlsConfig {
source: LocatedError<'static, dyn std::error::Error + Send + Sync>,
ssl_cert_path: String,
ssl_key_path: String,
},
}

pub async fn make_rust_tls(tsl_config: &Option<Tsl>) -> Option<Result<RustlsConfig, Error>> {
match tsl_config {
Some(tsl) => {
if let (Some(cert), Some(key)) = (tsl.ssl_cert_path.clone(), tsl.ssl_key_path.clone()) {
info!("Using https. Cert path: {cert}.");
info!("Using https. Key path: {key}.");

let ssl_cert_path = cert.clone().to_string();
let ssl_key_path = key.clone().to_string();

Some(
RustlsConfig::from_pem_file(cert, key)
.await
.map_err(|err| Error::BadTlsConfig {
source: (Arc::new(err) as DynError).into(),
ssl_cert_path,
ssl_key_path,
}),
)
} else {
Some(Err(Error::MissingTlsConfig {
location: Location::caller(),
}))
}
}
None => {
info!("TLS not enabled");
None
}
}
}
2 changes: 1 addition & 1 deletion tests/environments/app_starter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl AppStarter {
.expect("the app starter should not be dropped");

match api_version {
Version::V1 => app.api_server.unwrap().await,
Version::V1 => app.api_server.await,
}
});

Expand Down

0 comments on commit 284d235

Please sign in to comment.