From 6bb1af4becf5146f77cb0b9f5cac35d09e4f5935 Mon Sep 17 00:00:00 2001 From: ihciah Date: Sat, 4 Feb 2023 00:04:10 +0800 Subject: [PATCH] fix: client args parse error --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/client.rs | 60 +++++++++++++++++++++++++++++++++------------------ src/main.rs | 13 ++++++----- src/sip003.rs | 2 +- 5 files changed, 48 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 49ac54b..4d720d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -595,7 +595,7 @@ dependencies = [ [[package]] name = "shadow-tls" -version = "0.2.9" +version = "0.2.10" dependencies = [ "anyhow", "byteorder", diff --git a/Cargo.toml b/Cargo.toml index 59d40e1..3d7556b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "shadow-tls" readme = "README.md" repository = "https://github.com/ihciah/shadow-tls" -version = "0.2.9" +version = "0.2.10" [dependencies] monoio = {version = "0.0.9"} diff --git a/src/client.rs b/src/client.rs index 3b3733c..e1eb6a8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -14,12 +14,47 @@ use crate::{ /// ShadowTlsClient. pub struct ShadowTlsClient { tls_connector: TlsConnector, - server_names: Box<[ServerName]>, + server_names: TlsNames, address: A, password: String, opts: Opts, } +#[derive(Clone, Debug, PartialEq)] +pub struct TlsNames(Vec); + +impl TlsNames { + pub fn random_choose(&self) -> &ServerName { + self.0.choose(&mut rand::thread_rng()).unwrap() + } +} + +impl TryFrom<&str> for TlsNames { + type Error = anyhow::Error; + + fn try_from(value: &str) -> Result { + let v: Result, _> = value.trim().split(';').map(ServerName::try_from).collect(); + let v = v.map_err(Into::into).and_then(|v| { + if v.is_empty() { + Err(anyhow::anyhow!("empty tls names")) + } else { + Ok(v) + } + })?; + Ok(Self(v)) + } +} + +impl std::fmt::Display for TlsNames { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } +} + +pub fn parse_client_names(addrs: &str) -> anyhow::Result { + TlsNames::try_from(addrs) +} + #[derive(Default, Debug)] pub struct TlsExtConfig { alpn: Option>>, @@ -34,7 +69,7 @@ impl TlsExtConfig { impl ShadowTlsClient { /// Create new ShadowTlsClient. pub fn new( - server_names: std::slice::Iter, + server_names: TlsNames, address: A, password: String, opts: Opts, @@ -60,15 +95,10 @@ impl ShadowTlsClient { } let tls_connector = TlsConnector::from(tls_config); - let mut sni_list = Vec::new(); - for s in server_names { - let sni = ServerName::try_from(s.as_str())?; - sni_list.push(sni); - } Ok(Self { tls_connector, - server_names: sni_list.into_boxed_slice(), + server_names, address, password, opts, @@ -107,11 +137,7 @@ impl ShadowTlsClient { mod_tcp_conn(&mut stream, true, !self.opts.disable_nodelay); tracing::debug!("tcp connected, start handshaking"); let stream = HashedReadStream::new(stream, self.password.as_bytes())?; - let endpoint = self - .server_names - .choose(&mut rand::thread_rng()) - .expect("empty endpoints set") - .clone(); + let endpoint = self.server_names.random_choose().clone(); let tls_stream = self.tls_connector.connect(endpoint, stream).await?; let (io, _) = tls_stream.into_parts(); let hash = io.hash(); @@ -120,11 +146,3 @@ impl ShadowTlsClient { Ok((stream, hash)) } } - -pub fn parse_client_addrs(addrs: &str) -> anyhow::Result> { - Ok(addrs - .trim() - .split(';') - .map(|a| a.trim().to_string()) - .collect()) -} diff --git a/src/main.rs b/src/main.rs index 64889d1..8972a77 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ use tracing::{error, info}; use tracing_subscriber::{filter::LevelFilter, fmt, prelude::*, EnvFilter}; use crate::{ - client::{parse_client_addrs, ShadowTlsClient, TlsExtConfig}, + client::{parse_client_names, ShadowTlsClient, TlsExtConfig, TlsNames}, server::{parse_server_addrs, ShadowTlsServer, TlsAddrs}, util::mod_tcp_conn, }; @@ -75,9 +75,9 @@ enum Commands { #[clap( long = "sni", help = "TLS handshake SNI(like cloud.tencent.com, captive.apple.com;cloud.tencent.com)", - value_parser = parse_client_addrs + value_parser = parse_client_names )] - tls_names: Vec, + tls_names: TlsNames, #[clap(long = "password", help = "Password")] password: String, #[clap( @@ -198,16 +198,15 @@ fn get_parallelism(args: &Args) -> usize { async fn run_client( listen: String, server_addr: String, - tls_names: Vec, + tls_names: TlsNames, password: String, opts: Opts, tls_ext: TlsExtConfig, ) -> anyhow::Result<()> { - assert!(!tls_names.is_empty(), "empty tls name is not allowed"); - info!("Client is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server names: {tls_names:?}\nOpts: {opts}"); + info!("Client is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server names: {tls_names}\nOpts: {opts}"); let nodelay = !opts.disable_nodelay; let shadow_client = Rc::new(ShadowTlsClient::new( - tls_names.iter(), + tls_names, server_addr, password, opts, diff --git a/src/sip003.rs b/src/sip003.rs index a2a24aa..0cb1913 100644 --- a/src/sip003.rs +++ b/src/sip003.rs @@ -66,7 +66,7 @@ pub(crate) fn get_sip003_arg() -> Option { let host = opts .get("host") .expect("need host param(like host=www.baidu.com)"); - let hosts = crate::client::parse_client_addrs(host).expect("tls names parse failed"); + let hosts = crate::client::parse_client_names(host).expect("tls names parse failed"); Args { cmd: crate::Commands::Client { listen: format!("{ss_local_host}:{ss_local_port}"),