From f38b937327d4b4daeb337b26365b88e1c6a2d78e Mon Sep 17 00:00:00 2001 From: Stanimal Date: Fri, 11 Oct 2019 13:13:25 +0200 Subject: [PATCH] Filter Store and forward messages if feature is not enabled - Store and forward messages are discarded when they are not supported by the node - To implement this, tower-filter was used. However, it is not released yet for futures 0.3 so I've included it directly in this PR --- Cargo.toml | 2 + comms/dht/Cargo.toml | 1 + comms/dht/src/dht.rs | 77 ++++++++++++++-- comms/dht/src/macros.rs | 6 -- comms/dht/src/test_utils/makers.rs | 11 +++ comms/middleware/tower-filter/CHANGELOG.md | 12 +++ comms/middleware/tower-filter/Cargo.toml | 35 +++++++ comms/middleware/tower-filter/LICENSE | 25 +++++ comms/middleware/tower-filter/README.md | 14 +++ comms/middleware/tower-filter/src/error.rs | 46 ++++++++++ comms/middleware/tower-filter/src/future.rs | 91 +++++++++++++++++++ comms/middleware/tower-filter/src/layer.rs | 24 +++++ comms/middleware/tower-filter/src/lib.rs | 59 ++++++++++++ .../middleware/tower-filter/src/predicate.rs | 25 +++++ comms/middleware/tower-filter/tests/filter.rs | 61 +++++++++++++ 15 files changed, 477 insertions(+), 12 deletions(-) create mode 100644 comms/middleware/tower-filter/CHANGELOG.md create mode 100644 comms/middleware/tower-filter/Cargo.toml create mode 100644 comms/middleware/tower-filter/LICENSE create mode 100644 comms/middleware/tower-filter/README.md create mode 100644 comms/middleware/tower-filter/src/error.rs create mode 100644 comms/middleware/tower-filter/src/future.rs create mode 100644 comms/middleware/tower-filter/src/layer.rs create mode 100644 comms/middleware/tower-filter/src/lib.rs create mode 100644 comms/middleware/tower-filter/src/predicate.rs create mode 100644 comms/middleware/tower-filter/tests/filter.rs diff --git a/Cargo.toml b/Cargo.toml index 234353f528..d7418da9c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,8 @@ members = [ "comms", "comms/dht", "comms/middleware", + # TODO: Remove this once tower filter (0.3.0-alpha.3) is released + "comms/middleware/tower-filter", "digital_assets_layer/core", "infrastructure/broadcast_channel", "infrastructure/crypto", diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index fda410d121..a72854abaf 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -30,6 +30,7 @@ serde_repr = "0.1.5" tokio = "0.2.0-alpha.6" tokio-executor = "0.2.0-alpha.6" tower= "0.3.0-alpha.2" +tower-filter= {path="../middleware/tower-filter"}#version="=0.3.0-alpha.2", path="../" ttl_cache = "0.5.1" [dev-dependencies] diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 4cf34b6719..b1c532138c 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -24,23 +24,24 @@ use self::outbound::OutboundMessageRequester; use crate::{ envelope::{DhtMessageType, NodeDestination}, inbound, - inbound::{DecryptedDhtMessage, DiscoverMessage, JoinMessage}, + inbound::{DecryptedDhtMessage, DhtInboundMessage, DiscoverMessage, JoinMessage}, outbound, outbound::{BroadcastClosestRequest, BroadcastStrategy, DhtOutboundError, DhtOutboundRequest, OutboundEncryption}, store_forward, DhtConfig, }; -use futures::{channel::mpsc, Future}; -use log::debug; +use futures::{channel::mpsc, future, Future}; +use log::*; use std::sync::Arc; use tari_comms::{ message::InboundMessage, outbound_message_service::OutboundMessage, - peer_manager::{NodeId, NodeIdentity, PeerManager}, + peer_manager::{NodeId, NodeIdentity, PeerFeature, PeerManager}, types::CommsPublicKey, }; use tari_comms_middleware::MiddlewareError; use tower::{layer::Layer, Service, ServiceBuilder}; +use tower_filter::error::Error as FilterError; const LOG_TARGET: &'static str = "comms::dht"; @@ -99,6 +100,7 @@ impl Dht { ServiceBuilder::new() .layer(inbound::DeserializeLayer::new()) + .layer(tower_filter::FilterLayer::new(self.unsupported_saf_messages_filter())) .layer(inbound::DecryptionLayer::new(Arc::clone(&self.node_identity))) .layer(store_forward::ForwardLayer::new( Arc::clone(&self.peer_manager), @@ -185,6 +187,33 @@ impl Dht { Ok(()) } + /// Produces a filter predicate which disallows store and forward messages if that feature is not + /// supported by the node. + fn unsupported_saf_messages_filter( + &self, + ) -> impl tower_filter::Predicate>> + Clone + Send + { + let node_identity = Arc::clone(&self.node_identity); + move |msg: &DhtInboundMessage| { + if node_identity.has_peer_feature(&PeerFeature::DhtStoreForward) { + return future::ready(Ok(())); + } + + match msg.dht_header.message_type { + DhtMessageType::SAFRequestMessages | DhtMessageType::SAFStoredMessages => { + // TODO: This is an indication of node misbehaviour + warn!( + "Received store and forward message from PublicKey={}. Store and forward feature is not \ + supported by this node. Discarding message.", + msg.dht_header.origin_public_key + ); + future::ready(Err(FilterError::rejected())) + }, + _ => future::ready(Ok(())), + } + } + } + pub async fn send_discover( &self, dest_public_key: CommsPublicKey, @@ -233,9 +262,15 @@ impl Dht { #[cfg(test)] mod test { use crate::{ - envelope::DhtMessageFlags, + envelope::{DhtMessageFlags, DhtMessageType}, outbound::DhtOutboundRequest, - test_utils::{make_comms_inbound_message, make_dht_envelope, make_node_identity, make_peer_manager}, + test_utils::{ + make_client_identity, + make_comms_inbound_message, + make_dht_envelope, + make_node_identity, + make_peer_manager, + }, DhtBuilder, }; use futures::{channel::mpsc, StreamExt}; @@ -349,4 +384,34 @@ mod test { // Check the next service was not called assert!(rt.block_on(next_service_rx.next()).is_none()); } + + #[test] + fn stack_filter_saf_message() { + let node_identity = make_client_identity(); + let peer_manager = make_peer_manager(); + + let dht = DhtBuilder::new(Arc::clone(&node_identity), peer_manager).finish(); + + let rt = Runtime::new().unwrap(); + + let (next_service_tx, mut next_service_rx) = mpsc::channel(10); + + let mut service = dht + .inbound_middleware_layer() + .layer(SinkMiddleware::new(next_service_tx)); + + let msg = Message::from_message_format((), "secret".to_string()).unwrap(); + let mut dht_envelope = make_dht_envelope(&node_identity, msg.to_binary().unwrap(), DhtMessageFlags::empty()); + dht_envelope.header.message_type = DhtMessageType::SAFStoredMessages; + let inbound_message = + make_comms_inbound_message(&node_identity, dht_envelope.to_binary().unwrap(), MessageFlags::empty()); + + let err = rt.block_on(service.call(inbound_message)); + assert!(err.is_err()); + // This seems like the best way to tell that an open channel is empty without the test blocking indefinitely + assert_eq!( + format!("{}", next_service_rx.try_next().unwrap_err()), + "receiver channel is empty" + ); + } } diff --git a/comms/dht/src/macros.rs b/comms/dht/src/macros.rs index b99ee39585..5efb85d67f 100644 --- a/comms/dht/src/macros.rs +++ b/comms/dht/src/macros.rs @@ -40,9 +40,3 @@ macro_rules! acquire_write_lock { acquire_lock!($e, write) }; } - -macro_rules! acquire_read_lock { - ($e:expr) => { - acquire_lock!($e, read) - }; -} diff --git a/comms/dht/src/test_utils/makers.rs b/comms/dht/src/test_utils/makers.rs index 77d5494e06..f437415456 100644 --- a/comms/dht/src/test_utils/makers.rs +++ b/comms/dht/src/test_utils/makers.rs @@ -48,6 +48,17 @@ pub fn make_node_identity() -> Arc { ) } +pub fn make_client_identity() -> Arc { + Arc::new( + NodeIdentity::random( + &mut OsRng::new().unwrap(), + "127.0.0.1:9000".parse().unwrap(), + PeerFeatures::communication_client_default(), + ) + .unwrap(), + ) +} + pub fn make_comms_inbound_message( node_identity: &NodeIdentity, message: Vec, diff --git a/comms/middleware/tower-filter/CHANGELOG.md b/comms/middleware/tower-filter/CHANGELOG.md new file mode 100644 index 0000000000..4eb60cc391 --- /dev/null +++ b/comms/middleware/tower-filter/CHANGELOG.md @@ -0,0 +1,12 @@ +# 0.3.0-alpha.2 (September 30, 2019) + +- Move to `futures-*-preview 0.3.0-alpha.19` +- Move to `pin-project 0.4` + +# 0.3.0-alpha.1 + +- Move to `std::future` + +# 0.1.0 (unreleased) + +- Initial release diff --git a/comms/middleware/tower-filter/Cargo.toml b/comms/middleware/tower-filter/Cargo.toml new file mode 100644 index 0000000000..0a7204f44a --- /dev/null +++ b/comms/middleware/tower-filter/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "tower-filter" +# When releasing to crates.io: +# - Remove path dependencies +# - Update html_root_url. +# - Update doc url +# - Cargo.toml +# - README.md +# - Update CHANGELOG.md. +# - Create "v0.1.x" git tag. +version = "0.3.0-alpha.2" +authors = ["Tower Maintainers "] +license = "MIT" +readme = "README.md" +repository = "https://github.com/tower-rs/tower" +homepage = "https://github.com/tower-rs/tower" +documentation = "https://docs.rs/tower-filter/0.3.0-alpha.2" +description = """ +Conditionally allow requests to be dispatched to a service based on the result +of a predicate. +""" +categories = ["asynchronous", "network-programming"] +edition = "2018" +publish = false + +[dependencies] +tower= { version = "=0.3.0-alpha.2"} +pin-project = "0.4" +futures-core-preview = "=0.3.0-alpha.19" + +[dev-dependencies] +tower-test = { version = "=0.3.0-alpha.2" } +tokio-test = "=0.2.0-alpha.6" +tokio = "=0.2.0-alpha.6" +futures-util-preview = "=0.3.0-alpha.19" diff --git a/comms/middleware/tower-filter/LICENSE b/comms/middleware/tower-filter/LICENSE new file mode 100644 index 0000000000..b980cacc77 --- /dev/null +++ b/comms/middleware/tower-filter/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2019 Tower Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/comms/middleware/tower-filter/README.md b/comms/middleware/tower-filter/README.md new file mode 100644 index 0000000000..c862a192e2 --- /dev/null +++ b/comms/middleware/tower-filter/README.md @@ -0,0 +1,14 @@ +# Tower Filter + +Conditionally allow requests to be dispatched to a service based on the result +of a predicate. + +## License + +This project is licensed under the [MIT license](LICENSE). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in Tower by you, shall be licensed as MIT, without any additional +terms or conditions. diff --git a/comms/middleware/tower-filter/src/error.rs b/comms/middleware/tower-filter/src/error.rs new file mode 100644 index 0000000000..b2643f77f4 --- /dev/null +++ b/comms/middleware/tower-filter/src/error.rs @@ -0,0 +1,46 @@ +//! Error types + +use std::{error, fmt}; + +/// Error produced by `Filter` +#[derive(Debug)] +pub struct Error { + source: Option, +} + +pub(crate) type Source = Box; + +impl Error { + /// Create a new `Error` representing a rejected request. + pub fn rejected() -> Error { + Error { source: None } + } + + /// Create a new `Error` representing an inner service error. + pub fn inner(source: E) -> Error + where E: Into { + Error { + source: Some(source.into()), + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + if self.source.is_some() { + write!(fmt, "inner service errored") + } else { + write!(fmt, "rejected") + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + if let Some(ref err) = self.source { + Some(&**err) + } else { + None + } + } +} diff --git a/comms/middleware/tower-filter/src/future.rs b/comms/middleware/tower-filter/src/future.rs new file mode 100644 index 0000000000..267ae3500c --- /dev/null +++ b/comms/middleware/tower-filter/src/future.rs @@ -0,0 +1,91 @@ +//! Future types + +use crate::error::{self, Error}; +use futures_core::ready; +use pin_project::{pin_project, project}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::Service; + +/// Filtered response future +#[pin_project] +#[derive(Debug)] +pub struct ResponseFuture +where S: Service +{ + #[pin] + /// Response future state + state: State, + + #[pin] + /// Predicate future + check: T, + + /// Inner service + service: S, +} + +#[pin_project] +#[derive(Debug)] +enum State { + Check(Option), + WaitResponse(#[pin] U), +} + +impl ResponseFuture +where + F: Future>, + S: Service, + S::Error: Into, +{ + pub(crate) fn new(request: Request, check: F, service: S) -> Self { + ResponseFuture { + state: State::Check(Some(request)), + check, + service, + } + } +} + +impl Future for ResponseFuture +where + F: Future>, + S: Service, + S::Error: Into, +{ + type Output = Result; + + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + #[project] + match this.state.as_mut().project() { + State::Check(request) => { + let request = request + .take() + .expect("we either give it back or leave State::Check once we take"); + + // Poll predicate + match this.check.as_mut().poll(cx)? { + Poll::Ready(_) => { + let response = this.service.call(request); + this.state.set(State::WaitResponse(response)); + }, + Poll::Pending => { + this.state.set(State::Check(Some(request))); + return Poll::Pending; + }, + } + }, + State::WaitResponse(response) => { + return Poll::Ready(ready!(response.poll(cx)).map_err(Error::inner)); + }, + } + } + } +} diff --git a/comms/middleware/tower-filter/src/layer.rs b/comms/middleware/tower-filter/src/layer.rs new file mode 100644 index 0000000000..af6d6a0879 --- /dev/null +++ b/comms/middleware/tower-filter/src/layer.rs @@ -0,0 +1,24 @@ +use crate::Filter; +use tower::layer::Layer; + +/// Conditionally dispatch requests to the inner service based on a predicate. +#[derive(Debug)] +pub struct FilterLayer { + predicate: U, +} + +impl FilterLayer { + #[allow(missing_docs)] + pub fn new(predicate: U) -> Self { + FilterLayer { predicate } + } +} + +impl Layer for FilterLayer { + type Service = Filter; + + fn layer(&self, service: S) -> Self::Service { + let predicate = self.predicate.clone(); + Filter::new(service, predicate) + } +} diff --git a/comms/middleware/tower-filter/src/lib.rs b/comms/middleware/tower-filter/src/lib.rs new file mode 100644 index 0000000000..ba7c07567a --- /dev/null +++ b/comms/middleware/tower-filter/src/lib.rs @@ -0,0 +1,59 @@ +#![doc(html_root_url = "https://docs.rs/tower-filter/0.3.0-alpha.2")] +#![warn(missing_debug_implementations, missing_docs, rust_2018_idioms, unreachable_pub)] +#![allow(elided_lifetimes_in_paths)] + +//! Conditionally dispatch requests to the inner service based on the result of +//! a predicate. + +pub mod error; +pub mod future; +mod layer; +mod predicate; + +pub use crate::{layer::FilterLayer, predicate::Predicate}; + +use crate::{error::Error, future::ResponseFuture}; +use futures_core::ready; +use std::task::{Context, Poll}; +use tower::Service; + +/// Conditionally dispatch requests to the inner service based on a predicate. +#[derive(Clone, Debug)] +pub struct Filter { + inner: T, + predicate: U, +} + +impl Filter { + #[allow(missing_docs)] + pub fn new(inner: T, predicate: U) -> Self { + Filter { inner, predicate } + } +} + +impl Service for Filter +where + T: Service + Clone, + T::Error: Into, + U: Predicate, +{ + type Error = Error; + type Future = ResponseFuture; + type Response = T::Response; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(ready!(self.inner.poll_ready(cx)).map_err(error::Error::inner)) + } + + fn call(&mut self, request: Request) -> Self::Future { + use std::mem; + + let inner = self.inner.clone(); + let inner = mem::replace(&mut self.inner, inner); + + // Check the request + let check = self.predicate.check(&request); + + ResponseFuture::new(request, check, inner) + } +} diff --git a/comms/middleware/tower-filter/src/predicate.rs b/comms/middleware/tower-filter/src/predicate.rs new file mode 100644 index 0000000000..34d2e585f7 --- /dev/null +++ b/comms/middleware/tower-filter/src/predicate.rs @@ -0,0 +1,25 @@ +use crate::error::Error; +use std::future::Future; + +/// Checks a request +pub trait Predicate { + /// The future returned by `check`. + type Future: Future>; + + /// Check whether the given request should be forwarded. + /// + /// If the future resolves with `Ok`, the request is forwarded to the inner service. + fn check(&mut self, request: &Request) -> Self::Future; +} + +impl Predicate for F +where + F: Fn(&T) -> U, + U: Future>, +{ + type Future = U; + + fn check(&mut self, request: &T) -> Self::Future { + self(request) + } +} diff --git a/comms/middleware/tower-filter/tests/filter.rs b/comms/middleware/tower-filter/tests/filter.rs new file mode 100644 index 0000000000..ff869b894e --- /dev/null +++ b/comms/middleware/tower-filter/tests/filter.rs @@ -0,0 +1,61 @@ +use futures_util::{future::poll_fn, pin_mut}; +use std::{future::Future, thread}; +use tokio_test::{assert_ready, assert_ready_err, task}; +use tower::Service; +use tower_filter::{error::Error, Filter}; +use tower_test::{assert_request_eq, mock}; + +#[tokio::test] +async fn passthrough_sync() { + let (mut service, handle) = new_service(|_| async { Ok(()) }); + + let th = thread::spawn(move || { + // Receive the requests and respond + pin_mut!(handle); + for i in 0..10 { + assert_request_eq!(handle, format!("ping-{}", i)).send_response(format!("pong-{}", i)); + } + }); + + let mut responses = vec![]; + + for i in 0usize..10 { + let request = format!("ping-{}", i); + poll_fn(|cx| service.poll_ready(cx)).await.unwrap(); + let exchange = service.call(request); + let exchange = async move { + let response = exchange.await.unwrap(); + let expect = format!("pong-{}", i); + assert_eq!(response.as_str(), expect.as_str()); + }; + + responses.push(exchange); + } + + futures_util::future::join_all(responses).await; + th.join().unwrap(); +} + +#[test] +fn rejected_sync() { + task::mock(|cx| { + let (mut service, _handle) = new_service(|_| async { Err(Error::rejected()) }); + + let fut = service.call("hello".into()); + pin_mut!(fut); + assert_ready_err!(fut.poll(cx)); + }); +} + +type Mock = mock::Mock; +type Handle = mock::Handle; + +fn new_service(f: F) -> (Filter, Handle) +where + F: Fn(&String) -> U, + U: Future>, +{ + let (service, handle) = mock::pair(); + let service = Filter::new(service, f); + (service, handle) +}