diff --git a/src/client.rs b/src/client.rs index 0fc2f2c..5c37f10 100644 --- a/src/client.rs +++ b/src/client.rs @@ -17,9 +17,24 @@ pub struct ShadowTlsClient { password: String, } +pub struct TlsExtConfig { + alpn: Vec>, +} + +impl TlsExtConfig { + pub fn new(alpn: Vec>) -> TlsExtConfig { + TlsExtConfig { alpn } + } +} + impl ShadowTlsClient { /// Create new ShadowTlsClient. - pub fn new(server_name: &str, address: A, password: String) -> anyhow::Result { + pub fn new( + server_name: &str, + address: A, + password: String, + tls_ext_config: TlsExtConfig, + ) -> anyhow::Result { let mut root_store = RootCertStore::empty(); root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { OwnedTrustAnchor::from_subject_spki_name_constraints( @@ -29,10 +44,14 @@ impl ShadowTlsClient { ) })); // TLS 1.2 and TLS 1.3 is enabled. - let tls_config = rustls::ClientConfig::builder() + let mut tls_config = rustls::ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); + + // Set tls config + tls_config.alpn_protocols = tls_ext_config.alpn; + let tls_connector = TlsConnector::from(tls_config); let server_name = ServerName::try_from(server_name)?; Ok(Self { diff --git a/src/main.rs b/src/main.rs index 3cb9eb7..63329e0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,9 @@ use monoio::net::TcpListener; use tracing::{error, info}; use tracing_subscriber::{filter::LevelFilter, fmt, prelude::*, EnvFilter}; -use crate::{client::ShadowTlsClient, server::ShadowTlsServer, util::set_tcp_keepalive}; +use crate::{ + client::ShadowTlsClient, client::TlsExtConfig, server::ShadowTlsServer, util::set_tcp_keepalive, +}; #[derive(Parser, Debug)] #[clap( @@ -49,6 +51,12 @@ enum Commands { tls_name: String, #[clap(long = "password", help = "Password")] password: String, + #[clap( + long = "alpn", + default_value = "", + help = "Application-Layer Protocol Negotiation(like \"http/1.1\")" + )] + alpn: String, }, #[clap(about = "Run server side")] Server { @@ -118,12 +126,14 @@ async fn run(cli: Arc) { server_addr, tls_name, password, + alpn, } => { run_client( listen.clone(), server_addr.clone(), tls_name.clone(), password.clone(), + alpn.clone(), ) .await .expect("client exited"); @@ -151,9 +161,15 @@ async fn run_client( server_addr: String, tls_name: String, password: String, + alpn: String, ) -> anyhow::Result<()> { info!("Client is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server name: {tls_name}"); - let shadow_client = Rc::new(ShadowTlsClient::new(&tls_name, server_addr, password)?); + let shadow_client = Rc::new(ShadowTlsClient::new( + &tls_name, + server_addr, + password, + TlsExtConfig::new(vec![alpn.into()]), + )?); let listener = TcpListener::bind(&listen)?; loop { match listener.accept().await {