Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split DeviceWriter&DeviceReader #98

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
protocol = "sparse"

[build]
# target = ["x86_64-unknown-linux-musl"]
# target = ["x86_64-unknown-linux-gnu"]
# target = ["aarch64-linux-android"]
# target = ["aarch64-apple-ios"]
Expand Down
4 changes: 2 additions & 2 deletions src/async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ pub use codec::TunPacketCodec;
#[cfg(unix)]
mod unix_device;
#[cfg(unix)]
pub use unix_device::AsyncDevice;
pub use unix_device::{AsyncDevice, DeviceReader, DeviceWriter};

#[cfg(target_os = "windows")]
mod win_device;
#[cfg(target_os = "windows")]
pub use win_device::AsyncDevice;
pub use win_device::{AsyncDevice, DeviceReader, DeviceWriter};

/// Create a TUN device with the given name.
pub fn create_as_async(configuration: &Configuration) -> Result<AsyncDevice, error::Error> {
Expand Down
91 changes: 91 additions & 0 deletions src/async/unix_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use tokio_util::codec::Framed;

use super::TunPacketCodec;
use crate::device::AbstractDevice;
use crate::platform::posix::{Reader, Writer};
use crate::platform::Device;

/// An async TUN device wrapper around a TUN device.
Expand Down Expand Up @@ -62,6 +63,11 @@ impl AsyncDevice {
// associate mtu with the capacity of ReadBuf
Framed::with_capacity(self, codec, mtu as usize)
}
pub fn split(self) -> std::io::Result<(DeviceWriter, DeviceReader)> {
let device = self.inner.into_inner();
let (reader, writer) = device.split();
Ok((DeviceWriter::new(writer)?, DeviceReader::new(reader)?))
}

/// Recv a packet from tun device
pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
Expand Down Expand Up @@ -146,3 +152,88 @@ impl AsyncWrite for AsyncDevice {
true
}
}
pub struct DeviceReader {
inner: AsyncFd<Reader>,
}
impl DeviceReader {
fn new(reader: Reader) -> std::io::Result<Self> {
Ok(Self {
inner: AsyncFd::new(reader)?,
})
}
}
impl AsyncRead for DeviceReader {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<std::io::Result<()>> {
loop {
let mut guard = ready!(self.inner.poll_read_ready_mut(cx))?;
let rbuf = buf.initialize_unfilled();
match guard.try_io(|inner| inner.get_mut().read(rbuf)) {
Ok(res) => return Poll::Ready(res.map(|n| buf.advance(n))),
Err(_wb) => continue,
}
}
}
}

pub struct DeviceWriter {
inner: AsyncFd<Writer>,
}
impl DeviceWriter {
fn new(writer: Writer) -> std::io::Result<Self> {
Ok(Self {
inner: AsyncFd::new(writer)?,
})
}
}

impl AsyncWrite for DeviceWriter {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready_mut(cx))?;
match guard.try_io(|inner| inner.get_mut().write(buf)) {
Ok(res) => return Poll::Ready(res),
Err(_wb) => continue,
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready_mut(cx))?;
match guard.try_io(|inner| inner.get_mut().flush()) {
Ok(res) => return Poll::Ready(res),
Err(_wb) => continue,
}
}
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready_mut(cx))?;
match guard.try_io(|inner| inner.get_mut().write_vectored(bufs)) {
Ok(res) => return Poll::Ready(res),
Err(_wb) => continue,
}
}
}

fn is_write_vectored(&self) -> bool {
true
}
}
75 changes: 46 additions & 29 deletions src/async/win_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@ use core::task::{Context, Poll};
use std::io;
use std::io::Error;

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::codec::Framed;

use super::TunPacketCodec;
use crate::device::AbstractDevice;
use crate::platform::Device;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc::error::TrySendError;
use tokio_util::codec::Framed;
use wintun::Packet;

/// An async TUN device wrapper around a TUN device.
pub struct AsyncDevice {
inner: Device,
session: WinSession,
session_reader: DeviceReader,
session_writer: DeviceWriter,
}

/// Returns a shared reference to the underlying Device object.
Expand All @@ -49,10 +51,12 @@ impl core::ops::DerefMut for AsyncDevice {
impl AsyncDevice {
/// Create a new `AsyncDevice` wrapping around a `Device`.
pub fn new(device: Device) -> io::Result<AsyncDevice> {
let session = WinSession::new(device.tun.get_session())?;
let session_reader = DeviceReader::new(device.tun.get_session())?;
let session_writer = DeviceWriter::new(device.tun.get_session())?;
Ok(AsyncDevice {
inner: device,
session,
session_reader,
session_writer,
})
}

Expand All @@ -63,6 +67,9 @@ impl AsyncDevice {
// guarantee to avoid the mtu of wintun may far away larger than the default provided capacity of ReadBuf of Framed
Framed::with_capacity(self, codec, mtu as usize)
}
pub fn split(self) -> io::Result<(DeviceWriter, DeviceReader)> {
Ok((self.session_writer, self.session_reader))
}

/// Recv a packet from tun device - Not implemented for windows
pub async fn recv(&self, _buf: &mut [u8]) -> std::io::Result<usize> {
Expand All @@ -81,7 +88,7 @@ impl AsyncRead for AsyncDevice {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.session).poll_read(cx, buf)
Pin::new(&mut self.session_reader).poll_read(cx, buf)
}
}

Expand All @@ -91,34 +98,38 @@ impl AsyncWrite for AsyncDevice {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
Pin::new(&mut self.session).poll_write(cx, buf)
Pin::new(&mut self.session_writer).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.session).poll_flush(cx)
Pin::new(&mut self.session_writer).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.session).poll_shutdown(cx)
Pin::new(&mut self.session_writer).poll_shutdown(cx)
}
}

struct WinSession {
session: std::sync::Arc<wintun::Session>,
receiver: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
pub struct DeviceReader {
receiver: tokio::sync::mpsc::Receiver<Packet>,
_task: std::thread::JoinHandle<()>,
}

impl WinSession {
fn new(session: std::sync::Arc<wintun::Session>) -> Result<WinSession, io::Error> {
let session_reader = session.clone();
let (receiver_tx, receiver_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
impl DeviceReader {
fn new(session: std::sync::Arc<wintun::Session>) -> Result<DeviceReader, io::Error> {
let (receiver_tx, receiver_rx) = tokio::sync::mpsc::channel(1024);
let task = std::thread::spawn(move || loop {
match session_reader.receive_blocking() {
match session.receive_blocking() {
Ok(packet) => {
if let Err(err) = receiver_tx.send(packet.bytes().to_vec()) {
log::error!("{}", err);
break;
if let Err(err) = receiver_tx.try_send(packet) {
match err {
TrySendError::Full(_) => {
log::error!("receiver_tx Full");
continue;
}
TrySendError::Closed(_) => {
log::error!("receiver_tx Closed");
break;
}
}
}
}
Err(err) => {
Expand All @@ -127,32 +138,38 @@ impl WinSession {
}
}
});

Ok(WinSession {
session,
Ok(DeviceReader {
receiver: receiver_rx,
_task: task,
})
}
}
pub struct DeviceWriter {
session: std::sync::Arc<wintun::Session>,
}
impl DeviceWriter {
fn new(session: std::sync::Arc<wintun::Session>) -> Result<DeviceWriter, io::Error> {
Ok(Self { session })
}
}

impl AsyncRead for WinSession {
impl AsyncRead for DeviceReader {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match std::task::ready!(self.receiver.poll_recv(cx)) {
Some(bytes) => {
buf.put_slice(&bytes);
buf.put_slice(bytes.bytes());
std::task::Poll::Ready(Ok(()))
}
None => std::task::Poll::Ready(Ok(())),
}
}
}

impl AsyncWrite for WinSession {
impl AsyncWrite for DeviceWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
Expand Down
15 changes: 8 additions & 7 deletions src/platform/windows/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,14 @@ impl Write for Tun {
}
}

impl Drop for Tun {
fn drop(&mut self) {
if let Err(err) = self.session.shutdown() {
log::error!("failed to shutdown session: {:?}", err);
}
}
}
// impl Drop for Tun {
// fn drop(&mut self) {
// // The session has implemented drop
// if let Err(err) = self.session.shutdown() {
// log::error!("failed to shutdown session: {:?}", err);
// }
// }
// }

#[repr(transparent)]
pub struct Reader(Arc<Tun>);
Expand Down
Loading