Skip to content

Commit

Permalink
convert libc sockaddr to SocketAddr explicitly
Browse files Browse the repository at this point in the history
SocketAddr may not have the same layout as libc, rust-lang/rust#78802

- fixes #462
- ref tokio-rs/mio#1388, rust-lang/socket2#120
  • Loading branch information
zonyitoo committed Mar 22, 2021
1 parent bf409c2 commit 772f3c1
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 189 deletions.
70 changes: 38 additions & 32 deletions crates/shadowsocks-service/src/local/redir/sys/unix/bsd_pf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@ use std::{
ffi::CString,
io::{self, Error, ErrorKind},
mem,
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
net::SocketAddr,
ptr,
};

use lazy_static::lazy_static;
use log::trace;
use socket2::Protocol;

use crate::sys::sockaddr_to_std;
use socket2::{Protocol, SockAddr};

mod ffi {
use cfg_if::cfg_if;
Expand Down Expand Up @@ -161,7 +159,8 @@ impl PacketFilter {
SocketAddr::V4(ref v4) => {
pnl.af = libc::AF_INET as libc::sa_family_t;

let sockaddr: *const libc::sockaddr_in = v4 as *const SocketAddrV4 as *const _;
let sockaddr = SockAddr::from(*v4);
let sockaddr = sockaddr.as_ptr() as *const libc::sockaddr_in;

let addr: *const libc::in_addr = &((*sockaddr).sin_addr) as *const _;
let port: libc::in_port_t = (*sockaddr).sin_port;
Expand All @@ -172,7 +171,8 @@ impl PacketFilter {
SocketAddr::V6(ref v6) => {
pnl.af = libc::AF_INET6 as libc::sa_family_t;

let sockaddr: *const libc::sockaddr_in6 = v6 as *const SocketAddrV6 as *const _;
let sockaddr = SockAddr::from(*v6);
let sockaddr = sockaddr.as_ptr() as *const libc::sockaddr_in6;

let addr: *const libc::in6_addr = &((*sockaddr).sin6_addr) as *const _;
let port: libc::in_port_t = (*sockaddr).sin6_port;
Expand All @@ -188,7 +188,8 @@ impl PacketFilter {
return Err(Error::new(ErrorKind::InvalidInput, "client addr must be ipv4"));
}

let sockaddr: *const libc::sockaddr_in = v4 as *const SocketAddrV4 as *const _;
let sockaddr = SockAddr::from(*v4);
let sockaddr = sockaddr.as_ptr() as *const libc::sockaddr_in;

let addr: *const libc::in_addr = &((*sockaddr).sin_addr) as *const _;
let port: libc::in_port_t = (*sockaddr).sin_port;
Expand All @@ -201,7 +202,8 @@ impl PacketFilter {
return Err(Error::new(ErrorKind::InvalidInput, "client addr must be ipv6"));
}

let sockaddr: *const libc::sockaddr_in6 = v6 as *const SocketAddrV6 as *const _;
let sockaddr = SockAddr::from(*v6);
let sockaddr = sockaddr.as_ptr() as *const libc::sockaddr_in6;

let addr: *const libc::in6_addr = &((*sockaddr).sin6_addr) as *const _;
let port: libc::in_port_t = (*sockaddr).sin6_port;
Expand All @@ -222,31 +224,35 @@ impl PacketFilter {
return Err(nerr);
}

let mut dst_addr: libc::sockaddr_storage = mem::zeroed();

if pnl.af == libc::AF_INET as libc::sa_family_t {
let dst_addr: &mut libc::sockaddr_in = &mut *(&mut dst_addr as *mut _ as *mut _);
dst_addr.sin_family = pnl.af;
dst_addr.sin_port = pnl.rdport();
ptr::copy_nonoverlapping(
&pnl.rdaddr.pfa.v4,
&mut dst_addr.sin_addr,
mem::size_of_val(&pnl.rdaddr.pfa.v4),
);
} else if pnl.af == libc::AF_INET6 as libc::sa_family_t {
let dst_addr: &mut libc::sockaddr_in6 = &mut *(&mut dst_addr as *mut _ as *mut _);
dst_addr.sin6_family = pnl.af;
dst_addr.sin6_port = pnl.rdport();
ptr::copy_nonoverlapping(
&pnl.rdaddr.pfa.v6,
&mut dst_addr.sin6_addr,
mem::size_of_val(&pnl.rdaddr.pfa.v6),
);
} else {
unreachable!("sockaddr should be either ipv4 or ipv6");
}
let (_, dst_addr) = SockAddr::init(|dst_addr, addr_len| {
if pnl.af == libc::AF_INET as libc::sa_family_t {
let dst_addr: &mut libc::sockaddr_in = &mut *(dst_addr as *mut _);
dst_addr.sin_family = pnl.af;
dst_addr.sin_port = pnl.rdport();
ptr::copy_nonoverlapping(
&pnl.rdaddr.pfa.v4,
&mut dst_addr.sin_addr,
mem::size_of_val(&pnl.rdaddr.pfa.v4),
);
*addr_len = mem::size_of_val(&pnl.rdaddr.pfa.v4) as libc::socklen_t;
} else if pnl.af == libc::AF_INET6 as libc::sa_family_t {
let dst_addr: &mut libc::sockaddr_in6 = &mut *(dst_addr as *mut _);
dst_addr.sin6_family = pnl.af;
dst_addr.sin6_port = pnl.rdport();
ptr::copy_nonoverlapping(
&pnl.rdaddr.pfa.v6,
&mut dst_addr.sin6_addr,
mem::size_of_val(&pnl.rdaddr.pfa.v6),
);
*addr_len = mem::size_of_val(&pnl.rdaddr.pfa.v6) as libc::socklen_t;
} else {
unreachable!("sockaddr should be either ipv4 or ipv6");
}

Ok(())
})?;

sockaddr_to_std(&dst_addr)
Ok(dst_addr.as_socket().expect("SocketAddr"))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use std::{
};

use async_trait::async_trait;
use socket2::SockAddr;
use tokio::net::{TcpListener, TcpSocket, TcpStream};

use crate::{
config::RedirType,
local::redir::redir_ext::{TcpListenerRedirExt, TcpStreamRedirExt},
sys::sockaddr_to_std,
};

#[async_trait]
Expand Down Expand Up @@ -51,41 +51,41 @@ fn get_original_destination_addr(s: &TcpStream) -> io::Result<SocketAddr> {
let fd = s.as_raw_fd();

unsafe {
let mut target_addr: libc::sockaddr_storage = mem::zeroed();
let mut target_addr_len = mem::size_of_val(&target_addr) as libc::socklen_t;

match s.local_addr()? {
SocketAddr::V4(..) => {
let ret = libc::getsockopt(
fd,
libc::SOL_IP,
libc::SO_ORIGINAL_DST,
&mut target_addr as *mut _ as *mut _,
&mut target_addr_len,
);
if ret != 0 {
let err = Error::last_os_error();
return Err(err);
let (_, target_addr) = SockAddr::init(|target_addr, target_addr_len| {
match s.local_addr()? {
SocketAddr::V4(..) => {
let ret = libc::getsockopt(
fd,
libc::SOL_IP,
libc::SO_ORIGINAL_DST,
target_addr as *mut _,
target_addr_len, // libc::socklen_t
);
if ret != 0 {
let err = Error::last_os_error();
return Err(err);
}
}
}
SocketAddr::V6(..) => {
let ret = libc::getsockopt(
fd,
libc::SOL_IPV6,
libc::IP6T_SO_ORIGINAL_DST,
&mut target_addr as *mut _ as *mut _,
&mut target_addr_len,
);
SocketAddr::V6(..) => {
let ret = libc::getsockopt(
fd,
libc::SOL_IPV6,
libc::IP6T_SO_ORIGINAL_DST,
target_addr as *mut _,
target_addr_len, // libc::socklen_t
);

if ret != 0 {
let err = Error::last_os_error();
return Err(err);
if ret != 0 {
let err = Error::last_os_error();
return Err(err);
}
}
}
}
Ok(())
})?;

// Convert sockaddr_storage to SocketAddr
sockaddr_to_std(&target_addr)
Ok(target_addr.as_socket().expect("SocketAddr"))
}
}

Expand Down
80 changes: 40 additions & 40 deletions crates/shadowsocks-service/src/local/redir/udprelay/sys/unix/bsd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use tokio::io::unix::AsyncFd;
use crate::{
config::RedirType,
local::redir::redir_ext::{RedirSocketOpts, UdpSocketRedirExt},
sys::sockaddr_to_std,
};

pub fn check_support_tproxy() -> io::Result<()> {
Expand Down Expand Up @@ -186,45 +185,44 @@ fn set_socket_before_bind(addr: &SocketAddr, socket: &Socket) -> io::Result<()>
Ok(())
}

fn get_destination_addr(msg: &libc::msghdr) -> Option<libc::sockaddr_storage> {
fn get_destination_addr(msg: &libc::msghdr) -> io::Result<SocketAddr> {
// https://www.freebsd.org/cgi/man.cgi?ip(4)
//
// Called `recvmsg` with `IP_ORIGDSTADDR` set

unsafe {
let mut cmsg: *mut libc::cmsghdr = libc::CMSG_FIRSTHDR(msg);
while !cmsg.is_null() {
let rcmsg = &*cmsg;
match (rcmsg.cmsg_level, rcmsg.cmsg_type) {
(libc::IPPROTO_IP, libc::IP_ORIGDSTADDR) => {
let mut dst_addr: libc::sockaddr_storage = mem::zeroed();

ptr::copy(
libc::CMSG_DATA(cmsg),
&mut dst_addr as *mut _ as *mut _,
mem::size_of::<libc::sockaddr_in>(),
);

return Some(dst_addr);
let (_, addr) = SockAddr::init(|dst_addr, dst_addr_len| {
let mut cmsg: *mut libc::cmsghdr = libc::CMSG_FIRSTHDR(msg);
while !cmsg.is_null() {
let rcmsg = &*cmsg;
match (rcmsg.cmsg_level, rcmsg.cmsg_type) {
(libc::IPPROTO_IP, libc::IP_ORIGDSTADDR) => {
ptr::copy_nonoverlapping(libc::CMSG_DATA(cmsg), dst_addr, mem::size_of::<libc::sockaddr_in>());
*dst_addr_len = mem::size_of::<libc::sockaddr_in>() as libc::socklen_t;

return Ok(());
}
(libc::IPPROTO_IPV6, libc::IPV6_ORIGDSTADDR) => {
ptr::copy_nonoverlapping(
libc::CMSG_DATA(cmsg),
dst_addr as *mut _,
mem::size_of::<libc::sockaddr_in6>(),
);
*dst_addr_len = mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;

return Ok(());
}
_ => {}
}
(libc::IPPROTO_IPV6, libc::IPV6_ORIGDSTADDR) => {
let mut dst_addr: libc::sockaddr_storage = mem::zeroed();
cmsg = libc::CMSG_NXTHDR(msg, cmsg);
}

ptr::copy(
libc::CMSG_DATA(cmsg),
&mut dst_addr as *mut _ as *mut _,
mem::size_of::<libc::sockaddr_in6>(),
);
let err = Error::new(ErrorKind::InvalidData, "missing destination address in msghdr");
Err(err)
})?;

return Some(dst_addr);
}
_ => {}
}
cmsg = libc::CMSG_NXTHDR(msg, cmsg);
}
Ok(addr.as_socket().expect("SocketAddr"))
}

None
}

fn recv_dest_from(socket: &UdpSocket, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, SocketAddr)> {
Expand Down Expand Up @@ -252,14 +250,16 @@ fn recv_dest_from(socket: &UdpSocket, buf: &mut [u8]) -> io::Result<(usize, Sock
return Err(Error::last_os_error());
}

let dst_addr = match get_destination_addr(&msg) {
None => {
let err = Error::new(ErrorKind::InvalidData, "missing destination address in msghdr");
return Err(err);
}
Some(d) => d,
};

Ok((ret as usize, sockaddr_to_std(&src_addr)?, sockaddr_to_std(&dst_addr)?))
let (_, src_saddr) = SockAddr::init(|a, l| {
ptr::copy_nonoverlapping(msg.msg_name, a, msg.msg_namelen as usize);
*l = msg.msg_namelen;
Ok(())
})?;

Ok((
ret as usize,
src_saddr.as_socket().expect("SocketAddr"),
get_destination_addr(&msg)?,
))
}
}
Loading

0 comments on commit 772f3c1

Please sign in to comment.