Skip to content

Commit

Permalink
feat: support set TCP_NODELAY; code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ihciah committed Dec 22, 2022
1 parent 91da9f2 commit 109c654
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 95 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0"
name = "shadow-tls"
readme = "README.md"
repository = "https://github.com/ihciah/shadow-tls"
version = "0.2.3"
version = "0.2.5"

[dependencies]
monoio = {version = "0.0.9"}
Expand Down
14 changes: 11 additions & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use rustls::{OwnedTrustAnchor, RootCertStore, ServerName};

use crate::{
stream::HashedReadStream,
util::{copy_with_application_data, copy_without_application_data, set_tcp_keepalive},
util::{copy_with_application_data, copy_without_application_data, mod_tcp_conn},
Opts,
};

/// ShadowTlsClient.
Expand All @@ -15,11 +16,17 @@ pub struct ShadowTlsClient<A> {
server_name: ServerName,
address: A,
password: String,
opts: Opts,
}

impl<A> ShadowTlsClient<A> {
/// Create new ShadowTlsClient.
pub fn new(server_name: &str, address: A, password: String) -> anyhow::Result<Self> {
pub fn new(
server_name: &str,
address: A,
password: String,
opts: Opts,
) -> anyhow::Result<Self> {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
Expand All @@ -40,6 +47,7 @@ impl<A> ShadowTlsClient<A> {
server_name,
address,
password,
opts,
})
}

Expand Down Expand Up @@ -72,7 +80,7 @@ impl<A> ShadowTlsClient<A> {
A: std::net::ToSocketAddrs,
{
let mut stream = TcpStream::connect(&self.address).await?;
set_tcp_keepalive(&mut stream);
mod_tcp_conn(&mut stream, true, self.opts.nodelay);
tracing::debug!("tcp connected, start handshaking");
let stream = HashedReadStream::new(stream, self.password.as_bytes())?;
let tls_stream = self
Expand Down
131 changes: 83 additions & 48 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
#![feature(generic_associated_types)]
#![feature(type_alias_impl_trait)]

mod arg;
mod client;
mod server;
mod sip003;
mod stream;
mod util;

use std::{rc::Rc, sync::Arc};
use std::{fmt::Display, rc::Rc, sync::Arc};

use clap::{Parser, Subcommand};
use monoio::net::TcpListener;
use tracing::{error, info};
use tracing_subscriber::{filter::LevelFilter, fmt, prelude::*, EnvFilter};

use crate::{client::ShadowTlsClient, server::ShadowTlsServer, util::set_tcp_keepalive};
use crate::{client::ShadowTlsClient, server::ShadowTlsServer, util::mod_tcp_conn};

#[derive(Parser, Debug)]
#[clap(
Expand All @@ -27,8 +27,30 @@ use crate::{client::ShadowTlsClient, server::ShadowTlsServer, util::set_tcp_keep
struct Args {
#[clap(subcommand)]
cmd: Commands,
#[clap(short, long)]
#[clap(flatten)]
opts: Opts,
}

#[derive(Parser, Debug, Default, Clone)]
pub struct Opts {
#[clap(short, long, help = "Set parallelism manually")]
threads: Option<u8>,
#[clap(short, long, help = "Set TCP_NODELAY")]
nodelay: bool,
}

impl Display for Opts {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.threads {
Some(t) => {
write!(f, "fixed {t} threads")
}
None => {
write!(f, "auto adjusted threads")
}
}?;
write!(f, "; nodelay: {}", self.nodelay)
}
}

#[derive(Subcommand, Debug)]
Expand Down Expand Up @@ -74,6 +96,45 @@ enum Commands {
},
}

impl Args {
async fn start(&self) {
match &self.cmd {
Commands::Client {
listen,
server_addr,
tls_name,
password,
} => {
run_client(
listen.clone(),
server_addr.clone(),
tls_name.clone(),
password.clone(),
self.opts.clone(),
)
.await
.expect("client exited");
}
Commands::Server {
listen,
server_addr,
tls_addr,
password,
} => {
run_server(
listen.clone(),
server_addr.clone(),
tls_addr.clone(),
password.clone(),
self.opts.clone(),
)
.await
.expect("server exited");
}
}
}
}

fn main() {
tracing_subscriber::registry()
.with(fmt::layer())
Expand All @@ -83,7 +144,7 @@ fn main() {
.from_env_lossy(),
)
.init();
let args = match arg::get_sip003_arg() {
let args = match sip003::get_sip003_arg() {
Some(a) => Arc::new(a),
None => Arc::new(Args::parse()),
};
Expand All @@ -97,7 +158,7 @@ fn main() {
.enable_timer()
.build()
.expect("unable to build monoio runtime");
rt.block_on(run(args_clone));
rt.block_on(args_clone.start());
});
threads.push(t);
}
Expand All @@ -107,64 +168,36 @@ fn main() {
}

fn get_parallelism(args: &Args) -> usize {
if let Some(n) = args.threads {
if let Some(n) = args.opts.threads {
return n as usize;
}
std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1)
}

async fn run(cli: Arc<Args>) {
match &cli.cmd {
Commands::Client {
listen,
server_addr,
tls_name,
password,
} => {
run_client(
listen.clone(),
server_addr.clone(),
tls_name.clone(),
password.clone(),
)
.await
.expect("client exited");
}
Commands::Server {
listen,
server_addr,
tls_addr,
password,
} => {
run_server(
listen.clone(),
server_addr.clone(),
tls_addr.clone(),
password.clone(),
)
.await
.expect("server exited");
}
}
}

async fn run_client(
listen: String,
server_addr: String,
tls_name: String,
password: String,
opts: Opts,
) -> anyhow::Result<()> {
info!("Client is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server name: {tls_name}");
let shadow_client = Rc::new(ShadowTlsClient::new(&tls_name, server_addr, password)?);
info!("Client is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server name: {tls_name}\nOpts: {opts}");
let nodelay = opts.nodelay;
let shadow_client = Rc::new(ShadowTlsClient::new(
&tls_name,
server_addr,
password,
opts,
)?);
let listener = TcpListener::bind(&listen)?;
loop {
match listener.accept().await {
Ok((mut conn, addr)) => {
info!("Accepted a connection from {addr}");
let client = shadow_client.clone();
set_tcp_keepalive(&mut conn);
mod_tcp_conn(&mut conn, true, nodelay);
monoio::spawn(async move { client.relay(conn, addr).await });
}
Err(e) => {
Expand All @@ -179,15 +212,17 @@ async fn run_server(
server_addr: String,
tls_addr: String,
password: String,
opts: Opts,
) -> anyhow::Result<()> {
info!("Server is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server address: {tls_addr}");
let shadow_server = Rc::new(ShadowTlsServer::new(tls_addr, server_addr, password));
info!("Server is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server address: {tls_addr}\nOpts: {opts}");
let nodelay = opts.nodelay;
let shadow_server = Rc::new(ShadowTlsServer::new(tls_addr, server_addr, password, opts));
let listener = TcpListener::bind(&listen)?;
loop {
match listener.accept().await {
Ok((mut conn, addr)) => {
info!("Accepted a connection from {addr}");
set_tcp_keepalive(&mut conn);
mod_tcp_conn(&mut conn, true, nodelay);
let server = shadow_server.clone();
monoio::spawn(async move { server.relay(conn).await });
}
Expand Down
13 changes: 8 additions & 5 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,27 @@ use monoio::{
use crate::{
stream::{HashedWriteStream, HmacHandler},
util::{
copy_until_eof, copy_with_application_data, copy_without_application_data,
set_tcp_keepalive, ErrGroup, FirstRetGroup, APPLICATION_DATA,
copy_until_eof, copy_with_application_data, copy_without_application_data, mod_tcp_conn,
ErrGroup, FirstRetGroup, APPLICATION_DATA,
},
Opts,
};

/// ShadowTlsServer.
pub struct ShadowTlsServer<RA, RB> {
handshake_address: RA,
data_address: RB,
password: String,
opts: Opts,
}

impl<HA, DA> ShadowTlsServer<HA, DA> {
pub fn new(handshake_address: HA, data_address: DA, password: String) -> Self {
pub fn new(handshake_address: HA, data_address: DA, password: String, opts: Opts) -> Self {
Self {
handshake_address,
data_address,
password,
opts,
}
}
}
Expand All @@ -38,7 +41,7 @@ where
{
pub async fn relay(&self, in_stream: TcpStream) -> anyhow::Result<()> {
let mut out_stream = TcpStream::connect(&self.handshake_address).await?;
set_tcp_keepalive(&mut out_stream);
mod_tcp_conn(&mut out_stream, true, self.opts.nodelay);
tracing::debug!("handshake server connected");
let mut in_stream = HashedWriteStream::new(in_stream, self.password.as_bytes())?;
let mut hmac = in_stream.hmac_handler();
Expand All @@ -62,7 +65,7 @@ where
let _ = out_stream.shutdown().await;
drop(out_stream);
let mut data_stream = TcpStream::connect(&self.data_address).await?;
set_tcp_keepalive(&mut data_stream);
mod_tcp_conn(&mut data_stream, true, self.opts.nodelay);
tracing::debug!("data server connected, start relay");
let (mut data_r, mut data_w) = data_stream.split();
let (result, _) = data_w.write(data_left).await;
Expand Down
Loading

0 comments on commit 109c654

Please sign in to comment.