Skip to content

Commit

Permalink
fix: client args parse error
Browse files Browse the repository at this point in the history
  • Loading branch information
ihciah committed Feb 3, 2023
1 parent 41568cd commit 6bb1af4
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0"
name = "shadow-tls"
readme = "README.md"
repository = "https://github.com/ihciah/shadow-tls"
version = "0.2.9"
version = "0.2.10"

[dependencies]
monoio = {version = "0.0.9"}
Expand Down
60 changes: 39 additions & 21 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,47 @@ use crate::{
/// ShadowTlsClient.
pub struct ShadowTlsClient<A> {
tls_connector: TlsConnector,
server_names: Box<[ServerName]>,
server_names: TlsNames,
address: A,
password: String,
opts: Opts,
}

#[derive(Clone, Debug, PartialEq)]
pub struct TlsNames(Vec<ServerName>);

impl TlsNames {
pub fn random_choose(&self) -> &ServerName {
self.0.choose(&mut rand::thread_rng()).unwrap()
}
}

impl TryFrom<&str> for TlsNames {
type Error = anyhow::Error;

fn try_from(value: &str) -> Result<Self, Self::Error> {
let v: Result<Vec<_>, _> = value.trim().split(';').map(ServerName::try_from).collect();
let v = v.map_err(Into::into).and_then(|v| {
if v.is_empty() {
Err(anyhow::anyhow!("empty tls names"))
} else {
Ok(v)
}
})?;
Ok(Self(v))
}
}

impl std::fmt::Display for TlsNames {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}

pub fn parse_client_names(addrs: &str) -> anyhow::Result<TlsNames> {
TlsNames::try_from(addrs)
}

#[derive(Default, Debug)]
pub struct TlsExtConfig {
alpn: Option<Vec<Vec<u8>>>,
Expand All @@ -34,7 +69,7 @@ impl TlsExtConfig {
impl<A> ShadowTlsClient<A> {
/// Create new ShadowTlsClient.
pub fn new(
server_names: std::slice::Iter<String>,
server_names: TlsNames,
address: A,
password: String,
opts: Opts,
Expand All @@ -60,15 +95,10 @@ impl<A> ShadowTlsClient<A> {
}

let tls_connector = TlsConnector::from(tls_config);
let mut sni_list = Vec::new();
for s in server_names {
let sni = ServerName::try_from(s.as_str())?;
sni_list.push(sni);
}

Ok(Self {
tls_connector,
server_names: sni_list.into_boxed_slice(),
server_names,
address,
password,
opts,
Expand Down Expand Up @@ -107,11 +137,7 @@ impl<A> ShadowTlsClient<A> {
mod_tcp_conn(&mut stream, true, !self.opts.disable_nodelay);
tracing::debug!("tcp connected, start handshaking");
let stream = HashedReadStream::new(stream, self.password.as_bytes())?;
let endpoint = self
.server_names
.choose(&mut rand::thread_rng())
.expect("empty endpoints set")
.clone();
let endpoint = self.server_names.random_choose().clone();
let tls_stream = self.tls_connector.connect(endpoint, stream).await?;
let (io, _) = tls_stream.into_parts();
let hash = io.hash();
Expand All @@ -120,11 +146,3 @@ impl<A> ShadowTlsClient<A> {
Ok((stream, hash))
}
}

pub fn parse_client_addrs(addrs: &str) -> anyhow::Result<Vec<String>> {
Ok(addrs
.trim()
.split(';')
.map(|a| a.trim().to_string())
.collect())
}
13 changes: 6 additions & 7 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use tracing::{error, info};
use tracing_subscriber::{filter::LevelFilter, fmt, prelude::*, EnvFilter};

use crate::{
client::{parse_client_addrs, ShadowTlsClient, TlsExtConfig},
client::{parse_client_names, ShadowTlsClient, TlsExtConfig, TlsNames},
server::{parse_server_addrs, ShadowTlsServer, TlsAddrs},
util::mod_tcp_conn,
};
Expand Down Expand Up @@ -75,9 +75,9 @@ enum Commands {
#[clap(
long = "sni",
help = "TLS handshake SNI(like cloud.tencent.com, captive.apple.com;cloud.tencent.com)",
value_parser = parse_client_addrs
value_parser = parse_client_names
)]
tls_names: Vec<String>,
tls_names: TlsNames,
#[clap(long = "password", help = "Password")]
password: String,
#[clap(
Expand Down Expand Up @@ -198,16 +198,15 @@ fn get_parallelism(args: &Args) -> usize {
async fn run_client(
listen: String,
server_addr: String,
tls_names: Vec<String>,
tls_names: TlsNames,
password: String,
opts: Opts,
tls_ext: TlsExtConfig,
) -> anyhow::Result<()> {
assert!(!tls_names.is_empty(), "empty tls name is not allowed");
info!("Client is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server names: {tls_names:?}\nOpts: {opts}");
info!("Client is running!\nListen address: {listen}\nRemote address: {server_addr}\nTLS server names: {tls_names}\nOpts: {opts}");
let nodelay = !opts.disable_nodelay;
let shadow_client = Rc::new(ShadowTlsClient::new(
tls_names.iter(),
tls_names,
server_addr,
password,
opts,
Expand Down
2 changes: 1 addition & 1 deletion src/sip003.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub(crate) fn get_sip003_arg() -> Option<Args> {
let host = opts
.get("host")
.expect("need host param(like host=www.baidu.com)");
let hosts = crate::client::parse_client_addrs(host).expect("tls names parse failed");
let hosts = crate::client::parse_client_names(host).expect("tls names parse failed");
Args {
cmd: crate::Commands::Client {
listen: format!("{ss_local_host}:{ss_local_port}"),
Expand Down

0 comments on commit 6bb1af4

Please sign in to comment.