Skip to content

Commit

Permalink
Bootstrap/send timeout traited (#4079)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: modship <yeskinokay@gmail.com>
  • Loading branch information
Ben-PH and modship authored Jun 15, 2023
1 parent 33528b3 commit 8392ef0
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 12 deletions.
49 changes: 48 additions & 1 deletion massa-bootstrap/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,16 @@ trait BindingReadExact: io::Read {
count += n;
}
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(e) => return Err((e, count)),
Err(e) => {
if e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock {
return Err((
std::io::Error::new(ErrorKind::TimedOut, "deadline has elapsed"),
count,
));
} else {
return Err((e, count));
}
}
}
}
if count != buf.len() {
Expand All @@ -54,3 +63,41 @@ trait BindingReadExact: io::Read {
/// Internal helper
fn set_read_timeout(&mut self, duration: Option<Duration>) -> Result<(), std::io::Error>;
}

trait BindingWriteExact: io::Write {
fn write_all_timeout(
&mut self,
write_buf: &[u8],
deadline: Option<Instant>,
) -> Result<(), (std::io::Error, usize)> {
self.set_write_timeout(None).map_err(|e| (e, 0))?;
let mut total_bytes_written = 0;

while total_bytes_written < write_buf.len() {
if let Some(deadline) = deadline {
let dur = deadline.saturating_duration_since(Instant::now());
if dur.is_zero() {
return Err((
std::io::Error::new(ErrorKind::TimedOut, "deadline has elapsed"),
total_bytes_written,
));
}
self.set_write_timeout(Some(dur))
.map_err(|e| (e, total_bytes_written))?;
}

match self.write(&write_buf[total_bytes_written..]) {
Ok(bytes_written) => {
total_bytes_written += bytes_written;
}
Err(err) => {
return Err((err, total_bytes_written));
}
}
}

Ok(())
}
/// Internal helper
fn set_write_timeout(&mut self, duration: Option<Duration>) -> Result<(), std::io::Error>;
}
28 changes: 23 additions & 5 deletions massa-bootstrap/src/bindings/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2022 MASSA LABS <info@massa.net>

use crate::bindings::BindingReadExact;
use crate::bindings::{BindingReadExact, BindingWriteExact};
use crate::error::BootstrapError;
use crate::messages::{
BootstrapClientMessage, BootstrapClientMessageSerializer, BootstrapServerMessage,
Expand All @@ -17,7 +17,7 @@ use massa_serialization::{DeserializeError, Deserializer, Serializer};
use massa_signature::{PublicKey, Signature};
use rand::{rngs::StdRng, RngCore, SeedableRng};
use std::time::Instant;
use std::{io::Write, net::TcpStream, time::Duration};
use std::{net::TcpStream, time::Duration};

/// Bootstrap client binder
pub struct BootstrapClientBinder {
Expand Down Expand Up @@ -64,7 +64,8 @@ impl BootstrapClientBinder {
vec![0u8; version_ser.len() + self.cfg.randomness_size_bytes];
version_random_bytes[..version_ser.len()].clone_from_slice(&version_ser);
StdRng::from_entropy().fill_bytes(&mut version_random_bytes[version_ser.len()..]);
self.duplex.write_all(&version_random_bytes)?;
self.write_all_timeout(&version_random_bytes, None)
.map_err(|(e, _)| e)?;
Hash::compute_from(&version_random_bytes)
};

Expand Down Expand Up @@ -145,6 +146,7 @@ impl BootstrapClientBinder {
msg: &BootstrapClientMessage,
duration: Option<Duration>,
) -> Result<(), BootstrapError> {
let deadline = duration.map(|d| Instant::now() + d);
let mut msg_bytes = Vec::new();
let message_serializer = BootstrapClientMessageSerializer::new();
message_serializer.serialize(msg, &mut msg_bytes)?;
Expand Down Expand Up @@ -172,15 +174,15 @@ impl BootstrapClientBinder {
}

// Provide the message length
self.duplex.set_write_timeout(duration)?;
let msg_len_bytes = msg_len.to_be_bytes_min(MAX_BOOTSTRAP_MESSAGE_SIZE)?;
write_buf.extend(&msg_len_bytes);

// Provide the message
write_buf.extend(&msg_bytes);

// And send it off
self.duplex.write_all(&write_buf)?;
self.write_all_timeout(&write_buf, deadline)
.map_err(|(e, _)| e)?;
Ok(())
}

Expand Down Expand Up @@ -213,3 +215,19 @@ impl std::io::Read for BootstrapClientBinder {
self.duplex.read(buf)
}
}

impl crate::bindings::BindingWriteExact for BootstrapClientBinder {
fn set_write_timeout(&mut self, duration: Option<Duration>) -> Result<(), std::io::Error> {
self.duplex.set_write_timeout(duration)
}
}

impl std::io::Write for BootstrapClientBinder {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
self.duplex.write(buf)
}

fn flush(&mut self) -> Result<(), std::io::Error> {
self.duplex.flush()
}
}
30 changes: 25 additions & 5 deletions massa-bootstrap/src/bindings/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ use std::io;
use std::time::Instant;
use std::{
convert::TryInto,
io::{ErrorKind, Read, Write},
io::ErrorKind,
net::{SocketAddr, TcpStream},
thread,
time::Duration,
};
use tracing::error;

use super::BindingWriteExact;

const KNOWN_PREFIX_LEN: usize = HASH_SIZE_BYTES + MAX_BOOTSTRAP_MESSAGE_SIZE_BYTES;
/// The known-length component of a message to be received.
struct ClientMessageLeader {
Expand Down Expand Up @@ -86,14 +88,15 @@ impl BootstrapServerBinder {
version: Version,
duration: Option<Duration>,
) -> Result<(), BootstrapError> {
let deadline = duration.map(|d| Instant::now() + d);
// read version and random bytes, send signature
let msg_hash = {
let mut version_bytes = Vec::new();
self.version_serializer
.serialize(&version, &mut version_bytes)?;
let mut msg_bytes = vec![0u8; version_bytes.len() + self.randomness_size_bytes];
self.duplex.set_read_timeout(duration)?;
self.duplex.read_exact(&mut msg_bytes)?;
self.read_exact_timeout(&mut msg_bytes, deadline)
.map_err(|(e, _)| e)?;
let (_, received_version) = self
.version_deserializer
.deserialize::<DeserializeError>(&msg_bytes[..version_bytes.len()])
Expand Down Expand Up @@ -182,6 +185,7 @@ impl BootstrapServerBinder {
msg: BootstrapServerMessage,
duration: Option<Duration>,
) -> Result<(), BootstrapError> {
let deadline = duration.map(|d| Instant::now() + d);
// serialize the message to bytes
let mut msg_bytes = Vec::new();
BootstrapServerMessageSerializer::new().serialize(&msg, &mut msg_bytes)?;
Expand Down Expand Up @@ -211,8 +215,8 @@ impl BootstrapServerBinder {
let stream_data = [sig.to_bytes().as_slice(), &msg_len_bytes, &msg_bytes].concat();

// send the data
self.duplex.set_write_timeout(duration)?;
self.duplex.write_all(&stream_data)?;
self.write_all_timeout(&stream_data, deadline)
.map_err(|(e, _)| e)?;

// update prev sig
self.prev_message = Some(Hash::compute_from(&sig.to_bytes()));
Expand Down Expand Up @@ -316,3 +320,19 @@ impl crate::bindings::BindingReadExact for BootstrapServerBinder {
self.duplex.set_read_timeout(duration)
}
}

impl io::Write for BootstrapServerBinder {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.duplex.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
self.duplex.flush()
}
}

impl crate::bindings::BindingWriteExact for BootstrapServerBinder {
fn set_write_timeout(&mut self, duration: Option<Duration>) -> Result<(), std::io::Error> {
self.duplex.set_write_timeout(duration)
}
}
91 changes: 90 additions & 1 deletion massa-bootstrap/src/tests/binders.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::messages::{BootstrapClientMessage, BootstrapServerMessage};
use crate::settings::{BootstrapClientConfig, BootstrapSrvBindCfg};
use crate::BootstrapConfig;
use crate::{
bindings::{BootstrapClientBinder, BootstrapServerBinder},
tests::tools::get_bootstrap_config,
BootstrapPeers,
};
use crate::{BootstrapConfig, BootstrapError};
use massa_models::config::{
BOOTSTRAP_RANDOMNESS_SIZE_BYTES, CONSENSUS_BOOTSTRAP_PART_SIZE, ENDORSEMENT_COUNT,
MAX_ADVERTISE_LENGTH, MAX_ASYNC_MESSAGE_DATA, MAX_ASYNC_POOL_LENGTH,
Expand All @@ -23,8 +23,10 @@ use massa_protocol_exports::{PeerId, TransportType};
use massa_signature::{KeyPair, PublicKey};
use massa_time::MassaTime;
use std::collections::HashMap;
use std::io::Write;
use std::net::TcpStream;
use std::str::FromStr;
use std::time::Duration;

lazy_static::lazy_static! {
pub static ref BOOTSTRAP_CONFIG_KEYPAIR: (BootstrapConfig, KeyPair) = {
Expand Down Expand Up @@ -484,3 +486,90 @@ fn test_binders_try_double_send_client_works() {
server_thread.join().unwrap();
client_thread.join().unwrap();
}

#[test]
fn test_client_drip_feed() {
let (bootstrap_config, server_keypair): &(BootstrapConfig, KeyPair) = &BOOTSTRAP_CONFIG_KEYPAIR;
let server = std::net::TcpListener::bind("localhost:0").unwrap();
let addr = server.local_addr().unwrap();
let client = std::net::TcpStream::connect(addr).unwrap();
let mut client_clone = client.try_clone().unwrap();
let server = server.accept().unwrap();
let version = || Version::from_str("TEST.1.10").unwrap();

let mut server = BootstrapServerBinder::new(
server.0,
server_keypair.clone(),
BootstrapSrvBindCfg {
max_bytes_read_write: f64::INFINITY,
thread_count: THREAD_COUNT,
max_datastore_key_length: MAX_DATASTORE_KEY_LENGTH,
randomness_size_bytes: BOOTSTRAP_RANDOMNESS_SIZE_BYTES,
consensus_bootstrap_part_size: CONSENSUS_BOOTSTRAP_PART_SIZE,
write_error_timeout: MassaTime::from_millis(1000),
},
);
let mut client = BootstrapClientBinder::test_default(
client,
bootstrap_config.bootstrap_list[0].1.get_public_key(),
);

let start = std::time::Instant::now();
let server_thread = std::thread::Builder::new()
.name("test_binders::server_thread".to_string())
.spawn({
move || {
server.handshake_timeout(version(), None).unwrap();

let message = server
.next_timeout(Some(Duration::from_secs(1)))
.unwrap_err();
match message {
BootstrapError::TimedOut(message) => {
assert_eq!(message.to_string(), "deadline has elapsed");
assert_eq!(message.kind(), std::io::ErrorKind::TimedOut);
}
message => panic!("expected timeout error, got {:?}", message),
}
std::mem::forget(server);
}
})
.unwrap();

let client_thread = std::thread::Builder::new()
.name("test_binders::server_thread".to_string())
.spawn({
move || {
client.handshake(version()).unwrap();

// write the signature.
// This test assumes that the the signature is not checked until the message is read in
// its entirety. The signature here would cause the message exchange to fail on that basis
// if this assumption is broken.
client_clone
.write_all(b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
.unwrap();
// give a message size that we can drip-feed
client_clone.write_all(&[0, 0, 0, 120]).unwrap();
for i in 0..120 {
client_clone.write(&[i]).unwrap();
client_clone.flush().unwrap();
std::thread::sleep(Duration::from_millis(10));
}
}
})
.unwrap();

server_thread.join().unwrap();
assert!(
start.elapsed() > Duration::from_millis(1000),
"elapsed {:?}",
start.elapsed()
);
assert!(
start.elapsed() < Duration::from_millis(1100),
"elapsed {:?}",
start.elapsed()
);
client_thread.join().unwrap();
}

0 comments on commit 8392ef0

Please sign in to comment.