diff --git a/Cargo.toml b/Cargo.toml index 4d3dd44..84ff985 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ regex = "1" futures-timer = "3.0.2" futures = "0.3.5" hyper = { version = "0.14", features = ["full"] } -tokio = { version = "1.5.0", features = ["rt", "io-util", "time"] } +tokio = { version = "1.5.0", features = ["rt"] } deadpool = "0.9.2" async-trait = "0.1" once_cell = "1" diff --git a/src/mock_server/bare_server.rs b/src/mock_server/bare_server.rs index e2f2211..8db499b 100644 --- a/src/mock_server/bare_server.rs +++ b/src/mock_server/bare_server.rs @@ -3,7 +3,10 @@ use crate::mock_set::MockId; use crate::mock_set::MountedMockSet; use crate::{mock::Mock, verification::VerificationOutcome, Request}; use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::pin::pin; +use std::sync::atomic::AtomicBool; use std::sync::Arc; +use tokio::sync::Notify; use tokio::sync::RwLock; use tokio::task::LocalSet; @@ -101,8 +104,9 @@ impl BareMockServer { /// When the returned `MockGuard` is dropped, `MockServer` will verify that the expectations set on the scoped `Mock` were /// verified - if not, it will panic. pub async fn register_as_scoped(&self, mock: Mock) -> MockGuard { - let mock_id = self.state.write().await.mock_set.register(mock); + let (notify, mock_id) = self.state.write().await.mock_set.register(mock); MockGuard { + notify, mock_id, server_state: self.state.clone(), } @@ -182,6 +186,7 @@ Check `wiremock`'s documentation on scoped mocks for more details."] pub struct MockGuard { mock_id: MockId, server_state: Arc>, + notify: Arc<(Notify, AtomicBool)>, } impl MockGuard { @@ -190,6 +195,22 @@ impl MockGuard { let (mounted_mock, _) = &state.mock_set[self.mock_id]; mounted_mock.received_requests() } + + pub async fn wait_until_satisfied(&self) { + let (notify, flag) = &*self.notify; + let mut notification = pin!(notify.notified()); + + // listen for events of satisfaction. + notification.as_mut().enable(); + + // check if satisfaction has previously been recorded + if flag.load(std::sync::atomic::Ordering::Acquire) { + return; + } + + // await event + notification.await + } } impl Drop for MockGuard { @@ -198,6 +219,7 @@ impl Drop for MockGuard { let MockGuard { mock_id, server_state, + .. } = self; let mut state = server_state.write().await; let report = state.mock_set.verify(*mock_id); diff --git a/src/mock_set.rs b/src/mock_set.rs index f507b4a..3b0c164 100644 --- a/src/mock_set.rs +++ b/src/mock_set.rs @@ -6,7 +6,11 @@ use crate::{Mock, Request, ResponseTemplate}; use futures_timer::Delay; use http_types::{Response, StatusCode}; use log::debug; -use std::ops::{Index, IndexMut}; +use std::{ + ops::{Index, IndexMut}, + sync::{atomic::AtomicBool, Arc}, +}; +use tokio::sync::Notify; /// The collection of mocks used by a `MockServer` instance to match against /// incoming requests. @@ -67,15 +71,18 @@ impl MountedMockSet { } } - pub(crate) fn register(&mut self, mock: Mock) -> MockId { + pub(crate) fn register(&mut self, mock: Mock) -> (Arc<(Notify, AtomicBool)>, MockId) { let n_registered_mocks = self.mocks.len(); let active_mock = MountedMock::new(mock, n_registered_mocks); + let notify = active_mock.notify(); self.mocks.push((active_mock, MountedMockState::InScope)); - - MockId { - index: self.mocks.len() - 1, - generation: self.generation, - } + ( + notify, + MockId { + index: self.mocks.len() - 1, + generation: self.generation, + }, + ) } pub(crate) fn reset(&mut self) { @@ -179,7 +186,7 @@ mod tests { // Assert let mut set = MountedMockSet::new(); let mock = Mock::given(path("/")).respond_with(ResponseTemplate::new(200)); - let mock_id = set.register(mock); + let (_, mock_id) = set.register(mock); // Act set.reset(); @@ -194,8 +201,8 @@ mod tests { let mut set = MountedMockSet::new(); let first_mock = Mock::given(path("/")).respond_with(ResponseTemplate::new(200)); let second_mock = Mock::given(path("/hello")).respond_with(ResponseTemplate::new(500)); - let first_mock_id = set.register(first_mock); - let second_mock_id = set.register(second_mock); + let (_, first_mock_id) = set.register(first_mock); + let (_, second_mock_id) = set.register(second_mock); // Act set.deactivate(first_mock_id); diff --git a/src/mounted_mock.rs b/src/mounted_mock.rs index ce6907b..1bbe8a9 100644 --- a/src/mounted_mock.rs +++ b/src/mounted_mock.rs @@ -1,3 +1,7 @@ +use std::sync::{atomic::AtomicBool, Arc}; + +use tokio::sync::Notify; + use crate::{verification::VerificationReport, Match, Mock, Request, ResponseTemplate}; /// Given the behaviour specification as a [`Mock`](crate::Mock), keep track of runtime information @@ -14,6 +18,8 @@ pub(crate) struct MountedMock { // matched requests: matched_requests: Vec, + + notify: Arc<(Notify, AtomicBool)>, } impl MountedMock { @@ -23,6 +29,7 @@ impl MountedMock { n_matched_requests: 0, position_in_set, matched_requests: Vec::new(), + notify: Arc::new((Notify::new(), AtomicBool::new(false))), } } @@ -46,7 +53,16 @@ impl MountedMock { // Increase match count self.n_matched_requests += 1; // Keep track of request - self.matched_requests.push(request.clone()) + self.matched_requests.push(request.clone()); + + // notification of satisfaction + if self.verify().is_satisfied() { + // always set the satisfaction flag **before** raising the event + self.notify + .1 + .store(true, std::sync::atomic::Ordering::Release); + self.notify.0.notify_waiters(); + } } matched @@ -71,4 +87,8 @@ impl MountedMock { pub(crate) fn received_requests(&self) -> Vec { self.matched_requests.clone() } + + pub(crate) fn notify(&self) -> Arc<(Notify, AtomicBool)> { + self.notify.clone() + } } diff --git a/tests/mocks.rs b/tests/mocks.rs index 58b7241..51fe0a2 100644 --- a/tests/mocks.rs +++ b/tests/mocks.rs @@ -1,7 +1,9 @@ +use futures::FutureExt; use http_types::StatusCode; use serde::Serialize; use serde_json::json; use std::net::TcpStream; +use std::time::Duration; use wiremock::matchers::{body_json, body_partial_json, method, path, PathExactMatcher}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -273,3 +275,67 @@ async fn use_mock_guard_to_verify_requests_from_mock() { assert_eq!(value, json!({"attempt": 99})); } + +#[async_std::test] +async fn use_mock_guard_to_await_satisfaction_readiness() { + // Arrange + let mock_server = MockServer::start().await; + + let satisfy = mock_server + .register_as_scoped( + Mock::given(method("POST")) + .and(PathExactMatcher::new("satisfy")) + .respond_with(ResponseTemplate::new(200)) + .expect(1), + ) + .await; + + let eventually_satisfy = mock_server + .register_as_scoped( + Mock::given(method("POST")) + .and(PathExactMatcher::new("eventually_satisfy")) + .respond_with(ResponseTemplate::new(200)) + .expect(1), + ) + .await; + + // Act one + let uri = mock_server.uri(); + let response = surf::post(format!("{uri}/satisfy")).await.unwrap(); + assert_eq!(response.status(), StatusCode::Ok); + + // Assert + satisfy + .wait_until_satisfied() + .now_or_never() + .expect("should be satisfied immediately"); + + eventually_satisfy + .wait_until_satisfied() + .now_or_never() + .ok_or(()) + .expect_err("should not be satisfied yet"); + + // Act two + async_std::task::spawn(async move { + async_std::task::sleep(Duration::from_millis(100)).await; + let response = surf::post(format!("{uri}/eventually_satisfy")) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::Ok); + }); + + // Assert + eventually_satisfy + .wait_until_satisfied() + .now_or_never() + .ok_or(()) + .expect_err("should not be satisfied yet"); + + async_std::io::timeout( + Duration::from_millis(1000), + eventually_satisfy.wait_until_satisfied().map(Ok), + ) + .await + .expect("should be satisfied"); +}