Skip to content

Commit

Permalink
refactor: improve socket rebind
Browse files Browse the repository at this point in the history
  • Loading branch information
Itsusinn committed Jun 27, 2024
1 parent 79bffe3 commit e3a25c0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 34 deletions.
46 changes: 16 additions & 30 deletions tuic-client/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use quinn::{
use register_count::Counter;
use rustls::ClientConfig as RustlsClientConfig;
use std::{
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
sync::{atomic::AtomicU32, Arc},
time::Duration,
};
Expand Down Expand Up @@ -49,7 +49,7 @@ pub struct Connection {
}

impl Connection {
pub fn set_config(cfg: Relay) -> Result<(), Error> {
pub async fn set_config(cfg: Relay) -> Result<(), Error> {
let certs = utils::load_certs(cfg.certificates, cfg.disable_native_certs)?;

let mut crypto =
Expand Down Expand Up @@ -88,11 +88,18 @@ impl Connection {
config.transport_config(Arc::new(tp_cfg));

// Try to create an IPv4 socket as the placeholder first, if it fails, try IPv6.
let socket = UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
.or_else(|err| {
UdpSocket::bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))).map_err(|_| err)
})
.map_err(|err| Error::Socket("failed to create endpoint UDP socket", err))?;
let server = ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip);
let server_ip: Option<IpAddr> = match server.resolve().await?.next() {
Some(SocketAddr::V4(v4)) => Some(v4.ip().to_owned().into()),
Some(SocketAddr::V6(v6)) => Some(v6.ip().to_owned().into()),
None => None,
};
let server_ip = server_ip.expect("Server ip not found");
let socket = if server_ip.is_ipv4() {
UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))?
} else {
UdpSocket::bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)))?
};

let mut ep = QuinnEndpoint::new(
EndpointConfig::default(),
Expand All @@ -105,7 +112,7 @@ impl Connection {

let ep = Endpoint {
ep,
server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip),
server,
uuid: cfg.uuid,
password: cfg.password,
udp_relay_mode: cfg.udp_relay_mode,
Expand All @@ -125,7 +132,7 @@ impl Connection {
Ok(())
}

pub async fn get() -> Result<Connection, Error> {
pub async fn get_conn() -> Result<Connection, Error> {
let try_init_conn = async {
ENDPOINT
.get()
Expand Down Expand Up @@ -259,27 +266,6 @@ impl Endpoint {

for addr in self.server.resolve().await? {
let connect_to = async {
let match_ipv4 =
addr.is_ipv4() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv4());
let match_ipv6 =
addr.is_ipv6() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv6());

if !match_ipv4 && !match_ipv6 {
let bind_addr = if addr.is_ipv4() {
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
} else {
SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
};

self.ep
.rebind(UdpSocket::bind(bind_addr).map_err(|err| {
Error::Socket("failed to create endpoint UDP socket", err)
})?)
.map_err(|err| {
Error::Socket("failed to rebind endpoint UDP socket", err)
})?;
}

let conn = self.ep.connect(addr, self.server.server_name())?;
let (conn, zero_rtt_accepted) = if self.zero_rtt_handshake {
match conn.into_0rtt() {
Expand Down
2 changes: 1 addition & 1 deletion tuic-client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async fn main() {
.format_target(false)
.init();

match Connection::set_config(cfg.relay) {
match Connection::set_config(cfg.relay).await {
Ok(()) => {}
Err(err) => {
eprintln!("{err}");
Expand Down
6 changes: 3 additions & 3 deletions tuic-client/src/socks5/handle_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl Server {
Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr),
};

match TuicConnection::get().await {
match TuicConnection::get_conn().await {
Ok(conn) => conn.packet(pkt, target_addr, assoc_id).await,
Err(err) => Err(err),
}
Expand Down Expand Up @@ -101,7 +101,7 @@ impl Server {
.remove(&assoc_id)
.unwrap();

let res = match TuicConnection::get().await {
let res = match TuicConnection::get_conn().await {
Ok(conn) => conn.dissociate(assoc_id).await,
Err(err) => Err(err),
};
Expand Down Expand Up @@ -151,7 +151,7 @@ impl Server {
Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr),
};

let relay = match TuicConnection::get().await {
let relay = match TuicConnection::get_conn().await {
Ok(conn) => conn.connect(target_addr.clone()).await,
Err(err) => Err(err),
};
Expand Down

0 comments on commit e3a25c0

Please sign in to comment.