From ae88b902e1398ecdcb5285bfb11e779b77b8a683 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 10:18:34 +0100 Subject: [PATCH 01/23] Moved auth_token to seperate file --- .../sfw-provider-requests/src/auth_token.rs | 111 +++++++++++++++++ sfw-provider/sfw-provider-requests/src/lib.rs | 112 +----------------- 2 files changed, 112 insertions(+), 111 deletions(-) create mode 100644 sfw-provider/sfw-provider-requests/src/auth_token.rs diff --git a/sfw-provider/sfw-provider-requests/src/auth_token.rs b/sfw-provider/sfw-provider-requests/src/auth_token.rs new file mode 100644 index 00000000000..3fb714bb7e9 --- /dev/null +++ b/sfw-provider/sfw-provider-requests/src/auth_token.rs @@ -0,0 +1,111 @@ +pub const AUTH_TOKEN_SIZE: usize = 32; + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub struct AuthToken([u8; AUTH_TOKEN_SIZE]); + +#[derive(Debug)] +pub enum AuthTokenConversionError { + InvalidStringError, + StringOfInvalidLengthError, +} + +impl AuthToken { + pub fn from_bytes(bytes: [u8; AUTH_TOKEN_SIZE]) -> Self { + AuthToken(bytes) + } + + pub fn to_bytes(&self) -> [u8; AUTH_TOKEN_SIZE] { + self.0 + } + + pub fn as_bytes(&self) -> &[u8] { + &self.0 + } + + pub fn try_from_base58_string>( + val: S, + ) -> Result { + let decoded = match bs58::decode(val.into()).into_vec() { + Ok(decoded) => decoded, + Err(_) => return Err(AuthTokenConversionError::InvalidStringError), + }; + + if decoded.len() != AUTH_TOKEN_SIZE { + return Err(AuthTokenConversionError::StringOfInvalidLengthError); + } + + let mut token = [0u8; AUTH_TOKEN_SIZE]; + token.copy_from_slice(&decoded[..]); + Ok(AuthToken(token)) + } + + pub fn to_base58_string(&self) -> String { + bs58::encode(self.0).into_string() + } +} + +impl Into for AuthToken { + fn into(self) -> String { + self.to_base58_string() + } +} + +#[cfg(test)] +mod auth_token_conversion { + use super::*; + + #[test] + fn it_is_possible_to_recover_it_from_valid_b58_string() { + let auth_token = AuthToken([42; AUTH_TOKEN_SIZE]); + let auth_token_string = auth_token.to_base58_string(); + assert_eq!( + auth_token, + AuthToken::try_from_base58_string(auth_token_string).unwrap() + ) + } + + #[test] + fn it_is_possible_to_recover_it_from_valid_b58_str_ref() { + let auth_token = AuthToken([42; AUTH_TOKEN_SIZE]); + let auth_token_string = auth_token.to_base58_string(); + let auth_token_str_ref: &str = &auth_token_string; + assert_eq!( + auth_token, + AuthToken::try_from_base58_string(auth_token_str_ref).unwrap() + ) + } + + #[test] + fn it_returns_error_on_b58_with_invalid_characters() { + let auth_token = AuthToken([42; AUTH_TOKEN_SIZE]); + let auth_token_string = auth_token.to_base58_string(); + + let mut chars = auth_token_string.chars(); + let _consumed_first_char = chars.next().unwrap(); + + let invalid_chars_token = "=".to_string() + chars.as_str(); + assert!(AuthToken::try_from_base58_string(invalid_chars_token).is_err()) + } + + #[test] + fn it_returns_error_on_too_long_b58_string() { + let auth_token = AuthToken([42; AUTH_TOKEN_SIZE]); + let mut auth_token_string = auth_token.to_base58_string(); + auth_token_string.push('f'); + + assert!(AuthToken::try_from_base58_string(auth_token_string).is_err()) + } + + #[test] + fn it_returns_error_on_too_short_b58_string() { + let auth_token = AuthToken([42; AUTH_TOKEN_SIZE]); + let auth_token_string = auth_token.to_base58_string(); + + let mut chars = auth_token_string.chars(); + let _consumed_first_char = chars.next().unwrap(); + let _consumed_second_char = chars.next().unwrap(); + let invalid_chars_token = 
chars.as_str(); + + assert!(AuthToken::try_from_base58_string(invalid_chars_token).is_err()) + } +} diff --git a/sfw-provider/sfw-provider-requests/src/lib.rs b/sfw-provider/sfw-provider-requests/src/lib.rs index 7a4656e4172..e6d4883c0a4 100644 --- a/sfw-provider/sfw-provider-requests/src/lib.rs +++ b/sfw-provider/sfw-provider-requests/src/lib.rs @@ -1,116 +1,6 @@ +pub mod auth_token; pub mod requests; pub mod responses; pub const DUMMY_MESSAGE_CONTENT: &[u8] = b"[DUMMY MESSAGE] Wanting something does not give you the right to have it."; - -// To be renamed to 'AuthToken' once it is safe to replace it -#[derive(Debug, PartialEq, Eq, Hash, Clone)] -pub struct AuthToken([u8; 32]); - -#[derive(Debug)] -pub enum AuthTokenConversionError { - InvalidStringError, - StringOfInvalidLengthError, -} - -impl AuthToken { - pub fn from_bytes(bytes: [u8; 32]) -> Self { - AuthToken(bytes) - } - - pub fn to_bytes(&self) -> [u8; 32] { - self.0 - } - - pub fn as_bytes(&self) -> &[u8] { - &self.0 - } - - pub fn try_from_base58_string>( - val: S, - ) -> Result { - let decoded = match bs58::decode(val.into()).into_vec() { - Ok(decoded) => decoded, - Err(_) => return Err(AuthTokenConversionError::InvalidStringError), - }; - - if decoded.len() != 32 { - return Err(AuthTokenConversionError::StringOfInvalidLengthError); - } - - let mut token = [0u8; 32]; - token.copy_from_slice(&decoded[..]); - Ok(AuthToken(token)) - } - - pub fn to_base58_string(&self) -> String { - bs58::encode(self.0).into_string() - } -} - -impl Into for AuthToken { - fn into(self) -> String { - self.to_base58_string() - } -} - -#[cfg(test)] -mod auth_token_conversion { - use super::*; - - #[test] - fn it_is_possible_to_recover_it_from_valid_b58_string() { - let auth_token = AuthToken([42; 32]); - let auth_token_string = auth_token.to_base58_string(); - assert_eq!( - auth_token, - AuthToken::try_from_base58_string(auth_token_string).unwrap() - ) - } - - #[test] - fn it_is_possible_to_recover_it_from_valid_b58_str_ref() { - let auth_token = AuthToken([42; 32]); - let auth_token_string = auth_token.to_base58_string(); - let auth_token_str_ref: &str = &auth_token_string; - assert_eq!( - auth_token, - AuthToken::try_from_base58_string(auth_token_str_ref).unwrap() - ) - } - - #[test] - fn it_returns_error_on_b58_with_invalid_characters() { - let auth_token = AuthToken([42; 32]); - let auth_token_string = auth_token.to_base58_string(); - - let mut chars = auth_token_string.chars(); - let _consumed_first_char = chars.next().unwrap(); - - let invalid_chars_token = "=".to_string() + chars.as_str(); - assert!(AuthToken::try_from_base58_string(invalid_chars_token).is_err()) - } - - #[test] - fn it_returns_error_on_too_long_b58_string() { - let auth_token = AuthToken([42; 32]); - let mut auth_token_string = auth_token.to_base58_string(); - auth_token_string.push('f'); - - assert!(AuthToken::try_from_base58_string(auth_token_string).is_err()) - } - - #[test] - fn it_returns_error_on_too_short_b58_string() { - let auth_token = AuthToken([42; 32]); - let auth_token_string = auth_token.to_base58_string(); - - let mut chars = auth_token_string.chars(); - let _consumed_first_char = chars.next().unwrap(); - let _consumed_second_char = chars.next().unwrap(); - let invalid_chars_token = chars.as_str(); - - assert!(AuthToken::try_from_base58_string(invalid_chars_token).is_err()) - } -} From 2268c014165285517d447989eb9039fdde0fc738 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 10:19:29 +0100 Subject: [PATCH 02/23] Extracted check_id as separate type 
--- common/healthcheck/src/path_check.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/common/healthcheck/src/path_check.rs b/common/healthcheck/src/path_check.rs index fb68953ee24..45f4ffb19c2 100644 --- a/common/healthcheck/src/path_check.rs +++ b/common/healthcheck/src/path_check.rs @@ -11,6 +11,8 @@ use std::net::SocketAddr; use std::time::Duration; use topology::provider; +pub(crate) type CheckId = [u8; 16]; + #[derive(Debug, PartialEq, Clone)] pub enum PathStatus { Healthy, @@ -23,7 +25,7 @@ pub(crate) struct PathChecker { mixnet_client: multi_tcp_client::Client, paths_status: HashMap, PathStatus>, our_destination: Destination, - check_id: [u8; 16], + check_id: CheckId, } impl PathChecker { @@ -31,7 +33,7 @@ impl PathChecker { providers: Vec, identity_keys: &MixIdentityKeyPair, connection_timeout: Duration, - check_id: [u8; 16], + check_id: CheckId, ) -> Self { let mut provider_clients = HashMap::new(); @@ -40,6 +42,9 @@ impl PathChecker { for provider in providers { let mut provider_client = ProviderClient::new(provider.client_listener, address.clone(), None); + // TODO: we might be sending unnecessary register requests since after first healthcheck, + // we are registered for any subsequent ones (since our address did not change) + let insertion_result = match provider_client.register().await { Ok(token) => { debug!("[Healthcheck] registered at provider {}", provider.pub_key); @@ -82,7 +87,7 @@ impl PathChecker { // iteration is used to distinguish packets sent through the same path (as the healthcheck // may try to send say 10 packets through given path) - fn unique_path_key(path: &[SphinxNode], check_id: [u8; 16], iteration: u8) -> Vec { + fn unique_path_key(path: &[SphinxNode], check_id: CheckId, iteration: u8) -> Vec { check_id .iter() .cloned() @@ -133,8 +138,8 @@ impl PathChecker { // pull messages from given provider until there are no more 'real' messages async fn resolve_pending_provider_checks( - &self, - provider_client: &ProviderClient, + provider_client: &mut ProviderClient, + check_id: CheckId, ) -> Vec> { // keep getting messages until we encounter the dummy message let mut provider_messages = Vec::new(); @@ -151,7 +156,7 @@ impl PathChecker { if msg == sfw_provider_requests::DUMMY_MESSAGE_CONTENT { // finish iterating the loop as the messages might not be ordered should_stop = true; - } else if msg[..16] != self.check_id { + } else if msg[..16] != check_id { warn!("received response from previous healthcheck") } else { provider_messages.push(msg); @@ -169,14 +174,15 @@ impl PathChecker { pub(crate) async fn resolve_pending_checks(&mut self) { // not sure how to nicely put it into an iterator due to it being async calls let mut provider_messages = Vec::new(); - for provider_client in self.provider_clients.values() { + for provider_client in self.provider_clients.values_mut() { // if it was none all associated paths were already marked as unhealthy let pc = match provider_client { Some(pc) => pc, None => continue, }; - provider_messages.extend(self.resolve_pending_provider_checks(pc).await); + provider_messages + .extend(Self::resolve_pending_provider_checks(pc, self.check_id).await); } self.update_path_statuses(provider_messages); From 7a9238fd7d66d07fac6310e8979bc7153321fe1e Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 10:20:31 +0100 Subject: [PATCH 03/23] Changes due to move of auth_token and making provider client mutable --- nym-client/src/client/mod.rs | 2 +- nym-client/src/client/provider_poller.rs | 
6 +++--- nym-client/src/commands/init.rs | 4 ++-- sfw-provider/src/provider/client_handling/ledger.rs | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/nym-client/src/client/mod.rs b/nym-client/src/client/mod.rs index 4a6af32df84..1fae01883b8 100644 --- a/nym-client/src/client/mod.rs +++ b/nym-client/src/client/mod.rs @@ -16,7 +16,7 @@ use futures::channel::{mpsc, oneshot}; use log::*; use nymsphinx::chunking::split_and_prepare_payloads; use pemstore::pemstore::PemStore; -use sfw_provider_requests::AuthToken; +use sfw_provider_requests::auth_token::AuthToken; use sphinx::route::Destination; use tokio::runtime::Runtime; use topology::NymTopology; diff --git a/nym-client/src/client/provider_poller.rs b/nym-client/src/client/provider_poller.rs index 57f7cb55ca8..9af6cc6a455 100644 --- a/nym-client/src/client/provider_poller.rs +++ b/nym-client/src/client/provider_poller.rs @@ -1,7 +1,7 @@ use futures::channel::mpsc; use log::*; use provider_client::ProviderClientError; -use sfw_provider_requests::AuthToken; +use sfw_provider_requests::auth_token::AuthToken; use sphinx::route::DestinationAddressBytes; use std::net::SocketAddr; use std::time; @@ -60,7 +60,7 @@ impl ProviderPoller { Ok(()) } - pub(crate) async fn start_provider_polling(self) { + pub(crate) async fn start_provider_polling(&mut self) { let loop_message = &mix_client::packet::LOOP_COVER_MESSAGE_PAYLOAD.to_vec(); let dummy_message = &sfw_provider_requests::DUMMY_MESSAGE_CONTENT.to_vec(); @@ -100,7 +100,7 @@ impl ProviderPoller { } } - pub(crate) fn start(self, handle: &Handle) -> JoinHandle<()> { + pub(crate) fn start(mut self, handle: &Handle) -> JoinHandle<()> { handle.spawn(async move { self.start_provider_polling().await }) } } diff --git a/nym-client/src/commands/init.rs b/nym-client/src/commands/init.rs index e6285c3999b..71b632b8f04 100644 --- a/nym-client/src/commands/init.rs +++ b/nym-client/src/commands/init.rs @@ -6,7 +6,7 @@ use config::NymConfig; use crypto::identity::MixIdentityKeyPair; use directory_client::presence::Topology; use pemstore::pemstore::PemStore; -use sfw_provider_requests::AuthToken; +use sfw_provider_requests::auth_token::AuthToken; use sphinx::route::DestinationAddressBytes; use topology::provider::Node; use topology::NymTopology; @@ -49,7 +49,7 @@ async fn try_provider_registrations( ) -> Option<(String, AuthToken)> { // since the order of providers is non-deterministic we can just try to get a first 'working' provider for provider in providers { - let provider_client = provider_client::ProviderClient::new( + let mut provider_client = provider_client::ProviderClient::new( provider.client_listener, our_address.clone(), None, diff --git a/sfw-provider/src/provider/client_handling/ledger.rs b/sfw-provider/src/provider/client_handling/ledger.rs index a5dda44f365..90f5efe59d6 100644 --- a/sfw-provider/src/provider/client_handling/ledger.rs +++ b/sfw-provider/src/provider/client_handling/ledger.rs @@ -1,6 +1,6 @@ use directory_client::presence::providers::MixProviderClient; use futures::lock::Mutex; -use sfw_provider_requests::AuthToken; +use sfw_provider_requests::auth_token::AuthToken; use sphinx::route::DestinationAddressBytes; use std::collections::HashMap; use std::io; From d7d8be234bc944f3cbb981b7c120cf78f274627e Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 10:41:13 +0100 Subject: [PATCH 04/23] New way of serialization provider requests/responses --- sfw-provider/sfw-provider-requests/Cargo.toml | 4 +- .../sfw-provider-requests/src/requests.rs | 338 
+++++++++++------- .../sfw-provider-requests/src/responses.rs | 313 +++++++++++++--- 3 files changed, 474 insertions(+), 181 deletions(-) diff --git a/sfw-provider/sfw-provider-requests/Cargo.toml b/sfw-provider/sfw-provider-requests/Cargo.toml index 4909c709a2b..dd5b21f4566 100644 --- a/sfw-provider/sfw-provider-requests/Cargo.toml +++ b/sfw-provider/sfw-provider-requests/Cargo.toml @@ -7,5 +7,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -bs58 = "0.3.0" +bs58 = "0.3" +bytes = "0.5" +tokio = "0.2" sphinx = { git = "https://github.com/nymtech/sphinx", rev="44d8f2aece5049eaa4fe84b7948758ce82b4b80d" } diff --git a/sfw-provider/sfw-provider-requests/src/requests.rs b/sfw-provider/sfw-provider-requests/src/requests.rs index 06ab0a87414..1262677ad8d 100644 --- a/sfw-provider/sfw-provider-requests/src/requests.rs +++ b/sfw-provider/sfw-provider-requests/src/requests.rs @@ -1,59 +1,204 @@ -use crate::AuthToken; +use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; +use sphinx::constants::DESTINATION_ADDRESS_LENGTH; use sphinx::route::DestinationAddressBytes; +use std::convert::TryFrom; +use std::io; +use std::io::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -const PULL_REQUEST_MESSAGE_PREFIX: [u8; 2] = [1, 0]; -const REGISTER_MESSAGE_PREFIX: [u8; 2] = [0, 1]; - -// TODO: how to do it more nicely, considering all sfw-provider-requests implement same trait that is exercised here? #[derive(Debug)] -pub enum ProviderRequests { - PullMessages(PullRequest), +pub enum ProviderRequestError { + MarshalError, + UnmarshalError, + UnmarshalErrorInvalidKind, + UnmarshalErrorInvalidLength, + TooLongRequestError, + TooShortRequestError, + IOError(io::Error), + RemoteConnectionClosed, +} + +impl From for ProviderRequestError { + fn from(e: Error) -> Self { + ProviderRequestError::IOError(e) + } +} + +impl<'a, R: AsyncRead + Unpin> Drop for TokioAsyncRequestReader<'a, R> { + fn drop(&mut self) { + println!("request reader drop"); + } +} + +impl<'a, R: AsyncWrite + Unpin> Drop for TokioAsyncRequestWriter<'a, R> { + fn drop(&mut self) { + println!("request writer drop"); + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum RequestKind { + Pull = 1, + Register = 2, +} + +impl TryFrom for RequestKind { + type Error = ProviderRequestError; + + fn try_from(value: u8) -> Result { + match value { + _ if value == (RequestKind::Pull as u8) => Ok(Self::Pull), + _ if value == (RequestKind::Register as u8) => Ok(Self::Register), + _ => Err(Self::Error::UnmarshalErrorInvalidKind), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ProviderRequest { + Pull(PullRequest), Register(RegisterRequest), } -impl ProviderRequests { - pub fn to_bytes(&self) -> Vec { - use ProviderRequests::*; - match self { - PullMessages(pr) => pr.to_bytes(), - Register(pr) => pr.to_bytes(), +// Ideally I would have used futures::AsyncRead for even more generic approach, but unfortunately +// tokio::io::AsyncRead differs from futures::AsyncRead +pub struct TokioAsyncRequestReader<'a, R: AsyncRead + Unpin> { + max_allowed_len: usize, + reader: &'a mut R, +} + +impl<'a, R: AsyncRead + Unpin> TokioAsyncRequestReader<'a, R> { + pub fn new(reader: &'a mut R, max_allowed_len: usize) -> Self { + TokioAsyncRequestReader { + reader, + max_allowed_len, } } - pub fn from_bytes(bytes: &[u8]) -> Result { - use ProviderRequests::*; - if bytes.len() < 2 { - return Err(ProviderRequestError::UnmarshalError); + pub async fn 
try_read_request(&mut self) -> Result { + let req_len = self.reader.read_u32().await?; + if req_len == 0 { + return Err(ProviderRequestError::RemoteConnectionClosed); } - let mut received_prefix = [0; 2]; - received_prefix.copy_from_slice(&bytes[..2]); - match received_prefix { - PULL_REQUEST_MESSAGE_PREFIX => Ok(PullMessages(PullRequest::from_bytes(bytes)?)), - REGISTER_MESSAGE_PREFIX => Ok(Register(RegisterRequest::from_bytes(bytes)?)), - _ => Err(ProviderRequestError::UnmarshalErrorIncorrectPrefix), + if req_len as usize > self.max_allowed_len { + // TODO: should reader be drained? + return Err(ProviderRequestError::TooLongRequestError); } + + let mut req_buf = Vec::with_capacity(req_len as usize); + let mut chunk = self.reader.take(req_len as u64); + + if let Err(_) = chunk.read_to_end(&mut req_buf).await { + return Err(ProviderRequestError::TooShortRequestError); + }; + + let parse_res = RequestDeserializer::new_with_len(req_len, &req_buf)?.try_to_parse(); + + parse_res } } -#[derive(Debug)] -pub enum ProviderRequestError { - MarshalError, - UnmarshalError, - UnmarshalErrorIncorrectPrefix, +pub struct RequestDeserializer<'a> { + kind: RequestKind, + data: &'a [u8], } -pub trait ProviderRequest -where - Self: Sized, -{ - fn get_prefix() -> [u8; 2]; - fn to_bytes(&self) -> Vec; - fn from_bytes(bytes: &[u8]) -> Result; +impl<'a> RequestDeserializer<'a> { + // perform initial parsing + pub fn new(raw_bytes: &'a [u8]) -> Result { + if raw_bytes.len() < 1 + 4 { + Err(ProviderRequestError::UnmarshalErrorInvalidLength) + } else { + let data_len = + u32::from_be_bytes([raw_bytes[0], raw_bytes[1], raw_bytes[2], raw_bytes[3]]); + let kind = RequestKind::try_from(raw_bytes[4])?; + let data = &raw_bytes[4..]; + + if data.len() != data_len as usize { + Err(ProviderRequestError::UnmarshalErrorInvalidLength) + } else { + Ok(RequestDeserializer { kind, data }) + } + } + } + + pub fn new_with_len(len: u32, raw_bytes: &'a [u8]) -> Result { + if raw_bytes.len() != len as usize { + Err(ProviderRequestError::UnmarshalErrorInvalidLength) + } else { + let kind = RequestKind::try_from(raw_bytes[0])?; + let data = &raw_bytes[1..]; + Ok(RequestDeserializer { kind, data }) + } + } + + pub fn get_kind(&self) -> RequestKind { + self.kind + } + + pub fn get_data(&self) -> &'a [u8] { + self.data + } + + pub fn try_to_parse(self) -> Result { + match self.get_kind() { + RequestKind::Pull => Ok(ProviderRequest::Pull(PullRequest::try_from_bytes( + self.data, + )?)), + RequestKind::Register => Ok(ProviderRequest::Register( + RegisterRequest::try_from_bytes(self.data)?, + )), + } + } } -#[derive(Debug)] +// Ideally I would have used futures::AsyncWrite for even more generic approach, but unfortunately +// tokio::io::AsyncWrite differs from futures::AsyncWrite +pub struct TokioAsyncRequestWriter<'a, W: AsyncWrite + Unpin> { + writer: &'a mut W, +} + +impl<'a, W: AsyncWrite + Unpin> TokioAsyncRequestWriter<'a, W> { + pub fn new(writer: &'a mut W) -> Self { + TokioAsyncRequestWriter { writer } + } + + pub async fn try_write_request(&mut self, res: ProviderRequest) -> io::Result<()> { + let res_bytes = RequestSerializer::new(res).into_bytes(); + self.writer.write_all(&res_bytes).await + } +} + +pub struct RequestSerializer { + req: ProviderRequest, +} + +impl RequestSerializer { + pub fn new(req: ProviderRequest) -> Self { + RequestSerializer { req } + } + + /// Serialized requests in general have the following structure: + /// follows: 4 byte len (be u32) || 1-byte kind prefix || request-specific data + pub fn 
into_bytes(self) -> Vec { + let (kind, req_bytes) = match self.req { + ProviderRequest::Pull(req) => (req.get_kind(), req.to_bytes()), + ProviderRequest::Register(req) => (req.get_kind(), req.to_bytes()), + }; + let req_len = req_bytes.len() as u32 + 1; // 1 is to accommodate for 'kind' + let req_len_bytes = req_len.to_be_bytes(); + req_len_bytes + .iter() + .cloned() + .chain(std::iter::once(kind as u8)) + .chain(req_bytes.into_iter()) + .collect() + } +} + +#[derive(Debug, Clone, PartialEq)] pub struct PullRequest { - // TODO: public keys, signatures, tokens, etc. basically some kind of authentication bs pub auth_token: AuthToken, pub destination_address: sphinx::route::DestinationAddressBytes, } @@ -68,38 +213,30 @@ impl PullRequest { destination_address, } } -} -impl ProviderRequest for PullRequest { - fn get_prefix() -> [u8; 2] { - PULL_REQUEST_MESSAGE_PREFIX + pub fn get_kind(&self) -> RequestKind { + RequestKind::Pull } fn to_bytes(&self) -> Vec { - Self::get_prefix() - .to_vec() - .into_iter() - .chain(self.destination_address.to_bytes().iter().cloned()) - .chain(self.auth_token.0.iter().cloned()) + self.destination_address + .to_bytes() + .iter() + .cloned() + .chain(self.auth_token.as_bytes().iter().cloned()) .collect() } - fn from_bytes(bytes: &[u8]) -> Result { - if bytes.len() != 2 + 32 + 32 { - return Err(ProviderRequestError::UnmarshalError); - } - - let mut received_prefix = [0u8; 2]; - received_prefix.copy_from_slice(&bytes[..2]); - if received_prefix != Self::get_prefix() { - return Err(ProviderRequestError::UnmarshalErrorIncorrectPrefix); + fn try_from_bytes(bytes: &[u8]) -> Result { + if bytes.len() != DESTINATION_ADDRESS_LENGTH + AUTH_TOKEN_SIZE { + return Err(ProviderRequestError::UnmarshalErrorInvalidLength); } - let mut destination_address = [0u8; 32]; - destination_address.copy_from_slice(&bytes[2..34]); + let mut destination_address = [0u8; DESTINATION_ADDRESS_LENGTH]; + destination_address.copy_from_slice(&bytes[..DESTINATION_ADDRESS_LENGTH]); - let mut auth_token = [0u8; 32]; - auth_token.copy_from_slice(&bytes[34..]); + let mut auth_token = [0u8; AUTH_TOKEN_SIZE]; + auth_token.copy_from_slice(&bytes[DESTINATION_ADDRESS_LENGTH..]); Ok(PullRequest { auth_token: AuthToken::from_bytes(auth_token), @@ -108,7 +245,7 @@ impl ProviderRequest for PullRequest { } } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub struct RegisterRequest { pub destination_address: DestinationAddressBytes, } @@ -119,34 +256,26 @@ impl RegisterRequest { destination_address, } } -} -impl ProviderRequest for RegisterRequest { - fn get_prefix() -> [u8; 2] { - REGISTER_MESSAGE_PREFIX + pub fn get_kind(&self) -> RequestKind { + RequestKind::Register } fn to_bytes(&self) -> Vec { - Self::get_prefix() - .to_vec() - .into_iter() - .chain(self.destination_address.to_bytes().iter().cloned()) + self.destination_address + .to_bytes() + .iter() + .cloned() .collect() } - fn from_bytes(bytes: &[u8]) -> Result { - if bytes.len() != 2 + 32 { - return Err(ProviderRequestError::UnmarshalError); + fn try_from_bytes(bytes: &[u8]) -> Result { + if bytes.len() != DESTINATION_ADDRESS_LENGTH { + return Err(ProviderRequestError::UnmarshalErrorInvalidLength); } - let mut received_prefix = [0u8; 2]; - received_prefix.copy_from_slice(&bytes[..2]); - if received_prefix != Self::get_prefix() { - return Err(ProviderRequestError::UnmarshalErrorIncorrectPrefix); - } - - let mut destination_address = [0u8; 32]; - destination_address.copy_from_slice(&bytes[2..]); + let mut destination_address = [0u8; 
DESTINATION_ADDRESS_LENGTH]; + destination_address.copy_from_slice(&bytes[..DESTINATION_ADDRESS_LENGTH]); Ok(RegisterRequest { destination_address: DestinationAddressBytes::from_bytes(destination_address), @@ -164,33 +293,12 @@ mod creating_pull_request { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, ]); - let auth_token = [1u8; 32]; - let pull_request = PullRequest::new(address.clone(), AuthToken(auth_token)); + let auth_token = [1u8; AUTH_TOKEN_SIZE]; + let pull_request = PullRequest::new(address.clone(), AuthToken::from_bytes(auth_token)); let bytes = pull_request.to_bytes(); - let recovered = PullRequest::from_bytes(&bytes).unwrap(); - assert_eq!(address, recovered.destination_address); - assert_eq!(AuthToken(auth_token), recovered.auth_token); - } - - #[test] - fn it_is_possible_to_recover_it_from_bytes_with_enum_wrapper() { - let address = DestinationAddressBytes::from_bytes([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, - 0, 1, 2, - ]); - let auth_token = [1u8; 32]; - let pull_request = PullRequest::new(address.clone(), AuthToken(auth_token)); - let bytes = pull_request.to_bytes(); - - let recovered = ProviderRequests::from_bytes(&bytes).unwrap(); - match recovered { - ProviderRequests::PullMessages(req) => { - assert_eq!(address, req.destination_address); - assert_eq!(AuthToken(auth_token), req.auth_token); - } - _ => panic!("expected to recover pull request!"), - } + let recovered = PullRequest::try_from_bytes(&bytes).unwrap(); + assert_eq!(recovered, pull_request); } } @@ -207,25 +315,7 @@ mod creating_register_request { let register_request = RegisterRequest::new(address.clone()); let bytes = register_request.to_bytes(); - let recovered = RegisterRequest::from_bytes(&bytes).unwrap(); - assert_eq!(address, recovered.destination_address); - } - - #[test] - fn it_is_possible_to_recover_it_from_bytes_with_enum_wrapper() { - let address = DestinationAddressBytes::from_bytes([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, - 0, 1, 2, - ]); - let register_request = RegisterRequest::new(address.clone()); - let bytes = register_request.to_bytes(); - - let recovered = ProviderRequests::from_bytes(&bytes).unwrap(); - match recovered { - ProviderRequests::Register(req) => { - assert_eq!(address, req.destination_address); - } - _ => panic!("expected to recover pull request!"), - } + let recovered = RegisterRequest::try_from_bytes(&bytes).unwrap(); + assert_eq!(recovered, register_request); } } diff --git a/sfw-provider/sfw-provider-requests/src/responses.rs b/sfw-provider/sfw-provider-requests/src/responses.rs index 47163ef0225..ef7b74ce14d 100644 --- a/sfw-provider/sfw-provider-requests/src/responses.rs +++ b/sfw-provider/sfw-provider-requests/src/responses.rs @@ -1,74 +1,234 @@ -use crate::AuthToken; -use std::convert::TryInto; +use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; +use std::convert::{TryFrom, TryInto}; +use std::io; +use std::io::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[derive(Debug)] pub enum ProviderResponseError { MarshalError, UnmarshalError, + UnmarshalErrorInvalidKind, UnmarshalErrorInvalidLength, + TooShortResponseError, + TooLongResponseError, + IOError(io::Error), + RemoteConnectionClosed, } -pub trait ProviderResponse -where - Self: Sized, -{ - fn to_bytes(&self) -> Vec; - fn from_bytes(bytes: &[u8]) -> Result; +impl From for ProviderResponseError { + fn from(e: Error) -> Self { + 
ProviderResponseError::IOError(e) + } } -#[derive(Debug)] -pub struct PullResponse { - pub messages: Vec>, +impl<'a, R: AsyncRead + Unpin> Drop for TokioAsyncResponseReader<'a, R> { + fn drop(&mut self) { + println!("response reader drop"); + } } -#[derive(Debug)] -pub struct RegisterResponse { - pub auth_token: AuthToken, +impl<'a, R: AsyncWrite + Unpin> Drop for TokioAsyncResponseWriter<'a, R> { + fn drop(&mut self) { + println!("response writer drop"); + } } -pub struct ErrorResponse { - pub message: String, +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum ResponseKind { + Failure = 0, // perhaps Error would have been a better name, but it'd clash with Self::TryFrom::Error + Pull = 1, + Register = 2, } -impl PullResponse { - pub fn new(messages: Vec>) -> Self { - PullResponse { messages } +impl TryFrom for ResponseKind { + type Error = ProviderResponseError; + + fn try_from(value: u8) -> Result { + match value { + _ if value == (ResponseKind::Failure as u8) => Ok(Self::Register), + _ if value == (ResponseKind::Pull as u8) => Ok(Self::Pull), + _ if value == (ResponseKind::Register as u8) => Ok(Self::Register), + _ => Err(Self::Error::UnmarshalErrorInvalidKind), + } } } -impl RegisterResponse { - pub fn new(auth_token: AuthToken) -> Self { - RegisterResponse { auth_token } +#[derive(Debug, Clone, PartialEq)] +pub enum ProviderResponse { + Failure(FailureResponse), + Pull(PullResponse), + Register(RegisterResponse), +} + +// Ideally I would have used futures::AsyncRead for even more generic approach, but unfortunately +// tokio::io::AsyncRead differs from futures::AsyncRead +pub struct TokioAsyncResponseReader<'a, R: AsyncRead + Unpin> { + max_allowed_len: usize, + reader: &'a mut R, +} + +impl<'a, R: AsyncRead + Unpin> TokioAsyncResponseReader<'a, R> { + pub fn new(reader: &'a mut R, max_allowed_len: usize) -> Self { + TokioAsyncResponseReader { + reader, + max_allowed_len, + } + } + + pub async fn try_read_response(&mut self) -> Result { + let res_len = self.reader.read_u32().await?; + if res_len == 0 { + return Err(ProviderResponseError::RemoteConnectionClosed); + } + if res_len as usize > self.max_allowed_len { + // TODO: should reader be drained? 
+ return Err(ProviderResponseError::TooLongResponseError); + } + + let mut res_buf = Vec::with_capacity(res_len as usize); + let mut chunk = self.reader.take(res_len as u64); + + if let Err(_) = chunk.read_to_end(&mut res_buf).await { + return Err(ProviderResponseError::TooShortResponseError); + }; + + ResponseDeserializer::new_with_len(res_len, &res_buf)?.try_to_parse() } } -impl ErrorResponse { - pub fn new>(message: S) -> Self { - ErrorResponse { - message: message.into(), +pub struct ResponseDeserializer<'a> { + kind: ResponseKind, + data: &'a [u8], +} + +impl<'a> ResponseDeserializer<'a> { + // perform initial parsing + pub fn new(raw_bytes: &'a [u8]) -> Result { + if raw_bytes.len() < 1 + 4 { + Err(ProviderResponseError::UnmarshalErrorInvalidLength) + } else { + let data_len = + u32::from_be_bytes([raw_bytes[0], raw_bytes[1], raw_bytes[2], raw_bytes[3]]); + let kind = ResponseKind::try_from(raw_bytes[4])?; + let data = &raw_bytes[4..]; + + if data.len() != data_len as usize { + Err(ProviderResponseError::UnmarshalErrorInvalidLength) + } else { + Ok(ResponseDeserializer { kind, data }) + } + } + } + + pub fn new_with_len(len: u32, raw_bytes: &'a [u8]) -> Result { + if raw_bytes.len() != len as usize { + Err(ProviderResponseError::UnmarshalErrorInvalidLength) + } else { + let kind = ResponseKind::try_from(raw_bytes[0])?; + let data = &raw_bytes[1..]; + Ok(ResponseDeserializer { kind, data }) + } + } + + pub fn get_kind(&self) -> ResponseKind { + self.kind + } + + pub fn get_data(&self) -> &'a [u8] { + self.data + } + + pub fn try_to_parse(self) -> Result { + match self.get_kind() { + ResponseKind::Failure => Ok(ProviderResponse::Failure( + FailureResponse::try_from_bytes(self.data)?, + )), + ResponseKind::Pull => Ok(ProviderResponse::Pull(PullResponse::try_from_bytes( + self.data, + )?)), + ResponseKind::Register => Ok(ProviderResponse::Register( + RegisterResponse::try_from_bytes(self.data)?, + )), } } } -// TODO: This should go into some kind of utils module/crate -fn read_be_u16(input: &mut &[u8]) -> u16 { - let (int_bytes, rest) = input.split_at(std::mem::size_of::()); - *input = rest; - u16::from_be_bytes(int_bytes.try_into().unwrap()) +// Ideally I would have used futures::AsyncWrite for even more generic approach, but unfortunately +// tokio::io::AsyncWrite differs from futures::AsyncWrite +pub struct TokioAsyncResponseWriter<'a, W: AsyncWrite + Unpin> { + writer: &'a mut W, +} + +impl<'a, W: AsyncWrite + Unpin> TokioAsyncResponseWriter<'a, W> { + pub fn new(writer: &'a mut W) -> Self { + TokioAsyncResponseWriter { writer } + } + + pub async fn try_write_response(&mut self, res: ProviderResponse) -> io::Result<()> { + let res_bytes = ResponseSerializer::new(res).into_bytes(); + self.writer.write_all(&res_bytes).await + } +} + +pub struct ResponseSerializer { + res: ProviderResponse, +} + +impl ResponseSerializer { + pub fn new(res: ProviderResponse) -> Self { + ResponseSerializer { res } + } + + /// Serialized responses in general have the following structure: + /// follows: 4 byte len (be u32) || 1-byte kind prefix || response-specific data + pub fn into_bytes(self) -> Vec { + let (kind, res_bytes) = match self.res { + ProviderResponse::Failure(res) => (res.get_kind(), res.to_bytes()), + ProviderResponse::Pull(res) => (res.get_kind(), res.to_bytes()), + ProviderResponse::Register(res) => (res.get_kind(), res.to_bytes()), + }; + let res_len = res_bytes.len() as u32 + 1; // 1 is to accommodate for 'kind' + let res_len_bytes = res_len.to_be_bytes(); + res_len_bytes + .iter() + 
.cloned() + .chain(std::iter::once(kind as u8)) + .chain(res_bytes.into_iter()) + .collect() + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PullResponse { + messages: Vec>, } -// TODO: currently this allows for maximum 64kB payload - -// if we go over that in sphinx we need to update this code. -impl ProviderResponse for PullResponse { +impl PullResponse { + pub fn new(messages: Vec>) -> Self { + PullResponse { messages } + } + + pub fn extract_messages(self) -> Vec> { + self.messages + } + + pub fn get_kind(&self) -> ResponseKind { + ResponseKind::Pull + } + + // TODO: currently this allows for maximum 64kB payload (max value of u16) - + // if we go over that in sphinx we need to update this code. // num_msgs || len1 || len2 || ... || msg1 || msg2 || ... - fn to_bytes(&self) -> Vec { + pub fn to_bytes(&self) -> Vec { let num_msgs = self.messages.len() as u16; let msgs_lens: Vec = self.messages.iter().map(|msg| msg.len() as u16).collect(); num_msgs .to_be_bytes() - .to_vec() - .into_iter() + .iter() + .cloned() .chain( msgs_lens .into_iter() @@ -78,7 +238,7 @@ impl ProviderResponse for PullResponse { .collect() } - fn from_bytes(bytes: &[u8]) -> Result { + pub fn try_from_bytes(bytes: &[u8]) -> Result { // can we read number of messages? if bytes.len() < 2 { return Err(ProviderResponseError::UnmarshalErrorInvalidLength); @@ -119,38 +279,79 @@ impl ProviderResponse for PullResponse { } } -impl ProviderResponse for RegisterResponse { - fn to_bytes(&self) -> Vec { - self.auth_token.0.to_vec() +#[derive(Debug, Clone, PartialEq)] +pub struct RegisterResponse { + auth_token: AuthToken, +} + +impl RegisterResponse { + pub fn new(auth_token: AuthToken) -> Self { + RegisterResponse { auth_token } } - fn from_bytes(bytes: &[u8]) -> Result { - match bytes.len() { - 32 => { - let mut auth_token = [0u8; 32]; - auth_token.copy_from_slice(&bytes[..32]); - Ok(RegisterResponse { - auth_token: AuthToken(auth_token), - }) - } - _ => Err(ProviderResponseError::UnmarshalErrorInvalidLength), + pub fn get_token(&self) -> AuthToken { + self.auth_token + } + + pub fn get_kind(&self) -> ResponseKind { + ResponseKind::Register + } + + pub fn to_bytes(&self) -> Vec { + self.auth_token.to_bytes().to_vec() + } + + pub fn try_from_bytes(bytes: &[u8]) -> Result { + if bytes.len() != AUTH_TOKEN_SIZE { + return Err(ProviderResponseError::UnmarshalErrorInvalidLength); } + + let mut auth_token = [0u8; AUTH_TOKEN_SIZE]; + auth_token.copy_from_slice(&bytes[..AUTH_TOKEN_SIZE]); + Ok(RegisterResponse { + auth_token: AuthToken::from_bytes(auth_token), + }) } } -impl ProviderResponse for ErrorResponse { - fn to_bytes(&self) -> Vec { +#[derive(Debug, Clone, PartialEq)] +pub struct FailureResponse { + message: String, +} + +impl FailureResponse { + pub fn new>(message: S) -> Self { + FailureResponse { + message: message.into(), + } + } + + pub fn get_message(&self) -> &str { + &self.message + } + + pub fn get_kind(&self) -> ResponseKind { + ResponseKind::Failure + } + + pub fn to_bytes(&self) -> Vec { self.message.clone().into_bytes() } - fn from_bytes(bytes: &[u8]) -> Result { + pub fn try_from_bytes(bytes: &[u8]) -> Result { match String::from_utf8(bytes.to_vec()) { Err(_) => Err(ProviderResponseError::UnmarshalError), - Ok(message) => Ok(ErrorResponse { message }), + Ok(message) => Ok(FailureResponse { message }), } } } +fn read_be_u16(input: &mut &[u8]) -> u16 { + let (int_bytes, rest) = input.split_at(std::mem::size_of::()); + *input = rest; + u16::from_be_bytes(int_bytes.try_into().unwrap()) +} + #[cfg(test)] mod 
creating_pull_response { use super::*; @@ -172,7 +373,7 @@ mod creating_pull_response { let pull_response = PullResponse::new(msgs.clone()); let bytes = pull_response.to_bytes(); - let recovered = PullResponse::from_bytes(&bytes).unwrap(); + let recovered = PullResponse::try_from_bytes(&bytes).unwrap(); assert_eq!(msgs, recovered.messages); } } From 8dde353609abbd5a707caa6856f01dd775a41561 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 10:41:54 +0100 Subject: [PATCH 05/23] Initial attempt of using new provider client --- Cargo.lock | 2 + common/clients/provider-client/src/lib.rs | 116 +++++++++++++----- .../src/provider/client_handling/listener.rs | 64 ++++++---- .../client_handling/request_processing.rs | 21 ++-- 4 files changed, 134 insertions(+), 69 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 618c1d260fc..1d39c2d371b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2421,7 +2421,9 @@ name = "sfw-provider-requests" version = "0.1.0" dependencies = [ "bs58", + "bytes 0.5.4", "sphinx", + "tokio 0.2.12", ] [[package]] diff --git a/common/clients/provider-client/src/lib.rs b/common/clients/provider-client/src/lib.rs index 7191941a93e..1f642e33e68 100644 --- a/common/clients/provider-client/src/lib.rs +++ b/common/clients/provider-client/src/lib.rs @@ -1,13 +1,16 @@ use futures::io::Error; +use futures::AsyncReadExt; use log::*; -use sfw_provider_requests::requests::{ProviderRequest, PullRequest, RegisterRequest}; +use sfw_provider_requests::auth_token::AuthToken; +use sfw_provider_requests::requests::{ + ProviderRequest, PullRequest, RegisterRequest, TokioAsyncRequestWriter, +}; use sfw_provider_requests::responses::{ ProviderResponse, ProviderResponseError, PullResponse, RegisterResponse, + TokioAsyncResponseReader, }; -use sfw_provider_requests::AuthToken; use sphinx::route::DestinationAddressBytes; -use std::net::{Shutdown, SocketAddr}; -use std::time::Duration; +use std::net::SocketAddr; use tokio::prelude::*; #[derive(Debug)] @@ -36,6 +39,12 @@ impl From for ProviderClientError { ProviderResponseError::MarshalError => InvalidRequestError, ProviderResponseError::UnmarshalError => InvalidResponseError, ProviderResponseError::UnmarshalErrorInvalidLength => InvalidResponseLengthError, + ProviderResponseError::UnmarshalErrorInvalidKind => InvalidResponseLengthError, + + ProviderResponseError::TooLongResponseError => InvalidResponseError, + ProviderResponseError::TooShortResponseError => InvalidResponseError, + ProviderResponseError::IOError(_) => NetworkError, + ProviderResponseError::RemoteConnectionClosed => NetworkError, } } } @@ -44,8 +53,11 @@ pub struct ProviderClient { provider_network_address: SocketAddr, our_address: DestinationAddressBytes, auth_token: Option, + connection: Option, } +const MAX_RESPONSE_SIZE: usize = 1_000_000_000; + impl ProviderClient { pub fn new( provider_network_address: SocketAddr, @@ -56,6 +68,23 @@ impl ProviderClient { provider_network_address, our_address, auth_token, + // establish connection when it's necessary (mainly to not break current code + // as then 'new' would need to be called within async context) + connection: None, + } + } + + async fn check_connection(&mut self) -> bool { + if self.connection.is_some() { + true + } else { + // TODO: possibly also introduce timeouts here? + // However, at this point it's slightly less important as we are in full control + // of providers. 
+ self.connection = tokio::net::TcpStream::connect(self.provider_network_address) + .await + .ok(); + self.connection.is_some() } } @@ -63,25 +92,30 @@ impl ProviderClient { self.auth_token = Some(auth_token) } - pub async fn send_request(&self, bytes: Vec) -> Result, ProviderClientError> { - let mut socket = tokio::net::TcpStream::connect(self.provider_network_address).await?; - - socket.set_keepalive(Some(Duration::from_secs(2)))?; - socket.write_all(&bytes[..]).await?; - if let Err(e) = socket.shutdown(Shutdown::Write) { - warn!("failed to close write part of the socket; err = {:?}", e) + pub async fn send_request( + &mut self, + request: ProviderRequest, + ) -> Result { + if !self.check_connection().await { + return Err(ProviderClientError::NetworkError); } - let mut response = Vec::new(); - socket.read_to_end(&mut response).await?; - if let Err(e) = socket.shutdown(Shutdown::Read) { - debug!("failed to close read part of the socket; err = {:?}. It was probably already closed by the provider", e) + let socket = self.connection.as_mut().unwrap(); + let (mut socket_reader, mut socket_writer) = socket.split(); + + let mut request_writer = TokioAsyncRequestWriter::new(&mut socket_writer); + let mut response_reader = + TokioAsyncResponseReader::new(&mut socket_reader, MAX_RESPONSE_SIZE); + + if let Err(e) = request_writer.try_write_request(request).await { + debug!("Failed to write the request - {:?}", e); + return Err(e.into()); } - Ok(response) + Ok(response_reader.try_read_response().await?) } - pub async fn retrieve_messages(&self) -> Result>, ProviderClientError> { + pub async fn retrieve_messages(&mut self) -> Result>, ProviderClientError> { let auth_token = match self.auth_token.as_ref() { Some(token) => token.clone(), None => { @@ -89,27 +123,45 @@ impl ProviderClient { } }; - let pull_request = PullRequest::new(self.our_address.clone(), auth_token); - let bytes = pull_request.to_bytes(); - - let response = self.send_request(bytes).await?; - - let parsed_response = PullResponse::from_bytes(&response)?; - Ok(parsed_response.messages) + let pull_request = + ProviderRequest::Pull(PullRequest::new(self.our_address.clone(), auth_token)); + match self.send_request(pull_request).await? { + ProviderResponse::Pull(res) => Ok(res.extract_messages()), + ProviderResponse::Failure(res) => { + error!( + "We failed to get our request processed - {:?}", + res.get_message() + ); + Err(ProviderClientError::InvalidResponseError) + } + _ => { + error!("Received response of unexpected type!"); + Err(ProviderClientError::InvalidResponseError) + } + } } - pub async fn register(&self) -> Result { + pub async fn register(&mut self) -> Result { if self.auth_token.is_some() { return Err(ProviderClientError::ClientAlreadyRegisteredError); } - let register_request = RegisterRequest::new(self.our_address.clone()); - let bytes = register_request.to_bytes(); - - let response = self.send_request(bytes).await?; - let parsed_response = RegisterResponse::from_bytes(&response)?; - - Ok(parsed_response.auth_token) + let register_request = + ProviderRequest::Register(RegisterRequest::new(self.our_address.clone())); + match self.send_request(register_request).await? 
{ + ProviderResponse::Register(res) => Ok(res.get_token()), + ProviderResponse::Failure(res) => { + error!( + "We failed to get our request processed - {:?}", + res.get_message() + ); + Err(ProviderClientError::InvalidResponseError) + } + _ => { + error!("Received response of unexpected type!"); + Err(ProviderClientError::InvalidResponseError) + } + } } pub fn is_registered(&self) -> bool { diff --git a/sfw-provider/src/provider/client_handling/listener.rs b/sfw-provider/src/provider/client_handling/listener.rs index 1508efa2321..e7d3d5e29a4 100644 --- a/sfw-provider/src/provider/client_handling/listener.rs +++ b/sfw-provider/src/provider/client_handling/listener.rs @@ -2,32 +2,33 @@ use crate::provider::client_handling::request_processing::{ ClientProcessingResult, RequestProcessor, }; use log::*; -use sfw_provider_requests::responses::{ - ErrorResponse, ProviderResponse, PullResponse, RegisterResponse, +use sfw_provider_requests::requests::{ + ProviderRequest, ProviderRequestError, TokioAsyncRequestReader, }; +use sfw_provider_requests::responses::*; use std::io; use std::net::SocketAddr; -use tokio::prelude::*; use tokio::runtime::Handle; use tokio::task::JoinHandle; -async fn process_request( - socket: &mut tokio::net::TcpStream, - packet_data: &[u8], +async fn process_request<'a>( + response_writer: &mut TokioAsyncResponseWriter<'a, tokio::net::tcp::WriteHalf<'a>>, + request: ProviderRequest, request_processor: &mut RequestProcessor, ) { - match request_processor.process_client_request(packet_data).await { + match request_processor.process_client_request(request).await { Err(e) => { warn!("We failed to process client request - {:?}", e); - let response_bytes = ErrorResponse::new(format!("{:?}", e)).to_bytes(); - if let Err(e) = socket.write_all(&response_bytes).await { + let failure_response = + ProviderResponse::Failure(FailureResponse::new(format!("{:?}", e))); + if let Err(e) = response_writer.try_write_response(failure_response).await { debug!("Failed to write response to the client - {:?}", e); } } Ok(res) => match res { ClientProcessingResult::RegisterResponse(auth_token) => { - let response_bytes = RegisterResponse::new(auth_token).to_bytes(); - if let Err(e) = socket.write_all(&response_bytes).await { + let response = ProviderResponse::Register(RegisterResponse::new(auth_token)); + if let Err(e) = response_writer.try_write_response(response).await { debug!("Failed to write response to the client - {:?}", e); } } @@ -36,8 +37,8 @@ async fn process_request( .into_iter() .map(|c| c.into_tuple()) .unzip(); - let response_bytes = PullResponse::new(messages).to_bytes(); - if socket.write_all(&response_bytes).await.is_ok() { + let response = ProviderResponse::Pull(PullResponse::new(messages)); + if response_writer.try_write_response(response).await.is_ok() { // only delete stored messages if we managed to actually send the response if let Err(e) = request_processor.delete_sent_messages(paths).await { error!("Somehow failed to delete stored messages! - {:?}", e); @@ -48,36 +49,49 @@ async fn process_request( } } +// TODO: temporary proof of concept. 
will later be moved into config +const MAX_REQUEST_SIZE: usize = 4_096; + async fn process_socket_connection( mut socket: tokio::net::TcpStream, mut request_processor: RequestProcessor, ) { - let mut buf = [0u8; 1024]; + let peer_addr = socket.peer_addr(); + let (mut socket_reader, mut socket_writer) = socket.split(); + let mut request_reader = TokioAsyncRequestReader::new(&mut socket_reader, MAX_REQUEST_SIZE); + let mut response_writer = TokioAsyncResponseWriter::new(&mut socket_writer); + loop { - match socket.read(&mut buf).await { - // socket closed - Ok(n) if n == 0 => { + match request_reader.try_read_request().await { + Err(ProviderRequestError::RemoteConnectionClosed) => { trace!("Remote connection closed."); return; } - // in here we do not really want to process multiple requests from the same - // client concurrently as then we might end up with really weird race conditions - // plus realistically it wouldn't really introduce any speed up - Ok(n) => process_request(&mut socket, &buf[0..n], &mut request_processor).await, - Err(e) => { + Err(ProviderRequestError::IOError(e)) => { if e.kind() == io::ErrorKind::UnexpectedEof { debug!("Read buffer was not fully filled. Most likely the client ({:?}) closed the connection.\ - Also closing the connection on this end.", socket.peer_addr()) + Also closing the connection on this end.", peer_addr) } else { warn!( "failed to read from socket (source: {:?}). Closing the connection; err = {:?}", - socket.peer_addr(), + peer_addr, e ); } return; } - }; + Err(e) => { + // let's leave it like this for time being and see if we need to decrease + // logging level and / or close the connection + warn!("the received request was invalid - {:?}", e); + } + // in here we do not really want to process multiple requests from the same + // client concurrently as then we might end up with really weird race conditions + // plus realistically it wouldn't really introduce any speed up + Ok(request) => { + process_request(&mut response_writer, request, &mut request_processor).await + } + } } } diff --git a/sfw-provider/src/provider/client_handling/request_processing.rs b/sfw-provider/src/provider/client_handling/request_processing.rs index 57707c1b0dd..6e9cb2fa5e9 100644 --- a/sfw-provider/src/provider/client_handling/request_processing.rs +++ b/sfw-provider/src/provider/client_handling/request_processing.rs @@ -3,10 +3,8 @@ use crate::provider::storage::{ClientFile, ClientStorage}; use crypto::encryption; use hmac::{Hmac, Mac}; use log::*; -use sfw_provider_requests::requests::{ - ProviderRequestError, ProviderRequests, PullRequest, RegisterRequest, -}; -use sfw_provider_requests::AuthToken; +use sfw_provider_requests::auth_token::AuthToken; +use sfw_provider_requests::requests::*; use sha2::Sha256; use sphinx::route::DestinationAddressBytes; use std::io; @@ -65,13 +63,11 @@ impl RequestProcessor { pub(crate) async fn process_client_request( &mut self, - request_bytes: &[u8], + client_request: ProviderRequest, ) -> Result { - let client_request = ProviderRequests::from_bytes(request_bytes)?; - debug!("Received the following request: {:?}", client_request); match client_request { - ProviderRequests::Register(req) => self.process_register_request(req).await, - ProviderRequests::PullMessages(req) => self.process_pull_request(req).await, + ProviderRequest::Register(req) => self.process_register_request(req).await, + ProviderRequest::Pull(req) => self.process_pull_request(req).await, } } @@ -126,6 +122,7 @@ impl RequestProcessor { &self, req: PullRequest, ) -> 
Result { + println!("pull request for {:?}", req.destination_address); if self .client_ledger .verify_token(&req.auth_token, &req.destination_address) @@ -135,10 +132,10 @@ impl RequestProcessor { .client_storage .retrieve_client_files(req.destination_address) .await?; - return Ok(ClientProcessingResult::PullResponse(retrieved_messages)); + Ok(ClientProcessingResult::PullResponse(retrieved_messages)) + } else { + Err(ClientProcessingError::InvalidToken) } - - Err(ClientProcessingError::InvalidToken) } pub(crate) async fn delete_sent_messages(&self, file_paths: Vec) -> io::Result<()> { From 18d1307508859ee5f5bdabd3d68528a52b8403c5 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 11:01:39 +0100 Subject: [PATCH 06/23] Moved requests and responses to separate modules --- .../sfw-provider-requests/src/{requests.rs => requests/mod.rs} | 0 .../sfw-provider-requests/src/{responses.rs => responses/mod.rs} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename sfw-provider/sfw-provider-requests/src/{requests.rs => requests/mod.rs} (100%) rename sfw-provider/sfw-provider-requests/src/{responses.rs => responses/mod.rs} (100%) diff --git a/sfw-provider/sfw-provider-requests/src/requests.rs b/sfw-provider/sfw-provider-requests/src/requests/mod.rs similarity index 100% rename from sfw-provider/sfw-provider-requests/src/requests.rs rename to sfw-provider/sfw-provider-requests/src/requests/mod.rs diff --git a/sfw-provider/sfw-provider-requests/src/responses.rs b/sfw-provider/sfw-provider-requests/src/responses/mod.rs similarity index 100% rename from sfw-provider/sfw-provider-requests/src/responses.rs rename to sfw-provider/sfw-provider-requests/src/responses/mod.rs From 0e11e2dacd124ab5db303ae5b2bee586c0699f46 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 11:10:08 +0100 Subject: [PATCH 07/23] Moved serialization to separate files --- .../sfw-provider-requests/src/requests/mod.rs | 84 +---------------- .../src/requests/serialization.rs | 88 ++++++++++++++++++ .../src/responses/mod.rs | 88 +----------------- .../src/responses/serialization.rs | 93 +++++++++++++++++++ 4 files changed, 187 insertions(+), 166 deletions(-) create mode 100644 sfw-provider/sfw-provider-requests/src/requests/serialization.rs create mode 100644 sfw-provider/sfw-provider-requests/src/responses/serialization.rs diff --git a/sfw-provider/sfw-provider-requests/src/requests/mod.rs b/sfw-provider/sfw-provider-requests/src/requests/mod.rs index 1262677ad8d..c10f8e15540 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/mod.rs @@ -1,4 +1,5 @@ use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; +use crate::requests::serialization::{RequestDeserializer, RequestSerializer}; use sphinx::constants::DESTINATION_ADDRESS_LENGTH; use sphinx::route::DestinationAddressBytes; use std::convert::TryFrom; @@ -6,6 +7,8 @@ use std::io; use std::io::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +pub(crate) mod serialization; + #[derive(Debug)] pub enum ProviderRequestError { MarshalError, @@ -99,60 +102,6 @@ impl<'a, R: AsyncRead + Unpin> TokioAsyncRequestReader<'a, R> { } } -pub struct RequestDeserializer<'a> { - kind: RequestKind, - data: &'a [u8], -} - -impl<'a> RequestDeserializer<'a> { - // perform initial parsing - pub fn new(raw_bytes: &'a [u8]) -> Result { - if raw_bytes.len() < 1 + 4 { - Err(ProviderRequestError::UnmarshalErrorInvalidLength) - } else { - let data_len = - u32::from_be_bytes([raw_bytes[0], 
raw_bytes[1], raw_bytes[2], raw_bytes[3]]); - let kind = RequestKind::try_from(raw_bytes[4])?; - let data = &raw_bytes[4..]; - - if data.len() != data_len as usize { - Err(ProviderRequestError::UnmarshalErrorInvalidLength) - } else { - Ok(RequestDeserializer { kind, data }) - } - } - } - - pub fn new_with_len(len: u32, raw_bytes: &'a [u8]) -> Result { - if raw_bytes.len() != len as usize { - Err(ProviderRequestError::UnmarshalErrorInvalidLength) - } else { - let kind = RequestKind::try_from(raw_bytes[0])?; - let data = &raw_bytes[1..]; - Ok(RequestDeserializer { kind, data }) - } - } - - pub fn get_kind(&self) -> RequestKind { - self.kind - } - - pub fn get_data(&self) -> &'a [u8] { - self.data - } - - pub fn try_to_parse(self) -> Result { - match self.get_kind() { - RequestKind::Pull => Ok(ProviderRequest::Pull(PullRequest::try_from_bytes( - self.data, - )?)), - RequestKind::Register => Ok(ProviderRequest::Register( - RegisterRequest::try_from_bytes(self.data)?, - )), - } - } -} - // Ideally I would have used futures::AsyncWrite for even more generic approach, but unfortunately // tokio::io::AsyncWrite differs from futures::AsyncWrite pub struct TokioAsyncRequestWriter<'a, W: AsyncWrite + Unpin> { @@ -170,33 +119,6 @@ impl<'a, W: AsyncWrite + Unpin> TokioAsyncRequestWriter<'a, W> { } } -pub struct RequestSerializer { - req: ProviderRequest, -} - -impl RequestSerializer { - pub fn new(req: ProviderRequest) -> Self { - RequestSerializer { req } - } - - /// Serialized requests in general have the following structure: - /// follows: 4 byte len (be u32) || 1-byte kind prefix || request-specific data - pub fn into_bytes(self) -> Vec { - let (kind, req_bytes) = match self.req { - ProviderRequest::Pull(req) => (req.get_kind(), req.to_bytes()), - ProviderRequest::Register(req) => (req.get_kind(), req.to_bytes()), - }; - let req_len = req_bytes.len() as u32 + 1; // 1 is to accommodate for 'kind' - let req_len_bytes = req_len.to_be_bytes(); - req_len_bytes - .iter() - .cloned() - .chain(std::iter::once(kind as u8)) - .chain(req_bytes.into_iter()) - .collect() - } -} - #[derive(Debug, Clone, PartialEq)] pub struct PullRequest { pub auth_token: AuthToken, diff --git a/sfw-provider/sfw-provider-requests/src/requests/serialization.rs b/sfw-provider/sfw-provider-requests/src/requests/serialization.rs new file mode 100644 index 00000000000..e62a6bb144d --- /dev/null +++ b/sfw-provider/sfw-provider-requests/src/requests/serialization.rs @@ -0,0 +1,88 @@ +use crate::requests::{ + ProviderRequest, ProviderRequestError, PullRequest, RegisterRequest, RequestKind, +}; +use std::convert::TryFrom; + +// TODO: way down the line, mostly for learning purposes, combine this with responses::serialization +// via procedural macros + +pub struct RequestSerializer { + req: ProviderRequest, +} + +impl RequestSerializer { + pub fn new(req: ProviderRequest) -> Self { + RequestSerializer { req } + } + + /// Serialized requests in general have the following structure: + /// follows: 4 byte len (be u32) || 1-byte kind prefix || request-specific data + pub fn into_bytes(self) -> Vec { + let (kind, req_bytes) = match self.req { + ProviderRequest::Pull(req) => (req.get_kind(), req.to_bytes()), + ProviderRequest::Register(req) => (req.get_kind(), req.to_bytes()), + }; + let req_len = req_bytes.len() as u32 + 1; // 1 is to accommodate for 'kind' + let req_len_bytes = req_len.to_be_bytes(); + req_len_bytes + .iter() + .cloned() + .chain(std::iter::once(kind as u8)) + .chain(req_bytes.into_iter()) + .collect() + } +} + +pub struct 
RequestDeserializer<'a> { + kind: RequestKind, + data: &'a [u8], +} + +impl<'a> RequestDeserializer<'a> { + // perform initial parsing + pub fn new(raw_bytes: &'a [u8]) -> Result { + if raw_bytes.len() < 1 + 4 { + Err(ProviderRequestError::UnmarshalErrorInvalidLength) + } else { + let data_len = + u32::from_be_bytes([raw_bytes[0], raw_bytes[1], raw_bytes[2], raw_bytes[3]]); + let kind = RequestKind::try_from(raw_bytes[4])?; + let data = &raw_bytes[4..]; + + if data.len() != data_len as usize { + Err(ProviderRequestError::UnmarshalErrorInvalidLength) + } else { + Ok(RequestDeserializer { kind, data }) + } + } + } + + pub fn new_with_len(len: u32, raw_bytes: &'a [u8]) -> Result { + if raw_bytes.len() != len as usize { + Err(ProviderRequestError::UnmarshalErrorInvalidLength) + } else { + let kind = RequestKind::try_from(raw_bytes[0])?; + let data = &raw_bytes[1..]; + Ok(RequestDeserializer { kind, data }) + } + } + + pub fn get_kind(&self) -> RequestKind { + self.kind + } + + pub fn get_data(&self) -> &'a [u8] { + self.data + } + + pub fn try_to_parse(self) -> Result { + match self.get_kind() { + RequestKind::Pull => Ok(ProviderRequest::Pull(PullRequest::try_from_bytes( + self.data, + )?)), + RequestKind::Register => Ok(ProviderRequest::Register( + RegisterRequest::try_from_bytes(self.data)?, + )), + } + } +} diff --git a/sfw-provider/sfw-provider-requests/src/responses/mod.rs b/sfw-provider/sfw-provider-requests/src/responses/mod.rs index ef7b74ce14d..09e4f30d8b0 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/mod.rs @@ -1,9 +1,12 @@ use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; +use crate::responses::serialization::{ResponseDeserializer, ResponseSerializer}; use std::convert::{TryFrom, TryInto}; use std::io; use std::io::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +pub(crate) mod serialization; + #[derive(Debug)] pub enum ProviderResponseError { MarshalError, @@ -98,63 +101,6 @@ impl<'a, R: AsyncRead + Unpin> TokioAsyncResponseReader<'a, R> { } } -pub struct ResponseDeserializer<'a> { - kind: ResponseKind, - data: &'a [u8], -} - -impl<'a> ResponseDeserializer<'a> { - // perform initial parsing - pub fn new(raw_bytes: &'a [u8]) -> Result { - if raw_bytes.len() < 1 + 4 { - Err(ProviderResponseError::UnmarshalErrorInvalidLength) - } else { - let data_len = - u32::from_be_bytes([raw_bytes[0], raw_bytes[1], raw_bytes[2], raw_bytes[3]]); - let kind = ResponseKind::try_from(raw_bytes[4])?; - let data = &raw_bytes[4..]; - - if data.len() != data_len as usize { - Err(ProviderResponseError::UnmarshalErrorInvalidLength) - } else { - Ok(ResponseDeserializer { kind, data }) - } - } - } - - pub fn new_with_len(len: u32, raw_bytes: &'a [u8]) -> Result { - if raw_bytes.len() != len as usize { - Err(ProviderResponseError::UnmarshalErrorInvalidLength) - } else { - let kind = ResponseKind::try_from(raw_bytes[0])?; - let data = &raw_bytes[1..]; - Ok(ResponseDeserializer { kind, data }) - } - } - - pub fn get_kind(&self) -> ResponseKind { - self.kind - } - - pub fn get_data(&self) -> &'a [u8] { - self.data - } - - pub fn try_to_parse(self) -> Result { - match self.get_kind() { - ResponseKind::Failure => Ok(ProviderResponse::Failure( - FailureResponse::try_from_bytes(self.data)?, - )), - ResponseKind::Pull => Ok(ProviderResponse::Pull(PullResponse::try_from_bytes( - self.data, - )?)), - ResponseKind::Register => Ok(ProviderResponse::Register( - RegisterResponse::try_from_bytes(self.data)?, - )), - } 
- } -} - // Ideally I would have used futures::AsyncWrite for even more generic approach, but unfortunately // tokio::io::AsyncWrite differs from futures::AsyncWrite pub struct TokioAsyncResponseWriter<'a, W: AsyncWrite + Unpin> { @@ -172,34 +118,6 @@ impl<'a, W: AsyncWrite + Unpin> TokioAsyncResponseWriter<'a, W> { } } -pub struct ResponseSerializer { - res: ProviderResponse, -} - -impl ResponseSerializer { - pub fn new(res: ProviderResponse) -> Self { - ResponseSerializer { res } - } - - /// Serialized responses in general have the following structure: - /// follows: 4 byte len (be u32) || 1-byte kind prefix || response-specific data - pub fn into_bytes(self) -> Vec { - let (kind, res_bytes) = match self.res { - ProviderResponse::Failure(res) => (res.get_kind(), res.to_bytes()), - ProviderResponse::Pull(res) => (res.get_kind(), res.to_bytes()), - ProviderResponse::Register(res) => (res.get_kind(), res.to_bytes()), - }; - let res_len = res_bytes.len() as u32 + 1; // 1 is to accommodate for 'kind' - let res_len_bytes = res_len.to_be_bytes(); - res_len_bytes - .iter() - .cloned() - .chain(std::iter::once(kind as u8)) - .chain(res_bytes.into_iter()) - .collect() - } -} - #[derive(Debug, Clone, PartialEq)] pub struct PullResponse { messages: Vec>, diff --git a/sfw-provider/sfw-provider-requests/src/responses/serialization.rs b/sfw-provider/sfw-provider-requests/src/responses/serialization.rs new file mode 100644 index 00000000000..9b9b995b797 --- /dev/null +++ b/sfw-provider/sfw-provider-requests/src/responses/serialization.rs @@ -0,0 +1,93 @@ +use crate::responses::{ + FailureResponse, ProviderResponse, ProviderResponseError, PullResponse, RegisterResponse, + ResponseKind, +}; +use std::convert::TryFrom; + +// TODO: way down the line, mostly for learning purposes, combine this with requests::serialization +// via procedural macros + +pub struct ResponseDeserializer<'a> { + kind: ResponseKind, + data: &'a [u8], +} + +impl<'a> ResponseDeserializer<'a> { + // perform initial parsing + pub fn new(raw_bytes: &'a [u8]) -> Result { + if raw_bytes.len() < 1 + 4 { + Err(ProviderResponseError::UnmarshalErrorInvalidLength) + } else { + let data_len = + u32::from_be_bytes([raw_bytes[0], raw_bytes[1], raw_bytes[2], raw_bytes[3]]); + let kind = ResponseKind::try_from(raw_bytes[4])?; + let data = &raw_bytes[4..]; + + if data.len() != data_len as usize { + Err(ProviderResponseError::UnmarshalErrorInvalidLength) + } else { + Ok(ResponseDeserializer { kind, data }) + } + } + } + + pub fn new_with_len(len: u32, raw_bytes: &'a [u8]) -> Result { + if raw_bytes.len() != len as usize { + Err(ProviderResponseError::UnmarshalErrorInvalidLength) + } else { + let kind = ResponseKind::try_from(raw_bytes[0])?; + let data = &raw_bytes[1..]; + Ok(ResponseDeserializer { kind, data }) + } + } + + pub fn get_kind(&self) -> ResponseKind { + self.kind + } + + pub fn get_data(&self) -> &'a [u8] { + self.data + } + + pub fn try_to_parse(self) -> Result { + match self.get_kind() { + ResponseKind::Failure => Ok(ProviderResponse::Failure( + FailureResponse::try_from_bytes(self.data)?, + )), + ResponseKind::Pull => Ok(ProviderResponse::Pull(PullResponse::try_from_bytes( + self.data, + )?)), + ResponseKind::Register => Ok(ProviderResponse::Register( + RegisterResponse::try_from_bytes(self.data)?, + )), + } + } +} + +pub struct ResponseSerializer { + res: ProviderResponse, +} + +impl ResponseSerializer { + pub fn new(res: ProviderResponse) -> Self { + ResponseSerializer { res } + } + + /// Serialized responses in general have the 
following structure: + /// follows: 4 byte len (be u32) || 1-byte kind prefix || response-specific data + pub fn into_bytes(self) -> Vec { + let (kind, res_bytes) = match self.res { + ProviderResponse::Failure(res) => (res.get_kind(), res.to_bytes()), + ProviderResponse::Pull(res) => (res.get_kind(), res.to_bytes()), + ProviderResponse::Register(res) => (res.get_kind(), res.to_bytes()), + }; + let res_len = res_bytes.len() as u32 + 1; // 1 is to accommodate for 'kind' + let res_len_bytes = res_len.to_be_bytes(); + res_len_bytes + .iter() + .cloned() + .chain(std::iter::once(kind as u8)) + .chain(res_bytes.into_iter()) + .collect() + } +} From f37b6362e7c40ec01f2e1e3142e580fffb42480f Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 11:28:14 +0100 Subject: [PATCH 08/23] Extracted readers and writers to io related modules --- common/clients/provider-client/src/lib.rs | 6 +- .../src/requests/async_io.rs | 74 +++++++++++++++++++ .../sfw-provider-requests/src/requests/mod.rs | 72 +----------------- .../src/responses/async_io.rs | 72 ++++++++++++++++++ .../src/responses/mod.rs | 70 +----------------- .../src/provider/client_handling/listener.rs | 7 +- 6 files changed, 157 insertions(+), 144 deletions(-) create mode 100644 sfw-provider/sfw-provider-requests/src/requests/async_io.rs create mode 100644 sfw-provider/sfw-provider-requests/src/responses/async_io.rs diff --git a/common/clients/provider-client/src/lib.rs b/common/clients/provider-client/src/lib.rs index 1f642e33e68..1f2db8cfed5 100644 --- a/common/clients/provider-client/src/lib.rs +++ b/common/clients/provider-client/src/lib.rs @@ -1,13 +1,11 @@ use futures::io::Error; -use futures::AsyncReadExt; use log::*; use sfw_provider_requests::auth_token::AuthToken; use sfw_provider_requests::requests::{ - ProviderRequest, PullRequest, RegisterRequest, TokioAsyncRequestWriter, + async_io::TokioAsyncRequestWriter, ProviderRequest, PullRequest, RegisterRequest, }; use sfw_provider_requests::responses::{ - ProviderResponse, ProviderResponseError, PullResponse, RegisterResponse, - TokioAsyncResponseReader, + async_io::TokioAsyncResponseReader, ProviderResponse, ProviderResponseError, }; use sphinx::route::DestinationAddressBytes; use std::net::SocketAddr; diff --git a/sfw-provider/sfw-provider-requests/src/requests/async_io.rs b/sfw-provider/sfw-provider-requests/src/requests/async_io.rs new file mode 100644 index 00000000000..2e70511521c --- /dev/null +++ b/sfw-provider/sfw-provider-requests/src/requests/async_io.rs @@ -0,0 +1,74 @@ +use crate::requests::serialization::{RequestDeserializer, RequestSerializer}; +use crate::requests::{ProviderRequest, ProviderRequestError}; +use std::io; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +// TODO: way down the line, mostly for learning purposes, combine this with responses::async_io +// via procedural macros + +impl<'a, R: AsyncRead + Unpin> Drop for TokioAsyncRequestReader<'a, R> { + fn drop(&mut self) { + println!("request reader drop"); + } +} + +impl<'a, R: AsyncWrite + Unpin> Drop for TokioAsyncRequestWriter<'a, R> { + fn drop(&mut self) { + println!("request writer drop"); + } +} + +// Ideally I would have used futures::AsyncRead for even more generic approach, but unfortunately +// tokio::io::AsyncRead differs from futures::AsyncRead +pub struct TokioAsyncRequestReader<'a, R: AsyncRead + Unpin> { + max_allowed_len: usize, + reader: &'a mut R, +} + +impl<'a, R: AsyncRead + Unpin> TokioAsyncRequestReader<'a, R> { + pub fn new(reader: &'a mut R, max_allowed_len: 
usize) -> Self { + TokioAsyncRequestReader { + reader, + max_allowed_len, + } + } + + pub async fn try_read_request(&mut self) -> Result { + let req_len = self.reader.read_u32().await?; + if req_len == 0 { + return Err(ProviderRequestError::RemoteConnectionClosed); + } + if req_len as usize > self.max_allowed_len { + // TODO: should reader be drained? + return Err(ProviderRequestError::TooLongRequestError); + } + + let mut req_buf = Vec::with_capacity(req_len as usize); + let mut chunk = self.reader.take(req_len as u64); + + if let Err(_) = chunk.read_to_end(&mut req_buf).await { + return Err(ProviderRequestError::TooShortRequestError); + }; + + let parse_res = RequestDeserializer::new_with_len(req_len, &req_buf)?.try_to_parse(); + + parse_res + } +} + +// Ideally I would have used futures::AsyncWrite for even more generic approach, but unfortunately +// tokio::io::AsyncWrite differs from futures::AsyncWrite +pub struct TokioAsyncRequestWriter<'a, W: AsyncWrite + Unpin> { + writer: &'a mut W, +} + +impl<'a, W: AsyncWrite + Unpin> TokioAsyncRequestWriter<'a, W> { + pub fn new(writer: &'a mut W) -> Self { + TokioAsyncRequestWriter { writer } + } + + pub async fn try_write_request(&mut self, res: ProviderRequest) -> io::Result<()> { + let res_bytes = RequestSerializer::new(res).into_bytes(); + self.writer.write_all(&res_bytes).await + } +} diff --git a/sfw-provider/sfw-provider-requests/src/requests/mod.rs b/sfw-provider/sfw-provider-requests/src/requests/mod.rs index c10f8e15540..904c521988e 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/mod.rs @@ -1,13 +1,12 @@ use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; -use crate::requests::serialization::{RequestDeserializer, RequestSerializer}; use sphinx::constants::DESTINATION_ADDRESS_LENGTH; use sphinx::route::DestinationAddressBytes; use std::convert::TryFrom; use std::io; use std::io::Error; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -pub(crate) mod serialization; +pub mod async_io; +pub mod serialization; #[derive(Debug)] pub enum ProviderRequestError { @@ -27,18 +26,6 @@ impl From for ProviderRequestError { } } -impl<'a, R: AsyncRead + Unpin> Drop for TokioAsyncRequestReader<'a, R> { - fn drop(&mut self) { - println!("request reader drop"); - } -} - -impl<'a, R: AsyncWrite + Unpin> Drop for TokioAsyncRequestWriter<'a, R> { - fn drop(&mut self) { - println!("request writer drop"); - } -} - #[derive(Debug, Clone, Copy, PartialEq)] #[repr(u8)] pub enum RequestKind { @@ -64,61 +51,6 @@ pub enum ProviderRequest { Register(RegisterRequest), } -// Ideally I would have used futures::AsyncRead for even more generic approach, but unfortunately -// tokio::io::AsyncRead differs from futures::AsyncRead -pub struct TokioAsyncRequestReader<'a, R: AsyncRead + Unpin> { - max_allowed_len: usize, - reader: &'a mut R, -} - -impl<'a, R: AsyncRead + Unpin> TokioAsyncRequestReader<'a, R> { - pub fn new(reader: &'a mut R, max_allowed_len: usize) -> Self { - TokioAsyncRequestReader { - reader, - max_allowed_len, - } - } - - pub async fn try_read_request(&mut self) -> Result { - let req_len = self.reader.read_u32().await?; - if req_len == 0 { - return Err(ProviderRequestError::RemoteConnectionClosed); - } - if req_len as usize > self.max_allowed_len { - // TODO: should reader be drained? 
- return Err(ProviderRequestError::TooLongRequestError); - } - - let mut req_buf = Vec::with_capacity(req_len as usize); - let mut chunk = self.reader.take(req_len as u64); - - if let Err(_) = chunk.read_to_end(&mut req_buf).await { - return Err(ProviderRequestError::TooShortRequestError); - }; - - let parse_res = RequestDeserializer::new_with_len(req_len, &req_buf)?.try_to_parse(); - - parse_res - } -} - -// Ideally I would have used futures::AsyncWrite for even more generic approach, but unfortunately -// tokio::io::AsyncWrite differs from futures::AsyncWrite -pub struct TokioAsyncRequestWriter<'a, W: AsyncWrite + Unpin> { - writer: &'a mut W, -} - -impl<'a, W: AsyncWrite + Unpin> TokioAsyncRequestWriter<'a, W> { - pub fn new(writer: &'a mut W) -> Self { - TokioAsyncRequestWriter { writer } - } - - pub async fn try_write_request(&mut self, res: ProviderRequest) -> io::Result<()> { - let res_bytes = RequestSerializer::new(res).into_bytes(); - self.writer.write_all(&res_bytes).await - } -} - #[derive(Debug, Clone, PartialEq)] pub struct PullRequest { pub auth_token: AuthToken, diff --git a/sfw-provider/sfw-provider-requests/src/responses/async_io.rs b/sfw-provider/sfw-provider-requests/src/responses/async_io.rs new file mode 100644 index 00000000000..4ad11818a52 --- /dev/null +++ b/sfw-provider/sfw-provider-requests/src/responses/async_io.rs @@ -0,0 +1,72 @@ +use crate::responses::serialization::{ResponseDeserializer, ResponseSerializer}; +use crate::responses::{ProviderResponse, ProviderResponseError}; +use std::io; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +// TODO: way down the line, mostly for learning purposes, combine this with requests::async_io +// via procedural macros + +impl<'a, R: AsyncRead + Unpin> Drop for TokioAsyncResponseReader<'a, R> { + fn drop(&mut self) { + println!("response reader drop"); + } +} + +impl<'a, R: AsyncWrite + Unpin> Drop for TokioAsyncResponseWriter<'a, R> { + fn drop(&mut self) { + println!("response writer drop"); + } +} + +// Ideally I would have used futures::AsyncRead for even more generic approach, but unfortunately +// tokio::io::AsyncRead differs from futures::AsyncRead +pub struct TokioAsyncResponseReader<'a, R: AsyncRead + Unpin> { + max_allowed_len: usize, + reader: &'a mut R, +} + +impl<'a, R: AsyncRead + Unpin> TokioAsyncResponseReader<'a, R> { + pub fn new(reader: &'a mut R, max_allowed_len: usize) -> Self { + TokioAsyncResponseReader { + reader, + max_allowed_len, + } + } + + pub async fn try_read_response(&mut self) -> Result { + let res_len = self.reader.read_u32().await?; + if res_len == 0 { + return Err(ProviderResponseError::RemoteConnectionClosed); + } + if res_len as usize > self.max_allowed_len { + // TODO: should reader be drained? 
+ return Err(ProviderResponseError::TooLongResponseError); + } + + let mut res_buf = Vec::with_capacity(res_len as usize); + let mut chunk = self.reader.take(res_len as u64); + + if let Err(_) = chunk.read_to_end(&mut res_buf).await { + return Err(ProviderResponseError::TooShortResponseError); + }; + + ResponseDeserializer::new_with_len(res_len, &res_buf)?.try_to_parse() + } +} + +// Ideally I would have used futures::AsyncWrite for even more generic approach, but unfortunately +// tokio::io::AsyncWrite differs from futures::AsyncWrite +pub struct TokioAsyncResponseWriter<'a, W: AsyncWrite + Unpin> { + writer: &'a mut W, +} + +impl<'a, W: AsyncWrite + Unpin> TokioAsyncResponseWriter<'a, W> { + pub fn new(writer: &'a mut W) -> Self { + TokioAsyncResponseWriter { writer } + } + + pub async fn try_write_response(&mut self, res: ProviderResponse) -> io::Result<()> { + let res_bytes = ResponseSerializer::new(res).into_bytes(); + self.writer.write_all(&res_bytes).await + } +} diff --git a/sfw-provider/sfw-provider-requests/src/responses/mod.rs b/sfw-provider/sfw-provider-requests/src/responses/mod.rs index 09e4f30d8b0..6d4d2733366 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/mod.rs @@ -1,11 +1,10 @@ use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; -use crate::responses::serialization::{ResponseDeserializer, ResponseSerializer}; use std::convert::{TryFrom, TryInto}; use std::io; use std::io::Error; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -pub(crate) mod serialization; +pub mod async_io; +pub mod serialization; #[derive(Debug)] pub enum ProviderResponseError { @@ -25,18 +24,6 @@ impl From for ProviderResponseError { } } -impl<'a, R: AsyncRead + Unpin> Drop for TokioAsyncResponseReader<'a, R> { - fn drop(&mut self) { - println!("response reader drop"); - } -} - -impl<'a, R: AsyncWrite + Unpin> Drop for TokioAsyncResponseWriter<'a, R> { - fn drop(&mut self) { - println!("response writer drop"); - } -} - #[derive(Debug, Clone, Copy, PartialEq)] #[repr(u8)] pub enum ResponseKind { @@ -65,59 +52,6 @@ pub enum ProviderResponse { Register(RegisterResponse), } -// Ideally I would have used futures::AsyncRead for even more generic approach, but unfortunately -// tokio::io::AsyncRead differs from futures::AsyncRead -pub struct TokioAsyncResponseReader<'a, R: AsyncRead + Unpin> { - max_allowed_len: usize, - reader: &'a mut R, -} - -impl<'a, R: AsyncRead + Unpin> TokioAsyncResponseReader<'a, R> { - pub fn new(reader: &'a mut R, max_allowed_len: usize) -> Self { - TokioAsyncResponseReader { - reader, - max_allowed_len, - } - } - - pub async fn try_read_response(&mut self) -> Result { - let res_len = self.reader.read_u32().await?; - if res_len == 0 { - return Err(ProviderResponseError::RemoteConnectionClosed); - } - if res_len as usize > self.max_allowed_len { - // TODO: should reader be drained? 
- return Err(ProviderResponseError::TooLongResponseError); - } - - let mut res_buf = Vec::with_capacity(res_len as usize); - let mut chunk = self.reader.take(res_len as u64); - - if let Err(_) = chunk.read_to_end(&mut res_buf).await { - return Err(ProviderResponseError::TooShortResponseError); - }; - - ResponseDeserializer::new_with_len(res_len, &res_buf)?.try_to_parse() - } -} - -// Ideally I would have used futures::AsyncWrite for even more generic approach, but unfortunately -// tokio::io::AsyncWrite differs from futures::AsyncWrite -pub struct TokioAsyncResponseWriter<'a, W: AsyncWrite + Unpin> { - writer: &'a mut W, -} - -impl<'a, W: AsyncWrite + Unpin> TokioAsyncResponseWriter<'a, W> { - pub fn new(writer: &'a mut W) -> Self { - TokioAsyncResponseWriter { writer } - } - - pub async fn try_write_response(&mut self, res: ProviderResponse) -> io::Result<()> { - let res_bytes = ResponseSerializer::new(res).into_bytes(); - self.writer.write_all(&res_bytes).await - } -} - #[derive(Debug, Clone, PartialEq)] pub struct PullResponse { messages: Vec>, diff --git a/sfw-provider/src/provider/client_handling/listener.rs b/sfw-provider/src/provider/client_handling/listener.rs index e7d3d5e29a4..eb3decddc36 100644 --- a/sfw-provider/src/provider/client_handling/listener.rs +++ b/sfw-provider/src/provider/client_handling/listener.rs @@ -3,9 +3,12 @@ use crate::provider::client_handling::request_processing::{ }; use log::*; use sfw_provider_requests::requests::{ - ProviderRequest, ProviderRequestError, TokioAsyncRequestReader, + async_io::TokioAsyncRequestReader, ProviderRequest, ProviderRequestError, +}; +use sfw_provider_requests::responses::{ + async_io::TokioAsyncResponseWriter, FailureResponse, ProviderResponse, PullResponse, + RegisterResponse, }; -use sfw_provider_requests::responses::*; use std::io; use std::net::SocketAddr; use tokio::runtime::Handle; From 6f6eb554dd908dd0e72e077083fbce5b3c459892 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 12:49:17 +0100 Subject: [PATCH 09/23] Extra tests + bug fixes --- .../sfw-provider-requests/src/requests/mod.rs | 55 +++++++-- .../src/responses/mod.rs | 107 +++++++++++++++++- 2 files changed, 149 insertions(+), 13 deletions(-) diff --git a/sfw-provider/sfw-provider-requests/src/requests/mod.rs b/sfw-provider/sfw-provider-requests/src/requests/mod.rs index 904c521988e..bceb072f73e 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/mod.rs @@ -138,17 +138,51 @@ impl RegisterRequest { } #[cfg(test)] -mod creating_pull_request { +mod request_kind { use super::*; #[test] - fn it_is_possible_to_recover_it_from_bytes() { + fn try_from_u8_is_defined_for_all_variants() { + // it is crucial this match statement is never removed as it ensures at compilation + // time that new variants of RequestKind weren't added; the actual code is not as + // important + let dummy_kind = RequestKind::Register; + match dummy_kind { + RequestKind::Register | RequestKind::Pull => (), + }; + + assert_eq!( + RequestKind::try_from(RequestKind::Register as u8).unwrap(), + RequestKind::Register + ); + assert_eq!( + RequestKind::try_from(RequestKind::Pull as u8).unwrap(), + RequestKind::Pull + ); + } +} + +#[cfg(test)] +mod pull_request { + use super::*; + + #[test] + fn returns_correct_kind() { + let pull_request = PullRequest::new( + DestinationAddressBytes::from_bytes(Default::default()), + AuthToken::from_bytes(Default::default()), + ); + assert_eq!(pull_request.get_kind(), RequestKind::Pull) + } 
+ + #[test] + fn can_be_converted_to_and_from_bytes() { let address = DestinationAddressBytes::from_bytes([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, ]); - let auth_token = [1u8; AUTH_TOKEN_SIZE]; - let pull_request = PullRequest::new(address.clone(), AuthToken::from_bytes(auth_token)); + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let pull_request = PullRequest::new(address, auth_token); let bytes = pull_request.to_bytes(); let recovered = PullRequest::try_from_bytes(&bytes).unwrap(); @@ -157,16 +191,23 @@ mod creating_pull_request { } #[cfg(test)] -mod creating_register_request { +mod register_request { use super::*; #[test] - fn it_is_possible_to_recover_it_from_bytes() { + fn returns_correct_kind() { + let register_request = + RegisterRequest::new(DestinationAddressBytes::from_bytes(Default::default())); + assert_eq!(register_request.get_kind(), RequestKind::Register) + } + + #[test] + fn can_be_converted_to_and_from_bytes() { let address = DestinationAddressBytes::from_bytes([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, ]); - let register_request = RegisterRequest::new(address.clone()); + let register_request = RegisterRequest::new(address); let bytes = register_request.to_bytes(); let recovered = RegisterRequest::try_from_bytes(&bytes).unwrap(); diff --git a/sfw-provider/sfw-provider-requests/src/responses/mod.rs b/sfw-provider/sfw-provider-requests/src/responses/mod.rs index 6d4d2733366..ba51e5f2526 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/mod.rs @@ -37,7 +37,7 @@ impl TryFrom for ResponseKind { fn try_from(value: u8) -> Result { match value { - _ if value == (ResponseKind::Failure as u8) => Ok(Self::Register), + _ if value == (ResponseKind::Failure as u8) => Ok(Self::Failure), _ if value == (ResponseKind::Pull as u8) => Ok(Self::Pull), _ if value == (ResponseKind::Register as u8) => Ok(Self::Register), _ => Err(Self::Error::UnmarshalErrorInvalidKind), @@ -75,7 +75,14 @@ impl PullResponse { // num_msgs || len1 || len2 || ... || msg1 || msg2 || ... 
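    // e.g. (illustrative values only): two messages [0xAA] and [0xBB, 0xCC] would be written,
    // with num_msgs and each length encoded as a big-endian u16, as:
    // 0x0002 || 0x0001 || 0x0002 || 0xAA || 0xBB 0xCC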
pub fn to_bytes(&self) -> Vec { let num_msgs = self.messages.len() as u16; - let msgs_lens: Vec = self.messages.iter().map(|msg| msg.len() as u16).collect(); + let msgs_lens: Vec = self + .messages + .iter() + .map(|msg| { + assert!(msg.len() <= u16::max_value() as usize); + msg.len() as u16 + }) + .collect(); num_msgs .to_be_bytes() @@ -205,11 +212,46 @@ fn read_be_u16(input: &mut &[u8]) -> u16 { } #[cfg(test)] -mod creating_pull_response { +mod response_kind { + use super::*; + + #[test] + fn try_from_u8_is_defined_for_all_variants() { + // it is crucial this match statement is never removed as it ensures at compilation + // time that new variants of ResponseKind weren't added; the actual code is not as + // important + let dummy_kind = ResponseKind::Register; + match dummy_kind { + ResponseKind::Register | ResponseKind::Pull | ResponseKind::Failure => (), + }; + + assert_eq!( + ResponseKind::try_from(ResponseKind::Register as u8).unwrap(), + ResponseKind::Register + ); + assert_eq!( + ResponseKind::try_from(ResponseKind::Pull as u8).unwrap(), + ResponseKind::Pull + ); + assert_eq!( + ResponseKind::try_from(ResponseKind::Failure as u8).unwrap(), + ResponseKind::Failure + ); + } +} + +#[cfg(test)] +mod pull_response { use super::*; #[test] - fn it_is_possible_to_recover_it_from_bytes() { + fn returns_correct_kind() { + let pull_response = PullResponse::new(Default::default()); + assert_eq!(pull_response.get_kind(), ResponseKind::Pull) + } + + #[test] + fn can_be_converted_to_and_from_bytes() { let msg1 = vec![1, 2, 3, 4, 5]; let msg2 = vec![]; let msg3 = vec![ @@ -222,10 +264,63 @@ mod creating_pull_response { let msg4 = vec![1, 2, 3, 4, 5, 6, 7]; let msgs = vec![msg1, msg2, msg3, msg4]; - let pull_response = PullResponse::new(msgs.clone()); + let pull_response = PullResponse::new(msgs); let bytes = pull_response.to_bytes(); let recovered = PullResponse::try_from_bytes(&bytes).unwrap(); - assert_eq!(msgs, recovered.messages); + assert_eq!(recovered, pull_response); + } + + #[test] + #[should_panic] + fn panics_if_message_is_longer_than_u16_max_when_converted_to_bytes() { + let msg = [1u8; u16::max_value() as usize + 1].to_vec(); + + let pull_response = PullResponse::new(vec![msg]); + pull_response.to_bytes(); + } +} + +#[cfg(test)] +mod register_response { + use super::*; + + #[test] + fn returns_correct_kind() { + let register_response = RegisterResponse::new(AuthToken::from_bytes(Default::default())); + assert_eq!(register_response.get_kind(), ResponseKind::Register) + } + + #[test] + fn can_be_converted_to_and_from_bytes() { + let address = AuthToken::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_response = RegisterResponse::new(address); + let bytes = register_response.to_bytes(); + + let recovered = RegisterResponse::try_from_bytes(&bytes).unwrap(); + assert_eq!(recovered, register_response); + } +} + +#[cfg(test)] +mod failure_response { + use super::*; + + #[test] + fn returns_correct_kind() { + let failure_response = FailureResponse::new("hello nym"); + assert_eq!(failure_response.get_kind(), ResponseKind::Failure) + } + + #[test] + fn can_be_converted_to_and_from_bytes() { + let failure_response = FailureResponse::new("hello nym"); + let bytes = failure_response.to_bytes(); + + let recovered = FailureResponse::try_from_bytes(&bytes).unwrap(); + assert_eq!(recovered, failure_response); } } From 3fb50a1b861a0aa7acee9d014b5df5cd494fb0a7 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 
2020 12:58:46 +0100 Subject: [PATCH 10/23] Updated tokio dependency to require correct features --- sfw-provider/sfw-provider-requests/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sfw-provider/sfw-provider-requests/Cargo.toml b/sfw-provider/sfw-provider-requests/Cargo.toml index dd5b21f4566..cf4ada74b76 100644 --- a/sfw-provider/sfw-provider-requests/Cargo.toml +++ b/sfw-provider/sfw-provider-requests/Cargo.toml @@ -9,5 +9,5 @@ edition = "2018" [dependencies] bs58 = "0.3" bytes = "0.5" -tokio = "0.2" +tokio = { version = "0.2", features = ["io-util"] } sphinx = { git = "https://github.com/nymtech/sphinx", rev="44d8f2aece5049eaa4fe84b7948758ce82b4b80d" } From 20c4753523e9becc1661151179381e670f3b7e16 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 13:00:09 +0100 Subject: [PATCH 11/23] typo --- .../sfw-provider-requests/src/requests/serialization.rs | 2 +- .../sfw-provider-requests/src/responses/serialization.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sfw-provider/sfw-provider-requests/src/requests/serialization.rs b/sfw-provider/sfw-provider-requests/src/requests/serialization.rs index e62a6bb144d..7de0aca7309 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/serialization.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/serialization.rs @@ -16,7 +16,7 @@ impl RequestSerializer { } /// Serialized requests in general have the following structure: - /// follows: 4 byte len (be u32) || 1-byte kind prefix || request-specific data + /// 4 byte len (be u32) || 1-byte kind prefix || request-specific data pub fn into_bytes(self) -> Vec { let (kind, req_bytes) = match self.req { ProviderRequest::Pull(req) => (req.get_kind(), req.to_bytes()), diff --git a/sfw-provider/sfw-provider-requests/src/responses/serialization.rs b/sfw-provider/sfw-provider-requests/src/responses/serialization.rs index 9b9b995b797..5a764cbc394 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/serialization.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/serialization.rs @@ -74,7 +74,7 @@ impl ResponseSerializer { } /// Serialized responses in general have the following structure: - /// follows: 4 byte len (be u32) || 1-byte kind prefix || response-specific data + /// 4 byte len (be u32) || 1-byte kind prefix || response-specific data pub fn into_bytes(self) -> Vec { let (kind, res_bytes) = match self.res { ProviderResponse::Failure(res) => (res.get_kind(), res.to_bytes()), From 9a2068939199cc399be4cd91912286d0b4d0d14e Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Tue, 14 Apr 2020 15:02:42 +0100 Subject: [PATCH 12/23] Easier conversion of requests/responses into enum variants --- .../sfw-provider-requests/src/requests/mod.rs | 12 ++++++++++++ .../sfw-provider-requests/src/responses/mod.rs | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sfw-provider/sfw-provider-requests/src/requests/mod.rs b/sfw-provider/sfw-provider-requests/src/requests/mod.rs index bceb072f73e..85cfca85198 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/mod.rs @@ -57,6 +57,12 @@ pub struct PullRequest { pub destination_address: sphinx::route::DestinationAddressBytes, } +impl Into for PullRequest { + fn into(self) -> ProviderRequest { + ProviderRequest::Pull(self) + } +} + impl PullRequest { pub fn new( destination_address: sphinx::route::DestinationAddressBytes, @@ -104,6 +110,12 @@ pub struct RegisterRequest { pub destination_address: 
DestinationAddressBytes, } +impl Into for RegisterRequest { + fn into(self) -> ProviderRequest { + ProviderRequest::Register(self) + } +} + impl RegisterRequest { pub fn new(destination_address: DestinationAddressBytes) -> Self { RegisterRequest { diff --git a/sfw-provider/sfw-provider-requests/src/responses/mod.rs b/sfw-provider/sfw-provider-requests/src/responses/mod.rs index ba51e5f2526..1a5f5c318b3 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/mod.rs @@ -57,6 +57,12 @@ pub struct PullResponse { messages: Vec>, } +impl Into for PullResponse { + fn into(self) -> ProviderResponse { + ProviderResponse::Pull(self) + } +} + impl PullResponse { pub fn new(messages: Vec>) -> Self { PullResponse { messages } @@ -143,6 +149,12 @@ pub struct RegisterResponse { auth_token: AuthToken, } +impl Into for RegisterResponse { + fn into(self) -> ProviderResponse { + ProviderResponse::Register(self) + } +} + impl RegisterResponse { pub fn new(auth_token: AuthToken) -> Self { RegisterResponse { auth_token } @@ -178,6 +190,12 @@ pub struct FailureResponse { message: String, } +impl Into for FailureResponse { + fn into(self) -> ProviderResponse { + ProviderResponse::Failure(self) + } +} + impl FailureResponse { pub fn new>(message: S) -> Self { FailureResponse { From cf3e82f9efd9e6e18f4fd3fac07a215f8a79dd2e Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 10:41:04 +0100 Subject: [PATCH 13/23] Renamed 'read_be_u16' to better show its purpose --- sfw-provider/sfw-provider-requests/src/responses/mod.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sfw-provider/sfw-provider-requests/src/responses/mod.rs b/sfw-provider/sfw-provider-requests/src/responses/mod.rs index 1a5f5c318b3..a17517cb84a 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/mod.rs @@ -110,7 +110,7 @@ impl PullResponse { } let mut bytes_copy = bytes; - let num_msgs = read_be_u16(&mut bytes_copy); + let num_msgs = extract_be_u16(&mut bytes_copy); // can we read all lengths of messages? 
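        // (each declared length is a 2-byte big-endian u16, hence num_msgs * 2)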
if bytes_copy.len() < (num_msgs * 2) as usize { @@ -118,7 +118,7 @@ impl PullResponse { } let msgs_lens: Vec<_> = (0..num_msgs) - .map(|_| read_be_u16(&mut bytes_copy)) + .map(|_| extract_be_u16(&mut bytes_copy)) .collect(); let required_remaining_len = msgs_lens @@ -223,7 +223,8 @@ impl FailureResponse { } } -fn read_be_u16(input: &mut &[u8]) -> u16 { +/// Takes first 2 bytes of slice, REMOVES THEM and reads it as u16 +fn extract_be_u16(input: &mut &[u8]) -> u16 { let (int_bytes, rest) = input.split_at(std::mem::size_of::()); *input = rest; u16::from_be_bytes(int_bytes.try_into().unwrap()) From dbbdb1777e49c9cb5970e16f52672a9029969c28 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 10:41:20 +0100 Subject: [PATCH 14/23] Serialization related tests and fixes --- sfw-provider/sfw-provider-requests/Cargo.toml | 1 + .../src/requests/serialization.rs | 222 ++++++++++++++- .../src/responses/serialization.rs | 256 +++++++++++++++--- 3 files changed, 433 insertions(+), 46 deletions(-) diff --git a/sfw-provider/sfw-provider-requests/Cargo.toml b/sfw-provider/sfw-provider-requests/Cargo.toml index cf4ada74b76..afe59ce5ed0 100644 --- a/sfw-provider/sfw-provider-requests/Cargo.toml +++ b/sfw-provider/sfw-provider-requests/Cargo.toml @@ -9,5 +9,6 @@ edition = "2018" [dependencies] bs58 = "0.3" bytes = "0.5" +byteorder = "1" tokio = { version = "0.2", features = ["io-util"] } sphinx = { git = "https://github.com/nymtech/sphinx", rev="44d8f2aece5049eaa4fe84b7948758ce82b4b80d" } diff --git a/sfw-provider/sfw-provider-requests/src/requests/serialization.rs b/sfw-provider/sfw-provider-requests/src/requests/serialization.rs index 7de0aca7309..4769dd2e3b5 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/serialization.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/serialization.rs @@ -1,11 +1,11 @@ -use crate::requests::{ - ProviderRequest, ProviderRequestError, PullRequest, RegisterRequest, RequestKind, -}; +use crate::requests::*; use std::convert::TryFrom; // TODO: way down the line, mostly for learning purposes, combine this with responses::serialization // via procedural macros +/// Responsible for taking a request and converting it into bytes that can be sent +/// over the wire, such that a `RequestDeserializer` can recover it. pub struct RequestSerializer { req: ProviderRequest, } @@ -33,27 +33,23 @@ impl RequestSerializer { } } +/// Responsible for taking raw bytes extracted from a stream that have been serialized +/// with `RequestSerializer` and eventually return original Request written. 
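/// A round-trip sketch (assuming some `request: RegisterRequest` was built elsewhere):
/// `RequestSerializer::new(request.into()).into_bytes()` produces the framed bytes and
/// `RequestDeserializer::new(&bytes)?.try_to_parse()?` recovers the original request.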
pub struct RequestDeserializer<'a> { kind: RequestKind, data: &'a [u8], } impl<'a> RequestDeserializer<'a> { - // perform initial parsing + // perform initial parsing and validation pub fn new(raw_bytes: &'a [u8]) -> Result { if raw_bytes.len() < 1 + 4 { Err(ProviderRequestError::UnmarshalErrorInvalidLength) } else { let data_len = u32::from_be_bytes([raw_bytes[0], raw_bytes[1], raw_bytes[2], raw_bytes[3]]); - let kind = RequestKind::try_from(raw_bytes[4])?; - let data = &raw_bytes[4..]; - - if data.len() != data_len as usize { - Err(ProviderRequestError::UnmarshalErrorInvalidLength) - } else { - Ok(RequestDeserializer { kind, data }) - } + + Self::new_with_len(data_len, &raw_bytes[4..]) } } @@ -86,3 +82,205 @@ impl<'a> RequestDeserializer<'a> { } } } + +#[cfg(test)] +mod request_serialization { + use super::*; + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; + use byteorder::{BigEndian, ByteOrder}; + use sphinx::route::DestinationAddressBytes; + use std::convert::TryInto; + + #[test] + fn correctly_serializes_pull_request() { + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let pull_request = PullRequest::new(address, auth_token); + + let raw_request_bytes = pull_request.to_bytes(); + let serializer = RequestSerializer::new(pull_request.clone().into()); + let bytes = serializer.into_bytes(); + + // we expect first four bytes to represent length then kind and finally raw data + let len = BigEndian::read_u32(&bytes); + let kind: RequestKind = bytes[4].try_into().unwrap(); + let data = &bytes[5..]; + assert_eq!(len as usize, data.len() + 1); + assert_eq!(data.to_vec(), raw_request_bytes); + + let recovered_request = PullRequest::try_from_bytes(data).unwrap(); + assert_eq!(pull_request, recovered_request); + assert_eq!(kind, pull_request.get_kind()); + } + + #[test] + fn correctly_serializes_register_request() { + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_request = RegisterRequest::new(address); + + let raw_request_bytes = register_request.to_bytes(); + let serializer = RequestSerializer::new(register_request.clone().into()); + let bytes = serializer.into_bytes(); + + // we expect first four bytes to represent length then kind and finally raw data + let len = BigEndian::read_u32(&bytes); + let kind: RequestKind = bytes[4].try_into().unwrap(); + let data = &bytes[5..]; + assert_eq!(len as usize, data.len() + 1); + assert_eq!(data.to_vec(), raw_request_bytes); + + let recovered_request = RegisterRequest::try_from_bytes(data).unwrap(); + assert_eq!(register_request, recovered_request); + assert_eq!(kind, register_request.get_kind()); + } +} + +#[cfg(test)] +mod request_deserialization { + use super::*; + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; + use byteorder::{BigEndian, ByteOrder}; + use sphinx::route::DestinationAddressBytes; + + #[test] + fn correctly_deserializes_pull_request() { + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let pull_request = PullRequest::new(address, auth_token); + + let raw_request_bytes = pull_request.to_bytes(); + let bytes = 
RequestSerializer::new(pull_request.clone().into()).into_bytes(); + + let deserializer_new = RequestDeserializer::new(&bytes).unwrap(); + assert_eq!(deserializer_new.get_kind(), pull_request.get_kind()); + assert_eq!(deserializer_new.get_data().to_vec(), raw_request_bytes); + + assert_eq!( + ProviderRequest::Pull(pull_request.clone()), + deserializer_new.try_to_parse().unwrap() + ); + + // simulate consuming first 4 bytes to read len + let len = BigEndian::read_u32(&bytes); + let bytes_without_len = &bytes[4..]; + let deserializer_new_with_len = + RequestDeserializer::new_with_len(len, bytes_without_len).unwrap(); + assert_eq!( + deserializer_new_with_len.get_kind(), + pull_request.get_kind() + ); + assert_eq!( + deserializer_new_with_len.get_data().to_vec(), + raw_request_bytes + ); + + assert_eq!( + ProviderRequest::Pull(pull_request), + deserializer_new_with_len.try_to_parse().unwrap() + ); + } + + #[test] + fn correctly_deserializes_register_request() { + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_request = RegisterRequest::new(address); + + let raw_request_bytes = register_request.to_bytes(); + let bytes = RequestSerializer::new(register_request.clone().into()).into_bytes(); + + let deserializer_new = RequestDeserializer::new(&bytes).unwrap(); + assert_eq!(deserializer_new.get_kind(), register_request.get_kind()); + assert_eq!(deserializer_new.get_data().to_vec(), raw_request_bytes); + + assert_eq!( + ProviderRequest::Register(register_request.clone()), + deserializer_new.try_to_parse().unwrap() + ); + + // simulate consuming first 4 bytes to read len + let len = BigEndian::read_u32(&bytes); + let bytes_without_len = &bytes[4..]; + let deserializer_new_with_len = + RequestDeserializer::new_with_len(len, bytes_without_len).unwrap(); + assert_eq!( + deserializer_new_with_len.get_kind(), + register_request.get_kind() + ); + assert_eq!( + deserializer_new_with_len.get_data().to_vec(), + raw_request_bytes + ); + + assert_eq!( + ProviderRequest::Register(register_request), + deserializer_new_with_len.try_to_parse().unwrap() + ); + } + + #[test] + fn returns_error_on_too_short_messages() { + // no matter the request, it must be AT LEAST 5 byte long (for length and 'kind') + let mut len_bytes = 1u32.to_be_bytes().to_vec(); + len_bytes.push(RequestKind::Register as u8); // to have a 'valid' kind + + // bare minimum should return no error + assert!(RequestDeserializer::new(&len_bytes).is_ok()); + + // but shorter should + assert!(RequestDeserializer::new(&0u32.to_be_bytes().to_vec()).is_err()); + } + + #[test] + fn returns_error_on_messages_of_contradictory_length() { + let data = vec![RequestKind::Register as u8, 1, 2, 3]; + + // it shouldn't fail if it matches up + assert!(RequestDeserializer::new_with_len(4, &data).is_ok()); + + assert!(RequestDeserializer::new_with_len(3, &data).is_err()); + } + + #[test] + fn returns_error_on_messages_of_unknown_kind() { + // perform proper serialization but change 'kind' byte to some invalid value + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_request = RegisterRequest::new(address); + let mut bytes = RequestSerializer::new(register_request.clone().into()).into_bytes(); + + let invalid_kind = 42u8; + // sanity check to ensure it IS invalid + assert!(RequestKind::try_from(invalid_kind).is_err()); + 
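        // index 4 is the kind byte: it sits right after the 4-byte big-endian length prefix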
bytes[4] = invalid_kind; + assert!(RequestDeserializer::new(&bytes).is_err()); + } + + #[test] + fn returns_error_on_parsing_invalid_data() { + // kind exists, length is correct, but data is unparsable + // no matter the request, it must be AT LEAST 5 byte long (for length and 'kind') + let mut len_bytes = 5u32.to_be_bytes().to_vec(); + len_bytes.push(RequestKind::Register as u8); // to have a 'valid' kind + len_bytes.push(1); + len_bytes.push(2); + len_bytes.push(3); + len_bytes.push(4); + + let deserializer = RequestDeserializer::new(&len_bytes).unwrap(); + assert!(deserializer.try_to_parse().is_err()); + } +} diff --git a/sfw-provider/sfw-provider-requests/src/responses/serialization.rs b/sfw-provider/sfw-provider-requests/src/responses/serialization.rs index 5a764cbc394..af2e541488b 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/serialization.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/serialization.rs @@ -1,33 +1,57 @@ -use crate::responses::{ - FailureResponse, ProviderResponse, ProviderResponseError, PullResponse, RegisterResponse, - ResponseKind, -}; +use crate::responses::*; use std::convert::TryFrom; // TODO: way down the line, mostly for learning purposes, combine this with requests::serialization // via procedural macros +/// Responsible for taking a response and converting it into bytes that can be sent +/// over the wire, such that a `ResponseDeserializer` can recover it. +pub struct ResponseSerializer { + res: ProviderResponse, +} + +impl ResponseSerializer { + pub fn new(res: ProviderResponse) -> Self { + ResponseSerializer { res } + } + + /// Serialized responses in general have the following structure: + /// 4 byte len (be u32) || 1-byte kind prefix || response-specific data + pub fn into_bytes(self) -> Vec { + let (kind, res_bytes) = match self.res { + // again, perhaps some extra macros/generics here? + ProviderResponse::Failure(res) => (res.get_kind(), res.to_bytes()), + ProviderResponse::Pull(res) => (res.get_kind(), res.to_bytes()), + ProviderResponse::Register(res) => (res.get_kind(), res.to_bytes()), + }; + let res_len = res_bytes.len() as u32 + 1; // 1 is to accommodate for 'kind' + let res_len_bytes = res_len.to_be_bytes(); + res_len_bytes + .iter() + .cloned() + .chain(std::iter::once(kind as u8)) + .chain(res_bytes.into_iter()) + .collect() + } +} + +/// Responsible for taking raw bytes extracted from a stream that have been serialized +/// with `ResponseSerializer` and eventually return original Response written. 
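/// Mirrors the request side: bytes produced by `ResponseSerializer::into_bytes` are recovered
/// here via `ResponseDeserializer::new(&bytes)?.try_to_parse()?` (sketch only).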
pub struct ResponseDeserializer<'a> { kind: ResponseKind, data: &'a [u8], } impl<'a> ResponseDeserializer<'a> { - // perform initial parsing + // perform initial parsing and validation pub fn new(raw_bytes: &'a [u8]) -> Result { if raw_bytes.len() < 1 + 4 { Err(ProviderResponseError::UnmarshalErrorInvalidLength) } else { let data_len = u32::from_be_bytes([raw_bytes[0], raw_bytes[1], raw_bytes[2], raw_bytes[3]]); - let kind = ResponseKind::try_from(raw_bytes[4])?; - let data = &raw_bytes[4..]; - - if data.len() != data_len as usize { - Err(ProviderResponseError::UnmarshalErrorInvalidLength) - } else { - Ok(ResponseDeserializer { kind, data }) - } + + Self::new_with_len(data_len, &raw_bytes[4..]) } } @@ -64,30 +88,194 @@ impl<'a> ResponseDeserializer<'a> { } } -pub struct ResponseSerializer { - res: ProviderResponse, +#[cfg(test)] +mod response_serialization { + use super::*; + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; + use byteorder::{BigEndian, ByteOrder}; + use std::convert::TryInto; + + #[test] + fn correctly_serializes_pull_response() { + let msg1 = vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]; + let msg2 = vec![1, 2, 3, 4, 5, 6, 7]; + let pull_response = PullResponse::new(vec![msg1, msg2]); + + let raw_response_bytes = pull_response.to_bytes(); + let serializer = ResponseSerializer::new(pull_response.clone().into()); + let bytes = serializer.into_bytes(); + + // we expect first four bytes to represent length then kind and finally raw data + let len = BigEndian::read_u32(&bytes); + let kind: ResponseKind = bytes[4].try_into().unwrap(); + let data = &bytes[5..]; + assert_eq!(len as usize, data.len() + 1); + assert_eq!(data.to_vec(), raw_response_bytes); + + let recovered_response = PullResponse::try_from_bytes(data).unwrap(); + assert_eq!(pull_response, recovered_response); + assert_eq!(kind, pull_response.get_kind()); + } + + #[test] + fn correctly_serializes_register_response() { + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let register_response = RegisterResponse::new(auth_token); + + let raw_response_bytes = register_response.to_bytes(); + let serializer = ResponseSerializer::new(register_response.clone().into()); + let bytes = serializer.into_bytes(); + + // we expect first four bytes to represent length then kind and finally raw data + let len = BigEndian::read_u32(&bytes); + let kind: ResponseKind = bytes[4].try_into().unwrap(); + let data = &bytes[5..]; + assert_eq!(len as usize, data.len() + 1); + assert_eq!(data.to_vec(), raw_response_bytes); + + let recovered_response = RegisterResponse::try_from_bytes(data).unwrap(); + assert_eq!(register_response, recovered_response); + assert_eq!(kind, register_response.get_kind()); + } } -impl ResponseSerializer { - pub fn new(res: ProviderResponse) -> Self { - ResponseSerializer { res } +#[cfg(test)] +mod response_deserialization { + use super::*; + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; + use byteorder::{BigEndian, ByteOrder}; + + #[test] + fn correctly_deserializes_pull_response() { + let msg1 = vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]; + let msg2 = vec![1, 2, 3, 4, 5, 6, 7]; + let pull_response = PullResponse::new(vec![msg1, msg2]); + + let raw_response_bytes = pull_response.to_bytes(); + let bytes = ResponseSerializer::new(pull_response.clone().into()).into_bytes(); + + let deserializer_new = ResponseDeserializer::new(&bytes).unwrap(); + 
assert_eq!(deserializer_new.get_kind(), pull_response.get_kind()); + assert_eq!(deserializer_new.get_data().to_vec(), raw_response_bytes); + + assert_eq!( + ProviderResponse::Pull(pull_response.clone()), + deserializer_new.try_to_parse().unwrap() + ); + + // simulate consuming first 4 bytes to read len + let len = BigEndian::read_u32(&bytes); + let bytes_without_len = &bytes[4..]; + let deserializer_new_with_len = + ResponseDeserializer::new_with_len(len, bytes_without_len).unwrap(); + assert_eq!( + deserializer_new_with_len.get_kind(), + pull_response.get_kind() + ); + assert_eq!( + deserializer_new_with_len.get_data().to_vec(), + raw_response_bytes + ); + + assert_eq!( + ProviderResponse::Pull(pull_response), + deserializer_new_with_len.try_to_parse().unwrap() + ); } - /// Serialized responses in general have the following structure: - /// 4 byte len (be u32) || 1-byte kind prefix || response-specific data - pub fn into_bytes(self) -> Vec { - let (kind, res_bytes) = match self.res { - ProviderResponse::Failure(res) => (res.get_kind(), res.to_bytes()), - ProviderResponse::Pull(res) => (res.get_kind(), res.to_bytes()), - ProviderResponse::Register(res) => (res.get_kind(), res.to_bytes()), - }; - let res_len = res_bytes.len() as u32 + 1; // 1 is to accommodate for 'kind' - let res_len_bytes = res_len.to_be_bytes(); - res_len_bytes - .iter() - .cloned() - .chain(std::iter::once(kind as u8)) - .chain(res_bytes.into_iter()) - .collect() + #[test] + fn correctly_deserializes_register_response() { + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let register_response = RegisterResponse::new(auth_token); + + let raw_response_bytes = register_response.to_bytes(); + let bytes = ResponseSerializer::new(register_response.clone().into()).into_bytes(); + + let deserializer_new = ResponseDeserializer::new(&bytes).unwrap(); + assert_eq!(deserializer_new.get_kind(), register_response.get_kind()); + assert_eq!(deserializer_new.get_data().to_vec(), raw_response_bytes); + + assert_eq!( + ProviderResponse::Register(register_response.clone()), + deserializer_new.try_to_parse().unwrap() + ); + + // simulate consuming first 4 bytes to read len + let len = BigEndian::read_u32(&bytes); + let bytes_without_len = &bytes[4..]; + let deserializer_new_with_len = + ResponseDeserializer::new_with_len(len, bytes_without_len).unwrap(); + assert_eq!( + deserializer_new_with_len.get_kind(), + register_response.get_kind() + ); + assert_eq!( + deserializer_new_with_len.get_data().to_vec(), + raw_response_bytes + ); + + assert_eq!( + ProviderResponse::Register(register_response), + deserializer_new_with_len.try_to_parse().unwrap() + ); + } + + #[test] + fn returns_error_on_too_short_messages() { + // no matter the response, it must be AT LEAST 5 byte long (for length and 'kind') + let mut len_bytes = 1u32.to_be_bytes().to_vec(); + len_bytes.push(ResponseKind::Register as u8); // to have a 'valid' kind + + // bare minimum should return no error + assert!(ResponseDeserializer::new(&len_bytes).is_ok()); + + // but shorter should + assert!(ResponseDeserializer::new(&0u32.to_be_bytes().to_vec()).is_err()); + } + + #[test] + fn returns_error_on_messages_of_contradictory_length() { + let data = vec![ResponseKind::Register as u8, 1, 2, 3]; + + // it shouldn't fail if it matches up + assert!(ResponseDeserializer::new_with_len(4, &data).is_ok()); + + assert!(ResponseDeserializer::new_with_len(3, &data).is_err()); + } + + #[test] + fn returns_error_on_messages_of_unknown_kind() { + // perform proper serialization but 
change 'kind' byte to some invalid value + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let register_response = RegisterResponse::new(auth_token); + + let mut bytes = ResponseSerializer::new(register_response.clone().into()).into_bytes(); + + let invalid_kind = 42u8; + // sanity check to ensure it IS invalid + assert!(ResponseKind::try_from(invalid_kind).is_err()); + bytes[4] = invalid_kind; + assert!(ResponseDeserializer::new(&bytes).is_err()); + } + + #[test] + fn returns_error_on_parsing_invalid_data() { + // kind exists, length is correct, but data is unparsable + // no matter the response, it must be AT LEAST 5 byte long (for length and 'kind') + let mut len_bytes = 5u32.to_be_bytes().to_vec(); + len_bytes.push(ResponseKind::Register as u8); // to have a 'valid' kind + len_bytes.push(1); + len_bytes.push(2); + len_bytes.push(3); + len_bytes.push(4); + + let deserializer = ResponseDeserializer::new(&len_bytes).unwrap(); + assert!(deserializer.try_to_parse().is_err()); } } From fcf3481e7cb933322ec127bfc68f38f02322437c Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 16:01:00 +0100 Subject: [PATCH 15/23] Tests for async_io + fixes --- sfw-provider/sfw-provider-requests/Cargo.toml | 4 + .../src/requests/async_io.rs | 268 ++++++++++++++++- .../src/responses/async_io.rs | 275 +++++++++++++++++- 3 files changed, 530 insertions(+), 17 deletions(-) diff --git a/sfw-provider/sfw-provider-requests/Cargo.toml b/sfw-provider/sfw-provider-requests/Cargo.toml index afe59ce5ed0..1807e117b0e 100644 --- a/sfw-provider/sfw-provider-requests/Cargo.toml +++ b/sfw-provider/sfw-provider-requests/Cargo.toml @@ -12,3 +12,7 @@ bytes = "0.5" byteorder = "1" tokio = { version = "0.2", features = ["io-util"] } sphinx = { git = "https://github.com/nymtech/sphinx", rev="44d8f2aece5049eaa4fe84b7948758ce82b4b80d" } + +[dev-dependencies] +tokio = { version = "0.2", features = ["rt-threaded", "time"] } +tokio-test = "0.2" diff --git a/sfw-provider/sfw-provider-requests/src/requests/async_io.rs b/sfw-provider/sfw-provider-requests/src/requests/async_io.rs index 2e70511521c..596fdbce566 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/async_io.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/async_io.rs @@ -39,20 +39,17 @@ impl<'a, R: AsyncRead + Unpin> TokioAsyncRequestReader<'a, R> { return Err(ProviderRequestError::RemoteConnectionClosed); } if req_len as usize > self.max_allowed_len { - // TODO: should reader be drained? + // TODO: should reader be drained or just assume caller will close the + // underlying reader and/or deal with the issue itself? 
return Err(ProviderRequestError::TooLongRequestError); } - let mut req_buf = Vec::with_capacity(req_len as usize); - let mut chunk = self.reader.take(req_len as u64); - - if let Err(_) = chunk.read_to_end(&mut req_buf).await { - return Err(ProviderRequestError::TooShortRequestError); - }; - - let parse_res = RequestDeserializer::new_with_len(req_len, &req_buf)?.try_to_parse(); + let mut req_buf = vec![0; req_len as usize]; + if let Err(e) = self.reader.read_exact(&mut req_buf).await { + return Err(ProviderRequestError::IOError(e)); + } - parse_res + RequestDeserializer::new_with_len(req_len, &req_buf)?.try_to_parse() } } @@ -72,3 +69,254 @@ impl<'a, W: AsyncWrite + Unpin> TokioAsyncRequestWriter<'a, W> { self.writer.write_all(&res_bytes).await } } + +#[cfg(test)] +mod request_writer { + use super::*; + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; + use crate::requests::{PullRequest, RegisterRequest}; + use sphinx::route::DestinationAddressBytes; + + // TODO: what else to test here? + + #[test] + fn writes_all_bytes_to_underlying_writer_for_register_request() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_request = RegisterRequest::new(address); + let expected_bytes = RequestSerializer::new(register_request.clone().into()).into_bytes(); + + let mut writer = Vec::new(); + + let mut request_writer = TokioAsyncRequestWriter::new(&mut writer); + rt.block_on(request_writer.try_write_request(register_request.into())) + .unwrap(); + + // to finish the mutable borrow since we don't need request_writer anymore anyway + drop(request_writer); + + assert_eq!(writer, expected_bytes); + } + + #[test] + fn writes_all_bytes_to_underlying_writer_for_pull_request() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let pull_request = PullRequest::new(address, auth_token); + let expected_bytes = RequestSerializer::new(pull_request.clone().into()).into_bytes(); + + let mut writer = Vec::new(); + + let mut request_writer = TokioAsyncRequestWriter::new(&mut writer); + rt.block_on(request_writer.try_write_request(pull_request.into())) + .unwrap(); + + // to finish the mutable borrow since we don't need request_writer anymore anyway + drop(request_writer); + + assert_eq!(writer, expected_bytes); + } +} + +#[cfg(test)] +mod request_reader { + use super::*; + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; + use crate::requests::{PullRequest, RegisterRequest, RequestKind}; + use sphinx::route::DestinationAddressBytes; + use std::io::Cursor; + use std::time; + + #[test] + fn correctly_reads_valid_register_request() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_request = RegisterRequest::new(address); + let serialized_bytes = RequestSerializer::new(register_request.clone().into()).into_bytes(); + + let mut reader = Cursor::new(serialized_bytes); + let mut request_reader = + TokioAsyncRequestReader::new(&mut reader, u32::max_value() as usize); + + let read_request = 
rt.block_on(request_reader.try_read_request()).unwrap(); + match read_request { + ProviderRequest::Register(req) => assert_eq!(register_request, req), + _ => panic!("read incorrect request!"), + } + } + + #[test] + fn correctly_reads_valid_pull_request() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let pull_request = PullRequest::new(address, auth_token); + let serialized_bytes = RequestSerializer::new(pull_request.clone().into()).into_bytes(); + + let mut reader = Cursor::new(serialized_bytes); + let mut request_reader = + TokioAsyncRequestReader::new(&mut reader, u32::max_value() as usize); + + let read_request = rt.block_on(request_reader.try_read_request()).unwrap(); + match read_request { + ProviderRequest::Pull(req) => assert_eq!(pull_request, req), + _ => panic!("read incorrect request!"), + } + } + + #[test] + fn correctly_reads_valid_register_request_even_if_more_random_bytes_follow() { + // note that if read was called again, it would have failed + + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_request = RegisterRequest::new(address); + let serialized_bytes = RequestSerializer::new(register_request.clone().into()).into_bytes(); + + let serialized_bytes_with_garbage: Vec<_> = serialized_bytes + .into_iter() + .chain(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into_iter()) + .collect(); + + let mut reader = Cursor::new(serialized_bytes_with_garbage); + let mut request_reader = + TokioAsyncRequestReader::new(&mut reader, u32::max_value() as usize); + + let read_request = rt.block_on(request_reader.try_read_request()).unwrap(); + match read_request { + ProviderRequest::Register(req) => assert_eq!(register_request, req), + _ => panic!("read incorrect request!"), + } + } + + #[test] + fn correctly_reads_two_consecutive_requests() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + + let pull_request = PullRequest::new(address.clone(), auth_token); + let register_request = RegisterRequest::new(address); + + let register_serialized_bytes = + RequestSerializer::new(register_request.clone().into()).into_bytes(); + let pull_serialized_bytes = + RequestSerializer::new(pull_request.clone().into()).into_bytes(); + + let combined_requests: Vec<_> = register_serialized_bytes + .into_iter() + .chain(pull_serialized_bytes.into_iter()) + .collect(); + + let mut reader = Cursor::new(combined_requests); + let mut request_reader = + TokioAsyncRequestReader::new(&mut reader, u32::max_value() as usize); + + let first_read_request = rt.block_on(request_reader.try_read_request()).unwrap(); + match first_read_request { + ProviderRequest::Register(req) => assert_eq!(register_request, req), + _ => panic!("read incorrect request!"), + } + + let second_read_request = rt.block_on(request_reader.try_read_request()).unwrap(); + match second_read_request { + ProviderRequest::Pull(req) => assert_eq!(pull_request, req), + _ => panic!("read incorrect request!"), + } + } + + 
#[test] + fn correctly_reads_valid_request_even_if_written_with_delay() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_request = RegisterRequest::new(address); + let serialized_bytes = RequestSerializer::new(register_request.clone().into()).into_bytes(); + + let (first_half, second_half) = serialized_bytes.split_at(30); // 30 is an arbitrary value + + let mut mock = tokio_test::io::Builder::new() + .read(&first_half) + .wait(time::Duration::from_millis(300)) + .read(&second_half) + .build(); + + let mut request_reader = TokioAsyncRequestReader::new(&mut mock, u32::max_value() as usize); + + let read_request = rt.block_on(request_reader.try_read_request()).unwrap(); + match read_request { + ProviderRequest::Register(req) => assert_eq!(register_request, req), + _ => panic!("read incorrect request!"), + } + } + + #[test] + fn fails_to_read_invalid_request() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let mut invalid_request = 9u32.to_be_bytes().to_vec(); + invalid_request.push(RequestKind::Register as u8); // to have a 'valid' kind + invalid_request.push(0); + invalid_request.push(1); + invalid_request.push(2); + invalid_request.push(3); + invalid_request.push(4); + invalid_request.push(5); + invalid_request.push(6); + invalid_request.push(7); + + let mut reader = Cursor::new(invalid_request); + let mut request_reader = + TokioAsyncRequestReader::new(&mut reader, u32::max_value() as usize); + + assert!(rt.block_on(request_reader.try_read_request()).is_err()); + } + + #[test] + fn fails_to_read_too_long_request() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let address = DestinationAddressBytes::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]); + let register_request = RegisterRequest::new(address); + let serialized_bytes = RequestSerializer::new(register_request.clone().into()).into_bytes(); + let serialized_bytes_len = serialized_bytes.len(); + + let mut reader = Cursor::new(serialized_bytes); + // note our reader accepts fewer bytes than what we have + let mut request_reader = + TokioAsyncRequestReader::new(&mut reader, serialized_bytes_len - 10); + + assert!(rt.block_on(request_reader.try_read_request()).is_err()); + } +} \ No newline at end of file diff --git a/sfw-provider/sfw-provider-requests/src/responses/async_io.rs b/sfw-provider/sfw-provider-requests/src/responses/async_io.rs index 4ad11818a52..bbefa31f786 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/async_io.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/async_io.rs @@ -39,16 +39,15 @@ impl<'a, R: AsyncRead + Unpin> TokioAsyncResponseReader<'a, R> { return Err(ProviderResponseError::RemoteConnectionClosed); } if res_len as usize > self.max_allowed_len { - // TODO: should reader be drained? + // TODO: should reader be drained or just assume caller will close the + // underlying reader and/or deal with the issue itself? 
return Err(ProviderResponseError::TooLongResponseError); } - let mut res_buf = Vec::with_capacity(res_len as usize); - let mut chunk = self.reader.take(res_len as u64); - - if let Err(_) = chunk.read_to_end(&mut res_buf).await { - return Err(ProviderResponseError::TooShortResponseError); - }; + let mut res_buf = vec![0; res_len as usize]; + if let Err(e) = self.reader.read_exact(&mut res_buf).await { + return Err(ProviderResponseError::IOError(e)); + } ResponseDeserializer::new_with_len(res_len, &res_buf)?.try_to_parse() } @@ -70,3 +69,265 @@ impl<'a, W: AsyncWrite + Unpin> TokioAsyncResponseWriter<'a, W> { self.writer.write_all(&res_bytes).await } } + +#[cfg(test)] +mod response_writer { + use super::*; + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; + use crate::responses::{FailureResponse, PullResponse, RegisterResponse}; + + // TODO: what else to test here? + + #[test] + fn writes_all_bytes_to_underlying_writer_for_register_response() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let register_response = RegisterResponse::new(auth_token); + let expected_bytes = ResponseSerializer::new(register_response.clone().into()).into_bytes(); + + let mut writer = Vec::new(); + + let mut response_writer = TokioAsyncResponseWriter::new(&mut writer); + rt.block_on(response_writer.try_write_response(register_response.into())) + .unwrap(); + + // to finish the mutable borrow since we don't need response_writer anymore anyway + drop(response_writer); + + assert_eq!(writer, expected_bytes); + } + + #[test] + fn writes_all_bytes_to_underlying_writer_for_pull_response() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let msg1 = vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]; + let msg2 = vec![1, 2, 3, 4, 5, 6, 7]; + let pull_response = PullResponse::new(vec![msg1, msg2]); + + let expected_bytes = ResponseSerializer::new(pull_response.clone().into()).into_bytes(); + + let mut writer = Vec::new(); + + let mut response_writer = TokioAsyncResponseWriter::new(&mut writer); + rt.block_on(response_writer.try_write_response(pull_response.into())) + .unwrap(); + + // to finish the mutable borrow since we don't need response_writer anymore anyway + drop(response_writer); + + assert_eq!(writer, expected_bytes); + } + + #[test] + fn writes_all_bytes_to_underlying_writer_for_failure_response() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let msg1 = "hello nym"; + let failure_response = FailureResponse::new(msg1); + + let expected_bytes = ResponseSerializer::new(failure_response.clone().into()).into_bytes(); + + let mut writer = Vec::new(); + + let mut response_writer = TokioAsyncResponseWriter::new(&mut writer); + rt.block_on(response_writer.try_write_response(failure_response.into())) + .unwrap(); + + // to finish the mutable borrow since we don't need response_writer anymore anyway + drop(response_writer); + + assert_eq!(writer, expected_bytes); + } +} + +#[cfg(test)] +mod response_reader { + use super::*; + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; + use crate::responses::{PullResponse, RegisterResponse, ResponseKind}; + use std::io::Cursor; + use std::time; + + #[test] + fn correctly_reads_valid_register_response() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let register_response = RegisterResponse::new(auth_token); + let serialized_bytes = + 
ResponseSerializer::new(register_response.clone().into()).into_bytes(); + + let mut reader = Cursor::new(serialized_bytes); + let mut response_reader = + TokioAsyncResponseReader::new(&mut reader, u32::max_value() as usize); + + let read_response = rt.block_on(response_reader.try_read_response()).unwrap(); + match read_response { + ProviderResponse::Register(req) => assert_eq!(register_response, req), + _ => panic!("read incorrect response!"), + } + } + + #[test] + fn correctly_reads_valid_pull_response() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let msg1 = vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]; + let msg2 = vec![1, 2, 3, 4, 5, 6, 7]; + let pull_response = PullResponse::new(vec![msg1, msg2]); + let serialized_bytes = ResponseSerializer::new(pull_response.clone().into()).into_bytes(); + + let mut reader = Cursor::new(serialized_bytes); + let mut response_reader = + TokioAsyncResponseReader::new(&mut reader, u32::max_value() as usize); + + let read_response = rt.block_on(response_reader.try_read_response()).unwrap(); + match read_response { + ProviderResponse::Pull(req) => assert_eq!(pull_response, req), + _ => panic!("read incorrect response!"), + } + } + + #[test] + fn correctly_reads_valid_register_response_even_if_more_random_bytes_follow() { + // note that if read was called again, it would have failed + + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let register_response = RegisterResponse::new(auth_token); + let serialized_bytes = + ResponseSerializer::new(register_response.clone().into()).into_bytes(); + + let serialized_bytes_with_garbage: Vec<_> = serialized_bytes + .into_iter() + .chain(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into_iter()) + .collect(); + + let mut reader = Cursor::new(serialized_bytes_with_garbage); + let mut response_reader = + TokioAsyncResponseReader::new(&mut reader, u32::max_value() as usize); + + let read_response = rt.block_on(response_reader.try_read_response()).unwrap(); + match read_response { + ProviderResponse::Register(req) => assert_eq!(register_response, req), + _ => panic!("read incorrect response!"), + } + } + + #[test] + fn correctly_reads_two_consecutive_responses() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let msg1 = vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 1, 2, + ]; + let msg2 = vec![1, 2, 3, 4, 5, 6, 7]; + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + + let pull_response = PullResponse::new(vec![msg1, msg2]); + let register_response = RegisterResponse::new(auth_token); + + let register_serialized_bytes = + ResponseSerializer::new(register_response.clone().into()).into_bytes(); + let pull_serialized_bytes = + ResponseSerializer::new(pull_response.clone().into()).into_bytes(); + + let combined_responses: Vec<_> = register_serialized_bytes + .into_iter() + .chain(pull_serialized_bytes.into_iter()) + .collect(); + + let mut reader = Cursor::new(combined_responses); + let mut response_reader = + TokioAsyncResponseReader::new(&mut reader, u32::max_value() as usize); + + let first_read_response = rt.block_on(response_reader.try_read_response()).unwrap(); + match first_read_response { + ProviderResponse::Register(req) => assert_eq!(register_response, req), + _ => panic!("read incorrect response!"), + } + + let second_read_response = rt.block_on(response_reader.try_read_response()).unwrap(); + 
match second_read_response { + ProviderResponse::Pull(req) => assert_eq!(pull_response, req), + _ => panic!("read incorrect response!"), + } + } + + #[test] + fn correctly_reads_valid_response_even_if_written_with_delay() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let register_response = RegisterResponse::new(auth_token); + let serialized_bytes = + ResponseSerializer::new(register_response.clone().into()).into_bytes(); + + let (first_half, second_half) = serialized_bytes.split_at(30); // 30 is an arbitrary value + + let mut mock = tokio_test::io::Builder::new() + .read(&first_half) + .wait(time::Duration::from_millis(300)) + .read(&second_half) + .build(); + + let mut response_reader = + TokioAsyncResponseReader::new(&mut mock, u32::max_value() as usize); + + let read_response = rt.block_on(response_reader.try_read_response()).unwrap(); + match read_response { + ProviderResponse::Register(req) => assert_eq!(register_response, req), + _ => panic!("read incorrect response!"), + } + } + + #[test] + fn fails_to_read_invalid_response() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let mut invalid_response = 9u32.to_be_bytes().to_vec(); + invalid_response.push(ResponseKind::Register as u8); // to have a 'valid' kind + invalid_response.push(0); + invalid_response.push(1); + invalid_response.push(2); + invalid_response.push(3); + invalid_response.push(4); + invalid_response.push(5); + invalid_response.push(6); + invalid_response.push(7); + + let mut reader = Cursor::new(invalid_response); + let mut response_reader = + TokioAsyncResponseReader::new(&mut reader, u32::max_value() as usize); + + assert!(rt.block_on(response_reader.try_read_response()).is_err()); + } + + #[test] + fn fails_to_read_too_long_response() { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let auth_token = AuthToken::from_bytes([1u8; AUTH_TOKEN_SIZE]); + let register_response = RegisterResponse::new(auth_token); + let serialized_bytes = + ResponseSerializer::new(register_response.clone().into()).into_bytes(); + let serialized_bytes_len = serialized_bytes.len(); + + let mut reader = Cursor::new(serialized_bytes); + // note our reader accepts fewer bytes than what we have + let mut response_reader = + TokioAsyncResponseReader::new(&mut reader, serialized_bytes_len - 10); + + assert!(rt.block_on(response_reader.try_read_response()).is_err()); + } +} From 85e54c1b190c40229c577708595b8c841a07154e Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 16:01:12 +0100 Subject: [PATCH 16/23] Future considerations --- common/clients/provider-client/src/lib.rs | 4 ++++ sfw-provider/sfw-provider-requests/src/lib.rs | 4 ++++ sfw-provider/src/provider/client_handling/listener.rs | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/common/clients/provider-client/src/lib.rs b/common/clients/provider-client/src/lib.rs index 1f2db8cfed5..660e2a52e27 100644 --- a/common/clients/provider-client/src/lib.rs +++ b/common/clients/provider-client/src/lib.rs @@ -101,6 +101,10 @@ impl ProviderClient { let socket = self.connection.as_mut().unwrap(); let (mut socket_reader, mut socket_writer) = socket.split(); + // TODO: benchmark and determine if below should be done: + // let mut socket_writer = tokio::io::BufWriter::new(socket_writer); + // let mut socket_reader = tokio::io::BufReader::new(socket_reader); + let mut request_writer = TokioAsyncRequestWriter::new(&mut socket_writer); let mut response_reader = 
TokioAsyncResponseReader::new(&mut socket_reader, MAX_RESPONSE_SIZE); diff --git a/sfw-provider/sfw-provider-requests/src/lib.rs b/sfw-provider/sfw-provider-requests/src/lib.rs index e6d4883c0a4..848954a81a5 100644 --- a/sfw-provider/sfw-provider-requests/src/lib.rs +++ b/sfw-provider/sfw-provider-requests/src/lib.rs @@ -4,3 +4,7 @@ pub mod responses; pub const DUMMY_MESSAGE_CONTENT: &[u8] = b"[DUMMY MESSAGE] Wanting something does not give you the right to have it."; + +// TODO: consideration for the future: should all request/responses have associated IDs +// for "async" API? However, TCP should ensure packets are received in order, so maybe +// it's not really required? diff --git a/sfw-provider/src/provider/client_handling/listener.rs b/sfw-provider/src/provider/client_handling/listener.rs index eb3decddc36..5324608bf18 100644 --- a/sfw-provider/src/provider/client_handling/listener.rs +++ b/sfw-provider/src/provider/client_handling/listener.rs @@ -61,6 +61,10 @@ async fn process_socket_connection( ) { let peer_addr = socket.peer_addr(); let (mut socket_reader, mut socket_writer) = socket.split(); + // TODO: benchmark and determine if below should be done: + // let mut socket_writer = tokio::io::BufWriter::new(socket_writer); + // let mut socket_reader = tokio::io::BufReader::new(socket_reader); + let mut request_reader = TokioAsyncRequestReader::new(&mut socket_reader, MAX_REQUEST_SIZE); let mut response_writer = TokioAsyncResponseWriter::new(&mut socket_writer); From 1ac7ecf6bf28d5efbaf7e69ae7b929fe0ee2a5c4 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 16:26:55 +0100 Subject: [PATCH 17/23] Configurable max request size --- sfw-provider/src/config/mod.rs | 13 ++++++++++++- .../src/provider/client_handling/listener.rs | 10 ++++++---- .../provider/client_handling/request_processing.rs | 10 ++++++++++ sfw-provider/src/provider/mod.rs | 1 + 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/sfw-provider/src/config/mod.rs b/sfw-provider/src/config/mod.rs index 9b49692a4d4..b4476b403a3 100644 --- a/sfw-provider/src/config/mod.rs +++ b/sfw-provider/src/config/mod.rs @@ -19,6 +19,7 @@ const DEFAULT_PRESENCE_SENDING_DELAY: u64 = 1500; const DEFAULT_STORED_MESSAGE_FILENAME_LENGTH: u16 = 16; const DEFAULT_MESSAGE_RETRIEVAL_LIMIT: u16 = 5; +const DEFAULT_MAX_REQUEST_SIZE: u32 = 16 * 1024; #[derive(Debug, Default, Deserialize, PartialEq, Serialize)] #[serde(deny_unknown_fields)] @@ -333,6 +334,10 @@ impl Config { pub fn get_stored_messages_filename_length(&self) -> u16 { self.debug.stored_messages_filename_length } + + pub fn get_max_request_size(&self) -> usize { + self.debug.max_request_size as usize + } } #[derive(Debug, Deserialize, PartialEq, Serialize)] @@ -479,10 +484,15 @@ pub struct Debug { /// Length of filenames for new client messages. stored_messages_filename_length: u16, - /// number of messages client gets on each request + /// Number of messages client gets on each request /// if there are no real messages, dummy ones are create to always return /// `message_retrieval_limit` total messages message_retrieval_limit: u16, + + /// Maximum allowed length for requests received. + /// Anything declaring bigger size than that will be regarded as an error and + /// is going to be rejected. 
+ max_request_size: u32, } impl Debug { @@ -503,6 +513,7 @@ impl Default for Debug { presence_sending_delay: DEFAULT_PRESENCE_SENDING_DELAY, stored_messages_filename_length: DEFAULT_STORED_MESSAGE_FILENAME_LENGTH, message_retrieval_limit: DEFAULT_MESSAGE_RETRIEVAL_LIMIT, + max_request_size: DEFAULT_MAX_REQUEST_SIZE, } } } diff --git a/sfw-provider/src/provider/client_handling/listener.rs b/sfw-provider/src/provider/client_handling/listener.rs index 5324608bf18..1763df95808 100644 --- a/sfw-provider/src/provider/client_handling/listener.rs +++ b/sfw-provider/src/provider/client_handling/listener.rs @@ -52,9 +52,6 @@ async fn process_request<'a>( } } -// TODO: temporary proof of concept. will later be moved into config -const MAX_REQUEST_SIZE: usize = 4_096; - async fn process_socket_connection( mut socket: tokio::net::TcpStream, mut request_processor: RequestProcessor, @@ -65,7 +62,8 @@ async fn process_socket_connection( // let mut socket_writer = tokio::io::BufWriter::new(socket_writer); // let mut socket_reader = tokio::io::BufReader::new(socket_reader); - let mut request_reader = TokioAsyncRequestReader::new(&mut socket_reader, MAX_REQUEST_SIZE); + let mut request_reader = + TokioAsyncRequestReader::new(&mut socket_reader, request_processor.max_request_size()); let mut response_writer = TokioAsyncResponseWriter::new(&mut socket_writer); loop { @@ -91,6 +89,10 @@ async fn process_socket_connection( // let's leave it like this for time being and see if we need to decrease // logging level and / or close the connection warn!("the received request was invalid - {:?}", e); + // should the connection be closed here? invalid request might imply + // the subsequent requests in the reader buffer might not be aligned anymore + // however, that might not necessarily be the case + return; } // in here we do not really want to process multiple requests from the same // client concurrently as then we might end up with really weird race conditions diff --git a/sfw-provider/src/provider/client_handling/request_processing.rs b/sfw-provider/src/provider/client_handling/request_processing.rs index 6e9cb2fa5e9..3a81608380e 100644 --- a/sfw-provider/src/provider/client_handling/request_processing.rs +++ b/sfw-provider/src/provider/client_handling/request_processing.rs @@ -46,6 +46,7 @@ pub struct RequestProcessor { secret_key: Arc, client_storage: ClientStorage, client_ledger: ClientLedger, + max_request_size: usize, } impl RequestProcessor { @@ -53,14 +54,20 @@ impl RequestProcessor { secret_key: encryption::PrivateKey, client_storage: ClientStorage, client_ledger: ClientLedger, + max_request_size: usize, ) -> Self { RequestProcessor { secret_key: Arc::new(secret_key), client_storage, client_ledger, + max_request_size, } } + pub(crate) fn max_request_size(&self) -> usize { + self.max_request_size + } + pub(crate) async fn process_client_request( &mut self, client_request: ProviderRequest, @@ -157,6 +164,7 @@ mod generating_new_auth_token { secret_key: Arc::new(key), client_storage: ClientStorage::new(3, 16, Default::default()), client_ledger: ClientLedger::new(), + max_request_size: 42, }; let token1 = request_processor.generate_new_auth_token(client_address1); @@ -175,12 +183,14 @@ mod generating_new_auth_token { secret_key: Arc::new(key1), client_storage: ClientStorage::new(3, 16, Default::default()), client_ledger: ClientLedger::new(), + max_request_size: 42, }; let request_processor2 = RequestProcessor { secret_key: Arc::new(key2), client_storage: ClientStorage::new(3, 16, Default::default()), 
client_ledger: ClientLedger::new(), + max_request_size: 42, }; let token1 = request_processor1.generate_new_auth_token(client_address1.clone()); diff --git a/sfw-provider/src/provider/mod.rs b/sfw-provider/src/provider/mod.rs index fc2aba3e200..939c0a68668 100644 --- a/sfw-provider/src/provider/mod.rs +++ b/sfw-provider/src/provider/mod.rs @@ -78,6 +78,7 @@ impl ServiceProvider { self.sphinx_keypair.private_key().clone(), client_storage, self.registered_clients_ledger.clone(), + self.config.get_max_request_size(), ); client_handling::listener::run_client_socket_listener( From b46a73f8182c1890d1f2c95b1fe9cff13a46f652 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 16:50:36 +0100 Subject: [PATCH 18/23] Configurable max response size for client --- common/clients/provider-client/src/lib.rs | 7 ++++--- common/healthcheck/src/lib.rs | 3 +++ common/healthcheck/src/path_check.rs | 9 +++++++-- nym-client/src/client/mod.rs | 1 + nym-client/src/client/provider_poller.rs | 2 ++ nym-client/src/commands/init.rs | 4 ++++ nym-client/src/config/mod.rs | 13 +++++++++++++ 7 files changed, 34 insertions(+), 5 deletions(-) diff --git a/common/clients/provider-client/src/lib.rs b/common/clients/provider-client/src/lib.rs index 660e2a52e27..2ab7712730c 100644 --- a/common/clients/provider-client/src/lib.rs +++ b/common/clients/provider-client/src/lib.rs @@ -52,20 +52,21 @@ pub struct ProviderClient { our_address: DestinationAddressBytes, auth_token: Option, connection: Option, + max_response_size: usize, } -const MAX_RESPONSE_SIZE: usize = 1_000_000_000; - impl ProviderClient { pub fn new( provider_network_address: SocketAddr, our_address: DestinationAddressBytes, auth_token: Option, + max_response_size: usize, ) -> Self { ProviderClient { provider_network_address, our_address, auth_token, + max_response_size, // establish connection when it's necessary (mainly to not break current code // as then 'new' would need to be called within async context) connection: None, @@ -107,7 +108,7 @@ impl ProviderClient { let mut request_writer = TokioAsyncRequestWriter::new(&mut socket_writer); let mut response_reader = - TokioAsyncResponseReader::new(&mut socket_reader, MAX_RESPONSE_SIZE); + TokioAsyncResponseReader::new(&mut socket_reader, self.max_response_size); if let Err(e) = request_writer.try_write_request(request).await { debug!("Failed to write the request - {:?}", e); diff --git a/common/healthcheck/src/lib.rs b/common/healthcheck/src/lib.rs index 0aee269d139..c751372c7f0 100644 --- a/common/healthcheck/src/lib.rs +++ b/common/healthcheck/src/lib.rs @@ -5,6 +5,9 @@ use std::fmt::{Error, Formatter}; use std::time::Duration; use topology::{NymTopology, NymTopologyError}; +// basically no limit +pub(crate) const MAX_PROVIDER_RESPONSE_SIZE: usize = 1024 * 1024; + pub mod config; mod path_check; mod result; diff --git a/common/healthcheck/src/path_check.rs b/common/healthcheck/src/path_check.rs index 45f4ffb19c2..4cc30e17874 100644 --- a/common/healthcheck/src/path_check.rs +++ b/common/healthcheck/src/path_check.rs @@ -1,3 +1,4 @@ +use crate::MAX_PROVIDER_RESPONSE_SIZE; use crypto::identity::MixIdentityKeyPair; use itertools::Itertools; use log::{debug, error, info, trace, warn}; @@ -40,8 +41,12 @@ impl PathChecker { let address = identity_keys.public_key().derive_address(); for provider in providers { - let mut provider_client = - ProviderClient::new(provider.client_listener, address.clone(), None); + let mut provider_client = ProviderClient::new( + provider.client_listener, + address.clone(), + None, 
+ MAX_PROVIDER_RESPONSE_SIZE, + ); // TODO: we might be sending unnecessary register requests since after first healthcheck, // we are registered for any subsequent ones (since our address did not change) diff --git a/nym-client/src/client/mod.rs b/nym-client/src/client/mod.rs index 1fae01883b8..333480f3d3b 100644 --- a/nym-client/src/client/mod.rs +++ b/nym-client/src/client/mod.rs @@ -171,6 +171,7 @@ impl NymClient { .map(|str_token| AuthToken::try_from_base58_string(str_token).ok()) .unwrap_or(None), self.config.get_fetch_message_delay(), + self.config.get_max_response_size(), ); if !provider_poller.is_registered() { diff --git a/nym-client/src/client/provider_poller.rs b/nym-client/src/client/provider_poller.rs index 9af6cc6a455..c98cf824135 100644 --- a/nym-client/src/client/provider_poller.rs +++ b/nym-client/src/client/provider_poller.rs @@ -24,12 +24,14 @@ impl ProviderPoller { client_address: DestinationAddressBytes, auth_token: Option, polling_rate: time::Duration, + max_response_size: usize, ) -> Self { ProviderPoller { provider_client: provider_client::ProviderClient::new( provider_client_listener_address, client_address, auth_token, + max_response_size, ), poller_tx, polling_rate, diff --git a/nym-client/src/commands/init.rs b/nym-client/src/commands/init.rs index 71b632b8f04..35337bbc7f8 100644 --- a/nym-client/src/commands/init.rs +++ b/nym-client/src/commands/init.rs @@ -47,12 +47,16 @@ async fn try_provider_registrations( providers: Vec, our_address: DestinationAddressBytes, ) -> Option<(String, AuthToken)> { + // we don't expect the response to be longer than AUTH_TOKEN_SIZE, but allow for more bytes + // in case there was an error message + let max_response_len = 16 * 1024; // since the order of providers is non-deterministic we can just try to get a first 'working' provider for provider in providers { let mut provider_client = provider_client::ProviderClient::new( provider.client_listener, our_address.clone(), None, + max_response_len, ); let auth_token = provider_client.register().await; if let Ok(token) = auth_token { diff --git a/nym-client/src/config/mod.rs b/nym-client/src/config/mod.rs index d747a22eb45..0897872d42e 100644 --- a/nym-client/src/config/mod.rs +++ b/nym-client/src/config/mod.rs @@ -26,6 +26,9 @@ const DEFAULT_HEALTHCHECK_CONNECTION_TIMEOUT: u64 = DEFAULT_INITIAL_CONNECTION_T const DEFAULT_NUMBER_OF_HEALTHCHECK_TEST_PACKETS: u64 = 2; const DEFAULT_NODE_SCORE_THRESHOLD: f64 = 0.0; +// for time being treat it as if there was no limit +const DEFAULT_MAX_RESPONSE_SIZE: u32 = u32::max_value(); + #[derive(Debug, Deserialize, PartialEq, Serialize, Clone, Copy)] #[serde(deny_unknown_fields)] pub enum SocketType { @@ -223,6 +226,10 @@ impl Config { pub fn get_healthcheck_connection_timeout(&self) -> time::Duration { time::Duration::from_millis(self.debug.healthcheck_connection_timeout) } + + pub fn get_max_response_size(&self) -> usize { + self.debug.max_response_size as usize + } } fn de_option_string<'de, D>(deserializer: D) -> Result, D::Error> @@ -395,6 +402,11 @@ pub struct Debug { /// during healthcheck. /// The provider value is interpreted as milliseconds. healthcheck_connection_timeout: u64, + + /// Maximum allowed length for sfw-provider responses received. + /// Anything declaring bigger size than that will be regarded as an error and + /// is going to be rejected. 
+ max_response_size: u32, } impl Default for Debug { @@ -413,6 +425,7 @@ impl Default for Debug { packet_forwarding_maximum_backoff: DEFAULT_PACKET_FORWARDING_MAXIMUM_BACKOFF, initial_connection_timeout: DEFAULT_INITIAL_CONNECTION_TIMEOUT, healthcheck_connection_timeout: DEFAULT_HEALTHCHECK_CONNECTION_TIMEOUT, + max_response_size: DEFAULT_MAX_RESPONSE_SIZE } } } From 2f4b6fdc29d51e2c495c8ffef04f4e5849cc9974 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 16:51:01 +0100 Subject: [PATCH 19/23] Removed debug drop implementations --- .../sfw-provider-requests/src/requests/async_io.rs | 14 +------------- .../src/responses/async_io.rs | 12 ------------ 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/sfw-provider/sfw-provider-requests/src/requests/async_io.rs b/sfw-provider/sfw-provider-requests/src/requests/async_io.rs index 596fdbce566..e68af235788 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/async_io.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/async_io.rs @@ -6,18 +6,6 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; // TODO: way down the line, mostly for learning purposes, combine this with responses::async_io // via procedural macros -impl<'a, R: AsyncRead + Unpin> Drop for TokioAsyncRequestReader<'a, R> { - fn drop(&mut self) { - println!("request reader drop"); - } -} - -impl<'a, R: AsyncWrite + Unpin> Drop for TokioAsyncRequestWriter<'a, R> { - fn drop(&mut self) { - println!("request writer drop"); - } -} - // Ideally I would have used futures::AsyncRead for even more generic approach, but unfortunately // tokio::io::AsyncRead differs from futures::AsyncRead pub struct TokioAsyncRequestReader<'a, R: AsyncRead + Unpin> { @@ -319,4 +307,4 @@ mod request_reader { assert!(rt.block_on(request_reader.try_read_request()).is_err()); } -} \ No newline at end of file +} diff --git a/sfw-provider/sfw-provider-requests/src/responses/async_io.rs b/sfw-provider/sfw-provider-requests/src/responses/async_io.rs index bbefa31f786..f5ccdf95e4d 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/async_io.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/async_io.rs @@ -6,18 +6,6 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; // TODO: way down the line, mostly for learning purposes, combine this with requests::async_io // via procedural macros -impl<'a, R: AsyncRead + Unpin> Drop for TokioAsyncResponseReader<'a, R> { - fn drop(&mut self) { - println!("response reader drop"); - } -} - -impl<'a, R: AsyncWrite + Unpin> Drop for TokioAsyncResponseWriter<'a, R> { - fn drop(&mut self) { - println!("response writer drop"); - } -} - // Ideally I would have used futures::AsyncRead for even more generic approach, but unfortunately // tokio::io::AsyncRead differs from futures::AsyncRead pub struct TokioAsyncResponseReader<'a, R: AsyncRead + Unpin> { From 7daa2d99061a5dd0b2e8ae4cbcde36c9a27e6a90 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 16:51:11 +0100 Subject: [PATCH 20/23] Removed debug print statement --- sfw-provider/src/provider/client_handling/request_processing.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/sfw-provider/src/provider/client_handling/request_processing.rs b/sfw-provider/src/provider/client_handling/request_processing.rs index 3a81608380e..5bdbff64c0f 100644 --- a/sfw-provider/src/provider/client_handling/request_processing.rs +++ b/sfw-provider/src/provider/client_handling/request_processing.rs @@ -129,7 +129,6 @@ impl RequestProcessor { &self, req: 
PullRequest, ) -> Result { - println!("pull request for {:?}", req.destination_address); if self .client_ledger .verify_token(&req.auth_token, &req.destination_address) From b9fed2d00c8d2e39840dce57a74966226aeb7a8d Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 16:51:20 +0100 Subject: [PATCH 21/23] Changes to lock file --- Cargo.lock | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 1d39c2d371b..0b7a7789dea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2421,9 +2421,11 @@ name = "sfw-provider-requests" version = "0.1.0" dependencies = [ "bs58", + "byteorder", "bytes 0.5.4", "sphinx", "tokio 0.2.12", + "tokio-test", ] [[package]] @@ -2794,6 +2796,17 @@ dependencies = [ "tokio-reactor", ] +[[package]] +name = "tokio-test" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cf9705471976fa5fc6817d3fbc9c4ff9696a6647af0e5c1870c81ca7445b05" +dependencies = [ + "bytes 0.5.4", + "futures-core", + "tokio 0.2.12", +] + [[package]] name = "tokio-threadpool" version = "0.1.18" From f110c09c303d024b30248b605f27ec41dfb55973 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 16:59:05 +0100 Subject: [PATCH 22/23] Added license notifications --- .../sfw-provider-requests/src/auth_token.rs | 14 ++++++++++++++ .../sfw-provider-requests/src/requests/async_io.rs | 14 ++++++++++++++ .../sfw-provider-requests/src/requests/mod.rs | 14 ++++++++++++++ .../src/requests/serialization.rs | 14 ++++++++++++++ .../src/responses/async_io.rs | 14 ++++++++++++++ .../sfw-provider-requests/src/responses/mod.rs | 14 ++++++++++++++ .../src/responses/serialization.rs | 14 ++++++++++++++ 7 files changed, 98 insertions(+) diff --git a/sfw-provider/sfw-provider-requests/src/auth_token.rs b/sfw-provider/sfw-provider-requests/src/auth_token.rs index 3fb714bb7e9..92768c1dbe1 100644 --- a/sfw-provider/sfw-provider-requests/src/auth_token.rs +++ b/sfw-provider/sfw-provider-requests/src/auth_token.rs @@ -1,3 +1,17 @@ +// Copyright 2020 Nym Technologies SA +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + pub const AUTH_TOKEN_SIZE: usize = 32; #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] diff --git a/sfw-provider/sfw-provider-requests/src/requests/async_io.rs b/sfw-provider/sfw-provider-requests/src/requests/async_io.rs index e68af235788..b5a6a3ae6d4 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/async_io.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/async_io.rs @@ -1,3 +1,17 @@ +// Copyright 2020 Nym Technologies SA +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + use crate::requests::serialization::{RequestDeserializer, RequestSerializer}; use crate::requests::{ProviderRequest, ProviderRequestError}; use std::io; diff --git a/sfw-provider/sfw-provider-requests/src/requests/mod.rs b/sfw-provider/sfw-provider-requests/src/requests/mod.rs index 85cfca85198..2e7a1087492 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/mod.rs @@ -1,3 +1,17 @@ +// Copyright 2020 Nym Technologies SA +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; use sphinx::constants::DESTINATION_ADDRESS_LENGTH; use sphinx::route::DestinationAddressBytes; diff --git a/sfw-provider/sfw-provider-requests/src/requests/serialization.rs b/sfw-provider/sfw-provider-requests/src/requests/serialization.rs index 4769dd2e3b5..53827981203 100644 --- a/sfw-provider/sfw-provider-requests/src/requests/serialization.rs +++ b/sfw-provider/sfw-provider-requests/src/requests/serialization.rs @@ -1,3 +1,17 @@ +// Copyright 2020 Nym Technologies SA +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use crate::requests::*; use std::convert::TryFrom; diff --git a/sfw-provider/sfw-provider-requests/src/responses/async_io.rs b/sfw-provider/sfw-provider-requests/src/responses/async_io.rs index f5ccdf95e4d..4c68589a584 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/async_io.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/async_io.rs @@ -1,3 +1,17 @@ +// Copyright 2020 Nym Technologies SA +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ use crate::responses::serialization::{ResponseDeserializer, ResponseSerializer}; use crate::responses::{ProviderResponse, ProviderResponseError}; use std::io; diff --git a/sfw-provider/sfw-provider-requests/src/responses/mod.rs b/sfw-provider/sfw-provider-requests/src/responses/mod.rs index a17517cb84a..80bc0d5085c 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/mod.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/mod.rs @@ -1,3 +1,17 @@ +// Copyright 2020 Nym Technologies SA +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use crate::auth_token::{AuthToken, AUTH_TOKEN_SIZE}; use std::convert::{TryFrom, TryInto}; use std::io; diff --git a/sfw-provider/sfw-provider-requests/src/responses/serialization.rs b/sfw-provider/sfw-provider-requests/src/responses/serialization.rs index af2e541488b..d059beb58ad 100644 --- a/sfw-provider/sfw-provider-requests/src/responses/serialization.rs +++ b/sfw-provider/sfw-provider-requests/src/responses/serialization.rs @@ -1,3 +1,17 @@ +// Copyright 2020 Nym Technologies SA +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use crate::responses::*; use std::convert::TryFrom; From 030a3e302baf12def8bcca5d6b8d414b9afbc316 Mon Sep 17 00:00:00 2001 From: jstuczyn Date: Wed, 15 Apr 2020 17:02:45 +0100 Subject: [PATCH 23/23] Cargo fmt --- nym-client/src/config/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nym-client/src/config/mod.rs b/nym-client/src/config/mod.rs index 9c58a0b6e04..5591410bdcf 100644 --- a/nym-client/src/config/mod.rs +++ b/nym-client/src/config/mod.rs @@ -439,7 +439,7 @@ impl Default for Debug { packet_forwarding_maximum_backoff: DEFAULT_PACKET_FORWARDING_MAXIMUM_BACKOFF, initial_connection_timeout: DEFAULT_INITIAL_CONNECTION_TIMEOUT, healthcheck_connection_timeout: DEFAULT_HEALTHCHECK_CONNECTION_TIMEOUT, - max_response_size: DEFAULT_MAX_RESPONSE_SIZE + max_response_size: DEFAULT_MAX_RESPONSE_SIZE, } } }
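
For readers tracing the wire-format changes in the async_io patches above: each request and response is framed as a 4-byte big-endian length (counting the kind byte plus the payload), followed by the one-byte kind and the payload itself, and the readers reject any frame whose declared length exceeds the configured maximum (max_request_size on the provider side, max_response_size on the client side). The sketch below is an illustrative, standalone rendering of that framing and check, inferred from the serialization tests in this series; `frame` and `parse` are hypothetical helper names and are not part of the crate's RequestSerializer / ResponseDeserializer API.

    use std::convert::TryInto;

    // Frame a payload: 4-byte big-endian length (kind byte + data), then kind, then data.
    fn frame(kind: u8, data: &[u8]) -> Vec<u8> {
        let len = data.len() as u32 + 1; // +1 for the kind byte
        let mut out = Vec::with_capacity(4 + len as usize);
        out.extend_from_slice(&len.to_be_bytes());
        out.push(kind);
        out.extend_from_slice(data);
        out
    }

    // Parse one frame, rejecting anything declaring more than `max_allowed_len` bytes,
    // analogous to the TooLongRequestError / TooLongResponseError checks above.
    fn parse(bytes: &[u8], max_allowed_len: usize) -> Result<(u8, &[u8]), &'static str> {
        if bytes.len() < 5 {
            return Err("a frame is at least 5 bytes: length prefix + kind");
        }
        let len = u32::from_be_bytes(bytes[..4].try_into().unwrap()) as usize;
        if len == 0 {
            return Err("declared length must at least cover the kind byte");
        }
        if len > max_allowed_len {
            return Err("declared length exceeds the configured maximum");
        }
        if bytes.len() < 4 + len {
            return Err("fewer bytes available than declared");
        }
        Ok((bytes[4], &bytes[5..4 + len]))
    }

    fn main() {
        let framed = frame(1, b"hello provider");
        let (kind, data) = parse(&framed, 16 * 1024).unwrap();
        assert_eq!((kind, data), (1u8, &b"hello provider"[..]));

        // Anything declaring a size above the limit is rejected up front,
        // before any payload bytes are read.
        assert!(parse(&framed, 4).is_err());
    }

In the real readers the same check happens between reading the length prefix and calling read_exact, which is why an over-sized declaration fails without the payload ever being buffered.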