Skip to content

Commit

Permalink
[#184] Supports platforms that requires connect with the first packets
Browse files Browse the repository at this point in the history
  • Loading branch information
zonyitoo committed Jan 4, 2020
1 parent e107c81 commit b35758c
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 109 deletions.
4 changes: 2 additions & 2 deletions build/build-release
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ function build() {
echo "* Done build package ${PKG_NAME}"
}

#build "x86_64-unknown-linux-musl"
build "x86_64-pc-windows-gnu"
build "x86_64-unknown-linux-musl"
#build "x86_64-pc-windows-gnu"
114 changes: 68 additions & 46 deletions src/relay/tcprelay/utils/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
//! TCP API wrappers

use std::{io, mem::MaybeUninit, net::SocketAddr, pin::Pin, task, time::Duration};
use std::{
io,
mem::MaybeUninit,
net::SocketAddr,
pin::Pin,
task::{self, Poll},
time::Duration,
};

use bytes::{Buf, BufMut};
use log::{error, trace};
use tokio::{
io::{AsyncRead, AsyncWrite},
net,
net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream},
};

use crate::{
Expand All @@ -22,7 +29,7 @@ use super::{

/// A TCP socket server, listening for connections.
pub struct TcpListener {
inner: net::TcpListener,
inner: TokioTcpListener,
}

impl TcpListener {
Expand All @@ -33,14 +40,17 @@ impl TcpListener {
if fast_open {
tfo::bind_listener(addr).await
} else {
net::TcpListener::bind(addr).await
TokioTcpListener::bind(addr).await
}
.map(|inner| TcpListener { inner })
}

/// Accept a new incoming connection from this listener.
pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> {
self.inner.accept().await.map(|(s, a)| (TcpStream { inner: s }, a))
self.inner
.accept()
.await
.map(|(s, a)| (TcpStream::from_tokio_stream(s), a))
}

/// Returns the local address that this listener is bound to.
Expand All @@ -51,17 +61,36 @@ impl TcpListener {

/// A TCP stream between a local and a remote socket.
pub struct TcpStream {
inner: net::TcpStream,
inner: TokioTcpStream,

// For TFO connect
// Some operating systems require calling specific APIs to perform actual connect
// with payload data. So we should keep the remote address here and use it in the
// first call of poll_write
connect_context: Option<tfo::ConnectContext>,
}

impl TcpStream {
fn new(inner: TokioTcpStream, connect_context: tfo::ConnectContext) -> TcpStream {
TcpStream {
inner,
connect_context: Some(connect_context),
}
}

fn from_tokio_stream(inner: TokioTcpStream) -> TcpStream {
TcpStream {
inner,
connect_context: None,
}
}

async fn connect(addr: &SocketAddr, fast_open: bool) -> io::Result<TcpStream> {
if fast_open {
tfo::connect_stream(addr).await
tfo::connect_stream(addr).await.map(|(s, c)| TcpStream::new(s, c))
} else {
net::TcpStream::connect(addr).await
TokioTcpStream::connect(addr).await.map(TcpStream::from_tokio_stream)
}
.map(|inner| TcpStream { inner })
}

/// Opens a TCP connection to a remote host.
Expand Down Expand Up @@ -197,67 +226,60 @@ impl AsyncRead for TcpStream {
self.inner.prepare_uninitialized_buffer(buf)
}

fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> task::Poll<io::Result<usize>> {
// Pin::new(&mut self.inner).poll_read(cx, buf)

match Pin::new(&mut self.inner).poll_read(cx, buf) {
task::Poll::Pending => task::Poll::Pending,
task::Poll::Ready(Ok(n)) => {
println!("READ {}", n);
task::Poll::Ready(Ok(n))
}
task::Poll::Ready(Err(err)) => {
println!("READ ERR {:?}", err);
task::Poll::Ready(Err(err))
}
}
fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}

fn poll_read_buf<B: BufMut>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> task::Poll<io::Result<usize>> {
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_read_buf(cx, buf)
}
}

impl AsyncWrite for TcpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<Result<usize, io::Error>> {
match Pin::new(&mut self.inner).poll_write(cx, buf) {
task::Poll::Pending => task::Poll::Pending,
task::Poll::Ready(Ok(n)) => {
println!("WRITE {}", n);
task::Poll::Ready(Ok(n))
}
task::Poll::Ready(Err(err)) => {
println!("WRITE ERR {:?}", err);
task::Poll::Ready(Err(err))
}
fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
if let Some(cc) = self.connect_context.take() {
// For TFO, first send has something different between operating systems
Poll::Ready(cc.connect_with_data(buf))
} else {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> task::Poll<Result<usize, io::Error>> {
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
}
}

#[cfg(unix)]
mod sys {
use super::{TcpListener, TcpStream};
use std::os::unix::prelude::*;

impl AsRawFd for TcpStream {
fn as_raw_fd(&self) -> RawFd {
self.inner.as_raw_fd()
}
}

impl AsRawFd for TcpListener {
fn as_raw_fd(&self) -> RawFd {
self.inner.as_raw_fd()
}
}
}
33 changes: 33 additions & 0 deletions src/relay/tcprelay/utils/tfo/bsd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,39 @@ pub async fn bind_listener(addr: &SocketAddr) -> io::Result<TcpListener> {
}
}

pub struct ConnectContext {
// Reference to the partial connected socket fd
// This struct doesn't own the fd, so do not close it while dropping
socket: RawFd,

// Target address for calling `sendto`
remote_addr: SocketAddr,
}

impl ConnectContext {
/// Performing actual connect operation
pub fn connect_with_data(self, buf: &[u8]) -> io::Result<usize> {
unsafe {
let (saddr, saddr_len) = addr2raw(&self.remote_addr);

let ret = libc::sendto(
self.socket,
buf.as_ptr() as *const _ as *const libc::c_void,
buf.len(),
0,
saddr,
saddr_len,
);

if ret < 0 {
Err(Error::last_os_error())
} else {
Ok(ret as usize)
}
}
}
}

pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
let domain = match addr {
SocketAddr::V4(..) => libc::AF_INET,
Expand Down
Loading

0 comments on commit b35758c

Please sign in to comment.