Skip to content

Commit

Permalink
feature-gate custom transports behind any_transport feature
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxVerevkin committed Sep 9, 2024
1 parent adb1908 commit eaec63e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 44 deletions.
7 changes: 7 additions & 0 deletions wayrs-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ edition.workspace = true
rust-version.workspace = true
license.workspace = true

[features]
any_transport = []

[[example]]
name = "custom_transport"
required-features = ["any_transport"]

[dependencies]
wayrs-core = { version = "1.0", path = "../wayrs-core" }
wayrs-scanner = { version = "0.15", path = "../wayrs-scanner" }
Expand Down
59 changes: 59 additions & 0 deletions wayrs-client/src/any_transport.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use std::any::Any;
use std::collections::VecDeque;
use std::io;
use std::os::fd::{OwnedFd, RawFd};

use wayrs_core::transport::Transport;
use wayrs_core::IoMode;

pub struct AnyTranpsort(Box<dyn AnyTransportImp>);

impl AnyTranpsort {
pub fn new<T>(transport: T) -> Self
where
T: Transport + Send + 'static,
{
Self(Box::new(transport))
}

pub fn as_any(&self) -> &dyn Any {
self.0.as_any()
}

pub fn as_any_mut(&mut self) -> &mut dyn Any {
self.0.as_any_mut()
}
}

trait AnyTransportImp: Transport + Send {
fn as_any(&self) -> &dyn Any;
fn as_any_mut(&mut self) -> &mut dyn Any;
}

impl<T: Transport + Send + 'static> AnyTransportImp for T {
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}

impl Transport for AnyTranpsort {
fn pollable_fd(&self) -> RawFd {
self.0.as_ref().pollable_fd()
}

fn send(&mut self, bytes: &[io::IoSlice], fds: &[OwnedFd], mode: IoMode) -> io::Result<usize> {
self.0.as_mut().send(bytes, fds, mode)
}

fn recv(
&mut self,
bytes: &mut [io::IoSliceMut],
fds: &mut VecDeque<OwnedFd>,
mode: IoMode,
) -> io::Result<usize> {
self.0.as_mut().recv(bytes, fds, mode)
}
}
75 changes: 31 additions & 44 deletions wayrs-client/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
//! Wayland connection
use std::any::Any;
use std::collections::VecDeque;
use std::env;
use std::io;
use std::num::NonZeroU32;
use std::os::fd::{AsRawFd, OwnedFd, RawFd};
use std::os::fd::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::PathBuf;

Expand All @@ -15,14 +14,21 @@ use crate::protocol::wl_registry::GlobalArgs;
use crate::protocol::*;
use crate::EventCtx;

use wayrs_core::transport::{
BufferedSocket, PeekHeaderError, RecvMessageError, SendMessageError, Transport,
};
use wayrs_core::transport::{BufferedSocket, PeekHeaderError, RecvMessageError, SendMessageError};
use wayrs_core::{ArgType, ArgValue, Interface, IoMode, Message, MessageBuffersPool, ObjectId};

#[cfg(feature = "tokio")]
use tokio::io::unix::AsyncFd;

#[cfg(feature = "any_transport")]
use crate::any_transport::AnyTranpsort;
#[cfg(feature = "any_transport")]
use wayrs_core::transport::Transport;
#[cfg(feature = "any_transport")]
type TransportImp = AnyTranpsort;
#[cfg(not(feature = "any_transport"))]
type TransportImp = UnixStream;

/// An error that can occur while connecting to a Wayland socket.
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
Expand All @@ -44,7 +50,7 @@ pub struct Connection<D> {
#[cfg(feature = "tokio")]
async_fd: Option<AsyncFd<RawFd>>,

socket: BufferedSocket<AnyTranpsort>,
socket: BufferedSocket<TransportImp>,
msg_buffers_pool: MessageBuffersPool,

object_mgr: ObjectManager<D>,
Expand Down Expand Up @@ -83,7 +89,7 @@ impl<D> AsRawFd for Connection<D> {
/// Connect to a Wayland socket at the standard path with [`connect`](Self::connect), or use any
/// Wayland transport method with [`with_transport`](Self::with_transport).
pub struct ConnectionBuilder {
transport: AnyTranpsort,
transport: TransportImp,
}

impl ConnectionBuilder {
Expand All @@ -98,13 +104,20 @@ impl ConnectionBuilder {
path.push(runtime_dir);
path.push(wayland_disp);

Ok(Self::with_transport(UnixStream::connect(path)?))
Ok(Self {
#[cfg(feature = "any_transport")]
transport: TransportImp::new(UnixStream::connect(path)?),
#[cfg(not(feature = "any_transport"))]
transport: UnixStream::connect(path)?,
})
}

/// Use a custom transport
#[cfg(feature = "any_transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "any_transport")))]
pub fn with_transport<T: Transport + Send + 'static>(transport: T) -> Self {
Self {
transport: AnyTranpsort(Box::new(transport)),
transport: TransportImp::new(transport),
}
}

Expand Down Expand Up @@ -305,15 +318,19 @@ impl<D> Connection<D> {
/// Try to get a reference to the underlying transport.
///
/// Returns `None` if the type of the transport is not `T`.
#[cfg(feature = "any_transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "any_transport")))]
pub fn transport<T: 'static>(&self) -> Option<&T> {
self.socket.transport().0.as_any().downcast_ref()
self.socket.transport().as_any().downcast_ref()
}

/// Try to get a mutable reference to the underlying transport.
///
/// Returns `None` if the type of the transport is not `T`.
#[cfg(feature = "any_transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "any_transport")))]
pub fn transport_mut<T: 'static>(&mut self) -> Option<&mut T> {
self.socket.transport_mut().0.as_any_mut().downcast_mut()
self.socket.transport_mut().as_any_mut().downcast_mut()
}

#[doc(hidden)]
Expand Down Expand Up @@ -657,39 +674,6 @@ impl<D> Connection<D> {
}
}

trait IsAnyTransport: Transport + Send {
fn as_any(&self) -> &dyn Any;
fn as_any_mut(&mut self) -> &mut dyn Any;
}
impl<T: Transport + Send + 'static> IsAnyTransport for T {
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}

struct AnyTranpsort(Box<dyn IsAnyTransport>);
impl Transport for AnyTranpsort {
fn pollable_fd(&self) -> RawFd {
self.0.as_ref().pollable_fd()
}

fn send(&mut self, bytes: &[io::IoSlice], fds: &[OwnedFd], mode: IoMode) -> io::Result<usize> {
self.0.as_mut().send(bytes, fds, mode)
}

fn recv(
&mut self,
bytes: &mut [io::IoSliceMut],
fds: &mut VecDeque<OwnedFd>,
mode: IoMode,
) -> io::Result<usize> {
self.0.as_mut().recv(bytes, fds, mode)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -702,7 +686,10 @@ mod tests {
}

#[test]
#[cfg(feature = "any_transport")]
fn transport_downcast() {
use std::os::fd::OwnedFd;

struct T;
impl Transport for T {
fn pollable_fd(&self) -> RawFd {
Expand Down
3 changes: 3 additions & 0 deletions wayrs-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ pub mod protocol;
mod connection;
mod debug_message;

#[cfg(feature = "any_transport")]
mod any_transport;

pub use connection::{ConnectError, Connection, ConnectionBuilder};

#[doc(hidden)]
Expand Down

0 comments on commit eaec63e

Please sign in to comment.