Skip to content

Commit dfefd23

Browse files
authored
Improvements to sfw-provider - client communcation (#180)
* Moved auth_token to seperate file * Extracted check_id as separate type * Changes due to move of auth_token and making provider client mutable * New way of serialization provider requests/responses * Initial attempt of using new provider client * Moved requests and responses to separate modules * Moved serialization to separate files * Extracted readers and writers to io related modules * Extra tests + bug fixes * Updated tokio dependency to require correct features * typo * Easier conversion of requests/responses into enum variants * Renamed 'read_be_u16' to better show its purpose * Serialization related tests and fixes * Tests for async_io + fixes * Future considerations * Configurable max request size * Configurable max response size for client * Removed debug drop implementations * Removed debug print statement * Changes to lock file * Added license notifications * Cargo fmt
1 parent 3f06ccc commit dfefd23

File tree

24 files changed

+2224
-635
lines changed

24 files changed

+2224
-635
lines changed

Cargo.lock

+15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

common/clients/provider-client/src/lib.rs

+88-33
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
use futures::io::Error;
1616
use log::*;
17-
use sfw_provider_requests::requests::{ProviderRequest, PullRequest, RegisterRequest};
17+
use sfw_provider_requests::auth_token::AuthToken;
18+
use sfw_provider_requests::requests::{
19+
async_io::TokioAsyncRequestWriter, ProviderRequest, PullRequest, RegisterRequest,
20+
};
1821
use sfw_provider_requests::responses::{
19-
ProviderResponse, ProviderResponseError, PullResponse, RegisterResponse,
22+
async_io::TokioAsyncResponseReader, ProviderResponse, ProviderResponseError,
2023
};
21-
use sfw_provider_requests::AuthToken;
2224
use sphinx::route::DestinationAddressBytes;
23-
use std::net::{Shutdown, SocketAddr};
24-
use std::time::Duration;
25+
use std::net::SocketAddr;
2526
use tokio::prelude::*;
2627

2728
#[derive(Debug)]
@@ -50,6 +51,12 @@ impl From<ProviderResponseError> for ProviderClientError {
5051
ProviderResponseError::MarshalError => InvalidRequestError,
5152
ProviderResponseError::UnmarshalError => InvalidResponseError,
5253
ProviderResponseError::UnmarshalErrorInvalidLength => InvalidResponseLengthError,
54+
ProviderResponseError::UnmarshalErrorInvalidKind => InvalidResponseLengthError,
55+
56+
ProviderResponseError::TooLongResponseError => InvalidResponseError,
57+
ProviderResponseError::TooShortResponseError => InvalidResponseError,
58+
ProviderResponseError::IOError(_) => NetworkError,
59+
ProviderResponseError::RemoteConnectionClosed => NetworkError,
5360
}
5461
}
5562
}
@@ -58,72 +65,120 @@ pub struct ProviderClient {
5865
provider_network_address: SocketAddr,
5966
our_address: DestinationAddressBytes,
6067
auth_token: Option<AuthToken>,
68+
connection: Option<tokio::net::TcpStream>,
69+
max_response_size: usize,
6170
}
6271

6372
impl ProviderClient {
6473
pub fn new(
6574
provider_network_address: SocketAddr,
6675
our_address: DestinationAddressBytes,
6776
auth_token: Option<AuthToken>,
77+
max_response_size: usize,
6878
) -> Self {
6979
ProviderClient {
7080
provider_network_address,
7181
our_address,
7282
auth_token,
83+
max_response_size,
84+
// establish connection when it's necessary (mainly to not break current code
85+
// as then 'new' would need to be called within async context)
86+
connection: None,
87+
}
88+
}
89+
90+
async fn check_connection(&mut self) -> bool {
91+
if self.connection.is_some() {
92+
true
93+
} else {
94+
// TODO: possibly also introduce timeouts here?
95+
// However, at this point it's slightly less important as we are in full control
96+
// of providers.
97+
self.connection = tokio::net::TcpStream::connect(self.provider_network_address)
98+
.await
99+
.ok();
100+
self.connection.is_some()
73101
}
74102
}
75103

76104
pub fn update_token(&mut self, auth_token: AuthToken) {
77105
self.auth_token = Some(auth_token)
78106
}
79107

80-
pub async fn send_request(&self, bytes: Vec<u8>) -> Result<Vec<u8>, ProviderClientError> {
81-
let mut socket = tokio::net::TcpStream::connect(self.provider_network_address).await?;
82-
83-
socket.set_keepalive(Some(Duration::from_secs(2)))?;
84-
socket.write_all(&bytes[..]).await?;
85-
if let Err(e) = socket.shutdown(Shutdown::Write) {
86-
warn!("failed to close write part of the socket; err = {:?}", e)
108+
pub async fn send_request(
109+
&mut self,
110+
request: ProviderRequest,
111+
) -> Result<ProviderResponse, ProviderClientError> {
112+
if !self.check_connection().await {
113+
return Err(ProviderClientError::NetworkError);
87114
}
88115

89-
let mut response = Vec::new();
90-
socket.read_to_end(&mut response).await?;
91-
if let Err(e) = socket.shutdown(Shutdown::Read) {
92-
debug!("failed to close read part of the socket; err = {:?}. It was probably already closed by the provider", e)
116+
let socket = self.connection.as_mut().unwrap();
117+
let (mut socket_reader, mut socket_writer) = socket.split();
118+
119+
// TODO: benchmark and determine if below should be done:
120+
// let mut socket_writer = tokio::io::BufWriter::new(socket_writer);
121+
// let mut socket_reader = tokio::io::BufReader::new(socket_reader);
122+
123+
let mut request_writer = TokioAsyncRequestWriter::new(&mut socket_writer);
124+
let mut response_reader =
125+
TokioAsyncResponseReader::new(&mut socket_reader, self.max_response_size);
126+
127+
if let Err(e) = request_writer.try_write_request(request).await {
128+
debug!("Failed to write the request - {:?}", e);
129+
return Err(e.into());
93130
}
94131

95-
Ok(response)
132+
Ok(response_reader.try_read_response().await?)
96133
}
97134

98-
pub async fn retrieve_messages(&self) -> Result<Vec<Vec<u8>>, ProviderClientError> {
135+
pub async fn retrieve_messages(&mut self) -> Result<Vec<Vec<u8>>, ProviderClientError> {
99136
let auth_token = match self.auth_token.as_ref() {
100137
Some(token) => token.clone(),
101138
None => {
102139
return Err(ProviderClientError::EmptyAuthTokenError);
103140
}
104141
};
105142

106-
let pull_request = PullRequest::new(self.our_address.clone(), auth_token);
107-
let bytes = pull_request.to_bytes();
108-
109-
let response = self.send_request(bytes).await?;
110-
111-
let parsed_response = PullResponse::from_bytes(&response)?;
112-
Ok(parsed_response.messages)
143+
let pull_request =
144+
ProviderRequest::Pull(PullRequest::new(self.our_address.clone(), auth_token));
145+
match self.send_request(pull_request).await? {
146+
ProviderResponse::Pull(res) => Ok(res.extract_messages()),
147+
ProviderResponse::Failure(res) => {
148+
error!(
149+
"We failed to get our request processed - {:?}",
150+
res.get_message()
151+
);
152+
Err(ProviderClientError::InvalidResponseError)
153+
}
154+
_ => {
155+
error!("Received response of unexpected type!");
156+
Err(ProviderClientError::InvalidResponseError)
157+
}
158+
}
113159
}
114160

115-
pub async fn register(&self) -> Result<AuthToken, ProviderClientError> {
161+
pub async fn register(&mut self) -> Result<AuthToken, ProviderClientError> {
116162
if self.auth_token.is_some() {
117163
return Err(ProviderClientError::ClientAlreadyRegisteredError);
118164
}
119165

120-
let register_request = RegisterRequest::new(self.our_address.clone());
121-
let bytes = register_request.to_bytes();
122-
123-
let response = self.send_request(bytes).await?;
124-
let parsed_response = RegisterResponse::from_bytes(&response)?;
125-
126-
Ok(parsed_response.auth_token)
166+
let register_request =
167+
ProviderRequest::Register(RegisterRequest::new(self.our_address.clone()));
168+
match self.send_request(register_request).await? {
169+
ProviderResponse::Register(res) => Ok(res.get_token()),
170+
ProviderResponse::Failure(res) => {
171+
error!(
172+
"We failed to get our request processed - {:?}",
173+
res.get_message()
174+
);
175+
Err(ProviderClientError::InvalidResponseError)
176+
}
177+
_ => {
178+
error!("Received response of unexpected type!");
179+
Err(ProviderClientError::InvalidResponseError)
180+
}
181+
}
127182
}
128183

129184
pub fn is_registered(&self) -> bool {

common/healthcheck/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ use std::fmt::{Error, Formatter};
1919
use std::time::Duration;
2020
use topology::{NymTopology, NymTopologyError};
2121

22+
// basically no limit
23+
pub(crate) const MAX_PROVIDER_RESPONSE_SIZE: usize = 1024 * 1024;
24+
2225
pub mod config;
2326
mod path_check;
2427
mod result;

common/healthcheck/src/path_check.rs

+21-10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use crate::MAX_PROVIDER_RESPONSE_SIZE;
1516
use crypto::identity::MixIdentityKeyPair;
1617
use itertools::Itertools;
1718
use log::{debug, error, info, trace, warn};
@@ -25,6 +26,8 @@ use std::net::SocketAddr;
2526
use std::time::Duration;
2627
use topology::provider;
2728

29+
pub(crate) type CheckId = [u8; 16];
30+
2831
#[derive(Debug, PartialEq, Clone)]
2932
pub enum PathStatus {
3033
Healthy,
@@ -37,23 +40,30 @@ pub(crate) struct PathChecker {
3740
mixnet_client: multi_tcp_client::Client,
3841
paths_status: HashMap<Vec<u8>, PathStatus>,
3942
our_destination: Destination,
40-
check_id: [u8; 16],
43+
check_id: CheckId,
4144
}
4245

4346
impl PathChecker {
4447
pub(crate) async fn new(
4548
providers: Vec<provider::Node>,
4649
identity_keys: &MixIdentityKeyPair,
4750
connection_timeout: Duration,
48-
check_id: [u8; 16],
51+
check_id: CheckId,
4952
) -> Self {
5053
let mut provider_clients = HashMap::new();
5154

5255
let address = identity_keys.public_key().derive_address();
5356

5457
for provider in providers {
55-
let mut provider_client =
56-
ProviderClient::new(provider.client_listener, address.clone(), None);
58+
let mut provider_client = ProviderClient::new(
59+
provider.client_listener,
60+
address.clone(),
61+
None,
62+
MAX_PROVIDER_RESPONSE_SIZE,
63+
);
64+
// TODO: we might be sending unnecessary register requests since after first healthcheck,
65+
// we are registered for any subsequent ones (since our address did not change)
66+
5767
let insertion_result = match provider_client.register().await {
5868
Ok(token) => {
5969
debug!("[Healthcheck] registered at provider {}", provider.pub_key);
@@ -96,7 +106,7 @@ impl PathChecker {
96106

97107
// iteration is used to distinguish packets sent through the same path (as the healthcheck
98108
// may try to send say 10 packets through given path)
99-
fn unique_path_key(path: &[SphinxNode], check_id: [u8; 16], iteration: u8) -> Vec<u8> {
109+
fn unique_path_key(path: &[SphinxNode], check_id: CheckId, iteration: u8) -> Vec<u8> {
100110
check_id
101111
.iter()
102112
.cloned()
@@ -147,8 +157,8 @@ impl PathChecker {
147157

148158
// pull messages from given provider until there are no more 'real' messages
149159
async fn resolve_pending_provider_checks(
150-
&self,
151-
provider_client: &ProviderClient,
160+
provider_client: &mut ProviderClient,
161+
check_id: CheckId,
152162
) -> Vec<Vec<u8>> {
153163
// keep getting messages until we encounter the dummy message
154164
let mut provider_messages = Vec::new();
@@ -165,7 +175,7 @@ impl PathChecker {
165175
if msg == sfw_provider_requests::DUMMY_MESSAGE_CONTENT {
166176
// finish iterating the loop as the messages might not be ordered
167177
should_stop = true;
168-
} else if msg[..16] != self.check_id {
178+
} else if msg[..16] != check_id {
169179
warn!("received response from previous healthcheck")
170180
} else {
171181
provider_messages.push(msg);
@@ -183,14 +193,15 @@ impl PathChecker {
183193
pub(crate) async fn resolve_pending_checks(&mut self) {
184194
// not sure how to nicely put it into an iterator due to it being async calls
185195
let mut provider_messages = Vec::new();
186-
for provider_client in self.provider_clients.values() {
196+
for provider_client in self.provider_clients.values_mut() {
187197
// if it was none all associated paths were already marked as unhealthy
188198
let pc = match provider_client {
189199
Some(pc) => pc,
190200
None => continue,
191201
};
192202

193-
provider_messages.extend(self.resolve_pending_provider_checks(pc).await);
203+
provider_messages
204+
.extend(Self::resolve_pending_provider_checks(pc, self.check_id).await);
194205
}
195206

196207
self.update_path_statuses(provider_messages);

nym-client/src/client/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use futures::channel::{mpsc, oneshot};
3030
use log::*;
3131
use nymsphinx::chunking::split_and_prepare_payloads;
3232
use pemstore::pemstore::PemStore;
33-
use sfw_provider_requests::AuthToken;
33+
use sfw_provider_requests::auth_token::AuthToken;
3434
use sphinx::route::Destination;
3535
use tokio::runtime::Runtime;
3636
use topology::NymTopology;
@@ -185,6 +185,7 @@ impl NymClient {
185185
.map(|str_token| AuthToken::try_from_base58_string(str_token).ok())
186186
.unwrap_or(None),
187187
self.config.get_fetch_message_delay(),
188+
self.config.get_max_response_size(),
188189
);
189190

190191
if !provider_poller.is_registered() {

nym-client/src/client/provider_poller.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
use futures::channel::mpsc;
1616
use log::*;
1717
use provider_client::ProviderClientError;
18-
use sfw_provider_requests::AuthToken;
18+
use sfw_provider_requests::auth_token::AuthToken;
1919
use sphinx::route::DestinationAddressBytes;
2020
use std::net::SocketAddr;
2121
use std::time;
@@ -38,12 +38,14 @@ impl ProviderPoller {
3838
client_address: DestinationAddressBytes,
3939
auth_token: Option<AuthToken>,
4040
polling_rate: time::Duration,
41+
max_response_size: usize,
4142
) -> Self {
4243
ProviderPoller {
4344
provider_client: provider_client::ProviderClient::new(
4445
provider_client_listener_address,
4546
client_address,
4647
auth_token,
48+
max_response_size,
4749
),
4850
poller_tx,
4951
polling_rate,
@@ -74,7 +76,7 @@ impl ProviderPoller {
7476
Ok(())
7577
}
7678

77-
pub(crate) async fn start_provider_polling(self) {
79+
pub(crate) async fn start_provider_polling(&mut self) {
7880
let loop_message = &mix_client::packet::LOOP_COVER_MESSAGE_PAYLOAD.to_vec();
7981
let dummy_message = &sfw_provider_requests::DUMMY_MESSAGE_CONTENT.to_vec();
8082

@@ -114,7 +116,7 @@ impl ProviderPoller {
114116
}
115117
}
116118

117-
pub(crate) fn start(self, handle: &Handle) -> JoinHandle<()> {
119+
pub(crate) fn start(mut self, handle: &Handle) -> JoinHandle<()> {
118120
handle.spawn(async move { self.start_provider_polling().await })
119121
}
120122
}

0 commit comments

Comments
 (0)