From ad063c1f66f400ecd92659b7392b64f0d7452775 Mon Sep 17 00:00:00 2001 From: Florin Lipan Date: Mon, 11 Nov 2024 21:26:04 +0100 Subject: [PATCH] Allow matching the Request object based on a closure --- src/lib.rs | 24 +++++++++++++++++++ src/matcher.rs | 32 +++++++++++++++++++++++++ src/mock.rs | 34 ++++++++++++++++++++++++++- src/request.rs | 9 +++++++- src/server.rs | 5 ++++ tests/lib.rs | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 165 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4acd54b..6057d81 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -472,6 +472,30 @@ //! .create(); //! ``` //! +//! # Custom matchers +//! +//! If you need a more custom matcher, you can use the [`Mock::match_request`] function, which +//! takes a closure and exposes the [`Request`] object as an argument. The closure should return +//! a boolean value. +//! +//! ## Example +//! +//! ``` +//! use mockito::Matcher; +//! +//! let mut s = mockito::Server::new(); +//! +//! // This will match requests that have the x-test header set +//! // and contain the word "hello" inside the body +//! s.mock("GET", "/") +//! .match_request(|request| { +//! request.has_header("x-test") && +//! request.utf8_lossy_body().unwrap().contains("hello") +//! }) +//! .create(); +//! +//! ``` +//! //! # Asserts //! //! You can use the [`Mock::assert`] method to **assert that a mock was called**. In other words, diff --git a/src/matcher.rs b/src/matcher.rs index 57808bc..9c1a463 100644 --- a/src/matcher.rs +++ b/src/matcher.rs @@ -1,3 +1,4 @@ +use crate::request::Request; use assert_json_diff::{assert_json_matches_no_panic, CompareMode}; use http::header::HeaderValue; use regex::Regex; @@ -9,6 +10,7 @@ use std::io; use std::io::Read; use std::path::Path; use std::string::ToString; +use std::sync::Arc; /// /// Allows matching the request path, headers or body in multiple ways: by the exact value, by any value (as @@ -281,3 +283,33 @@ impl fmt::Display for BinaryBody { } } } + +#[derive(Clone)] +pub(crate) struct RequestMatcher(Arc bool + Send + Sync>); + +impl RequestMatcher { + pub(crate) fn matches(&self, value: &Request) -> bool { + self.0(value) + } +} + +impl From for RequestMatcher +where + F: Fn(&Request) -> bool + Send + Sync + 'static, +{ + fn from(value: F) -> Self { + Self(Arc::new(value)) + } +} + +impl Default for RequestMatcher { + fn default() -> Self { + RequestMatcher(Arc::new(|_| true)) + } +} + +impl fmt::Debug for RequestMatcher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(RequestMatcher)") + } +} diff --git a/src/mock.rs b/src/mock.rs index ca3e21f..c87999c 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -1,5 +1,5 @@ use crate::diff; -use crate::matcher::{Matcher, PathAndQueryMatcher}; +use crate::matcher::{Matcher, PathAndQueryMatcher, RequestMatcher}; use crate::response::{Body, Response}; use crate::server::RemoteMock; use crate::server::State; @@ -67,6 +67,7 @@ pub struct InnerMock { pub(crate) path: PathAndQueryMatcher, pub(crate) headers: HeaderMap, pub(crate) body: Matcher, + pub(crate) request_matcher: RequestMatcher, pub(crate) response: Response, pub(crate) hits: usize, pub(crate) expected_hits_at_least: Option, @@ -161,6 +162,7 @@ impl Mock { path: PathAndQueryMatcher::Unified(path.into()), headers: HeaderMap::::default(), body: Matcher::Any, + request_matcher: RequestMatcher::default(), response: Response::default(), hits: 0, expected_hits_at_least: None, @@ -303,6 +305,36 @@ impl Mock { self } + /// + /// Allows matching the entire request based on a closure that takes + /// the [`Request`] object as an argument and returns a boolean value. + /// + /// ## Example + /// + /// ``` + /// use mockito::Matcher; + /// + /// let mut s = mockito::Server::new(); + /// + /// // This will match requests that have the x-test header set + /// // and contain the word "hello" inside the body + /// s.mock("GET", "/") + /// .match_request(|request| { + /// request.has_header("x-test") && + /// request.utf8_lossy_body().unwrap().contains("hello") + /// }) + /// .create(); + /// ``` + /// + pub fn match_request(mut self, request_matcher: F) -> Self + where + F: Fn(&Request) -> bool + Send + Sync + 'static, + { + self.inner.request_matcher = request_matcher.into(); + + self + } + /// /// Sets the status code of the mock response. The default status code is 200. /// diff --git a/src/request.rs b/src/request.rs index 4b1fd48..462bbe8 100644 --- a/src/request.rs +++ b/src/request.rs @@ -3,6 +3,7 @@ use http::header::{AsHeaderName, HeaderValue}; use http::Request as HttpRequest; use http_body_util::BodyExt; use hyper::body::Incoming; +use std::borrow::Cow; /// /// Stores a HTTP request @@ -51,13 +52,19 @@ impl Request { } /// Returns the request body or an error, if the body hasn't been read - /// up to this moment. + /// yet. pub fn body(&self) -> Result<&Vec, Error> { self.body .as_ref() .ok_or_else(|| Error::new(ErrorKind::RequestBodyFailure)) } + /// Returns the request body as UTF8 or an error, if the body hasn't + /// been read yet. + pub fn utf8_lossy_body(&self) -> Result, Error> { + self.body().map(|body| String::from_utf8_lossy(body)) + } + /// Reads the body (if it hasn't been read already) and returns it pub(crate) async fn read_body(&mut self) -> &Vec { if self.body.is_none() { diff --git a/src/server.rs b/src/server.rs index 4b04409..03530fd 100644 --- a/src/server.rs +++ b/src/server.rs @@ -41,6 +41,7 @@ impl RemoteMock { && self.path_matches(other) && self.headers_match(other) && self.body_matches(other) + && self.request_matches(other) } fn method_matches(&self, request: &Request) -> bool { @@ -65,6 +66,10 @@ impl RemoteMock { self.inner.body.matches_value(safe_body) || self.inner.body.matches_binary_value(body) } + fn request_matches(&self, request: &Request) -> bool { + self.inner.request_matcher.matches(request) + } + #[allow(clippy::missing_const_for_fn)] fn is_missing_hits(&self) -> bool { match ( diff --git a/tests/lib.rs b/tests/lib.rs index 0fc1aef..2a96723 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1879,6 +1879,69 @@ fn test_anyof_exact_path_and_query_matcher() { mock.assert(); } +#[test] +fn test_request_matcher_path() { + let mut s = Server::new(); + let host = s.host_with_port(); + let m = s + .mock("GET", Matcher::Any) + .match_request(|req| req.path().contains("hello")) + .with_body("world") + .create(); + + let (status_line, _, _) = request(&host, "GET /", ""); + assert_eq!("HTTP/1.1 501 Not Implemented\r\n", status_line); + + let (status_line, _, body) = request(host, "GET /hello", ""); + assert_eq!("HTTP/1.1 200 OK\r\n", status_line); + assert_eq!("world", body); + + m.assert(); +} + +#[test] +fn test_request_matcher_headers() { + let mut s = Server::new(); + let host = s.host_with_port(); + let m = s + .mock("GET", "/") + .match_request(|req| req.has_header("x-test")) + .with_body("world") + .create(); + + let (status_line, _, _) = request(&host, "GET /", ""); + assert_eq!("HTTP/1.1 501 Not Implemented\r\n", status_line); + + let (status_line, _, body) = request(host, "GET /", "x-test: 1\r\n"); + assert_eq!("HTTP/1.1 200 OK\r\n", status_line); + assert_eq!("world", body); + + m.assert(); +} + +#[test] +fn test_request_matcher_body() { + let mut s = Server::new(); + let host = s.host_with_port(); + let m = s + .mock("GET", "/") + .match_request(|req| { + let body = req.utf8_lossy_body().unwrap(); + body.contains("hello") + }) + .with_body("world") + .create(); + + let (status_line, _, _) = request_with_body(&host, "GET /", "", "bye"); + assert_eq!("HTTP/1.1 501 Not Implemented\r\n", status_line); + + let (status_line, _, body) = request_with_body(host, "GET /", "", "hello"); + assert_eq!("HTTP/1.1 200 OK\r\n", status_line); + assert_eq!("world", body); + + m.assert(); +} + #[test] fn test_default_headers() { let mut s = Server::new();