diff --git a/Cargo.toml b/Cargo.toml index f0f9a25..4296218 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,17 @@ [dependencies] async-trait = {version = "0.1", optional = true} +axum = {version = "0.6", optional = true} bytes = {version = "1.5", optional = true} chrono = {version = "0.4", optional = true} mockall = {version = "0.11", optional = true} open = {version = "5.0", optional = true} -tokio = {version = "1.32", features = ["process"], optional = true} +tokio = {version = "1.32", features = ["process", "rt"], optional = true} tracing = "0.1" uuid = {version = "0.8", features = ["v4"], optional = true} [dev-dependencies] mockall = "0.11" +reqwest = "0.11" tokio = {version = "1.32", features = ["full"]} tokio-test = "0.4" @@ -18,6 +20,7 @@ browser = ["dep:open"] clock = ["dep:chrono"] cmd = ["dep:async-trait", "dep:tokio"] full = ["browser", "clock", "cmd", "uuid"] +http = ["dep:axum", "dep:tokio"] mock = ["dep:mockall"] uuid = ["dep:uuid"] diff --git a/examples/http.rs b/examples/http.rs new file mode 100644 index 0000000..e1740b9 --- /dev/null +++ b/examples/http.rs @@ -0,0 +1,38 @@ +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + +use mockable::{DefaultHttpServer, HttpServer}; + +#[tokio::main] +async fn main() { + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000)); + let mut server = DefaultHttpServer::start(&addr) + .await + .expect("failed to start server"); + let req = server.next().await.expect("failed to get request"); + println!("{:?}", req); + server.stop().await; +} + +#[cfg(test)] +mod test { + use mockable::{HttpRequest, MockHttpServer}; + + use super::*; + + #[tokio::test] + async fn test() { + let expected = HttpRequest { + body: vec![], + headers: Default::default(), + method: "GET".into(), + path: "/".into(), + query: Default::default(), + }; + let mut server = MockHttpServer::new(); + server.expect_next().return_const(expected.clone()); + server.expect_stop().return_const(()); + let req = server.next().await.expect("failed to get request"); + server.stop().await; + assert_eq!(req, expected); + } +} diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..9099dd3 --- /dev/null +++ b/src/http.rs @@ -0,0 +1,220 @@ +use std::{collections::HashMap, io, net::SocketAddr}; + +use async_trait::async_trait; +use axum::{ + body::Bytes, + extract::Query, + http::{HeaderMap, Method, StatusCode, Uri}, + Router, Server, +}; +use tokio::{ + spawn, + sync::{mpsc, oneshot}, + task::JoinHandle, +}; +use tracing::{error, warn}; + +// HttpRequest + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct HttpRequest { + pub body: Vec, + pub headers: HashMap, + pub method: String, + pub path: String, + pub query: HashMap>, +} + +// HttpServer + +/// Simple HTTP server that listen all requests. +/// +/// **This is supported on `feature=http` only.** +/// +/// [Example](https://github.com/leroyguillaume/mockable/tree/main/examples/http.rs). +#[async_trait] +pub trait HttpServer { + /// Returns the next request received by the server. + /// + /// `None` is returned if the server is stopped. + async fn next(&mut self) -> Option; + + /// Stops the server. + async fn stop(self); +} + +// DefaultHttpServer + +/// Default implementation of [`HttpServer`](trait.HttpServer.html). +/// +/// **This is supported on `feature=http` only.** +/// +/// [Example](https://github.com/leroyguillaume/mockable/tree/main/examples/http.rs). +pub struct DefaultHttpServer { + req_rx: mpsc::Receiver, + server: JoinHandle<()>, + stop_tx: oneshot::Sender<()>, +} + +impl DefaultHttpServer { + /// Starts a new server listening on the given address. + pub async fn start(addr: &SocketAddr) -> io::Result { + let (stop_tx, stop_rx) = oneshot::channel(); + let (req_tx, req_rx) = mpsc::channel(1); + let app = Router::new().fallback( + move |method: Method, + uri: Uri, + Query(query): Query>, + headers: HeaderMap, + body: Bytes| async move { + let mut req_headers = HashMap::new(); + for (name, val) in headers { + let name = if let Some(name) = &name { + name.as_str() + } else { + warn!("request contains header with no name"); + continue; + }; + let val = match val.to_str() { + Ok(val) => val, + Err(err) => { + warn!(details = %err, header = name, "failed to decode header value"); + continue; + } + }; + req_headers.insert(name.into(), val.into()); + } + let query = query.into_iter().fold( + HashMap::>::new(), + |mut query, (key, val)| { + query.entry(key).or_default().push(val); + query + }, + ); + let req = HttpRequest { + body: body.to_vec(), + headers: req_headers, + method: method.to_string(), + path: uri.path().into(), + query, + }; + req_tx.send(req).await.ok(); + StatusCode::OK + }, + ); + let server = Server::bind(addr) + .serve(app.into_make_service()) + .with_graceful_shutdown(async { + stop_rx.await.ok(); + }); + let server = spawn(async { + if let Err(err) = server.await { + error!(details = %err, "failed to start server"); + } + }); + Ok(Self { + req_rx, + server, + stop_tx, + }) + } +} + +#[async_trait] +impl HttpServer for DefaultHttpServer { + async fn next(&mut self) -> Option { + self.req_rx.recv().await + } + + async fn stop(self) { + self.stop_tx.send(()).ok(); + if let Err(err) = self.server.await { + error!(details = %err, "failed to stop server"); + } + } +} + +// MockHttpServer + +#[cfg(feature = "mock")] +mockall::mock! { + /// `mockall` implementation of [`HttpServer`](trait.HttpServer.html). + /// + /// **This is supported on `feature=http,mock` only.** + /// + /// [Example](https://github.com/leroyguillaume/mockable/tree/main/examples/http.rs). + pub HttpServer {} + + #[async_trait] + impl HttpServer for HttpServer { + async fn next(&mut self) -> Option; + async fn stop(self); + } +} + +// Tests + +#[cfg(test)] +mod test { + use std::{ + net::{Ipv4Addr, SocketAddrV4}, + time::Duration, + }; + + use reqwest::Client; + use tokio::time::sleep; + + use super::*; + + mod default_http_server { + use super::*; + + #[tokio::test] + async fn test() { + let port = 8000; + let expected = HttpRequest { + body: "abc".to_string().into_bytes(), + headers: HashMap::from_iter([ + ("accept".into(), "*/*".into()), + ("content-length".into(), "3".into()), + ("host".into(), format!("localhost:{port}")), + ]), + method: "GET".into(), + path: "/a/b".into(), + query: HashMap::from_iter([("foo".into(), vec!["bar1".into(), "bar2".into()])]), + }; + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)); + let mut server = DefaultHttpServer::start(&addr) + .await + .expect("failed to start server"); + sleep(Duration::from_secs(1)).await; + let client = Client::new(); + let query: Vec<(String, String)> = expected + .query + .clone() + .into_iter() + .flat_map(|(key, val)| val.into_iter().map(move |val| (key.clone(), val))) + .collect(); + let resp = client + .get(format!("http://localhost:{port}{}", expected.path)) + .query(&query) + .body(expected.body.clone()) + .send() + .await + .expect("failed to send request"); + let status = resp.status(); + if status != reqwest::StatusCode::OK { + let body = resp.text().await.expect("failed to read response body"); + panic!("request failed with status {status}: {body}"); + } + let req = server.next().await.expect("failed to receive request"); + assert_eq!(req, expected); + server.stop().await; + client + .get(format!("http://localhost:{port}")) + .send() + .await + .expect_err("request should fail after server is stopped"); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 35422d7..97bae0a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,10 @@ pub use self::clock::{Clock, DefaultClock}; pub use self::cmd::MockCommandRunner; #[cfg(feature = "cmd")] pub use self::cmd::{Command, CommandOutput, CommandRunner, DefaultCommandRunner}; +#[cfg(all(feature = "http", feature = "mock"))] +pub use self::http::MockHttpServer; +#[cfg(feature = "http")] +pub use self::http::{DefaultHttpServer, HttpRequest, HttpServer}; #[cfg(any(feature = "mock", test))] pub use self::mock::Mock; #[cfg(all(feature = "uuid", feature = "mock"))] @@ -28,6 +32,8 @@ mod clock; #[cfg(feature = "cmd")] mod cmd; mod env; +#[cfg(feature = "http")] +mod http; #[cfg(any(feature = "mock", test))] mod mock; mod sys;