Skip to content

Commit

Permalink
split DeviceWriter&DeviceReader
Browse files Browse the repository at this point in the history
  • Loading branch information
vnt-dev committed Aug 19, 2024
1 parent cf1468b commit dd8711d
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 38 deletions.
1 change: 1 addition & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
protocol = "sparse"

[build]
# target = ["x86_64-unknown-linux-musl"]
# target = ["x86_64-unknown-linux-gnu"]
# target = ["aarch64-linux-android"]
# target = ["aarch64-apple-ios"]
Expand Down
4 changes: 2 additions & 2 deletions src/async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ pub use codec::TunPacketCodec;
#[cfg(unix)]
mod unix_device;
#[cfg(unix)]
pub use unix_device::AsyncDevice;
pub use unix_device::{AsyncDevice, DeviceReader, DeviceWriter};

#[cfg(target_os = "windows")]
mod win_device;
#[cfg(target_os = "windows")]
pub use win_device::AsyncDevice;
pub use win_device::{AsyncDevice, DeviceReader, DeviceWriter};

/// Create a TUN device with the given name.
pub fn create_as_async(configuration: &Configuration) -> Result<AsyncDevice, error::Error> {
Expand Down
91 changes: 91 additions & 0 deletions src/async/unix_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use tokio_util::codec::Framed;

use super::TunPacketCodec;
use crate::device::AbstractDevice;
use crate::platform::posix::{Reader, Writer};
use crate::platform::Device;

/// An async TUN device wrapper around a TUN device.
Expand Down Expand Up @@ -62,6 +63,11 @@ impl AsyncDevice {
// associate mtu with the capacity of ReadBuf
Framed::with_capacity(self, codec, mtu as usize)
}
pub fn split(self) -> std::io::Result<(DeviceWriter, DeviceReader)> {
let device = self.inner.into_inner();
let (reader, writer) = device.split();
Ok((DeviceWriter::new(writer)?, DeviceReader::new(reader)?))
}

/// Recv a packet from tun device
pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
Expand Down Expand Up @@ -146,3 +152,88 @@ impl AsyncWrite for AsyncDevice {
true
}
}
pub struct DeviceReader {
inner: AsyncFd<Reader>,
}
impl DeviceReader {
fn new(reader: Reader) -> std::io::Result<Self> {
Ok(Self {
inner: AsyncFd::new(reader)?,
})
}
}
impl AsyncRead for DeviceReader {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<std::io::Result<()>> {
loop {
let mut guard = ready!(self.inner.poll_read_ready_mut(cx))?;
let rbuf = buf.initialize_unfilled();
match guard.try_io(|inner| inner.get_mut().read(rbuf)) {
Ok(res) => return Poll::Ready(res.map(|n| buf.advance(n))),
Err(_wb) => continue,
}
}
}
}

pub struct DeviceWriter {
inner: AsyncFd<Writer>,
}
impl DeviceWriter {
fn new(writer: Writer) -> std::io::Result<Self> {
Ok(Self {
inner: AsyncFd::new(writer)?,
})
}
}

impl AsyncWrite for DeviceWriter {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready_mut(cx))?;
match guard.try_io(|inner| inner.get_mut().write(buf)) {
Ok(res) => return Poll::Ready(res),
Err(_wb) => continue,
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready_mut(cx))?;
match guard.try_io(|inner| inner.get_mut().flush()) {
Ok(res) => return Poll::Ready(res),
Err(_wb) => continue,
}
}
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready_mut(cx))?;
match guard.try_io(|inner| inner.get_mut().write_vectored(bufs)) {
Ok(res) => return Poll::Ready(res),
Err(_wb) => continue,
}
}
}

fn is_write_vectored(&self) -> bool {
true
}
}
75 changes: 46 additions & 29 deletions src/async/win_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@ use core::task::{Context, Poll};
use std::io;
use std::io::Error;

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::codec::Framed;

use super::TunPacketCodec;
use crate::device::AbstractDevice;
use crate::platform::Device;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc::error::TrySendError;
use tokio_util::codec::Framed;
use wintun::Packet;

/// An async TUN device wrapper around a TUN device.
pub struct AsyncDevice {
inner: Device,
session: WinSession,
session_reader: DeviceReader,
session_writer: DeviceWriter,
}

/// Returns a shared reference to the underlying Device object.
Expand All @@ -49,10 +51,12 @@ impl core::ops::DerefMut for AsyncDevice {
impl AsyncDevice {
/// Create a new `AsyncDevice` wrapping around a `Device`.
pub fn new(device: Device) -> io::Result<AsyncDevice> {
let session = WinSession::new(device.tun.get_session())?;
let session_reader = DeviceReader::new(device.tun.get_session())?;
let session_writer = DeviceWriter::new(device.tun.get_session())?;
Ok(AsyncDevice {
inner: device,
session,
session_reader,
session_writer,
})
}

Expand All @@ -63,6 +67,9 @@ impl AsyncDevice {
// guarantee to avoid the mtu of wintun may far away larger than the default provided capacity of ReadBuf of Framed
Framed::with_capacity(self, codec, mtu as usize)
}
pub fn split(self) -> io::Result<(DeviceWriter, DeviceReader)> {
Ok((self.session_writer, self.session_reader))
}

/// Recv a packet from tun device - Not implemented for windows
pub async fn recv(&self, _buf: &mut [u8]) -> std::io::Result<usize> {
Expand All @@ -81,7 +88,7 @@ impl AsyncRead for AsyncDevice {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.session).poll_read(cx, buf)
Pin::new(&mut self.session_reader).poll_read(cx, buf)
}
}

Expand All @@ -91,34 +98,38 @@ impl AsyncWrite for AsyncDevice {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
Pin::new(&mut self.session).poll_write(cx, buf)
Pin::new(&mut self.session_writer).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.session).poll_flush(cx)
Pin::new(&mut self.session_writer).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.session).poll_shutdown(cx)
Pin::new(&mut self.session_writer).poll_shutdown(cx)
}
}

struct WinSession {
session: std::sync::Arc<wintun::Session>,
receiver: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
pub struct DeviceReader {
receiver: tokio::sync::mpsc::Receiver<Packet>,
_task: std::thread::JoinHandle<()>,
}

impl WinSession {
fn new(session: std::sync::Arc<wintun::Session>) -> Result<WinSession, io::Error> {
let session_reader = session.clone();
let (receiver_tx, receiver_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
impl DeviceReader {
fn new(session: std::sync::Arc<wintun::Session>) -> Result<DeviceReader, io::Error> {
let (receiver_tx, receiver_rx) = tokio::sync::mpsc::channel(1024);
let task = std::thread::spawn(move || loop {
match session_reader.receive_blocking() {
match session.receive_blocking() {
Ok(packet) => {
if let Err(err) = receiver_tx.send(packet.bytes().to_vec()) {
log::error!("{}", err);
break;
if let Err(err) = receiver_tx.try_send(packet) {
match err {
TrySendError::Full(_) => {
log::error!("receiver_tx Full");
continue;
}
TrySendError::Closed(_) => {
log::error!("receiver_tx Closed");
break;
}
}
}
}
Err(err) => {
Expand All @@ -127,32 +138,38 @@ impl WinSession {
}
}
});

Ok(WinSession {
session,
Ok(DeviceReader {
receiver: receiver_rx,
_task: task,
})
}
}
pub struct DeviceWriter {
session: std::sync::Arc<wintun::Session>,
}
impl DeviceWriter {
fn new(session: std::sync::Arc<wintun::Session>) -> Result<DeviceWriter, io::Error> {
Ok(Self { session })
}
}

impl AsyncRead for WinSession {
impl AsyncRead for DeviceReader {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match std::task::ready!(self.receiver.poll_recv(cx)) {
Some(bytes) => {
buf.put_slice(&bytes);
buf.put_slice(bytes.bytes());
std::task::Poll::Ready(Ok(()))
}
None => std::task::Poll::Ready(Ok(())),
}
}
}

impl AsyncWrite for WinSession {
impl AsyncWrite for DeviceWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
Expand Down
15 changes: 8 additions & 7 deletions src/platform/windows/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,14 @@ impl Write for Tun {
}
}

impl Drop for Tun {
fn drop(&mut self) {
if let Err(err) = self.session.shutdown() {
log::error!("failed to shutdown session: {:?}", err);
}
}
}
// impl Drop for Tun {
// fn drop(&mut self) {
// // The session has implemented drop
// if let Err(err) = self.session.shutdown() {
// log::error!("failed to shutdown session: {:?}", err);
// }
// }
// }

#[repr(transparent)]
pub struct Reader(Arc<Tun>);
Expand Down

0 comments on commit dd8711d

Please sign in to comment.