diff --git a/Cargo.lock b/Cargo.lock index 94fbe7d8..4b833504 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4274,6 +4274,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c22c3dd4..a65c2a74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,7 +80,7 @@ tower-http = { version = "0", features = ["compression-full", "cors", "propagate trace = "0" tracing = "0" tracing-subscriber = { version = "0.3.18", features = ["json"] } -url = "2" +url = {version = "2", features = ["serde"] } uuid = { version = "1", features = ["v4"] } zerocopy = "0.7.33" diff --git a/packages/configuration/src/lib.rs b/packages/configuration/src/lib.rs index c8c91443..ca008a49 100644 --- a/packages/configuration/src/lib.rs +++ b/packages/configuration/src/lib.rs @@ -9,6 +9,7 @@ pub mod v1; use std::collections::HashMap; use std::env; use std::sync::Arc; +use std::time::Duration; use camino::Utf8PathBuf; use derive_more::Constructor; @@ -20,6 +21,10 @@ use torrust_tracker_located_error::{DynError, LocatedError}; /// The maximum number of returned peers for a torrent. pub const TORRENT_PEERS_LIMIT: usize = 74; +/// Default timeout for sending and receiving packets. And waiting for sockets +/// to be readable and writable. +pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); + // Environment variables /// The whole `tracker.toml` file content. It has priority over the config file. diff --git a/src/console/clients/checker/app.rs b/src/console/clients/checker/app.rs index e3bca231..9f9825d9 100644 --- a/src/console/clients/checker/app.rs +++ b/src/console/clients/checker/app.rs @@ -98,7 +98,7 @@ pub async fn run() -> Result> { console: console_printer, }; - Ok(service.run_checks().await) + service.run_checks().await.context("it should run the check tasks") } fn tracing_stdout_init(filter: LevelFilter) { diff --git a/src/console/clients/checker/checks/health.rs b/src/console/clients/checker/checks/health.rs index 47eec4cb..b1fb7914 100644 --- a/src/console/clients/checker/checks/health.rs +++ b/src/console/clients/checker/checks/health.rs @@ -1,49 +1,77 @@ +use std::sync::Arc; use std::time::Duration; -use reqwest::{Client as HttpClient, Url, Url as ServiceUrl}; +use anyhow::Result; +use hyper::StatusCode; +use reqwest::{Client as HttpClient, Response}; +use serde::Serialize; +use thiserror::Error; +use url::Url; -use super::structs::{CheckerOutput, Status}; -use crate::console::clients::checker::service::{CheckError, CheckResult}; +#[derive(Debug, Clone, Error, Serialize)] +#[serde(into = "String")] +pub enum Error { + #[error("Failed to Build a Http Client: {err:?}")] + ClientBuildingError { err: Arc }, + #[error("Heath check failed to get a response: {err:?}")] + ResponseError { err: Arc }, + #[error("Http check returned a non-success code: \"{code}\" with the response: \"{response:?}\"")] + UnsuccessfulResponse { code: StatusCode, response: Arc }, +} + +impl From for String { + fn from(value: Error) -> Self { + value.to_string() + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct Checks { + url: Url, + result: Result, +} -#[allow(clippy::missing_panics_doc)] -pub async fn run(health_checks: &Vec, check_results: &mut Vec) -> Vec { - let mut health_checkers: Vec = Vec::new(); +pub async fn run(health_checks: Vec, timeout: Duration) -> Vec> { + let mut results = Vec::default(); - for health_check_url in health_checks { - let mut health_checker = CheckerOutput { - url: health_check_url.to_string(), - status: Status { - code: String::new(), - message: String::new(), - }, + tracing::debug!("Health checks ..."); + + for url in health_checks { + let result = match run_health_check(url.clone(), timeout).await { + Ok(response) => Ok(response.status().to_string()), + Err(err) => Err(err), }; - match run_health_check(health_check_url.clone()).await { - Ok(()) => { - check_results.push(Ok(())); - health_checker.status.code = "ok".to_string(); - } - Err(err) => { - check_results.push(Err(err)); - health_checker.status.code = "error".to_string(); - health_checker.status.message = "Health API is failing.".to_string(); - } + + let check = Checks { url, result }; + + if check.result.is_err() { + results.push(Err(check)); + } else { + results.push(Ok(check)); } - health_checkers.push(health_checker); } - health_checkers + + results } -async fn run_health_check(url: Url) -> Result<(), CheckError> { - let client = HttpClient::builder().timeout(Duration::from_secs(5)).build().unwrap(); +async fn run_health_check(url: Url, timeout: Duration) -> Result { + let client = HttpClient::builder() + .timeout(timeout) + .build() + .map_err(|e| Error::ClientBuildingError { err: e.into() })?; - match client.get(url.clone()).send().await { - Ok(response) => { - if response.status().is_success() { - Ok(()) - } else { - Err(CheckError::HealthCheckError { url }) - } - } - Err(_) => Err(CheckError::HealthCheckError { url }), + let response = client + .get(url.clone()) + .send() + .await + .map_err(|e| Error::ResponseError { err: e.into() })?; + + if response.status().is_success() { + Ok(response) + } else { + Err(Error::UnsuccessfulResponse { + code: response.status(), + response: response.into(), + }) } } diff --git a/src/console/clients/checker/checks/http.rs b/src/console/clients/checker/checks/http.rs index 57f8c301..bb285374 100644 --- a/src/console/clients/checker/checks/http.rs +++ b/src/console/clients/checker/checks/http.rs @@ -1,59 +1,63 @@ -use std::str::FromStr; +use std::str::FromStr as _; +use std::time::Duration; -use reqwest::Url as ServiceUrl; +use serde::Serialize; use torrust_tracker_primitives::info_hash::InfoHash; -use tracing::debug; use url::Url; -use super::structs::{CheckerOutput, Status}; -use crate::console::clients::checker::service::{CheckError, CheckResult}; -use crate::shared::bit_torrent::tracker::http::client::requests::announce::QueryBuilder; +use crate::console::clients::http::Error; use crate::shared::bit_torrent::tracker::http::client::responses::announce::Announce; use crate::shared::bit_torrent::tracker::http::client::responses::scrape; use crate::shared::bit_torrent::tracker::http::client::{requests, Client}; -#[allow(clippy::missing_panics_doc)] -pub async fn run(http_trackers: &Vec, check_results: &mut Vec) -> Vec { - let mut http_checkers: Vec = Vec::new(); - - for http_tracker in http_trackers { - let mut http_checker = CheckerOutput { - url: http_tracker.to_string(), - status: Status { - code: String::new(), - message: String::new(), - }, +#[derive(Debug, Clone, Serialize)] +pub struct Checks { + url: Url, + results: Vec<(Check, Result<(), Error>)>, +} + +#[derive(Debug, Clone, Serialize)] +pub enum Check { + Announce, + Scrape, +} + +pub async fn run(http_trackers: Vec, timeout: Duration) -> Vec> { + let mut results = Vec::default(); + + tracing::debug!("HTTP trackers ..."); + + for ref url in http_trackers { + let mut checks = Checks { + url: url.clone(), + results: Vec::default(), }; - match check_http_announce(http_tracker).await { - Ok(()) => { - check_results.push(Ok(())); - http_checker.status.code = "ok".to_string(); - } - Err(err) => { - check_results.push(Err(err)); - http_checker.status.code = "error".to_string(); - http_checker.status.message = "Announce is failing.".to_string(); - } + // Announce + { + let check = check_http_announce(url, timeout).await.map(|_| ()); + + checks.results.push((Check::Announce, check)); } - match check_http_scrape(http_tracker).await { - Ok(()) => { - check_results.push(Ok(())); - http_checker.status.code = "ok".to_string(); - } - Err(err) => { - check_results.push(Err(err)); - http_checker.status.code = "error".to_string(); - http_checker.status.message = "Scrape is failing.".to_string(); - } + // Scrape + { + let check = check_http_scrape(url, timeout).await.map(|_| ()); + + checks.results.push((Check::Scrape, check)); + } + + if checks.results.iter().any(|f| f.1.is_err()) { + results.push(Err(checks)); + } else { + results.push(Ok(checks)); } - http_checkers.push(http_checker); } - http_checkers + + results } -async fn check_http_announce(tracker_url: &Url) -> Result<(), CheckError> { +async fn check_http_announce(url: &Url, timeout: Duration) -> Result { let info_hash_str = "9c38422213e30bff212b30c360d26f9a02136422".to_string(); // # DevSkim: ignore DS173237 let info_hash = InfoHash::from_str(&info_hash_str).expect("a valid info-hash is required"); @@ -61,37 +65,28 @@ async fn check_http_announce(tracker_url: &Url) -> Result<(), CheckError> { // We should change the client to catch that error and return a `CheckError`. // Otherwise the checking process will stop. The idea is to process all checks // and return a final report. - let Ok(client) = Client::new(tracker_url.clone()) else { - return Err(CheckError::HttpError { - url: (tracker_url.to_owned()), - }); - }; - let Ok(response) = client - .announce(&QueryBuilder::with_default_values().with_info_hash(&info_hash).query()) + let client = Client::new(url.clone(), timeout).map_err(|err| Error::HttpClientError { err })?; + + let response = client + .announce( + &requests::announce::QueryBuilder::with_default_values() + .with_info_hash(&info_hash) + .query(), + ) .await - else { - return Err(CheckError::HttpError { - url: (tracker_url.to_owned()), - }); - }; - - if let Ok(body) = response.bytes().await { - if let Ok(_announce_response) = serde_bencode::from_bytes::(&body) { - Ok(()) - } else { - debug!("announce body {:#?}", body); - Err(CheckError::HttpError { - url: tracker_url.clone(), - }) - } - } else { - Err(CheckError::HttpError { - url: tracker_url.clone(), - }) - } + .map_err(|err| Error::HttpClientError { err })?; + + let response = response.bytes().await.map_err(|e| Error::ResponseError { err: e.into() })?; + + let response = serde_bencode::from_bytes::(&response).map_err(|e| Error::ParseBencodeError { + data: response, + err: e.into(), + })?; + + Ok(response) } -async fn check_http_scrape(url: &Url) -> Result<(), CheckError> { +async fn check_http_scrape(url: &Url, timeout: Duration) -> Result { let info_hashes: Vec = vec!["9c38422213e30bff212b30c360d26f9a02136422".to_string()]; // # DevSkim: ignore DS173237 let query = requests::scrape::Query::try_from(info_hashes).expect("a valid array of info-hashes is required"); @@ -100,21 +95,16 @@ async fn check_http_scrape(url: &Url) -> Result<(), CheckError> { // Otherwise the checking process will stop. The idea is to process all checks // and return a final report. - let Ok(client) = Client::new(url.clone()) else { - return Err(CheckError::HttpError { url: (url.to_owned()) }); - }; - let Ok(response) = client.scrape(&query).await else { - return Err(CheckError::HttpError { url: (url.to_owned()) }); - }; + let client = Client::new(url.clone(), timeout).map_err(|err| Error::HttpClientError { err })?; - if let Ok(body) = response.bytes().await { - if let Ok(_scrape_response) = scrape::Response::try_from_bencoded(&body) { - Ok(()) - } else { - debug!("scrape body {:#?}", body); - Err(CheckError::HttpError { url: url.clone() }) - } - } else { - Err(CheckError::HttpError { url: url.clone() }) - } + let response = client.scrape(&query).await.map_err(|err| Error::HttpClientError { err })?; + + let response = response.bytes().await.map_err(|e| Error::ResponseError { err: e.into() })?; + + let response = scrape::Response::try_from_bencoded(&response).map_err(|e| Error::BencodeParseError { + data: response, + err: e.into(), + })?; + + Ok(response) } diff --git a/src/console/clients/checker/checks/udp.rs b/src/console/clients/checker/checks/udp.rs index 072aa5ca..dd4d5e63 100644 --- a/src/console/clients/checker/checks/udp.rs +++ b/src/console/clients/checker/checks/udp.rs @@ -1,94 +1,98 @@ use std::net::SocketAddr; +use std::time::Duration; -use aquatic_udp_protocol::{Port, TransactionId}; +use aquatic_udp_protocol::TransactionId; use hex_literal::hex; +use serde::Serialize; use torrust_tracker_primitives::info_hash::InfoHash; -use tracing::debug; -use crate::console::clients::checker::checks::structs::{CheckerOutput, Status}; -use crate::console::clients::checker::service::{CheckError, CheckResult}; -use crate::console::clients::udp::checker; +use crate::console::clients::udp::checker::Client; +use crate::console::clients::udp::Error; -const ASSIGNED_BY_OS: u16 = 0; -const RANDOM_TRANSACTION_ID: i32 = -888_840_697; - -#[allow(clippy::missing_panics_doc)] -pub async fn run(udp_trackers: &Vec, check_results: &mut Vec) -> Vec { - let mut udp_checkers: Vec = Vec::new(); - - for udp_tracker in udp_trackers { - let mut checker_output = CheckerOutput { - url: udp_tracker.to_string(), - status: Status { - code: String::new(), - message: String::new(), - }, - }; +#[derive(Debug, Clone, Serialize)] +pub struct Checks { + remote_addr: SocketAddr, + results: Vec<(Check, Result<(), Error>)>, +} - debug!("UDP tracker: {:?}", udp_tracker); +#[derive(Debug, Clone, Serialize)] +pub enum Check { + Setup, + Connect, + Announce, + Scrape, +} - let transaction_id = TransactionId::new(RANDOM_TRANSACTION_ID); +#[allow(clippy::missing_panics_doc)] +pub async fn run(udp_trackers: Vec, timeout: Duration) -> Vec> { + let mut results = Vec::default(); - let mut client = checker::Client::default(); + tracing::debug!("UDP trackers ..."); - debug!("Bind and connect"); + let info_hash = InfoHash(hex!("9c38422213e30bff212b30c360d26f9a02136422")); // # DevSkim: ignore DS173237 - let Ok(bound_to) = client.bind_and_connect(ASSIGNED_BY_OS, udp_tracker).await else { - check_results.push(Err(CheckError::UdpError { - socket_addr: *udp_tracker, - })); - checker_output.status.code = "error".to_string(); - checker_output.status.message = "Can't connect to socket.".to_string(); - break; + for remote_addr in udp_trackers { + let mut checks = Checks { + remote_addr, + results: Vec::default(), }; - debug!("Send connection request"); - - let Ok(connection_id) = client.send_connection_request(transaction_id).await else { - check_results.push(Err(CheckError::UdpError { - socket_addr: *udp_tracker, - })); - checker_output.status.code = "error".to_string(); - checker_output.status.message = "Can't make tracker connection request.".to_string(); - break; + tracing::debug!("UDP tracker: {:?}", remote_addr); + + // Setup + let client = match Client::new(remote_addr, timeout).await { + Ok(client) => { + checks.results.push((Check::Setup, Ok(()))); + client + } + Err(err) => { + checks.results.push((Check::Setup, Err(err))); + results.push(Err(checks)); + break; + } }; - let info_hash = InfoHash(hex!("9c38422213e30bff212b30c360d26f9a02136422")); // # DevSkim: ignore DS173237 - - debug!("Send announce request"); + let transaction_id = TransactionId::new(1); + + // Connect Remote + let connection_id = match client.send_connection_request(transaction_id).await { + Ok(connection_id) => { + checks.results.push((Check::Connect, Ok(()))); + connection_id + } + Err(err) => { + checks.results.push((Check::Connect, Err(err))); + results.push(Err(checks)); + break; + } + }; - if (client - .send_announce_request(connection_id, transaction_id, info_hash, Port(bound_to.port().into())) - .await) - .is_ok() + // Announce { - check_results.push(Ok(())); - checker_output.status.code = "ok".to_string(); - } else { - let err = CheckError::UdpError { - socket_addr: *udp_tracker, - }; - check_results.push(Err(err)); - checker_output.status.code = "error".to_string(); - checker_output.status.message = "Announce is failing.".to_string(); + let check = client + .send_announce_request(transaction_id, connection_id, info_hash) + .await + .map(|_| ()); + + checks.results.push((Check::Announce, check)); } - debug!("Send scrape request"); + // Scrape + { + let check = client + .send_scrape_request(connection_id, transaction_id, &[info_hash]) + .await + .map(|_| ()); - let info_hashes = vec![InfoHash(hex!("9c38422213e30bff212b30c360d26f9a02136422"))]; // # DevSkim: ignore DS173237 + checks.results.push((Check::Announce, check)); + } - if (client.send_scrape_request(connection_id, transaction_id, info_hashes).await).is_ok() { - check_results.push(Ok(())); - checker_output.status.code = "ok".to_string(); + if checks.results.iter().any(|f| f.1.is_err()) { + results.push(Err(checks)); } else { - let err = CheckError::UdpError { - socket_addr: *udp_tracker, - }; - check_results.push(Err(err)); - checker_output.status.code = "error".to_string(); - checker_output.status.message = "Scrape is failing.".to_string(); + results.push(Ok(checks)); } - udp_checkers.push(checker_output); } - udp_checkers + + results } diff --git a/src/console/clients/checker/service.rs b/src/console/clients/checker/service.rs index 16483e92..acd312d8 100644 --- a/src/console/clients/checker/service.rs +++ b/src/console/clients/checker/service.rs @@ -1,9 +1,11 @@ -use std::net::SocketAddr; use std::sync::Arc; -use reqwest::Url; +use futures::FutureExt as _; +use serde::Serialize; +use tokio::task::{JoinError, JoinSet}; +use torrust_tracker_configuration::DEFAULT_TIMEOUT; -use super::checks::{self}; +use super::checks::{health, http, udp}; use super::config::Configuration; use super::console::Console; use crate::console::clients::checker::printer::Printer; @@ -13,33 +15,48 @@ pub struct Service { pub(crate) console: Console, } -pub type CheckResult = Result<(), CheckError>; - -#[derive(Debug)] -pub enum CheckError { - UdpError { socket_addr: SocketAddr }, - HttpError { url: Url }, - HealthCheckError { url: Url }, +#[derive(Debug, Clone, Serialize)] +pub enum CheckResult { + Udp(Result), + Http(Result), + Health(Result), } impl Service { /// # Errors /// - /// Will return OK is all checks pass or an array with the check errors. - #[allow(clippy::missing_panics_doc)] - pub async fn run_checks(&self) -> Vec { - let mut check_results = vec![]; - - let udp_checkers = checks::udp::run(&self.config.udp_trackers, &mut check_results).await; - - let http_checkers = checks::http::run(&self.config.http_trackers, &mut check_results).await; - - let health_checkers = checks::health::run(&self.config.health_checks, &mut check_results).await; - - let json_output = - serde_json::json!({ "udp_trackers": udp_checkers, "http_trackers": http_checkers, "health_checks": health_checkers }); - self.console.println(&serde_json::to_string_pretty(&json_output).unwrap()); - - check_results + /// It will return an error if some of the tests panic or otherwise fail to run. + /// On success it will return a vector of `Ok(())` of [`CheckResult`]. + /// + /// # Panics + /// + /// It would panic if `serde_json` produces invalid json for the `to_string_pretty` function. + pub async fn run_checks(self) -> Result, JoinError> { + tracing::info!("Running checks for trackers ..."); + + let mut check_results = Vec::default(); + + let mut checks = JoinSet::new(); + checks.spawn( + udp::run(self.config.udp_trackers.clone(), DEFAULT_TIMEOUT).map(|mut f| f.drain(..).map(CheckResult::Udp).collect()), + ); + checks.spawn( + http::run(self.config.http_trackers.clone(), DEFAULT_TIMEOUT) + .map(|mut f| f.drain(..).map(CheckResult::Http).collect()), + ); + checks.spawn( + health::run(self.config.health_checks.clone(), DEFAULT_TIMEOUT) + .map(|mut f| f.drain(..).map(CheckResult::Health).collect()), + ); + + while let Some(results) = checks.join_next().await { + check_results.append(&mut results?); + } + + let json_output = serde_json::json!(check_results); + self.console + .println(&serde_json::to_string_pretty(&json_output).expect("it should consume valid json")); + + Ok(check_results) } } diff --git a/src/console/clients/http/app.rs b/src/console/clients/http/app.rs index 8fc9db0c..a54db5f8 100644 --- a/src/console/clients/http/app.rs +++ b/src/console/clients/http/app.rs @@ -14,10 +14,12 @@ //! cargo run --bin http_tracker_client scrape http://127.0.0.1:7070 9c38422213e30bff212b30c360d26f9a02136422 | jq //! ``` use std::str::FromStr; +use std::time::Duration; use anyhow::Context; use clap::{Parser, Subcommand}; use reqwest::Url; +use torrust_tracker_configuration::DEFAULT_TIMEOUT; use torrust_tracker_primitives::info_hash::InfoHash; use crate::shared::bit_torrent::tracker::http::client::requests::announce::QueryBuilder; @@ -46,25 +48,25 @@ pub async fn run() -> anyhow::Result<()> { match args.command { Command::Announce { tracker_url, info_hash } => { - announce_command(tracker_url, info_hash).await?; + announce_command(tracker_url, info_hash, DEFAULT_TIMEOUT).await?; } Command::Scrape { tracker_url, info_hashes, } => { - scrape_command(&tracker_url, &info_hashes).await?; + scrape_command(&tracker_url, &info_hashes, DEFAULT_TIMEOUT).await?; } } Ok(()) } -async fn announce_command(tracker_url: String, info_hash: String) -> anyhow::Result<()> { +async fn announce_command(tracker_url: String, info_hash: String, timeout: Duration) -> anyhow::Result<()> { let base_url = Url::parse(&tracker_url).context("failed to parse HTTP tracker base URL")?; let info_hash = InfoHash::from_str(&info_hash).expect("Invalid infohash. Example infohash: `9c38422213e30bff212b30c360d26f9a02136422`"); - let response = Client::new(base_url)? + let response = Client::new(base_url, timeout)? .announce(&QueryBuilder::with_default_values().with_info_hash(&info_hash).query()) .await?; @@ -80,12 +82,12 @@ async fn announce_command(tracker_url: String, info_hash: String) -> anyhow::Res Ok(()) } -async fn scrape_command(tracker_url: &str, info_hashes: &[String]) -> anyhow::Result<()> { +async fn scrape_command(tracker_url: &str, info_hashes: &[String], timeout: Duration) -> anyhow::Result<()> { let base_url = Url::parse(tracker_url).context("failed to parse HTTP tracker base URL")?; let query = requests::scrape::Query::try_from(info_hashes).context("failed to parse infohashes")?; - let response = Client::new(base_url)?.scrape(&query).await?; + let response = Client::new(base_url, timeout)?.scrape(&query).await?; let body = response.bytes().await?; diff --git a/src/console/clients/http/mod.rs b/src/console/clients/http/mod.rs index 309be628..eaa71957 100644 --- a/src/console/clients/http/mod.rs +++ b/src/console/clients/http/mod.rs @@ -1 +1,36 @@ +use std::sync::Arc; + +use serde::Serialize; +use thiserror::Error; + +use crate::shared::bit_torrent::tracker::http::client::responses::scrape::BencodeParseError; + pub mod app; + +#[derive(Debug, Clone, Error, Serialize)] +#[serde(into = "String")] +pub enum Error { + #[error("Http request did not receive a response within the timeout: {err:?}")] + HttpClientError { + err: crate::shared::bit_torrent::tracker::http::client::Error, + }, + #[error("Http failed to get a response at all: {err:?}")] + ResponseError { err: Arc }, + #[error("Failed to deserialize the bencoded response data with the error: \"{err:?}\"")] + ParseBencodeError { + data: hyper::body::Bytes, + err: Arc, + }, + + #[error("Failed to deserialize the bencoded response data with the error: \"{err:?}\"")] + BencodeParseError { + data: hyper::body::Bytes, + err: Arc, + }, +} + +impl From for String { + fn from(value: Error) -> Self { + value.to_string() + } +} diff --git a/src/console/clients/udp/app.rs b/src/console/clients/udp/app.rs index 51d21b51..bcba3955 100644 --- a/src/console/clients/udp/app.rs +++ b/src/console/clients/udp/app.rs @@ -60,18 +60,19 @@ use std::net::{SocketAddr, ToSocketAddrs}; use std::str::FromStr; use anyhow::Context; -use aquatic_udp_protocol::{Port, Response, TransactionId}; +use aquatic_udp_protocol::{Response, TransactionId}; use clap::{Parser, Subcommand}; +use torrust_tracker_configuration::DEFAULT_TIMEOUT; use torrust_tracker_primitives::info_hash::InfoHash as TorrustInfoHash; use tracing::debug; use tracing::level_filters::LevelFilter; use url::Url; +use super::Error; use crate::console::clients::udp::checker; use crate::console::clients::udp::responses::dto::SerializableResponse; use crate::console::clients::udp::responses::json::ToJson; -const ASSIGNED_BY_OS: u16 = 0; const RANDOM_TRANSACTION_ID: i32 = -888_840_697; #[derive(Parser, Debug)] @@ -109,13 +110,13 @@ pub async fn run() -> anyhow::Result<()> { let response = match args.command { Command::Announce { - tracker_socket_addr, + tracker_socket_addr: remote_addr, info_hash, - } => handle_announce(&tracker_socket_addr, &info_hash).await?, + } => handle_announce(remote_addr, &info_hash).await?, Command::Scrape { - tracker_socket_addr, + tracker_socket_addr: remote_addr, info_hashes, - } => handle_scrape(&tracker_socket_addr, &info_hashes).await?, + } => handle_scrape(remote_addr, &info_hashes).await?, }; let response: SerializableResponse = response.into(); @@ -131,32 +132,24 @@ fn tracing_stdout_init(filter: LevelFilter) { debug!("logging initialized."); } -async fn handle_announce(tracker_socket_addr: &SocketAddr, info_hash: &TorrustInfoHash) -> anyhow::Result { +async fn handle_announce(remote_addr: SocketAddr, info_hash: &TorrustInfoHash) -> Result { let transaction_id = TransactionId::new(RANDOM_TRANSACTION_ID); - let mut client = checker::Client::default(); - - let bound_to = client.bind_and_connect(ASSIGNED_BY_OS, tracker_socket_addr).await?; + let client = checker::Client::new(remote_addr, DEFAULT_TIMEOUT).await?; let connection_id = client.send_connection_request(transaction_id).await?; - client - .send_announce_request(connection_id, transaction_id, *info_hash, Port(bound_to.port().into())) - .await + client.send_announce_request(transaction_id, connection_id, *info_hash).await } -async fn handle_scrape(tracker_socket_addr: &SocketAddr, info_hashes: &[TorrustInfoHash]) -> anyhow::Result { +async fn handle_scrape(remote_addr: SocketAddr, info_hashes: &[TorrustInfoHash]) -> Result { let transaction_id = TransactionId::new(RANDOM_TRANSACTION_ID); - let mut client = checker::Client::default(); - - let _bound_to = client.bind_and_connect(ASSIGNED_BY_OS, tracker_socket_addr).await?; + let client = checker::Client::new(remote_addr, DEFAULT_TIMEOUT).await?; let connection_id = client.send_connection_request(transaction_id).await?; - client - .send_scrape_request(connection_id, transaction_id, info_hashes.to_vec()) - .await + client.send_scrape_request(connection_id, transaction_id, info_hashes).await } fn parse_socket_addr(tracker_socket_addr_str: &str) -> anyhow::Result { diff --git a/src/console/clients/udp/checker.rs b/src/console/clients/udp/checker.rs index afde63d1..49f0ac41 100644 --- a/src/console/clients/udp/checker.rs +++ b/src/console/clients/udp/checker.rs @@ -1,99 +1,46 @@ use std::net::{Ipv4Addr, SocketAddr}; +use std::num::NonZeroU16; +use std::time::Duration; -use anyhow::Context; use aquatic_udp_protocol::common::InfoHash; use aquatic_udp_protocol::{ AnnounceActionPlaceholder, AnnounceEvent, AnnounceRequest, ConnectRequest, ConnectionId, NumberOfBytes, NumberOfPeers, PeerId, PeerKey, Port, Response, ScrapeRequest, TransactionId, }; -use thiserror::Error; use torrust_tracker_primitives::info_hash::InfoHash as TorrustInfoHash; use tracing::debug; -use crate::shared::bit_torrent::tracker::udp::client::{UdpClient, UdpTrackerClient}; - -#[derive(Error, Debug)] -pub enum ClientError { - #[error("Local socket address is not bound yet. Try binding before connecting.")] - NotBound, - #[error("Not connected to remote tracker UDP socket. Try connecting before making requests.")] - NotConnected, - #[error("Unexpected response while connecting the the remote server.")] - UnexpectedConnectionResponse, -} +use super::Error; +use crate::shared::bit_torrent::tracker::udp::client::UdpTrackerClient; /// A UDP Tracker client to make test requests (checks). -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Client { - /// Local UDP socket. It could be 0 to assign a free port. - local_binding_address: Option, - - /// Local UDP socket after binding. It's equals to binding address if a - /// non- zero port was used. - local_bound_address: Option, - - /// Remote UDP tracker socket - remote_socket: Option, - - /// The client used to make UDP requests to the tracker. - udp_tracker_client: Option, + client: UdpTrackerClient, } impl Client { - /// Binds to the local socket and connects to the remote one. + /// Creates a new `[Client]` for checking a UDP Tracker Service /// /// # Errors /// - /// Will return an error if - /// - /// - It can't bound to the local socket address. - /// - It can't make a connection request successfully to the remote UDP server. - pub async fn bind_and_connect(&mut self, local_port: u16, remote_socket_addr: &SocketAddr) -> anyhow::Result { - let bound_to = self.bind(local_port).await?; - self.connect(remote_socket_addr).await?; - Ok(bound_to) - } - - /// Binds local client socket. - /// - /// # Errors + /// It will error if unable to bind and connect to the udp remote address. /// - /// Will return an error if it can't bound to the local address. - async fn bind(&mut self, local_port: u16) -> anyhow::Result { - let local_bind_to = format!("0.0.0.0:{local_port}"); - let binding_address = local_bind_to.parse().context("binding local address")?; - - debug!("Binding to: {local_bind_to}"); - let udp_client = UdpClient::bind(&local_bind_to).await?; - - let bound_to = udp_client.socket.local_addr().context("bound local address")?; - debug!("Bound to: {bound_to}"); - - self.local_binding_address = Some(binding_address); - self.local_bound_address = Some(bound_to); - - self.udp_tracker_client = Some(UdpTrackerClient { udp_client }); + pub async fn new(remote_addr: SocketAddr, timeout: Duration) -> Result { + let client = UdpTrackerClient::new(remote_addr, timeout) + .await + .map_err(|err| Error::UnableToBindAndConnect { remote_addr, err })?; - Ok(bound_to) + Ok(Self { client }) } - /// Connects to the remote server socket. + /// Returns the local addr of this [`Client`]. /// /// # Errors /// - /// Will return and error if it can't make a connection request successfully - /// to the remote UDP server. - async fn connect(&mut self, tracker_socket_addr: &SocketAddr) -> anyhow::Result<()> { - debug!("Connecting to tracker: udp://{tracker_socket_addr}"); - - match &self.udp_tracker_client { - Some(client) => { - client.udp_client.connect(&tracker_socket_addr.to_string()).await?; - self.remote_socket = Some(*tracker_socket_addr); - Ok(()) - } - None => Err(ClientError::NotBound.into()), - } + /// This function will return an error if the socket is somehow not bound. + pub fn local_addr(&self) -> std::io::Result { + self.client.client.socket.local_addr() } /// Sends a connection request to the UDP Tracker server. @@ -109,25 +56,26 @@ impl Client { /// # Panics /// /// Will panic if it receives an unexpected response. - pub async fn send_connection_request(&self, transaction_id: TransactionId) -> anyhow::Result { + pub async fn send_connection_request(&self, transaction_id: TransactionId) -> Result { debug!("Sending connection request with transaction id: {transaction_id:#?}"); let connect_request = ConnectRequest { transaction_id }; - match &self.udp_tracker_client { - Some(client) => { - client.send(connect_request.into()).await?; - - let response = client.receive().await?; - - debug!("connection request response:\n{response:#?}"); - - match response { - Response::Connect(connect_response) => Ok(connect_response.connection_id), - _ => Err(ClientError::UnexpectedConnectionResponse.into()), - } - } - None => Err(ClientError::NotConnected.into()), + let _ = self + .client + .send(connect_request.into()) + .await + .map_err(|err| Error::UnableToSendConnectionRequest { err })?; + + let response = self + .client + .receive() + .await + .map_err(|err| Error::UnableToReceiveConnectResponse { err })?; + + match response { + Response::Connect(connect_response) => Ok(connect_response.connection_id), + _ => Err(Error::UnexpectedConnectionResponse { response }), } } @@ -137,15 +85,28 @@ impl Client { /// /// Will return and error if the client is not connected. You have to connect /// before calling this function. + /// + /// # Panics + /// + /// It will panic if the `local_address` has a zero port. pub async fn send_announce_request( &self, - connection_id: ConnectionId, transaction_id: TransactionId, + connection_id: ConnectionId, info_hash: TorrustInfoHash, - client_port: Port, - ) -> anyhow::Result { + ) -> Result { debug!("Sending announce request with transaction id: {transaction_id:#?}"); + let port = NonZeroU16::new( + self.client + .client + .socket + .local_addr() + .expect("it should get the local address") + .port(), + ) + .expect("it should no be zero"); + let announce_request = AnnounceRequest { connection_id, action_placeholder: AnnounceActionPlaceholder::default(), @@ -159,21 +120,22 @@ impl Client { ip_address: Ipv4Addr::new(0, 0, 0, 0).into(), key: PeerKey::new(0i32), peers_wanted: NumberOfPeers(1i32.into()), - port: client_port, + port: Port::new(port), }; - match &self.udp_tracker_client { - Some(client) => { - client.send(announce_request.into()).await?; - - let response = client.receive().await?; + let _ = self + .client + .send(announce_request.into()) + .await + .map_err(|err| Error::UnableToSendAnnounceRequest { err })?; - debug!("announce request response:\n{response:#?}"); + let response = self + .client + .receive() + .await + .map_err(|err| Error::UnableToReceiveAnnounceResponse { err })?; - Ok(response) - } - None => Err(ClientError::NotConnected.into()), - } + Ok(response) } /// Sends a scrape request to the UDP Tracker server. @@ -186,8 +148,8 @@ impl Client { &self, connection_id: ConnectionId, transaction_id: TransactionId, - info_hashes: Vec, - ) -> anyhow::Result { + info_hashes: &[TorrustInfoHash], + ) -> Result { debug!("Sending scrape request with transaction id: {transaction_id:#?}"); let scrape_request = ScrapeRequest { @@ -199,17 +161,18 @@ impl Client { .collect(), }; - match &self.udp_tracker_client { - Some(client) => { - client.send(scrape_request.into()).await?; - - let response = client.receive().await?; + let _ = self + .client + .send(scrape_request.into()) + .await + .map_err(|err| Error::UnableToSendScrapeRequest { err })?; - debug!("scrape request response:\n{response:#?}"); + let response = self + .client + .receive() + .await + .map_err(|err| Error::UnableToReceiveScrapeResponse { err })?; - Ok(response) - } - None => Err(ClientError::NotConnected.into()), - } + Ok(response) } } diff --git a/src/console/clients/udp/mod.rs b/src/console/clients/udp/mod.rs index 2fcb26ed..b92bed09 100644 --- a/src/console/clients/udp/mod.rs +++ b/src/console/clients/udp/mod.rs @@ -1,3 +1,51 @@ +use std::net::SocketAddr; + +use aquatic_udp_protocol::Response; +use serde::Serialize; +use thiserror::Error; + +use crate::shared::bit_torrent::tracker::udp; + pub mod app; pub mod checker; pub mod responses; + +#[derive(Error, Debug, Clone, Serialize)] +#[serde(into = "String")] +pub enum Error { + #[error("Failed to Connect to: {remote_addr}, with error: {err}")] + UnableToBindAndConnect { remote_addr: SocketAddr, err: udp::Error }, + + #[error("Failed to send a connection request, with error: {err}")] + UnableToSendConnectionRequest { err: udp::Error }, + + #[error("Failed to receive a connect response, with error: {err}")] + UnableToReceiveConnectResponse { err: udp::Error }, + + #[error("Failed to send a announce request, with error: {err}")] + UnableToSendAnnounceRequest { err: udp::Error }, + + #[error("Failed to receive a announce response, with error: {err}")] + UnableToReceiveAnnounceResponse { err: udp::Error }, + + #[error("Failed to send a scrape request, with error: {err}")] + UnableToSendScrapeRequest { err: udp::Error }, + + #[error("Failed to receive a scrape response, with error: {err}")] + UnableToReceiveScrapeResponse { err: udp::Error }, + + #[error("Failed to receive a response, with error: {err}")] + UnableToReceiveResponse { err: udp::Error }, + + #[error("Failed to get local address for connection: {err}")] + UnableToGetLocalAddr { err: udp::Error }, + + #[error("Failed to get a connection response: {response:?}")] + UnexpectedConnectionResponse { response: Response }, +} + +impl From for String { + fn from(value: Error) -> Self { + value.to_string() + } +} diff --git a/src/servers/apis/routes.rs b/src/servers/apis/routes.rs index 2001afc2..4901d760 100644 --- a/src/servers/apis/routes.rs +++ b/src/servers/apis/routes.rs @@ -14,7 +14,7 @@ use axum::response::Response; use axum::routing::get; use axum::{middleware, BoxError, Router}; use hyper::{Request, StatusCode}; -use torrust_tracker_configuration::AccessTokens; +use torrust_tracker_configuration::{AccessTokens, DEFAULT_TIMEOUT}; use tower::timeout::TimeoutLayer; use tower::ServiceBuilder; use tower_http::compression::CompressionLayer; @@ -29,8 +29,6 @@ use super::v1::middlewares::auth::State; use crate::core::Tracker; use crate::servers::apis::API_LOG_TARGET; -const TIMEOUT: Duration = Duration::from_secs(5); - /// Add all API routes to the router. #[allow(clippy::needless_pass_by_value)] pub fn router(tracker: Arc, access_tokens: Arc) -> Router { @@ -84,6 +82,6 @@ pub fn router(tracker: Arc, access_tokens: Arc) -> Router // this middleware goes above `TimeoutLayer` because it will receive // errors returned by `TimeoutLayer` .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT })) - .layer(TimeoutLayer::new(TIMEOUT)), + .layer(TimeoutLayer::new(DEFAULT_TIMEOUT)), ) } diff --git a/src/servers/http/v1/routes.rs b/src/servers/http/v1/routes.rs index b2f37880..c24797c4 100644 --- a/src/servers/http/v1/routes.rs +++ b/src/servers/http/v1/routes.rs @@ -10,6 +10,7 @@ use axum::routing::get; use axum::{BoxError, Router}; use axum_client_ip::SecureClientIpSource; use hyper::{Request, StatusCode}; +use torrust_tracker_configuration::DEFAULT_TIMEOUT; use tower::timeout::TimeoutLayer; use tower::ServiceBuilder; use tower_http::compression::CompressionLayer; @@ -22,8 +23,6 @@ use super::handlers::{announce, health_check, scrape}; use crate::core::Tracker; use crate::servers::http::HTTP_TRACKER_LOG_TARGET; -const TIMEOUT: Duration = Duration::from_secs(5); - /// It adds the routes to the router. /// /// > **NOTICE**: it's added a layer to get the client IP from the connection @@ -80,6 +79,6 @@ pub fn router(tracker: Arc, server_socket_addr: SocketAddr) -> Router { // this middleware goes above `TimeoutLayer` because it will receive // errors returned by `TimeoutLayer` .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT })) - .layer(TimeoutLayer::new(TIMEOUT)), + .layer(TimeoutLayer::new(DEFAULT_TIMEOUT)), ) } diff --git a/src/shared/bit_torrent/tracker/http/client/mod.rs b/src/shared/bit_torrent/tracker/http/client/mod.rs index f5b1b331..4c70cd68 100644 --- a/src/shared/bit_torrent/tracker/http/client/mod.rs +++ b/src/shared/bit_torrent/tracker/http/client/mod.rs @@ -2,18 +2,30 @@ pub mod requests; pub mod responses; use std::net::IpAddr; +use std::sync::Arc; +use std::time::Duration; -use anyhow::{anyhow, Result}; -use requests::announce::{self, Query}; -use requests::scrape; -use reqwest::{Client as ReqwestClient, Response, Url}; +use hyper::StatusCode; +use requests::{announce, scrape}; +use reqwest::{Response, Url}; +use thiserror::Error; use crate::core::auth::Key; +#[derive(Debug, Clone, Error)] +pub enum Error { + #[error("Failed to Build a Http Client: {err:?}")] + ClientBuildingError { err: Arc }, + #[error("Failed to get a response: {err:?}")] + ResponseError { err: Arc }, + #[error("Returned a non-success code: \"{code}\" with the response: \"{response:?}\"")] + UnsuccessfulResponse { code: StatusCode, response: Arc }, +} + /// HTTP Tracker Client pub struct Client { + client: reqwest::Client, base_url: Url, - reqwest: ReqwestClient, key: Option, } @@ -29,11 +41,15 @@ impl Client { /// # Errors /// /// This method fails if the client builder fails. - pub fn new(base_url: Url) -> Result { - let reqwest = reqwest::Client::builder().build()?; + pub fn new(base_url: Url, timeout: Duration) -> Result { + let client = reqwest::Client::builder() + .timeout(timeout) + .build() + .map_err(|e| Error::ClientBuildingError { err: e.into() })?; + Ok(Self { base_url, - reqwest, + client, key: None, }) } @@ -43,11 +59,16 @@ impl Client { /// # Errors /// /// This method fails if the client builder fails. - pub fn bind(base_url: Url, local_address: IpAddr) -> Result { - let reqwest = reqwest::Client::builder().local_address(local_address).build()?; + pub fn bind(base_url: Url, timeout: Duration, local_address: IpAddr) -> Result { + let client = reqwest::Client::builder() + .timeout(timeout) + .local_address(local_address) + .build() + .map_err(|e| Error::ClientBuildingError { err: e.into() })?; + Ok(Self { base_url, - reqwest, + client, key: None, }) } @@ -55,54 +76,106 @@ impl Client { /// # Errors /// /// This method fails if the client builder fails. - pub fn authenticated(base_url: Url, key: Key) -> Result { - let reqwest = reqwest::Client::builder().build()?; + pub fn authenticated(base_url: Url, timeout: Duration, key: Key) -> Result { + let client = reqwest::Client::builder() + .timeout(timeout) + .build() + .map_err(|e| Error::ClientBuildingError { err: e.into() })?; + Ok(Self { base_url, - reqwest, + client, key: Some(key), }) } /// # Errors - pub async fn announce(&self, query: &announce::Query) -> Result { - self.get(&self.build_announce_path_and_query(query)).await + /// + /// This method fails if the returned response was not successful + pub async fn announce(&self, query: &announce::Query) -> Result { + let response = self.get(&self.build_announce_path_and_query(query)).await?; + + if response.status().is_success() { + Ok(response) + } else { + Err(Error::UnsuccessfulResponse { + code: response.status(), + response: response.into(), + }) + } } /// # Errors - pub async fn scrape(&self, query: &scrape::Query) -> Result { - self.get(&self.build_scrape_path_and_query(query)).await + /// + /// This method fails if the returned response was not successful + pub async fn scrape(&self, query: &scrape::Query) -> Result { + let response = self.get(&self.build_scrape_path_and_query(query)).await?; + + if response.status().is_success() { + Ok(response) + } else { + Err(Error::UnsuccessfulResponse { + code: response.status(), + response: response.into(), + }) + } } /// # Errors - pub async fn announce_with_header(&self, query: &Query, key: &str, value: &str) -> Result { - self.get_with_header(&self.build_announce_path_and_query(query), key, value) - .await + /// + /// This method fails if the returned response was not successful + pub async fn announce_with_header(&self, query: &announce::Query, key: &str, value: &str) -> Result { + let response = self + .get_with_header(&self.build_announce_path_and_query(query), key, value) + .await?; + + if response.status().is_success() { + Ok(response) + } else { + Err(Error::UnsuccessfulResponse { + code: response.status(), + response: response.into(), + }) + } } /// # Errors - pub async fn health_check(&self) -> Result { - self.get(&self.build_path("health_check")).await + /// + /// This method fails if the returned response was not successful + pub async fn health_check(&self) -> Result { + let response = self.get(&self.build_path("health_check")).await?; + + if response.status().is_success() { + Ok(response) + } else { + Err(Error::UnsuccessfulResponse { + code: response.status(), + response: response.into(), + }) + } } /// # Errors /// /// This method fails if there was an error while sending request. - pub async fn get(&self, path: &str) -> Result { - match self.reqwest.get(self.build_url(path)).send().await { - Ok(response) => Ok(response), - Err(err) => Err(anyhow!("{err}")), - } + pub async fn get(&self, path: &str) -> Result { + self.client + .get(self.build_url(path)) + .send() + .await + .map_err(|e| Error::ResponseError { err: e.into() }) } /// # Errors /// /// This method fails if there was an error while sending request. - pub async fn get_with_header(&self, path: &str, key: &str, value: &str) -> Result { - match self.reqwest.get(self.build_url(path)).header(key, value).send().await { - Ok(response) => Ok(response), - Err(err) => Err(anyhow!("{err}")), - } + pub async fn get_with_header(&self, path: &str, key: &str, value: &str) -> Result { + self.client + .get(self.build_url(path)) + .header(key, value) + .send() + .await + .map_err(|e| Error::ResponseError { err: e.into() }) } fn build_announce_path_and_query(&self, query: &announce::Query) -> String { diff --git a/src/shared/bit_torrent/tracker/udp/client.rs b/src/shared/bit_torrent/tracker/udp/client.rs index dce596e0..ec11f774 100644 --- a/src/shared/bit_torrent/tracker/udp/client.rs +++ b/src/shared/bit_torrent/tracker/udp/client.rs @@ -1,24 +1,20 @@ use core::result::Result::{Err, Ok}; use std::io::Cursor; -use std::net::SocketAddr; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::time::Duration; -use anyhow::{anyhow, Context, Result}; use aquatic_udp_protocol::{ConnectRequest, Request, Response, TransactionId}; use tokio::net::UdpSocket; use tokio::time; -use tracing::debug; +use torrust_tracker_configuration::DEFAULT_TIMEOUT; use zerocopy::network_endian::I32; -use crate::shared::bit_torrent::tracker::udp::{source_address, MAX_PACKET_SIZE}; +use super::Error; +use crate::shared::bit_torrent::tracker::udp::MAX_PACKET_SIZE; pub const UDP_CLIENT_LOG_TARGET: &str = "UDP CLIENT"; -/// Default timeout for sending and receiving packets. And waiting for sockets -/// to be readable and writable. -pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); - #[allow(clippy::module_name_repetitions)] #[derive(Debug)] pub struct UdpClient { @@ -30,51 +26,94 @@ pub struct UdpClient { } impl UdpClient { + /// Creates a new `UdpClient` bound to the default port and ipv6 address + /// /// # Errors /// - /// Will return error if the local address can't be bound. + /// Will return error if unable to bind to any port or ip address. /// - pub async fn bind(local_address: &str) -> Result { - let socket_addr = local_address - .parse::() - .context(format!("{local_address} is not a valid socket address"))?; - - let socket = match time::timeout(DEFAULT_TIMEOUT, UdpSocket::bind(socket_addr)).await { - Ok(bind_result) => match bind_result { - Ok(socket) => { - debug!("Bound to socket: {socket_addr}"); - Ok(socket) - } - Err(e) => Err(anyhow!("Failed to bind to socket: {socket_addr}, error: {e:?}")), - }, - Err(e) => Err(anyhow!("Timeout waiting to bind to socket: {socket_addr}, error: {e:?}")), - }?; + async fn default_ipv4(timeout: Duration) -> Result { + let addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); + + Self::bound(addr, timeout).await + } + + /// Creates a new `UdpClient` bound to the default port and ipv6 address + /// + /// # Errors + /// + /// Will return error if unable to bind to any port or ip address. + /// + async fn default_ipv6(timeout: Duration) -> Result { + let addr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0); + + Self::bound(addr, timeout).await + } + + /// Creates a new `UdpClient` connected to a Udp server + /// + /// # Errors + /// + /// Will return any errors present in the call stack + /// + pub async fn connected(remote_addr: SocketAddr, timeout: Duration) -> Result { + let client = if remote_addr.is_ipv4() { + Self::default_ipv4(timeout).await? + } else { + Self::default_ipv6(timeout).await? + }; + + client.connect(remote_addr).await?; + Ok(client) + } + + /// Creates a `[UdpClient]` bound to a Socket. + /// + /// # Panics + /// + /// Panics if unable to get the `local_addr` of the bound socket. + /// + /// # Errors + /// + /// This function will return an error if the binding takes to long + /// or if there is an underlying OS error. + pub async fn bound(addr: SocketAddr, timeout: Duration) -> Result { + tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "binding to socket: {addr:?} ..."); + + let socket = time::timeout(timeout, UdpSocket::bind(addr)) + .await + .map_err(|_| Error::TimeoutWhileBindingToSocket { addr })? + .map_err(|e| Error::UnableToBindToSocket { err: e.into(), addr })?; + + let addr = socket.local_addr().expect("it should get the local address"); + + tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "bound to socket: {addr:?}."); let udp_client = Self { socket: Arc::new(socket), - timeout: DEFAULT_TIMEOUT, + timeout, }; + Ok(udp_client) } /// # Errors /// /// Will return error if can't connect to the socket. - pub async fn connect(&self, remote_address: &str) -> Result<()> { - let socket_addr = remote_address - .parse::() - .context(format!("{remote_address} is not a valid socket address"))?; - - match time::timeout(self.timeout, self.socket.connect(socket_addr)).await { - Ok(connect_result) => match connect_result { - Ok(()) => { - debug!("Connected to socket {socket_addr}"); - Ok(()) - } - Err(e) => Err(anyhow!("Failed to connect to socket {socket_addr}: {e:?}")), - }, - Err(e) => Err(anyhow!("Timeout waiting to connect to socket {socket_addr}, error: {e:?}")), - } + pub async fn connect(&self, remote_addr: SocketAddr) -> Result<(), Error> { + tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "connecting to remote: {remote_addr:?} ..."); + + let () = time::timeout(self.timeout, self.socket.connect(remote_addr)) + .await + .map_err(|_| Error::TimeoutWhileConnectingToRemote { remote_addr })? + .map_err(|e| Error::UnableToConnectToRemote { + err: e.into(), + remote_addr, + })?; + + tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "connected to remote: {remote_addr:?}."); + + Ok(()) } /// # Errors @@ -83,26 +122,25 @@ impl UdpClient { /// /// - Can't write to the socket. /// - Can't send data. - pub async fn send(&self, bytes: &[u8]) -> Result { - debug!(target: UDP_CLIENT_LOG_TARGET, "sending {bytes:?} ..."); - - match time::timeout(self.timeout, self.socket.writable()).await { - Ok(writable_result) => { - match writable_result { - Ok(()) => (), - Err(e) => return Err(anyhow!("IO error waiting for the socket to become readable: {e:?}")), - }; - } - Err(e) => return Err(anyhow!("Timeout waiting for the socket to become readable: {e:?}")), - }; + pub async fn send(&self, bytes: &[u8]) -> Result { + tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "sending {bytes:?} ..."); - match time::timeout(self.timeout, self.socket.send(bytes)).await { - Ok(send_result) => match send_result { - Ok(size) => Ok(size), - Err(e) => Err(anyhow!("IO error during send: {e:?}")), - }, - Err(e) => Err(anyhow!("Send operation timed out: {e:?}")), - } + let () = time::timeout(self.timeout, self.socket.writable()) + .await + .map_err(|_| Error::TimeoutWaitForWriteableSocket)? + .map_err(|e| Error::UnableToGetWritableSocket { err: e.into() })?; + + let sent_bytes = time::timeout(self.timeout, self.socket.send(bytes)) + .await + .map_err(|_| Error::TimeoutWhileSendingData { data: bytes.to_vec() })? + .map_err(|e| Error::UnableToSendData { + err: e.into(), + data: bytes.to_vec(), + })?; + + tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "sent {sent_bytes} bytes to remote."); + + Ok(sent_bytes) } /// # Errors @@ -114,110 +152,76 @@ impl UdpClient { /// /// # Panics /// - pub async fn receive(&self) -> Result> { - let mut response_buffer = [0u8; MAX_PACKET_SIZE]; + pub async fn receive(&self) -> Result, Error> { + tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "receiving ..."); - debug!(target: UDP_CLIENT_LOG_TARGET, "receiving ..."); + let mut buffer = [0u8; MAX_PACKET_SIZE]; - match time::timeout(self.timeout, self.socket.readable()).await { - Ok(readable_result) => { - match readable_result { - Ok(()) => (), - Err(e) => return Err(anyhow!("IO error waiting for the socket to become readable: {e:?}")), - }; - } - Err(e) => return Err(anyhow!("Timeout waiting for the socket to become readable: {e:?}")), - }; + let () = time::timeout(self.timeout, self.socket.readable()) + .await + .map_err(|_| Error::TimeoutWaitForReadableSocket)? + .map_err(|e| Error::UnableToGetReadableSocket { err: e.into() })?; - let size = match time::timeout(self.timeout, self.socket.recv(&mut response_buffer)).await { - Ok(recv_result) => match recv_result { - Ok(size) => Ok(size), - Err(e) => Err(anyhow!("IO error during send: {e:?}")), - }, - Err(e) => Err(anyhow!("Receive operation timed out: {e:?}")), - }?; + let received_bytes = time::timeout(self.timeout, self.socket.recv(&mut buffer)) + .await + .map_err(|_| Error::TimeoutWhileReceivingData)? + .map_err(|e| Error::UnableToReceivingData { err: e.into() })?; - let mut res: Vec = response_buffer.to_vec(); - Vec::truncate(&mut res, size); + let mut received: Vec = buffer.to_vec(); + Vec::truncate(&mut received, received_bytes); - debug!(target: UDP_CLIENT_LOG_TARGET, "{size} bytes received {res:?}"); + tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "received {received_bytes} bytes: {received:?}"); - Ok(res) + Ok(received) } } -/// Creates a new `UdpClient` connected to a Udp server -/// -/// # Errors -/// -/// Will return any errors present in the call stack -/// -pub async fn new_udp_client_connected(remote_address: &str) -> Result { - let port = 0; // Let OS choose an unused port. - let client = UdpClient::bind(&source_address(port)).await?; - client.connect(remote_address).await?; - Ok(client) -} - #[allow(clippy::module_name_repetitions)] #[derive(Debug)] pub struct UdpTrackerClient { - pub udp_client: UdpClient, + pub client: UdpClient, } impl UdpTrackerClient { + /// Creates a new `UdpTrackerClient` connected to a Udp Tracker server + /// + /// # Errors + /// + /// If unable to connect to the remote address. + /// + pub async fn new(remote_addr: SocketAddr, timeout: Duration) -> Result { + let client = UdpClient::connected(remote_addr, timeout).await?; + Ok(UdpTrackerClient { client }) + } + /// # Errors /// /// Will return error if can't write request to bytes. - pub async fn send(&self, request: Request) -> Result { - debug!(target: UDP_CLIENT_LOG_TARGET, "send request {request:?}"); + pub async fn send(&self, request: Request) -> Result { + tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "sending request {request:?} ..."); // Write request into a buffer - let request_buffer = vec![0u8; MAX_PACKET_SIZE]; - let mut cursor = Cursor::new(request_buffer); - - let request_data_result = match request.write_bytes(&mut cursor) { - Ok(()) => { - #[allow(clippy::cast_possible_truncation)] - let position = cursor.position() as usize; - let inner_request_buffer = cursor.get_ref(); - // Return slice which contains written request data - Ok(&inner_request_buffer[..position]) - } - Err(e) => Err(anyhow!("could not write request to bytes: {e}.")), - }; + // todo: optimize the pre-allocated amount based upon request type. + let mut writer = Cursor::new(Vec::with_capacity(200)); + let () = request + .write_bytes(&mut writer) + .map_err(|e| Error::UnableToWriteDataFromRequest { err: e.into(), request })?; - let request_data = request_data_result?; - - self.udp_client.send(request_data).await + self.client.send(writer.get_ref()).await } /// # Errors /// /// Will return error if can't create response from the received payload (bytes buffer). - pub async fn receive(&self) -> Result { - let payload = self.udp_client.receive().await?; - - debug!(target: UDP_CLIENT_LOG_TARGET, "received {} bytes. Response {payload:?}", payload.len()); + pub async fn receive(&self) -> Result { + let response = self.client.receive().await?; - let response = Response::parse_bytes(&payload, true)?; + tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "received {} bytes: {response:?}", response.len()); - Ok(response) + Response::parse_bytes(&response, true).map_err(|e| Error::UnableToParseResponse { err: e.into(), response }) } } -/// Creates a new `UdpTrackerClient` connected to a Udp Tracker server -/// -/// # Errors -/// -/// Will return any errors present in the call stack -/// -pub async fn new_udp_tracker_client_connected(remote_address: &str) -> Result { - let udp_client = new_udp_client_connected(remote_address).await?; - let udp_tracker_client = UdpTrackerClient { udp_client }; - Ok(udp_tracker_client) -} - /// Helper Function to Check if a UDP Service is Connectable /// /// # Panics @@ -226,10 +230,10 @@ pub async fn new_udp_tracker_client_connected(remote_address: &str) -> Result Result { - debug!("Checking Service (detail): {binding:?}."); +pub async fn check(remote_addr: &SocketAddr) -> Result { + tracing::debug!("Checking Service (detail): {remote_addr:?}."); - match new_udp_tracker_client_connected(binding.to_string().as_str()).await { + match UdpTrackerClient::new(*remote_addr, DEFAULT_TIMEOUT).await { Ok(client) => { let connect_request = ConnectRequest { transaction_id: TransactionId(I32::new(123)), @@ -238,7 +242,7 @@ pub async fn check(binding: &SocketAddr) -> Result { // client.send() return usize, but doesn't use here match client.send(connect_request.into()).await { Ok(_) => (), - Err(e) => debug!("Error: {e:?}."), + Err(e) => tracing::debug!("Error: {e:?}."), }; let process = move |response| { diff --git a/src/shared/bit_torrent/tracker/udp/mod.rs b/src/shared/bit_torrent/tracker/udp/mod.rs index 9322ef04..b9d5f34f 100644 --- a/src/shared/bit_torrent/tracker/udp/mod.rs +++ b/src/shared/bit_torrent/tracker/udp/mod.rs @@ -1,3 +1,10 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use aquatic_udp_protocol::Request; +use thiserror::Error; +use torrust_tracker_located_error::DynError; + pub mod client; /// The maximum number of bytes in a UDP packet. @@ -6,7 +13,56 @@ pub const MAX_PACKET_SIZE: usize = 1496; /// identify the protocol. pub const PROTOCOL_ID: i64 = 0x0417_2710_1980; -/// Generates the source address for the UDP client -fn source_address(port: u16) -> String { - format!("127.0.0.1:{port}") +#[derive(Debug, Clone, Error)] +pub enum Error { + #[error("Timeout while waiting for socket to bind: {addr:?}")] + TimeoutWhileBindingToSocket { addr: SocketAddr }, + + #[error("Failed to bind to socket: {addr:?}, with error: {err:?}")] + UnableToBindToSocket { err: Arc, addr: SocketAddr }, + + #[error("Timeout while waiting for connection to remote: {remote_addr:?}")] + TimeoutWhileConnectingToRemote { remote_addr: SocketAddr }, + + #[error("Failed to connect to remote: {remote_addr:?}, with error: {err:?}")] + UnableToConnectToRemote { + err: Arc, + remote_addr: SocketAddr, + }, + + #[error("Timeout while waiting for the socket to become writable.")] + TimeoutWaitForWriteableSocket, + + #[error("Failed to get writable socket: {err:?}")] + UnableToGetWritableSocket { err: Arc }, + + #[error("Timeout while trying to send data: {data:?}")] + TimeoutWhileSendingData { data: Vec }, + + #[error("Failed to send data: {data:?}, with error: {err:?}")] + UnableToSendData { err: Arc, data: Vec }, + + #[error("Timeout while waiting for the socket to become readable.")] + TimeoutWaitForReadableSocket, + + #[error("Failed to get readable socket: {err:?}")] + UnableToGetReadableSocket { err: Arc }, + + #[error("Timeout while trying to receive data.")] + TimeoutWhileReceivingData, + + #[error("Failed to receive data: {err:?}")] + UnableToReceivingData { err: Arc }, + + #[error("Failed to get data from request: {request:?}, with error: {err:?}")] + UnableToWriteDataFromRequest { err: Arc, request: Request }, + + #[error("Failed to parse response: {response:?}, with error: {err:?}")] + UnableToParseResponse { err: Arc, response: Vec }, +} + +impl From for DynError { + fn from(e: Error) -> Self { + Arc::new(Box::new(e)) + } } diff --git a/tests/servers/udp/contract.rs b/tests/servers/udp/contract.rs index b23b2090..e37ef7bf 100644 --- a/tests/servers/udp/contract.rs +++ b/tests/servers/udp/contract.rs @@ -6,8 +6,9 @@ use core::panic; use aquatic_udp_protocol::{ConnectRequest, ConnectionId, Response, TransactionId}; -use torrust_tracker::shared::bit_torrent::tracker::udp::client::{new_udp_client_connected, UdpTrackerClient}; +use torrust_tracker::shared::bit_torrent::tracker::udp::client::UdpTrackerClient; use torrust_tracker::shared::bit_torrent::tracker::udp::MAX_PACKET_SIZE; +use torrust_tracker_configuration::DEFAULT_TIMEOUT; use torrust_tracker_test_helpers::configuration; use crate::servers::udp::asserts::is_error_response; @@ -40,17 +41,17 @@ async fn send_connection_request(transaction_id: TransactionId, client: &UdpTrac async fn should_return_a_bad_request_response_when_the_client_sends_an_empty_request() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = match new_udp_client_connected(&env.bind_address().to_string()).await { + let client = match UdpTrackerClient::new(env.bind_address(), DEFAULT_TIMEOUT).await { Ok(udp_client) => udp_client, Err(err) => panic!("{err}"), }; - match client.send(&empty_udp_request()).await { + match client.client.send(&empty_udp_request()).await { Ok(_) => (), Err(err) => panic!("{err}"), }; - let response = match client.receive().await { + let response = match client.client.receive().await { Ok(response) => response, Err(err) => panic!("{err}"), }; @@ -64,7 +65,8 @@ async fn should_return_a_bad_request_response_when_the_client_sends_an_empty_req mod receiving_a_connection_request { use aquatic_udp_protocol::{ConnectRequest, TransactionId}; - use torrust_tracker::shared::bit_torrent::tracker::udp::client::new_udp_tracker_client_connected; + use torrust_tracker::shared::bit_torrent::tracker::udp::client::UdpTrackerClient; + use torrust_tracker_configuration::DEFAULT_TIMEOUT; use torrust_tracker_test_helpers::configuration; use crate::servers::udp::asserts::is_connect_response; @@ -74,7 +76,7 @@ mod receiving_a_connection_request { async fn should_return_a_connect_response() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + let client = match UdpTrackerClient::new(env.bind_address(), DEFAULT_TIMEOUT).await { Ok(udp_tracker_client) => udp_tracker_client, Err(err) => panic!("{err}"), }; @@ -106,7 +108,8 @@ mod receiving_an_announce_request { AnnounceActionPlaceholder, AnnounceEvent, AnnounceRequest, ConnectionId, InfoHash, NumberOfBytes, NumberOfPeers, PeerId, PeerKey, Port, TransactionId, }; - use torrust_tracker::shared::bit_torrent::tracker::udp::client::{new_udp_tracker_client_connected, UdpTrackerClient}; + use torrust_tracker::shared::bit_torrent::tracker::udp::client::UdpTrackerClient; + use torrust_tracker_configuration::DEFAULT_TIMEOUT; use torrust_tracker_test_helpers::configuration; use crate::servers::udp::asserts::is_ipv4_announce_response; @@ -129,7 +132,7 @@ mod receiving_an_announce_request { ip_address: Ipv4Addr::new(0, 0, 0, 0).into(), key: PeerKey::new(0i32), peers_wanted: NumberOfPeers(1i32.into()), - port: Port(client.udp_client.socket.local_addr().unwrap().port().into()), + port: Port(client.client.socket.local_addr().unwrap().port().into()), }; match client.send(announce_request.into()).await { @@ -151,7 +154,7 @@ mod receiving_an_announce_request { async fn should_return_an_announce_response() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + let client = match UdpTrackerClient::new(env.bind_address(), DEFAULT_TIMEOUT).await { Ok(udp_tracker_client) => udp_tracker_client, Err(err) => panic!("{err}"), }; @@ -169,7 +172,7 @@ mod receiving_an_announce_request { async fn should_return_many_announce_response() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + let client = match UdpTrackerClient::new(env.bind_address(), DEFAULT_TIMEOUT).await { Ok(udp_tracker_client) => udp_tracker_client, Err(err) => panic!("{err}"), }; @@ -189,7 +192,8 @@ mod receiving_an_announce_request { mod receiving_an_scrape_request { use aquatic_udp_protocol::{ConnectionId, InfoHash, ScrapeRequest, TransactionId}; - use torrust_tracker::shared::bit_torrent::tracker::udp::client::new_udp_tracker_client_connected; + use torrust_tracker::shared::bit_torrent::tracker::udp::client::UdpTrackerClient; + use torrust_tracker_configuration::DEFAULT_TIMEOUT; use torrust_tracker_test_helpers::configuration; use crate::servers::udp::asserts::is_scrape_response; @@ -200,7 +204,7 @@ mod receiving_an_scrape_request { async fn should_return_a_scrape_response() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + let client = match UdpTrackerClient::new(env.bind_address(), DEFAULT_TIMEOUT).await { Ok(udp_tracker_client) => udp_tracker_client, Err(err) => panic!("{err}"), }; @@ -211,12 +215,13 @@ mod receiving_an_scrape_request { // Full scrapes are not allowed you need to pass an array of info hashes otherwise // it will return "bad request" error with empty vector - let info_hashes = vec![InfoHash([0u8; 20])]; + + let empty_info_hash = vec![InfoHash([0u8; 20])]; let scrape_request = ScrapeRequest { connection_id: ConnectionId(connection_id.0), transaction_id: TransactionId::new(123i32), - info_hashes, + info_hashes: empty_info_hash, }; match client.send(scrape_request.into()).await { diff --git a/tests/servers/udp/environment.rs b/tests/servers/udp/environment.rs index 2232cb0e..c580c355 100644 --- a/tests/servers/udp/environment.rs +++ b/tests/servers/udp/environment.rs @@ -7,8 +7,7 @@ use torrust_tracker::servers::registar::Registar; use torrust_tracker::servers::udp::server::spawner::Spawner; use torrust_tracker::servers::udp::server::states::{Running, Stopped}; use torrust_tracker::servers::udp::server::Server; -use torrust_tracker::shared::bit_torrent::tracker::udp::client::DEFAULT_TIMEOUT; -use torrust_tracker_configuration::{Configuration, UdpTracker}; +use torrust_tracker_configuration::{Configuration, UdpTracker, DEFAULT_TIMEOUT}; use torrust_tracker_primitives::info_hash::InfoHash; use torrust_tracker_primitives::peer;