diff --git a/chain/src/chain.rs b/chain/src/chain.rs index 1ff3450333..91fcffbc82 100644 --- a/chain/src/chain.rs +++ b/chain/src/chain.rs @@ -646,8 +646,7 @@ impl Chain { /// TODO - Write this data to disk and validate the rebuilt kernel MMR. pub fn kernel_data_write(&self, reader: &mut Read) -> Result<(), Error> { let mut count = 0; - let mut stream = - StreamingReader::new(reader, ProtocolVersion::local(), Duration::from_secs(1)); + let mut stream = StreamingReader::new(reader, ProtocolVersion::local()); while let Ok(_kernel) = TxKernelEntry::read(&mut stream) { count += 1; } diff --git a/core/src/core/block.rs b/core/src/core/block.rs index 1270d9bc7c..7ba27be020 100644 --- a/core/src/core/block.rs +++ b/core/src/core/block.rs @@ -180,9 +180,9 @@ impl Default for HeaderVersion { // self-conscious increment function courtesy of Jasper impl HeaderVersion { - fn next(&self) -> Self { - Self(self.0+1) - } + fn next(&self) -> Self { + Self(self.0 + 1) + } } impl HeaderVersion { diff --git a/core/src/global.rs b/core/src/global.rs index 719f0a7b6b..1bf265cc66 100644 --- a/core/src/global.rs +++ b/core/src/global.rs @@ -17,13 +17,15 @@ //! should be used sparingly. use crate::consensus::{ - HeaderInfo, valid_header_version, graph_weight, BASE_EDGE_BITS, BLOCK_TIME_SEC, + graph_weight, valid_header_version, HeaderInfo, BASE_EDGE_BITS, BLOCK_TIME_SEC, COINBASE_MATURITY, CUT_THROUGH_HORIZON, DAY_HEIGHT, DEFAULT_MIN_EDGE_BITS, DIFFICULTY_ADJUST_WINDOW, INITIAL_DIFFICULTY, MAX_BLOCK_WEIGHT, PROOFSIZE, SECOND_POW_EDGE_BITS, STATE_SYNC_THRESHOLD, }; use crate::core::block::HeaderVersion; -use crate::pow::{self, new_cuckatoo_ctx, new_cuckaroo_ctx, new_cuckarood_ctx, EdgeType, PoWContext}; +use crate::pow::{ + self, new_cuckaroo_ctx, new_cuckarood_ctx, new_cuckatoo_ctx, EdgeType, PoWContext, +}; /// An enum collecting sets of parameters used throughout the /// code wherever mining is needed. This should allow for /// different sets of parameters for different purposes, @@ -164,14 +166,16 @@ where match chain_type { // Mainnet has Cuckaroo(d)29 for AR and Cuckatoo31+ for AF ChainTypes::Mainnet if edge_bits > 29 => new_cuckatoo_ctx(edge_bits, proof_size, max_sols), - ChainTypes::Mainnet if valid_header_version(height, HeaderVersion::new(2)) - => new_cuckarood_ctx(edge_bits, proof_size), + ChainTypes::Mainnet if valid_header_version(height, HeaderVersion::new(2)) => { + new_cuckarood_ctx(edge_bits, proof_size) + } ChainTypes::Mainnet => new_cuckaroo_ctx(edge_bits, proof_size), // Same for Floonet ChainTypes::Floonet if edge_bits > 29 => new_cuckatoo_ctx(edge_bits, proof_size, max_sols), - ChainTypes::Floonet if valid_header_version(height, HeaderVersion::new(2)) - => new_cuckarood_ctx(edge_bits, proof_size), + ChainTypes::Floonet if valid_header_version(height, HeaderVersion::new(2)) => { + new_cuckarood_ctx(edge_bits, proof_size) + } ChainTypes::Floonet => new_cuckaroo_ctx(edge_bits, proof_size), // Everything else is Cuckatoo only diff --git a/core/src/pow.rs b/core/src/pow.rs index 8d97effe2f..e744fe4750 100644 --- a/core/src/pow.rs +++ b/core/src/pow.rs @@ -33,9 +33,9 @@ use num; #[macro_use] mod common; -pub mod cuckatoo; pub mod cuckaroo; pub mod cuckarood; +pub mod cuckatoo; mod error; #[allow(dead_code)] pub mod lean; @@ -49,9 +49,9 @@ use chrono::prelude::{DateTime, NaiveDateTime, Utc}; pub use self::common::EdgeType; pub use self::types::*; -pub use crate::pow::cuckatoo::{new_cuckatoo_ctx, CuckatooContext}; pub use crate::pow::cuckaroo::{new_cuckaroo_ctx, CuckarooContext}; pub use crate::pow::cuckarood::{new_cuckarood_ctx, CuckaroodContext}; +pub use crate::pow::cuckatoo::{new_cuckatoo_ctx, CuckatooContext}; pub use crate::pow::error::Error; const MAX_SOLS: u32 = 10; diff --git a/core/src/pow/cuckarood.rs b/core/src/pow/cuckarood.rs index e6274a2381..be46712fd4 100644 --- a/core/src/pow/cuckarood.rs +++ b/core/src/pow/cuckarood.rs @@ -22,11 +22,11 @@ //! a rotation by 25, halves the number of graph nodes in each partition, //! and requires cycles to alternate between even- and odd-indexed edges. +use crate::global; use crate::pow::common::{CuckooParams, EdgeType}; use crate::pow::error::{Error, ErrorKind}; use crate::pow::siphash::siphash_block; use crate::pow::{PoWContext, Proof}; -use crate::global; /// Instantiate a new CuckaroodContext as a PowContext. Note that this can't /// be moved in the PoWContext trait as this particular trait needs to be @@ -69,8 +69,7 @@ where fn verify(&self, proof: &Proof) -> Result<(), Error> { if proof.proof_size() != global::proofsize() { - return Err(ErrorKind::Verification( - "wrong cycle length".to_owned(),))?; + return Err(ErrorKind::Verification("wrong cycle length".to_owned()))?; } let nonces = &proof.nonces; let mut uvs = vec![0u64; 2 * proof.proof_size()]; @@ -92,10 +91,10 @@ where } let edge = to_edge!(T, siphash_block(&self.params.siphash_keys, nonces[n], 25)); let idx = 4 * ndir[dir] + 2 * dir; - uvs[idx ] = to_u64!( edge & nodemask); - uvs[idx+1] = to_u64!((edge >> 32) & nodemask); - xor0 ^= uvs[idx ]; - xor1 ^= uvs[idx+1]; + uvs[idx] = to_u64!(edge & nodemask); + uvs[idx + 1] = to_u64!((edge >> 32) & nodemask); + xor0 ^= uvs[idx]; + xor1 ^= uvs[idx + 1]; ndir[dir] += 1; } if xor0 | xor1 != 0 { @@ -110,7 +109,8 @@ where // follow cycle j = i; for k in (((i % 4) ^ 2)..(2 * self.params.proof_size)).step_by(4) { - if uvs[k] == uvs[i] { // find reverse edge endpoint identical to one at i + if uvs[k] == uvs[i] { + // find reverse edge endpoint identical to one at i if j != i { return Err(ErrorKind::Verification("branch in cycle".to_owned()))?; } @@ -173,11 +173,15 @@ mod test { fn cuckarood19_29_vectors() { let mut ctx19 = new_impl::(19, 42); ctx19.params.siphash_keys = V1_19_HASH.clone(); - assert!(ctx19.verify(&Proof::new(V1_19_SOL.to_vec().clone())).is_ok()); + assert!(ctx19 + .verify(&Proof::new(V1_19_SOL.to_vec().clone())) + .is_ok()); assert!(ctx19.verify(&Proof::zero(42)).is_err()); let mut ctx29 = new_impl::(29, 42); ctx29.params.siphash_keys = V2_29_HASH.clone(); - assert!(ctx29.verify(&Proof::new(V2_29_SOL.to_vec().clone())).is_ok()); + assert!(ctx29 + .verify(&Proof::new(V2_29_SOL.to_vec().clone())) + .is_ok()); assert!(ctx29.verify(&Proof::zero(42)).is_err()); } diff --git a/core/src/ser.rs b/core/src/ser.rs index 57a0898167..e7de7e0a0f 100644 --- a/core/src/ser.rs +++ b/core/src/ser.rs @@ -22,7 +22,6 @@ use crate::core::hash::{DefaultHashable, Hash, Hashed}; use crate::global::PROTOCOL_VERSION; use crate::keychain::{BlindingFactor, Identifier, IDENTIFIER_SIZE}; -use crate::util::read_write::read_exact; use crate::util::secp::constants::{ AGG_SIGNATURE_SIZE, COMPRESSED_PUBLIC_KEY_SIZE, MAX_PROOF_SIZE, PEDERSEN_COMMITMENT_SIZE, SECRET_KEY_SIZE, @@ -35,7 +34,6 @@ use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use std::fmt::{self, Debug}; use std::io::{self, Read, Write}; use std::marker; -use std::time::Duration; use std::{cmp, error}; /// Possible errors deriving from serializing or deserializing. @@ -459,22 +457,16 @@ pub struct StreamingReader<'a> { total_bytes_read: u64, version: ProtocolVersion, stream: &'a mut dyn Read, - timeout: Duration, } impl<'a> StreamingReader<'a> { /// Create a new streaming reader with the provided underlying stream. /// Also takes a duration to be used for each individual read_exact call. - pub fn new( - stream: &'a mut dyn Read, - version: ProtocolVersion, - timeout: Duration, - ) -> StreamingReader<'a> { + pub fn new(stream: &'a mut dyn Read, version: ProtocolVersion) -> StreamingReader<'a> { StreamingReader { total_bytes_read: 0, version, stream, - timeout, } } @@ -521,7 +513,7 @@ impl<'a> Reader for StreamingReader<'a> { /// Read a fixed number of bytes. fn read_fixed_bytes(&mut self, len: usize) -> Result, Error> { let mut buf = vec![0u8; len]; - read_exact(&mut self.stream, &mut buf, self.timeout, true)?; + self.stream.read_exact(&mut buf)?; self.total_bytes_read += len as u64; Ok(buf) } diff --git a/p2p/src/conn.rs b/p2p/src/conn.rs index e916669b8a..8c098fce16 100644 --- a/p2p/src/conn.rs +++ b/p2p/src/conn.rs @@ -20,24 +20,26 @@ //! forces us to go through some additional gymnastic to loop over the async //! stream and make sure we get the right number of bytes out. +use crate::core::ser; +use crate::core::ser::{FixedLength, ProtocolVersion}; +use crate::msg::{ + read_body, read_discard, read_header, read_item, write_to_buf, MsgHeader, MsgHeaderWrapper, + Type, +}; +use crate::types::Error; +use crate::util::{RateCounter, RwLock}; use std::fs::File; use std::io::{self, Read, Write}; use std::net::{Shutdown, TcpStream}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{mpsc, Arc}; +use std::time::Duration; use std::{ cmp, thread::{self, JoinHandle}, - time, }; -use crate::core::ser::{self, FixedLength, ProtocolVersion}; -use crate::msg::{ - read_body, read_discard, read_header, read_item, write_to_buf, MsgHeader, MsgHeaderWrapper, - Type, -}; -use crate::types::Error; -use crate::util::read_write::{read_exact, write_all}; -use crate::util::{RateCounter, RwLock}; +const IO_TIMEOUT: Duration = Duration::from_millis(1000); /// A trait to be implemented in order to receive messages from the /// connection. Allows providing an optional response. @@ -56,7 +58,11 @@ macro_rules! try_break { ($inner:expr) => { match $inner { Ok(v) => Some(v), - Err(Error::Connection(ref e)) if e.kind() == io::ErrorKind::WouldBlock => None, + Err(Error::Connection(ref e)) + if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => + { + None + } Err(Error::Store(_)) | Err(Error::Chain(_)) | Err(Error::Internal) @@ -106,12 +112,7 @@ impl<'a> Message<'a> { while written < len { let read_len = cmp::min(8000, len - written); let mut buf = vec![0u8; read_len]; - read_exact( - &mut self.stream, - &mut buf[..], - time::Duration::from_secs(10), - true, - )?; + self.stream.read_exact(&mut buf[..])?; writer.write_all(&mut buf)?; written += read_len; } @@ -151,7 +152,7 @@ impl<'a> Response<'a> { self.version, )?; msg.append(&mut self.body); - write_all(&mut self.stream, &msg[..], time::Duration::from_secs(10))?; + self.stream.write_all(&msg[..])?; tracker.inc_sent(msg.len() as u64); if let Some(mut file) = self.attachment { @@ -160,7 +161,7 @@ impl<'a> Response<'a> { match file.read(&mut buf[..]) { Ok(0) => break, Ok(n) => { - write_all(&mut self.stream, &buf[..n], time::Duration::from_secs(10))?; + self.stream.write_all(&buf[..n])?; // Increase sent bytes "quietly" without incrementing the counter. // (In a loop here for the single attachment). tracker.inc_quiet_sent(n as u64); @@ -181,34 +182,39 @@ pub const SEND_CHANNEL_CAP: usize = 10; pub struct StopHandle { /// Channel to close the connection - pub close_channel: mpsc::Sender<()>, + stopped: Arc, // we need Option to take ownhership of the handle in stop() - peer_thread: Option>, + reader_thread: Option>, + writer_thread: Option>, } impl StopHandle { /// Schedule this connection to safely close via the async close_channel. pub fn stop(&self) { - if self.close_channel.send(()).is_err() { - debug!("peer's close_channel is disconnected, must be stopped already"); - return; - } + self.stopped.store(true, Ordering::Relaxed); } pub fn wait(&mut self) { - if let Some(peer_thread) = self.peer_thread.take() { - // wait only if other thread is calling us, eg shutdown - if thread::current().id() != peer_thread.thread().id() { - debug!("waiting for thread {:?} exit", peer_thread.thread().id()); - if let Err(e) = peer_thread.join() { - error!("failed to wait for peer thread to stop: {:?}", e); - } - } else { - debug!( - "attempt to wait for thread {:?} from itself", - peer_thread.thread().id() - ); + if let Some(reader_thread) = self.reader_thread.take() { + self.join_thread(reader_thread); + } + if let Some(writer_thread) = self.writer_thread.take() { + self.join_thread(writer_thread); + } + } + + fn join_thread(&self, peer_thread: JoinHandle<()>) { + // wait only if other thread is calling us, eg shutdown + if thread::current().id() != peer_thread.thread().id() { + debug!("waiting for thread {:?} exit", peer_thread.thread().id()); + if let Err(e) = peer_thread.join() { + error!("failed to stop peer thread: {:?}", e); } + } else { + debug!( + "attempt to stop thread {:?} from itself", + peer_thread.thread().id() + ); } } } @@ -277,20 +283,27 @@ where H: MessageHandler, { let (send_tx, send_rx) = mpsc::sync_channel(SEND_CHANNEL_CAP); - let (close_tx, close_rx) = mpsc::channel(); stream - .set_nonblocking(true) - .expect("Non-blocking IO not available."); - let peer_thread = poll(stream, version, handler, send_rx, close_rx, tracker)?; + .set_read_timeout(Some(IO_TIMEOUT)) + .expect("can't set read timeout"); + stream + .set_write_timeout(Some(IO_TIMEOUT)) + .expect("can't set read timeout"); + + let stopped = Arc::new(AtomicBool::new(false)); + + let (reader_thread, writer_thread) = + poll(stream, version, handler, send_rx, stopped.clone(), tracker)?; Ok(( ConnHandle { send_channel: send_tx, }, StopHandle { - close_channel: close_tx, - peer_thread: Some(peer_thread), + stopped, + reader_thread: Some(reader_thread), + writer_thread: Some(writer_thread), }, )) } @@ -300,24 +313,24 @@ fn poll( version: ProtocolVersion, handler: H, send_rx: mpsc::Receiver>, - close_rx: mpsc::Receiver<()>, + stopped: Arc, tracker: Arc, -) -> io::Result> +) -> io::Result<(JoinHandle<()>, JoinHandle<()>)> where H: MessageHandler, { // Split out tcp stream out into separate reader/writer halves. let mut reader = conn.try_clone().expect("clone conn for reader failed"); let mut writer = conn.try_clone().expect("clone conn for writer failed"); + let mut responder = conn.try_clone().expect("clone conn for writer failed"); + let reader_stopped = stopped.clone(); - thread::Builder::new() - .name("peer".to_string()) + let reader_thread = thread::Builder::new() + .name("peer_read".to_string()) .spawn(move || { - let sleep_time = time::Duration::from_millis(5); - let mut retry_send = Err(()); loop { // check the read end - match try_break!(read_header(&mut reader, version, None)) { + match try_break!(read_header(&mut reader, version)) { Some(MsgHeaderWrapper::Known(header)) => { let msg = Message::from_header(header, &mut reader, version); @@ -331,7 +344,7 @@ where tracker.inc_received(MsgHeader::LEN as u64 + msg.header.msg_len); if let Some(Some(resp)) = - try_break!(handler.consume(msg, &mut writer, tracker.clone())) + try_break!(handler.consume(msg, &mut responder, tracker.clone())) { try_break!(resp.write(tracker.clone())); } @@ -345,35 +358,48 @@ where None => {} } - // check the write end, use or_else so try_recv is lazily eval'd - let maybe_data = retry_send.or_else(|_| send_rx.try_recv()); + // check the close channel + if reader_stopped.load(Ordering::Relaxed) { + break; + } + } + + debug!( + "Shutting down reader connection with {}", + reader + .peer_addr() + .map(|a| a.to_string()) + .unwrap_or("?".to_owned()) + ); + let _ = reader.shutdown(Shutdown::Both); + })?; + + let writer_thread = thread::Builder::new() + .name("peer_read".to_string()) + .spawn(move || { + let mut retry_send = Err(()); + loop { + let maybe_data = retry_send.or_else(|_| send_rx.recv_timeout(IO_TIMEOUT)); retry_send = Err(()); if let Ok(data) = maybe_data { - let written = try_break!(write_all( - &mut writer, - &data[..], - std::time::Duration::from_secs(10) - ) - .map_err(&From::from)); + let written = try_break!(writer.write_all(&data[..]).map_err(&From::from)); if written.is_none() { retry_send = Ok(data); } } - // check the close channel - if let Ok(_) = close_rx.try_recv() { + if stopped.load(Ordering::Relaxed) { break; } - - thread::sleep(sleep_time); } debug!( - "Shutting down connection with {}", - conn.peer_addr() + "Shutting down reader connection with {}", + writer + .peer_addr() .map(|a| a.to_string()) .unwrap_or("?".to_owned()) ); - let _ = conn.shutdown(Shutdown::Both); - }) + })?; + Ok((reader_thread, writer_thread)) } diff --git a/p2p/src/msg.rs b/p2p/src/msg.rs index 67338d3c04..e79c2b827e 100644 --- a/p2p/src/msg.rs +++ b/p2p/src/msg.rs @@ -14,10 +14,6 @@ //! Message types that transit over the network and related serialization code. -use num::FromPrimitive; -use std::io::{Read, Write}; -use std::time; - use crate::core::core::hash::Hash; use crate::core::core::BlockHeader; use crate::core::pow::Difficulty; @@ -28,7 +24,8 @@ use crate::core::{consensus, global}; use crate::types::{ Capabilities, Error, PeerAddr, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS, }; -use crate::util::read_write::read_exact; +use num::FromPrimitive; +use std::io::{Read, Write}; /// Grin's user agent with current version pub const USER_AGENT: &'static str = concat!("MW/Grin ", env!("CARGO_PKG_VERSION")); @@ -126,14 +123,9 @@ fn magic() -> [u8; 2] { pub fn read_header( stream: &mut dyn Read, version: ProtocolVersion, - msg_type: Option, ) -> Result { let mut head = vec![0u8; MsgHeader::LEN]; - if Some(Type::Hand) == msg_type { - read_exact(stream, &mut head, time::Duration::from_millis(10), true)?; - } else { - read_exact(stream, &mut head, time::Duration::from_secs(10), false)?; - } + stream.read_exact(&mut head)?; let header = ser::deserialize::(&mut &head[..], version)?; Ok(header) } @@ -145,8 +137,7 @@ pub fn read_item( stream: &mut dyn Read, version: ProtocolVersion, ) -> Result<(T, u64), Error> { - let timeout = time::Duration::from_secs(20); - let mut reader = StreamingReader::new(stream, version, timeout); + let mut reader = StreamingReader::new(stream, version); let res = T::read(&mut reader)?; Ok((res, reader.total_bytes_read())) } @@ -159,14 +150,14 @@ pub fn read_body( version: ProtocolVersion, ) -> Result { let mut body = vec![0u8; h.msg_len as usize]; - read_exact(stream, &mut body, time::Duration::from_secs(20), true)?; + stream.read_exact(&mut body)?; ser::deserialize(&mut &body[..], version).map_err(From::from) } /// Read (an unknown) message from the provided stream and discard it. pub fn read_discard(msg_len: u64, stream: &mut dyn Read) -> Result<(), Error> { let mut buffer = vec![0u8; msg_len as usize]; - read_exact(stream, &mut buffer, time::Duration::from_secs(20), true)?; + stream.read_exact(&mut buffer)?; Ok(()) } @@ -176,7 +167,7 @@ pub fn read_message( version: ProtocolVersion, msg_type: Type, ) -> Result { - match read_header(stream, version, Some(msg_type))? { + match read_header(stream, version)? { MsgHeaderWrapper::Known(header) => { if header.msg_type == msg_type { read_body(&header, stream, version) diff --git a/store/src/types.rs b/store/src/types.rs index 2d4de2870f..5a8161f580 100644 --- a/store/src/types.rs +++ b/store/src/types.rs @@ -16,14 +16,14 @@ use memmap; use tempfile::tempfile; use crate::core::ser::{ - self, BinWriter, FixedLength, ProtocolVersion, Readable, Reader, StreamingReader, Writeable, Writer, + self, BinWriter, FixedLength, ProtocolVersion, Readable, Reader, StreamingReader, Writeable, + Writer, }; use std::fmt::Debug; use std::fs::{self, File, OpenOptions}; use std::io::{self, BufReader, BufWriter, Seek, SeekFrom, Write}; use std::marker; use std::path::{Path, PathBuf}; -use std::time; /// Represents a single entry in the size_file. /// Offset (in bytes) and size (in bytes) of a variable sized entry @@ -482,8 +482,7 @@ where { let reader = File::open(&self.path)?; let mut buf_reader = BufReader::new(reader); - let mut streaming_reader = - StreamingReader::new(&mut buf_reader, self.version, time::Duration::from_secs(1)); + let mut streaming_reader = StreamingReader::new(&mut buf_reader, self.version); let mut buf_writer = BufWriter::new(File::create(&tmp_path)?); let mut bin_writer = BinWriter::new(&mut buf_writer, self.version); @@ -529,11 +528,7 @@ where { let reader = File::open(&self.path)?; let mut buf_reader = BufReader::new(reader); - let mut streaming_reader = StreamingReader::new( - &mut buf_reader, - self.version, - time::Duration::from_secs(1), - ); + let mut streaming_reader = StreamingReader::new(&mut buf_reader, self.version); let mut buf_writer = BufWriter::new(File::create(&tmp_path)?); let mut bin_writer = BinWriter::new(&mut buf_writer, self.version); diff --git a/util/src/lib.rs b/util/src/lib.rs index d4577ed43b..d48a363f7b 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -47,9 +47,6 @@ pub use crate::types::{LogLevel, LoggingConfig, ZeroingString}; pub mod macros; -// read_exact and write_all impls -pub mod read_write; - // other utils #[allow(unused_imports)] use std::ops::Deref; diff --git a/util/src/read_write.rs b/util/src/read_write.rs deleted file mode 100644 index 15e3f3f72a..0000000000 --- a/util/src/read_write.rs +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2018 The Grin Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Custom impls of read_exact and write_all to work around async stream restrictions. - -use std::io; -use std::io::prelude::*; -use std::thread; -use std::time::Duration; - -/// The default implementation of read_exact is useless with an async stream (TcpStream) as -/// it will return as soon as something has been read, regardless of -/// whether the buffer has been filled (and then errors). This implementation -/// will block until it has read exactly `len` bytes and returns them as a -/// `vec`. Except for a timeout, this implementation will never return a -/// partially filled buffer. -/// -/// The timeout in milliseconds aborts the read when it's met. Note that the -/// time is not guaranteed to be exact. To support cases where we want to poll -/// instead of blocking, a `block_on_empty` boolean, when false, ensures -/// `read_exact` returns early with a `io::ErrorKind::WouldBlock` if nothing -/// has been read from the socket. -pub fn read_exact( - stream: &mut dyn Read, - mut buf: &mut [u8], - timeout: Duration, - block_on_empty: bool, -) -> io::Result<()> { - let sleep_time = Duration::from_micros(10); - let mut count = Duration::new(0, 0); - - let mut read = 0; - loop { - match stream.read(buf) { - Ok(0) => { - return Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "read_exact", - )); - } - Ok(n) => { - let tmp = buf; - buf = &mut tmp[n..]; - read += n; - } - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if read == 0 && !block_on_empty { - return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact")); - } - } - Err(e) => return Err(e), - } - if !buf.is_empty() { - thread::sleep(sleep_time); - count += sleep_time; - } else { - break; - } - if count > timeout { - return Err(io::Error::new( - io::ErrorKind::TimedOut, - "reading from stream", - )); - } - } - Ok(()) -} - -/// Same as `read_exact` but for writing. -pub fn write_all(stream: &mut dyn Write, mut buf: &[u8], timeout: Duration) -> io::Result<()> { - let sleep_time = Duration::from_micros(10); - let mut count = Duration::new(0, 0); - - while !buf.is_empty() { - match stream.write(buf) { - Ok(0) => { - return Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write whole buffer", - )); - } - Ok(n) => buf = &buf[n..], - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - Err(e) => return Err(e), - } - if !buf.is_empty() { - thread::sleep(sleep_time); - count += sleep_time; - } else { - break; - } - if count > timeout { - return Err(io::Error::new(io::ErrorKind::TimedOut, "writing to stream")); - } - } - Ok(()) -}