Skip to content

Commit

Permalink
优化环路检测逻辑
Browse files Browse the repository at this point in the history
  • Loading branch information
vnt-dev committed Apr 29, 2024
1 parent 18fcee5 commit bbfa255
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 20 deletions.
47 changes: 29 additions & 18 deletions vnt/src/channel/punch.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::str::FromStr;
use std::time::Duration;
use std::{io, thread};
Expand All @@ -11,6 +11,7 @@ use rand::Rng;
use crate::channel::context::Context;
use crate::channel::sender::AcceptSocketSender;
use crate::external_route::ExternalRoute;
use crate::nat::NatTest;

#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum PunchModel {
Expand Down Expand Up @@ -187,6 +188,7 @@ pub struct Punch {
is_tcp: bool,
tcp_socket_sender: AcceptSocketSender<(TcpStream, SocketAddr, Option<Vec<u8>>)>,
external_route: ExternalRoute,
nat_test: NatTest,
}

impl Punch {
Expand All @@ -196,6 +198,7 @@ impl Punch {
is_tcp: bool,
tcp_socket_sender: AcceptSocketSender<(TcpStream, SocketAddr, Option<Vec<u8>>)>,
external_route: ExternalRoute,
nat_test: NatTest,
) -> Self {
let mut port_vec: Vec<u16> = (1..65535).collect();
port_vec.push(65535);
Expand All @@ -209,19 +212,16 @@ impl Punch {
is_tcp,
tcp_socket_sender,
external_route,
nat_test,
}
}
}

impl Punch {
fn connect_tcp(&self, buf: &[u8], addr: SocketAddr) -> bool {
if let IpAddr::V4(ip) = addr.ip() {
if self.external_route.route(&ip).is_some() {
log::warn!("跳过打洞目标{},防止环路 ", addr);
return false;
}
if self.nat_test.is_local_address(true, addr) {
return false;
}

// mio是非阻塞的,不能立马判断是否能连接成功,所以用标准库的tcp
match std::net::TcpStream::connect_timeout(&addr, Duration::from_millis(100)) {
Ok(tcp_stream) => {
Expand All @@ -243,13 +243,26 @@ impl Punch {
&mut self,
buf: &[u8],
id: Ipv4Addr,
nat_info: NatInfo,
mut nat_info: NatInfo,
punch_tcp: bool,
) -> io::Result<()> {
if self.context.route_table.no_need_punch(&id) {
log::info!("已打洞成功,无需打洞:{:?}", id);
return Ok(());
}
nat_info
.public_ips
.retain(|ip| self.external_route.route(&ip).is_none());
nat_info
.local_ipv4
.filter(|ip| self.external_route.route(&ip).is_none());
nat_info.ipv6.filter(|ip| {
if let Some(ip) = ip.to_ipv4_mapped() {
self.external_route.route(&ip).is_none()
} else {
true
}
});
if punch_tcp && self.is_tcp && nat_info.tcp_port != 0 {
//向tcp发起连接
if let Some(ipv6_addr) = nat_info.local_tcp_ipv6addr() {
Expand All @@ -274,23 +287,21 @@ impl Punch {
let channel_num = self.context.channel_num();
for index in 0..channel_num {
if let Some(ipv4_addr) = nat_info.local_udp_ipv4addr(index) {
if let IpAddr::V4(ip) = ipv4_addr.ip() {
if self.external_route.route(&ip).is_some() {
log::warn!("跳过打洞目标{},防止环路", ipv4_addr);
continue;
}
if !self.nat_test.is_local_address(false, ipv4_addr) {
let _ = self.context.send_main_udp(index, buf, ipv4_addr);
}
let _ = self.context.send_main_udp(index, buf, ipv4_addr);
}
}

if self.punch_model != PunchModel::IPv4 {
for index in 0..channel_num {
if let Some(ipv6_addr) = nat_info.local_udp_ipv6addr(index) {
let rs = self.context.send_main_udp(index, buf, ipv6_addr);
log::info!("发送到ipv6地址:{:?},rs={:?}", ipv6_addr, rs);
if rs.is_ok() && self.punch_model == PunchModel::IPv6 {
return Ok(());
if !self.nat_test.is_local_address(false, ipv6_addr) {
let rs = self.context.send_main_udp(index, buf, ipv6_addr);
log::info!("发送到ipv6地址:{:?},rs={:?}", ipv6_addr, rs);
if rs.is_ok() && self.punch_model == PunchModel::IPv6 {
return Ok(());
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions vnt/src/core/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ impl Vnt {
config.tcp,
tcp_socket_sender.clone(),
external_route.clone(),
nat_test.clone(),
);

#[cfg(not(target_os = "android"))]
Expand Down
41 changes: 41 additions & 0 deletions vnt/src/handle/recv_data/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use crate::protocol::{
use crate::tun_tap_device::tun_create_helper::DeviceAdapter;
#[cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))]
use tun::device::IFace;

/// 处理来源于客户端的包
#[derive(Clone)]
pub struct ClientPacketHandler {
Expand Down Expand Up @@ -142,6 +143,31 @@ impl ClientPacketHandler {
//拦截不符合的目标
return Ok(());
}
match ipv4.protocol() {
ipv4::protocol::Protocol::Tcp => {
let payload = ipv4.payload();
if payload.len() < 20 {
return Ok(());
}
let destination_port =
u16::from_be_bytes(payload[2..4].try_into().unwrap());
if self.nat_test.is_local_tcp(real_dest, destination_port) {
return Ok(());
}
}
ipv4::protocol::Protocol::Udp => {
let payload = ipv4.payload();
if payload.len() < 8 {
return Ok(());
}
let destination_port =
u16::from_be_bytes(payload[2..4].try_into().unwrap());
if self.nat_test.is_local_udp(real_dest, destination_port) {
return Ok(());
}
}
_ => {}
}
#[cfg(feature = "ip_proxy")]
if let Some(ip_proxy_map) = &self.ip_proxy_map {
if ip_proxy_map.recv_handle(&mut ipv4, source, destination)? {
Expand Down Expand Up @@ -192,6 +218,15 @@ impl ClientPacketHandler {
if context.use_channel_type().is_only_relay() {
return Ok(());
}
//忽略掉来源于自己的包
if route_key.is_tcp() {}
if self
.nat_test
.is_local_address(route_key.is_tcp(), route_key.addr)
{
return Ok(());
}

//回应
net_packet.set_transport_protocol(control_packet::Protocol::PunchResponse.into());
net_packet.set_source(current_device.virtual_ip);
Expand All @@ -207,6 +242,12 @@ impl ClientPacketHandler {
if context.use_channel_type().is_only_relay() {
return Ok(());
}
if self
.nat_test
.is_local_address(route_key.is_tcp(), route_key.addr)
{
return Ok(());
}
let route = Route::from_default_rt(route_key, 1);
context.route_table.add_route_if_absent(source, route);
}
Expand Down
70 changes: 68 additions & 2 deletions vnt/src/nat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io;
use std::net::UdpSocket;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::net::{SocketAddr, UdpSocket};
use std::ops::Sub;
use std::sync::Arc;
use std::time::{Duration, Instant};
Expand All @@ -22,6 +22,7 @@ pub fn local_ipv4_() -> io::Result<Ipv4Addr> {
IpAddr::V6(_) => Ok(Ipv4Addr::UNSPECIFIED),
}
}

pub fn local_ipv4() -> Option<Ipv4Addr> {
match local_ipv4_() {
Ok(ipv4) => Some(ipv4),
Expand All @@ -41,6 +42,7 @@ pub fn local_ipv6_() -> io::Result<Ipv6Addr> {
IpAddr::V6(ip) => Ok(ip),
}
}

pub fn local_ipv6() -> Option<Ipv6Addr> {
match local_ipv6_() {
Ok(ipv6) => Some(ipv6),
Expand All @@ -56,6 +58,8 @@ pub struct NatTest {
stun_server: Vec<String>,
info: Arc<Mutex<NatInfo>>,
time: Arc<AtomicCell<Instant>>,
udp_ports: Vec<u16>,
tcp_port: u16,
}

impl From<NatType> for PunchNatType {
Expand Down Expand Up @@ -94,7 +98,7 @@ impl NatTest {
0,
local_ipv4,
ipv6,
udp_ports,
udp_ports.clone(),
tcp_port,
NatType::Cone,
);
Expand All @@ -105,6 +109,8 @@ impl NatTest {
time: Arc::new(AtomicCell::new(
Instant::now().sub(Duration::from_secs(100)),
)),
udp_ports,
tcp_port,
}
}
pub fn can_update(&self) -> bool {
Expand All @@ -116,6 +122,66 @@ impl NatTest {
pub fn nat_info(&self) -> NatInfo {
self.info.lock().clone()
}
pub fn is_local_udp(&self, ipv4: Ipv4Addr, port: u16) -> bool {
for x in &self.udp_ports {
if x == &port {
let guard = self.info.lock();
if let Some(ip) = guard.local_ipv4 {
if ipv4 == ip {
return true;
}
}
break;
}
}
false
}
pub fn is_local_tcp(&self, ipv4: Ipv4Addr, port: u16) -> bool {
if self.tcp_port == port {
let guard = self.info.lock();
if let Some(ip) = guard.local_ipv4 {
if ipv4 == ip {
return true;
}
}
}
false
}
pub fn is_local_address(&self, is_tcp: bool, addr: SocketAddr) -> bool {
let port = addr.port();
let check_ip = || {
let guard = self.info.lock();
match addr.ip() {
IpAddr::V4(ipv4) => {
if let Some(ip) = guard.local_ipv4 {
if ipv4 == ip {
return true;
}
}
}
IpAddr::V6(ipv6) => {
if let Some(ip) = guard.ipv6 {
if ipv6 == ip {
return true;
}
}
}
}
false
};
if is_tcp {
if self.tcp_port == port {
return check_ip();
}
} else {
for x in &self.udp_ports {
if x == &port {
return check_ip();
}
}
}
false
}
pub fn update_addr(&self, index: usize, ip: Ipv4Addr, port: u16) {
let mut guard = self.info.lock();
guard.update_addr(index, ip, port)
Expand Down

0 comments on commit bbfa255

Please sign in to comment.