
feat(node): Implement sessions #130

Merged: 14 commits, Nov 14, 2023
celestia/src/native.rs: 7 changes (5 additions, 2 deletions)
@@ -52,14 +52,17 @@ pub(crate) async fn run(args: Params) -> Result<()> {
let network_id = network_id(network).to_owned();
let genesis_hash = network_genesis(network);

info!("Initializing store");

let store = if let Some(db_path) = args.store {
SledStore::new_in_path(db_path).await?
} else {
SledStore::new(network_id.clone()).await?
};

if let Ok(store_height) = store.head_height().await {
info!("Initialised store with head height: {store_height}");
match store.head_height().await {
Ok(height) => info!("Initialised store with head height: {height}"),
Err(_) => info!("Initialised new store"),
}

let node = Node::new(NodeConfig {
node/src/executor.rs: 5 changes (2 additions, 3 deletions)
@@ -107,13 +107,12 @@ mod imp {
#[derive(Debug)]
pub(crate) struct Elapsed;

#[allow(dead_code)]
pub(crate) fn timeout<F>(duration: Duration, future: F) -> Timeout<F>
where
F: Future,
{
let millis = u32::try_from(duration.as_millis().max(1)).unwrap_or(u32::MAX);
let delay = TimeoutFuture::new(millis);
let delay = SendWrapper::new(TimeoutFuture::new(millis));

Timeout {
value: future,
@@ -134,7 +133,7 @@ mod imp {
#[pin]
value: T,
#[pin]
delay: TimeoutFuture,
delay: SendWrapper<TimeoutFuture>,
}
impl<T> Future for Timeout<T>
where
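For context on the executor.rs change above: it appears to wrap gloo's TimeoutFuture in SendWrapper so that the wasm Timeout future can satisfy a Send bound while still only ever being polled on the single wasm thread. A minimal sketch of that mechanism, assuming the send_wrapper crate and using Rc as a stand-in for another !Send type:

// Sketch only: SendWrapper lets a !Send value satisfy a Send bound when it is
// guaranteed to stay on the thread that created it (as on single-threaded wasm).
// Accessing the wrapped value from another thread panics at runtime.
use send_wrapper::SendWrapper;
use std::rc::Rc;

fn requires_send<T: Send>(_value: T) {}

fn main() {
    let not_send = Rc::new(42); // Rc is !Send, much like gloo's TimeoutFuture
    let wrapped = SendWrapper::new(not_send);
    requires_send(wrapped); // compiles: SendWrapper<T> is unconditionally Send
}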
node/src/header_ex.rs: 160 changes (122 additions, 38 deletions)
@@ -1,46 +1,54 @@
use std::io;
use std::sync::Arc;
use std::task::{Context, Poll};
use tracing::warn;

use async_trait::async_trait;
use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
use celestia_types::ExtendedHeader;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use instant::{Duration, Instant};
use libp2p::{
core::Endpoint,
request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport},
swarm::{
ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandlerInEvent,
handler::ConnectionEvent, ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent,
ConnectionId, FromSwarm, NetworkBehaviour, SubstreamProtocol, THandlerInEvent,
THandlerOutEvent, ToSwarm,
},
Multiaddr, PeerId, StreamProtocol,
};
use prost::Message;
use tracing::debug;
use tracing::instrument;
use tracing::{debug, instrument, warn};

mod client;
mod server;
mod utils;
pub(crate) mod utils;

use crate::executor::timeout;
use crate::header_ex::client::HeaderExClientHandler;
use crate::header_ex::server::HeaderExServerHandler;
use crate::p2p::P2pError;
use crate::peer_tracker::PeerTracker;
use crate::store::Store;
use crate::utils::{protocol_id, OneshotResultSender};

/// Max request size in bytes
const REQUEST_SIZE_MAXIMUM: usize = 1024;
/// Max response size in bytes
const RESPONSE_SIZE_MAXIMUM: usize = 10 * 1024 * 1024;
/// Size limit of a request in bytes
const REQUEST_SIZE_LIMIT: usize = 1024;
/// Time limit on reading/writing a request
const REQUEST_TIME_LIMIT: Duration = Duration::from_secs(1);
/// Size limit of a response in bytes
const RESPONSE_SIZE_LIMIT: usize = 10 * 1024 * 1024;
/// Time limit on reading/writing a response
const RESPONSE_TIME_LIMIT: Duration = Duration::from_secs(5);
/// Substream negotiation timeout
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(1);

type RequestType = HeaderRequest;
type ResponseType = Vec<HeaderResponse>;
type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;

pub(crate) struct HeaderExBehaviour<S>
where
@@ -176,7 +184,7 @@ impl<S> NetworkBehaviour for HeaderExBehaviour<S>
where
S: Store + 'static,
{
type ConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
type ConnectionHandler = ConnHandler;
type ToSwarm = ();

fn handle_established_inbound_connection(
@@ -186,12 +194,9 @@ where
local_addr: &Multiaddr,
remote_addr: &Multiaddr,
) -> Result<Self::ConnectionHandler, ConnectionDenied> {
self.req_resp.handle_established_inbound_connection(
connection_id,
peer,
local_addr,
remote_addr,
)
self.req_resp
.handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
.map(ConnHandler)
}

fn handle_established_outbound_connection(
@@ -201,12 +206,9 @@ where
addr: &Multiaddr,
role_override: Endpoint,
) -> Result<Self::ConnectionHandler, ConnectionDenied> {
self.req_resp.handle_established_outbound_connection(
connection_id,
peer,
addr,
role_override,
)
self.req_resp
.handle_established_outbound_connection(connection_id, peer, addr, role_override)
.map(ConnHandler)
}

fn handle_pending_inbound_connection(
@@ -274,6 +276,62 @@ where
}
}

pub(crate) struct ConnHandler(ReqRespConnectionHandler);

impl ConnectionHandler for ConnHandler {
type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;

fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
}

fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<
ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
> {
match self.0.poll(cx) {
Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
})
}
ev => ev,
}
}

fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
self.0.on_behaviour_event(event)
}

fn on_connection_event(
&mut self,
event: ConnectionEvent<
'_,
Self::InboundProtocol,
Self::OutboundProtocol,
Self::InboundOpenInfo,
Self::OutboundOpenInfo,
>,
) {
self.0.on_connection_event(event)
}

fn connection_keep_alive(&self) -> bool {
self.0.connection_keep_alive()
}

fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
self.0.poll_close(cx)
}
}
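
Taken together, the ConnHandler wrapper above appears to exist solely to bound substream negotiation: every ConnectionHandler method delegates to the inner request-response handler, while listen_protocol and outbound substream requests additionally apply with_timeout(NEGOTIATION_TIMEOUT), so a negotiation that stalls is abandoned after roughly one second instead of being held open indefinitely.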

#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct HeaderCodec;

@@ -287,14 +345,19 @@ impl Codec for HeaderCodec {
where
T: AsyncRead + Unpin + Send,
{
let data = read_up_to(io, REQUEST_SIZE_MAXIMUM).await?;
let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;

if data.len() >= REQUEST_SIZE_MAXIMUM {
if data.len() >= REQUEST_SIZE_LIMIT {
debug!("Message filled the whole buffer (len: {})", data.len());
}

parse_header_request(&data)
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "invalid request"))
parse_header_request(&data).ok_or_else(|| {
// There are two cases that can reach here:
//
// 1. The request is invalid
// 2. The request is incomplete because of the size limit or time limit
io::Error::new(io::ErrorKind::Other, "invalid or incomplete request")
})
}

async fn read_response<T>(
@@ -305,9 +368,9 @@ impl Codec for HeaderCodec {
where
T: AsyncRead + Unpin + Send,
{
let data = read_up_to(io, RESPONSE_SIZE_MAXIMUM).await?;
let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;

if data.len() >= RESPONSE_SIZE_MAXIMUM {
if data.len() >= RESPONSE_SIZE_LIMIT {
debug!("Message filled the whole buffer (len: {})", data.len());
}

@@ -320,7 +383,14 @@ impl Codec for HeaderCodec {
}

if msgs.is_empty() {
return Err(io::Error::new(io::ErrorKind::Other, "invalid response"));
// There are two cases that can reach here:
//
// 1. The response is invalid
// 2. The response is incomplete because of the size limit or time limit
return Err(io::Error::new(
io::ErrorKind::Other,
"invalid or incomplete response",
));
}

Ok(msgs)
@@ -335,11 +405,13 @@ impl Codec for HeaderCodec {
where
T: AsyncWrite + Unpin + Send,
{
let mut buf = Vec::with_capacity(REQUEST_SIZE_MAXIMUM);
let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);

let _ = req.encode_length_delimited(&mut buf);

io.write_all(&buf).await?;
timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
.await
.map_err(|_| io::Error::new(io::ErrorKind::Other, "writing request timed out"))??;

Ok(())
}
@@ -353,7 +425,7 @@ impl Codec for HeaderCodec {
where
T: AsyncWrite + Unpin + Send,
{
let mut buf = Vec::with_capacity(RESPONSE_SIZE_MAXIMUM);
let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);

for resp in resps {
if resp.encode_length_delimited(&mut buf).is_err() {
@@ -364,26 +436,38 @@ impl Codec for HeaderCodec {
}
}

io.write_all(&buf).await?;
timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
.await
.map_err(|_| io::Error::new(io::ErrorKind::Other, "writing response timed out"))??;

Ok(())
}
}
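
The write_request and write_response bodies above end in a double ?: the outer Result comes from the timeout helper (Elapsed mapped to an io::Error) and the inner one from the write itself. Below is a minimal compilable sketch of that shape; fake_timeout only mirrors the helper's Result-in-Result signature and does not race a real timer, and futures::executor::block_on is used just to make the sketch self-running.

// Sketch: why `timeout(..., io.write_all(&buf)).await.map_err(...)??` needs two `?`s.
use std::future::Future;
use std::io;
use std::time::Duration;

struct Elapsed;

async fn fake_timeout<F: Future>(_limit: Duration, fut: F) -> Result<F::Output, Elapsed> {
    Ok(fut.await)
}

async fn write_with_limit(buf: &[u8]) -> io::Result<()> {
    // Stand-in for `io.write_all(&buf)`, which resolves to io::Result<()>.
    let write = async move {
        let _len = buf.len();
        Ok::<(), io::Error>(())
    };

    // First `?` unwraps the timeout layer (Elapsed mapped to io::Error),
    // second `?` unwraps the io::Result produced by the write itself.
    fake_timeout(Duration::from_secs(5), write)
        .await
        .map_err(|_| io::Error::new(io::ErrorKind::Other, "writing timed out"))??;

    Ok(())
}

fn main() -> io::Result<()> {
    futures::executor::block_on(write_with_limit(b"ping"))
}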

async fn read_up_to<T>(io: &mut T, limit: usize) -> io::Result<Vec<u8>>
/// Reads up to `size_limit` within `time_limit`.
async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
where
T: AsyncRead + Unpin + Send,
{
let mut buf = vec![0u8; limit];
let mut buf = vec![0u8; size_limit];
let mut read_len = 0;
let now = Instant::now();

loop {
if read_len == buf.len() {
// No empty space. Buffer is full.
break;
}

let len = io.read(&mut buf[read_len..]).await?;
let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
break;
};

let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
Ok(Ok(len)) => len,
Ok(Err(e)) => return Err(e),
Err(_) => break,
};

if len == 0 {
// EOF
@@ -511,7 +595,7 @@ mod tests {

#[async_test]
async fn test_decode_header_request_too_large() {
let too_long_message_len = REQUEST_SIZE_MAXIMUM + 1;
let too_long_message_len = REQUEST_SIZE_LIMIT + 1;
let mut length_delimiter_buffer = BytesMut::new();
prost::encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
let mut reader = Cursor::new(length_delimiter_buffer);
@@ -529,7 +613,7 @@ mod tests {

#[async_test]
async fn test_decode_header_response_too_large() {
let too_long_message_len = RESPONSE_SIZE_MAXIMUM + 1;
let too_long_message_len = RESPONSE_SIZE_LIMIT + 1;
let mut length_delimiter_buffer = BytesMut::new();
encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
let mut reader = Cursor::new(length_delimiter_buffer);
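
For reference, a blocking, runtime-free sketch of the deadline-splitting pattern read_up_to uses above: one time budget is started once, and the remaining slice is recomputed before every partial read so that many short reads cannot exceed the overall limit. The trailing truncate is an assumption, since the tail of the function is collapsed in the diff, and the real code additionally wraps each read in timeout(remaining, ...).

// Blocking sketch of the deadline-splitting loop in `read_up_to`. `checked_sub`
// yields None once the budget is spent, which ends the loop; the caller then
// sees an incomplete message and fails to parse it.
use std::io::Read;
use std::time::{Duration, Instant};

fn read_up_to_sketch<R: Read>(
    io: &mut R,
    size_limit: usize,
    time_limit: Duration,
) -> std::io::Result<Vec<u8>> {
    let mut buf = vec![0u8; size_limit];
    let mut read_len = 0;
    let start = Instant::now();

    loop {
        if read_len == buf.len() {
            break; // buffer full
        }

        // Remaining time budget; None means the deadline has already passed.
        let Some(_remaining) = time_limit.checked_sub(start.elapsed()) else {
            break;
        };

        // The real code wraps this read in `timeout(_remaining, ...)`; with
        // blocking I/O the budget can only be checked between reads.
        let len = io.read(&mut buf[read_len..])?;
        if len == 0 {
            break; // EOF
        }
        read_len += len;
    }

    buf.truncate(read_len);
    Ok(buf)
}

fn main() -> std::io::Result<()> {
    let mut reader = std::io::Cursor::new(b"hello".to_vec());
    let data = read_up_to_sketch(&mut reader, 1024, Duration::from_secs(1))?;
    assert_eq!(data, b"hello");
    Ok(())
}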