diff --git a/Cargo.toml b/Cargo.toml index 447d9d8dd..0397dc899 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,10 +62,17 @@ features = [ wasi = "0.11.0" libc = "0.2.149" +[target.x86_64-fortanix-unknown-sgx.dependencies] +async-usercalls = { git = "https://github.com/fortanix/rust-sgx.git", branch = "master" } # FIXME: use published version once available +crossbeam-channel = "0.5" + [dev-dependencies] env_logger = { version = "0.9.3", default-features = false } rand = "0.8" +[build-dependencies] +rustc_version = "0.2" + [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs", "--generate-link-to-definition"] @@ -97,3 +104,7 @@ required-features = ["os-poll", "net"] [[example]] name = "udp_server" required-features = ["os-poll", "net"] + +# set number of threads so tests can run properly +[package.metadata.fortanix-sgx] +threads = 100 diff --git a/build.rs b/build.rs new file mode 100644 index 000000000..226630a47 --- /dev/null +++ b/build.rs @@ -0,0 +1,7 @@ +use rustc_version::Version; + +fn main() { + if Version::parse("1.78.0").unwrap() <= rustc_version::version().unwrap() { + println!("cargo:rustc-cfg=compiler_has_send_sgx_types"); + } +} diff --git a/ct.sh b/ct.sh new file mode 100755 index 000000000..80f02c813 --- /dev/null +++ b/ct.sh @@ -0,0 +1,18 @@ +#!/bin/bash -ex +export RUST_BACKTRACE=1 + +toolchains=("my_sgx_build" "nightly") +platforms=("x86_64-fortanix-unknown-sgx" "x86_64-unknown-linux-gnu") + +for toolchain in "${toolchains[@]}"; do + for platform in "${platforms[@]}"; do + echo "toolchain: $toolchain" + echo "platform: $platform" + echo "" + cargo +${toolchain} test --target ${platform} + cargo +${toolchain} test --features "net,os-poll" --target ${platform} + cargo +${toolchain} test --features "net,os-ext" --target ${platform} + cargo +${toolchain} test --features "net,os-poll,os-ext" --target ${platform} + done +done +exit 0 diff --git a/examples/tcp_listenfd_server.rs b/examples/tcp_listenfd_server.rs index 941d7f048..c1881dc0b 100644 --- a/examples/tcp_listenfd_server.rs +++ b/examples/tcp_listenfd_server.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_env = "sgx", feature(sgx_platform))] + // You can run this example from the root of the mio repo: // cargo run --example tcp_listenfd_server --features="os-poll net" // or with wasi: @@ -23,7 +25,12 @@ fn get_first_listen_fd_listener() -> Option { use std::os::unix::io::FromRawFd; #[cfg(target_os = "wasi")] use std::os::wasi::io::FromRawFd; + #[cfg(target_env = "sgx")] + use std::os::fortanix_sgx::io::FromRawFd; + #[cfg(target_env = "sgx")] + let stdlistener = unsafe { std::net::TcpListener::from_raw_fd(3, Default::default()) }; + #[cfg(not(target_env = "sgx"))] let stdlistener = unsafe { std::net::TcpListener::from_raw_fd(3) }; stdlistener.set_nonblocking(true).unwrap(); Some(stdlistener) diff --git a/examples/udp_server.rs b/examples/udp_server.rs index 698d710cd..f1a96ac0b 100644 --- a/examples/udp_server.rs +++ b/examples/udp_server.rs @@ -1,13 +1,17 @@ // You can run this example from the root of the mio repo: // cargo run --example udp_server --features="os-poll net" -use log::warn; -use mio::{Events, Interest, Poll, Token}; -use std::io; +#[cfg(not(any(target_os = "wasi", target_env = "sgx")))] +use { + log::warn, + mio::{Events, Interest, Poll, Token}, + std::io, +}; // A token to allow us to identify which event is for the `UdpSocket`. +#[cfg(not(any(target_os = "wasi", target_env = "sgx")))] const UDP_SOCKET: Token = Token(0); -#[cfg(not(target_os = "wasi"))] +#[cfg(not(any(target_os = "wasi", target_env = "sgx")))] fn main() -> io::Result<()> { use mio::net::UdpSocket; @@ -84,6 +88,11 @@ fn main() -> io::Result<()> { } } +#[cfg(target_env = "sgx")] +fn main() { + println!("SGX does not support UDP yet"); +} + #[cfg(target_os = "wasi")] fn main() { panic!("can't bind to an address with wasi") diff --git a/src/io_source.rs b/src/io_source.rs index 06dc5e17e..e58a9a339 100644 --- a/src/io_source.rs +++ b/src/io_source.rs @@ -3,6 +3,8 @@ use std::ops::{Deref, DerefMut}; use std::os::unix::io::AsRawFd; #[cfg(target_os = "wasi")] use std::os::wasi::io::AsRawFd; +#[cfg(target_env = "sgx")] +use std::os::fortanix_sgx::io::AsRawFd; #[cfg(windows)] use std::os::windows::io::AsRawSocket; #[cfg(debug_assertions)] @@ -102,6 +104,7 @@ impl IoSource { /// [`deregister`] it. /// /// [`deregister`]: Registry::deregister + #[cfg(not(target_env = "sgx"))] pub fn into_inner(self) -> T { self.inner } @@ -129,7 +132,7 @@ impl DerefMut for IoSource { } } -#[cfg(unix)] +#[cfg(any(unix, target_env = "sgx"))] impl event::Source for IoSource where T: AsRawFd, diff --git a/src/lib.rs b/src/lib.rs index 56a7160be..3a87df2e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(compiler_has_send_sgx_types, feature(stmt_expr_attributes))] +#![cfg_attr(not(compiler_has_send_sgx_types), allow(suspicious_auto_trait_impls))] #![deny( missing_docs, missing_debug_implementations, @@ -10,6 +12,7 @@ #![cfg_attr(test, deny(warnings))] // Disallow warnings in examples. #![doc(test(attr(deny(warnings))))] +#![cfg_attr(target_env = "sgx", feature(sgx_platform))] //! Mio is a fast, low-level I/O library for Rust focusing on non-blocking APIs //! and event notification for building high performance I/O apps with as little diff --git a/src/net/mod.rs b/src/net/mod.rs index 7d714ca00..3610df017 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -28,9 +28,9 @@ mod tcp; pub use self::tcp::{TcpListener, TcpStream}; -#[cfg(not(target_os = "wasi"))] +#[cfg(not(any(target_os = "wasi", target_env = "sgx")))] mod udp; -#[cfg(not(target_os = "wasi"))] +#[cfg(not(any(target_os = "wasi", target_env = "sgx")))] pub use self::udp::UdpSocket; #[cfg(unix)] diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index df51d57ae..79189e496 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -1,3 +1,6 @@ +#[cfg(target_env = "sgx")] +use crate::sys::tcp::net::{self, SocketAddr}; +#[cfg(not(target_env = "sgx"))] use std::net::{self, SocketAddr}; #[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; @@ -11,7 +14,7 @@ use crate::io_source::IoSource; use crate::net::TcpStream; #[cfg(unix)] use crate::sys::tcp::set_reuseaddr; -#[cfg(not(target_os = "wasi"))] +#[cfg(not(any(target_os = "wasi", target_env = "sgx")))] use crate::sys::tcp::{bind, listen, new_for_addr}; use crate::{event, sys, Interest, Registry, Token}; @@ -57,25 +60,42 @@ impl TcpListener { /// 4. Calls `listen` on the socket to prepare it to receive new connections. #[cfg(not(target_os = "wasi"))] pub fn bind(addr: SocketAddr) -> io::Result { - let socket = new_for_addr(addr)?; - #[cfg(unix)] - let listener = unsafe { TcpListener::from_raw_fd(socket) }; - #[cfg(windows)] - let listener = unsafe { TcpListener::from_raw_socket(socket as _) }; - - // On platforms with Berkeley-derived sockets, this allows to quickly - // rebind a socket, without needing to wait for the OS to clean up the - // previous one. - // - // On Windows, this allows rebinding sockets which are actively in use, - // which allows “socket hijacking”, so we explicitly don't set it here. - // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse - #[cfg(not(windows))] - set_reuseaddr(&listener.inner, true)?; - - bind(&listener.inner, addr)?; - listen(&listener.inner, 1024)?; - Ok(listener) + #[cfg(not(target_env = "sgx"))] { + let socket = new_for_addr(addr)?; + #[cfg(unix)] + let listener = unsafe { TcpListener::from_raw_fd(socket) }; + #[cfg(windows)] + let listener = unsafe { TcpListener::from_raw_socket(socket as _) }; + + // On platforms with Berkeley-derived sockets, this allows to quickly + // rebind a socket, without needing to wait for the OS to clean up the + // previous one. + // + // On Windows, this allows rebinding sockets which are actively in use, + // which allows “socket hijacking”, so we explicitly don't set it here. + // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse + #[cfg(not(windows))] + set_reuseaddr(&listener.inner, true)?; + + bind(&listener.inner, addr)?; + listen(&listener.inner, 1024)?; + Ok(listener) + } + + #[cfg(target_env = "sgx")] { + Ok(TcpListener { + inner: IoSource::new(sys::tcp::bind(addr)?), + }) + } + } + + /// Convenience method to bind a new TCP listener to the specified address + /// to receive new connections. + #[cfg(target_env = "sgx")] + pub fn bind_str(addr: &str) -> io::Result { + Ok(TcpListener { + inner: IoSource::new(sys::tcp::bind_str(addr)?), + }) } /// Creates a new `TcpListener` from a standard `net::TcpListener`. @@ -84,9 +104,9 @@ impl TcpListener { /// standard library in the Mio equivalent. The conversion assumes nothing /// about the underlying listener; ; it is left up to the user to set it /// in non-blocking mode. - pub fn from_std(listener: net::TcpListener) -> TcpListener { + pub fn from_std(listener: std::net::TcpListener) -> TcpListener { TcpListener { - inner: IoSource::new(listener), + inner: IoSource::new(listener.into()), } } @@ -100,7 +120,13 @@ impl TcpListener { /// returned along with it. pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { self.inner.do_io(|inner| { - sys::tcp::accept(inner).map(|(stream, addr)| (TcpStream::from_std(stream), addr)) + sys::tcp::accept(inner).map(|(stream, addr)| { + #[cfg(target_env = "sgx")] + let stream = TcpStream::internal_new(stream); + #[cfg(not(target_env = "sgx"))] + let stream = TcpStream::from_std(stream); + (stream, addr) + }) }) } diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 8a3f6a2f2..9bc6a472d 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -1,5 +1,8 @@ use std::fmt; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; +#[cfg(target_env = "sgx")] +use crate::sys::tcp::net::{self, Shutdown, SocketAddr}; +#[cfg(not(target_env = "sgx"))] use std::net::{self, Shutdown, SocketAddr}; #[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; @@ -9,9 +12,11 @@ use std::os::wasi::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use crate::io_source::IoSource; -#[cfg(not(target_os = "wasi"))] +#[cfg(not(any(target_os = "wasi", target_env = "sgx")))] use crate::sys::tcp::{connect, new_for_addr}; use crate::{event, Interest, Registry, Token}; +#[cfg(target_env = "sgx")] +use crate::sys; /// A non-blocking TCP stream between a local socket and a remote socket. /// @@ -50,6 +55,13 @@ pub struct TcpStream { } impl TcpStream { + #[cfg(target_env = "sgx")] + pub(crate) fn internal_new(stream: sys::tcp::TcpStream) -> TcpStream { + TcpStream { + inner: IoSource::new(stream), + } + } + /// Create a new TCP stream and issue a non-blocking connect to the /// specified address. /// @@ -82,13 +94,26 @@ impl TcpStream { /// [write interest]: Interest::WRITABLE #[cfg(not(target_os = "wasi"))] pub fn connect(addr: SocketAddr) -> io::Result { - let socket = new_for_addr(addr)?; - #[cfg(unix)] - let stream = unsafe { TcpStream::from_raw_fd(socket) }; - #[cfg(windows)] - let stream = unsafe { TcpStream::from_raw_socket(socket as _) }; - connect(&stream.inner, addr)?; - Ok(stream) + #[cfg(not(target_env = "sgx"))] { + let socket = new_for_addr(addr)?; + #[cfg(unix)] + let stream = unsafe { TcpStream::from_raw_fd(socket) }; + #[cfg(windows)] + let stream = unsafe { TcpStream::from_raw_socket(socket as _) }; + connect(&stream.inner, addr)?; + Ok(stream) + } + + #[cfg(target_env = "sgx")] { + sys::tcp::connect(addr).map(TcpStream::internal_new) + } + } + + /// Create a new TCP stream and issue a non-blocking connect to the + /// specified address. + #[cfg(target_env = "sgx")] + pub fn connect_str(addr: &str) -> io::Result { + sys::tcp::connect_str(addr).map(TcpStream::internal_new) } /// Creates a new `TcpStream` from a standard `net::TcpStream`. @@ -103,7 +128,10 @@ impl TcpStream { /// The TCP stream here will not have `connect` called on it, so it /// should already be connected via some other means (be it manually, or /// the standard library). - pub fn from_std(stream: net::TcpStream) -> TcpStream { + pub fn from_std(stream: std::net::TcpStream) -> TcpStream { + #[cfg(target_env = "sgx")] + let stream: sys::tcp::TcpStream = stream.into(); + TcpStream { inner: IoSource::new(stream), } diff --git a/src/sys/mod.rs b/src/sys/mod.rs index 2a968b265..b17691882 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -14,6 +14,7 @@ //! * `tcp` and `udp` modules: see the [`crate::net`] module. //! * `Waker`: see [`crate::Waker`]. +#[cfg(not(target_env = "sgx"))] cfg_os_poll! { macro_rules! debug_detail { ( @@ -69,6 +70,12 @@ cfg_os_poll! { pub(crate) use self::wasi::*; } +#[cfg(target_env = "sgx")] +cfg_os_poll! { + mod sgx; + pub(crate) use self::sgx::*; +} + cfg_not_os_poll! { mod shell; pub(crate) use self::shell::*; diff --git a/src/sys/sgx/mod.rs b/src/sys/sgx/mod.rs new file mode 100644 index 000000000..5245d6fb9 --- /dev/null +++ b/src/sys/sgx/mod.rs @@ -0,0 +1,61 @@ +mod selector; +pub(crate) use self::selector::{event, Event, Events, Selector}; + +mod waker; +pub(crate) use self::waker::Waker; + +cfg_net! { + pub(crate) mod tcp; +} + +cfg_net! { + use std::io; + use std::os::fortanix_sgx::io::RawFd; + use crate::Registry; + use crate::Token; + use crate::Interest; + + pub(crate) struct IoSourceState; + + impl IoSourceState { + pub fn new() -> IoSourceState { + IoSourceState + } + + pub fn do_io(&self, f: F, io: &T) -> io::Result + where + F: FnOnce(&T) -> io::Result, + { + // We don't hold state, so we can just call the function and + // return. + f(io) + } + + pub fn register( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + fd: RawFd, + ) -> io::Result<()> { + // Pass through, we don't have any state + registry.selector().register(fd, token, interests) + } + + pub fn reregister( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + fd: RawFd, + ) -> io::Result<()> { + // Pass through, we don't have any state + registry.selector().reregister(fd, token, interests) + } + + pub fn deregister(&mut self, registry: &Registry, fd: RawFd) -> io::Result<()> { + // Pass through, we don't have any state + registry.selector().deregister(fd) + } + } +} diff --git a/src/sys/sgx/selector.rs b/src/sys/sgx/selector.rs new file mode 100644 index 000000000..48e062cf1 --- /dev/null +++ b/src/sys/sgx/selector.rs @@ -0,0 +1,295 @@ +use async_usercalls::{AsyncUsercallProvider, CallbackHandler, CallbackHandlerWaker}; +use crossbeam_channel as mpmc; +use std::collections::HashMap; +use std::io; +use std::ops::Deref; +use std::os::fortanix_sgx::io::{AsRawFd, RawFd}; +#[cfg(debug_assertions)] +use std::sync::atomic::AtomicBool; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +pub struct Selector { + #[cfg(debug_assertions)] + id: usize, + event_rx: mpmc::Receiver<(RegistrationId, EventKind)>, + callback_handler: Arc, + shared_inner: Arc, + #[cfg(debug_assertions)] + has_waker: AtomicBool, +} + +struct SelectorSharedInner { + event_tx: mpmc::Sender<(RegistrationId, EventKind)>, + registrations: Mutex>, + provider: AsyncUsercallProvider, + callback_handler_waker: CallbackHandlerWaker, +} + +impl Selector { + pub fn new() -> io::Result { + #[cfg(debug_assertions)] + static NEXT_ID: AtomicUsize = AtomicUsize::new(1); + let (event_tx, event_rx) = mpmc::unbounded(); + let (provider, callback_handler) = AsyncUsercallProvider::new(); + let callback_handler_waker = callback_handler.waker(); + Ok(Selector { + #[cfg(debug_assertions)] + id: NEXT_ID.fetch_add(1, Ordering::Relaxed), + event_rx, + callback_handler: Arc::new(callback_handler), + shared_inner: Arc::new(SelectorSharedInner { + event_tx, + registrations: Mutex::new(HashMap::new()), + provider, + callback_handler_waker, + }), + #[cfg(debug_assertions)] + has_waker: AtomicBool::new(false), + }) + } + + pub fn try_clone(&self) -> io::Result { + Ok(Selector { + #[cfg(debug_assertions)] + id: self.id, + event_rx: self.event_rx.clone(), + callback_handler: self.callback_handler.clone(), + shared_inner: self.shared_inner.clone(), + #[cfg(debug_assertions)] + has_waker: AtomicBool::new(self.has_waker.load(Ordering::Acquire)), + }) + } + + pub fn select(&self, events: &mut Events, mut timeout: Option) -> io::Result<()> { + self.shared_inner.callback_handler_waker.clear(); + if !self.event_rx.is_empty() { + timeout = Some(Duration::from_nanos(0)); + } + self.callback_handler.poll(timeout); + + events.clear(); + let registrations = self.shared_inner.registrations.lock().unwrap(); + for (reg_id, kind) in self.event_rx.try_iter() { + if let Some((token, interest)) = registrations.get(®_id) { + if kind.matches_interest(interest) { + events.push(Event::new(kind, *token)); + } + } + if events.len() == events.capacity() { + break; + } + } + Ok(()) + } +} + +cfg_io_source! { + use crate::{Interest, Token}; + + impl Selector { + pub fn register(&self, _: RawFd, _: Token, _: Interest) -> io::Result<()> { + unimplemented!(); + } + + pub fn reregister(&self, _: RawFd, _: Token, _: Interest) -> io::Result<()> { + unimplemented!(); + } + + pub fn deregister(&self, _: RawFd) -> io::Result<()> { + unimplemented!(); + } + } +} + +cfg_net! { + #[cfg(debug_assertions)] + impl Selector { + pub fn id(&self) -> usize { + self.id + } + } +} + +impl AsRawFd for Selector { + fn as_raw_fd(&self) -> RawFd { + unimplemented!() + } +} + +pub(crate) struct Provider(Arc); + +impl Provider { + pub fn new(selector: &Selector) -> Self { + Self(selector.shared_inner.clone()) + } +} + +impl Deref for Provider { + type Target = AsyncUsercallProvider; + + fn deref(&self) -> &Self::Target { + &self.0.provider + } +} + +pub(crate) struct Registration { + id: RegistrationId, + shared_inner: Arc, + token: Token, + interest: Interest, +} + +impl Registration { + pub fn new(selector: &Selector, token: Token, interest: Interest) -> Self { + let id = RegistrationId::new(); + selector.shared_inner.registrations.lock().unwrap().insert(id, (token, interest)); + Registration { + id, + shared_inner: selector.shared_inner.clone(), + token, + interest: interest, + } + } + + pub fn provider(&self) -> Provider { + Provider(self.shared_inner.clone()) + } + + pub fn change_details(&mut self, token: Token, interest: Interest) -> bool { + if self.token == token && self.interest == interest { + return false; + } + self.token = token; + self.interest = interest; + self.shared_inner.registrations.lock().unwrap().insert(self.id, (self.token, self.interest)); + true + } + + pub fn token(&self) -> Token { + self.token + } + + pub fn push_event(&self, kind: EventKind) { + if kind.matches_interest(&self.interest) { + let _ = self.shared_inner.event_tx.send((self.id, kind)); + self.shared_inner.callback_handler_waker.wake(); + } + } +} + +impl Drop for Registration { + fn drop(&mut self) { + self.shared_inner.registrations.lock().unwrap().remove(&self.id); + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +struct RegistrationId(usize); + +impl RegistrationId { + fn new() -> Self { + static NEXT_ID: AtomicUsize = AtomicUsize::new(1); + Self(NEXT_ID.fetch_add(1, Ordering::Relaxed)) + } +} + +#[derive(Clone, Debug)] +pub(crate) enum EventKind { + Readable, + ReadClosed, + ReadError, + Writable, + WriteClosed, + WriteError, +} + +impl EventKind { + fn matches_interest(&self, interest: &Interest) -> bool { + use EventKind::*; + match self { + Readable | ReadClosed => interest.is_readable(), + Writable | WriteClosed => interest.is_writable(), + // Always send error events + ReadError | WriteError => true, + } + } +} + +#[derive(Clone, Debug)] +pub struct Event { + kind: EventKind, + token: Token, +} + +impl Event { + pub(crate) fn new(kind: EventKind, token: Token) -> Self { + Event { kind, token } + } +} + +pub type Events = Vec; + +#[allow(clippy::trivially_copy_pass_by_ref)] +pub mod event { + use super::EventKind; + use crate::sys::Event; + use crate::Token; + use std::fmt; + + pub fn token(e: &Event) -> Token { + e.token + } + + pub fn is_readable(e: &Event) -> bool { + match e.kind { + EventKind::Readable | EventKind::ReadClosed | EventKind::ReadError => true, + _ => false, + } + } + + pub fn is_writable(e: &Event) -> bool { + match e.kind { + EventKind::Writable | EventKind::WriteClosed | EventKind::WriteError => true, + _ => false, + } + } + + pub fn is_error(e: &Event) -> bool { + match e.kind { + EventKind::ReadError | EventKind::WriteError => true, + _ => false, + } + } + + pub fn is_read_closed(e: &Event) -> bool { + match e.kind { + EventKind::ReadClosed => true, + _ => false, + } + } + + pub fn is_write_closed(e: &Event) -> bool { + match e.kind { + EventKind::WriteClosed => true, + _ => false, + } + } + + pub fn is_priority(_: &Event) -> bool { + false + } + + pub fn is_aio(_: &Event) -> bool { + false + } + + pub fn is_lio(_: &Event) -> bool { + false + } + + pub fn debug_details(f: &mut fmt::Formatter<'_>, e: &Event) -> fmt::Result { + fmt::Debug::fmt(e, f) + } +} diff --git a/src/sys/sgx/tcp/listener.rs b/src/sys/sgx/tcp/listener.rs new file mode 100644 index 000000000..4dfa29bbe --- /dev/null +++ b/src/sys/sgx/tcp/listener.rs @@ -0,0 +1,199 @@ +use async_usercalls::CancelHandle; +use std::fmt; +use std::io; +use std::mem; +use std::net::{self, SocketAddr}; +use std::os::fortanix_sgx::io::AsRawFd; +use std::os::fortanix_sgx::usercalls::raw::Fd; +use std::sync::{Arc, Mutex, MutexGuard}; + +use super::{other, would_block, State, TcpStream}; +use crate::sys::sgx::selector::{EventKind, Provider, Registration}; +use crate::{event, Interest, Registry, Token}; + +pub struct TcpListener { + listener: net::TcpListener, + imp: ListenerImp, +} + +#[derive(Clone)] +struct ListenerImp(Arc>); + +struct ListenerInner { + fd: Fd, + accept_state: State<(), Option, net::TcpStream>, + registration: Option, + provider: Option, +} + +impl TcpListener { + fn from_std(listener: net::TcpListener) -> TcpListener { + TcpListener { + imp: ListenerImp(Arc::new(Mutex::new(ListenerInner { + fd: listener.as_raw_fd(), + accept_state: State::New(()), + registration: None, + provider: None, + }))), + listener, + } + } + + pub fn bind(addr: SocketAddr) -> io::Result { + Ok(TcpListener::from_std(net::TcpListener::bind(addr)?)) + } + + pub fn bind_str(addr: &str) -> io::Result { + Ok(TcpListener::from_std(net::TcpListener::bind(addr)?)) + } + + pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + let mut inner = self.inner(); + let ret = match mem::replace(&mut inner.accept_state, State::New(())) { + State::New(()) => Err(would_block()), + State::Pending(cancel_handle) => { + inner.accept_state = State::Pending(cancel_handle); + return Err(would_block()); + } + State::Ready(stream) => Ok(TcpStream::from_std(stream)), + State::Error(e) => Err(e), + }; + self.imp.schedule_accept(&mut inner); + ret + } + + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.listener.set_ttl(ttl) + } + + pub fn ttl(&self) -> io::Result { + self.listener.ttl() + } + + pub fn take_error(&self) -> io::Result> { + self.listener.take_error() + } + + fn inner(&self) -> MutexGuard<'_, ListenerInner> { + self.imp.inner() + } +} + +impl ListenerImp { + fn inner(&self) -> MutexGuard<'_, ListenerInner> { + self.0.lock().unwrap() + } + + fn schedule_accept(&self, inner: &mut ListenerInner) { + let provider = match inner.provider.as_ref() { + Some(provider) => provider, + None => return, + }; + match inner.accept_state { + State::New(()) => {} + _ => return, + } + let weak_ref = Arc::downgrade(&self.0); + let cancel_handle = provider.accept_stream(inner.fd, move |res| { + let imp = match weak_ref.upgrade() { + Some(arc) => ListenerImp(arc), + None => return, + }; + let mut inner = imp.inner(); + assert!(inner.accept_state.is_pending()); + inner.accept_state = res.into(); + inner.push_event(if inner.accept_state.is_error() { + EventKind::ReadError + } else { + EventKind::Readable + }); + }); + inner.accept_state = State::Pending(Some(cancel_handle)); + } +} + +impl ListenerInner { + fn push_event(&self, kind: EventKind) { + if let Some(ref registration) = self.registration { + registration.push_event(kind); + } + } +} + +impl From for TcpListener { + fn from(listener: net::TcpListener) -> Self { + TcpListener::from_std(listener) + } +} + +impl event::Source for TcpListener { + fn register( + &mut self, + registry: &Registry, + token: Token, + interest: Interest, + ) -> io::Result<()> { + let mut inner = self.inner(); + match inner.registration { + Some(_) => return Err(other("I/O source already registered with a `Registry`")), + None => inner.registration = Some(Registration::new(registry.selector(), token, interest)), + } + inner.provider = Some(Provider::new(registry.selector())); + self.imp.schedule_accept(&mut inner); + Ok(()) + } + + fn reregister( + &mut self, + _registry: &Registry, + token: Token, + interest: Interest, + ) -> io::Result<()> { + let mut inner = self.inner(); + let changed = match inner.registration { + Some(ref mut registration) => registration.change_details(token, interest), + None => return Err(other("I/O source not registered with `Registry`")), + }; + if changed && inner.accept_state.is_ready() { + inner.push_event(EventKind::Readable); + } + if changed && inner.accept_state.is_error() { + inner.push_event(EventKind::ReadError); + } + Ok(()) + } + + fn deregister(&mut self, _registry: &Registry) -> io::Result<()> { + let mut inner = self.inner(); + match inner.registration { + Some(_) => inner.registration = None, + None => return Err(other("I/O source not registered with `Registry`")), + } + Ok(()) + } +} + +impl fmt::Debug for TcpListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let inner = self.inner(); + let mut res = f.debug_struct("TcpListener"); + res.field("accept_state", &inner.accept_state); + res.field("listener", &self.listener); + res.finish() + } +} + +impl Drop for TcpListener { + fn drop(&mut self) { + let mut inner = self.inner(); + // deregister so we don't send events after drop + inner.registration = None; + if let Some(cancel_handle) = inner.accept_state.as_pending_mut().and_then(|opt| opt.take()) { + cancel_handle.cancel(); + } + } +} diff --git a/src/sys/sgx/tcp/mod.rs b/src/sys/sgx/tcp/mod.rs new file mode 100644 index 000000000..37e52e0ce --- /dev/null +++ b/src/sys/sgx/tcp/mod.rs @@ -0,0 +1,177 @@ +use std::fmt; +use std::io; +use std::mem; +use std::net::SocketAddr; + +mod listener; +mod stream; + +pub use self::listener::TcpListener; +pub use self::stream::TcpStream; + +// The SGX platform is a bit special. The implementation of TcpStream/TcpListener in the standard library doesn't give us enough control. The `sys::sgx::tcp::TcpStream` and `sys::sgx::tcp::TcpListener` types provides more functionality. We use a `net` module here to make things easier to co-exist with the other platforms supported by this crate +pub(crate) mod net { + pub(crate) type TcpStream = super::TcpStream; + pub(crate) type TcpListener = super::TcpListener; + pub(crate) type Shutdown = std::net::Shutdown; + pub(crate) type SocketAddr = std::net::SocketAddr; +} + +pub fn connect(addr: SocketAddr) -> io::Result { + TcpStream::connect(addr) +} + +pub fn connect_str(addr: &str) -> io::Result { + TcpStream::connect_str(addr) +} + +pub fn bind(addr: SocketAddr) -> io::Result { + TcpListener::bind(addr) +} + +pub fn bind_str(addr: &str) -> io::Result { + TcpListener::bind_str(addr) +} + +pub fn accept(listener: &TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { + listener.accept() +} + +enum State { + New(N), + Pending(P), + Ready(R), + Error(io::Error), +} + +impl State { + fn as_ready(&self) -> Option<&R> { + match self { + State::Ready(ref r) => Some(r), + _ => None, + } + } + + fn as_pending_mut(&mut self) -> Option<&mut P> { + match self { + State::Pending(ref mut p) => Some(p), + _ => None, + } + } + + fn is_new(&self) -> bool { + match self { + State::New(_) => true, + _ => false, + } + } + + fn is_pending(&self) -> bool { + match self { + State::Pending(_) => true, + _ => false, + } + } + + fn is_ready(&self) -> bool { + match self { + State::Ready(_) => true, + _ => false, + } + } + + fn is_error(&self) -> bool { + match self { + State::Error(_) => true, + _ => false, + } + } + + fn take_error(&mut self, replacement: State) -> Option { + if self.is_error() { + match mem::replace(self, replacement) { + State::Error(e) => return Some(e), + _ => unreachable!(), + } + } + None + } +} + +impl From> for State { + fn from(res: io::Result) -> Self { + match res { + Ok(r) => State::Ready(r), + Err(e) => State::Error(e), + } + } +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + State::New(_) => f.pad("new"), + State::Pending(_) => f.pad("pending"), + State::Ready(_) => f.pad("ready"), + State::Error(_) => f.pad("error"), + } + } +} + +fn other(s: &str) -> io::Error { + io::Error::new(io::ErrorKind::Other, s) +} + +fn would_block() -> io::Error { + io::ErrorKind::WouldBlock.into() +} + +// Interim solution until we mark the target types appropriately +#[cfg(not(compiler_has_send_sgx_types))] +mod make_send { + use { + async_usercalls::{ReadBuffer, WriteBuffer}, + std::ops::{Deref, DerefMut}, + std::os::fortanix_sgx::usercalls::alloc::User, + }; + + pub(crate) struct MakeSend(T); + + impl MakeSend { + pub fn new(t: T) -> Self { + Self(t) + } + + #[allow(unused)] + pub fn into_inner(self) -> T { + self.0 + } + } + + impl Deref for MakeSend { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for MakeSend { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + impl From for MakeSend { + fn from(t: T) -> Self { + MakeSend::new(t) + } + } + + unsafe impl Send for MakeSend> {} + unsafe impl Send for MakeSend {} + unsafe impl Send for MakeSend {} +} + +#[cfg(not(compiler_has_send_sgx_types))] +pub(crate) use make_send::MakeSend; diff --git a/src/sys/sgx/tcp/stream.rs b/src/sys/sgx/tcp/stream.rs new file mode 100644 index 000000000..37e2d0239 --- /dev/null +++ b/src/sys/sgx/tcp/stream.rs @@ -0,0 +1,488 @@ +use async_usercalls::{CancelHandle, ReadBuffer, WriteBuffer}; +use std::fmt; +use std::io::{self, IoSlice, IoSliceMut, Read, Write}; +use std::mem; +use std::net::{self, Shutdown, SocketAddr}; +use std::os::fortanix_sgx::io::AsRawFd; +use std::os::fortanix_sgx::usercalls::alloc::User; +use std::sync::{Arc, Mutex, MutexGuard}; + +#[cfg(not(compiler_has_send_sgx_types))] +use super::MakeSend; +use super::{other, would_block, State}; +use crate::sys::sgx::selector::{EventKind, Provider, Registration}; +use crate::{event, Interest, Registry, Token}; + +const WRITE_BUFFER_SIZE: usize = 16 * 1024; +const READ_BUFFER_SIZE: usize = WRITE_BUFFER_SIZE; +const DEFAULT_FAKE_TTL: u32 = 64; + +pub struct TcpStream { + imp: StreamImp, +} + +#[derive(Clone)] +struct StreamImp(Arc>); + +struct StreamInner { + connect_state: State, net::TcpStream>, + #[cfg(compiler_has_send_sgx_types)] + write_buffer: WriteBuffer, + #[cfg(not(compiler_has_send_sgx_types))] + write_buffer: MakeSend, + write_state: State<(), Option, ()>, + #[cfg(compiler_has_send_sgx_types)] + read_buf: Option>, + #[cfg(not(compiler_has_send_sgx_types))] + read_buf: Option>>, + #[cfg(compiler_has_send_sgx_types)] + read_state: State<(), Option, ReadBuffer>, + #[cfg(not(compiler_has_send_sgx_types))] + read_state: State<(), Option, MakeSend>, + registration: Option, + provider: Option, +} + +impl TcpStream { + fn new(connect_state: State, net::TcpStream>) -> Self { + TcpStream { + imp: StreamImp(Arc::new(Mutex::new(StreamInner { + connect_state, + #[cfg(compiler_has_send_sgx_types)] + write_buffer: WriteBuffer::new(User::<[u8]>::uninitialized(WRITE_BUFFER_SIZE)), + #[cfg(not(compiler_has_send_sgx_types))] + write_buffer: MakeSend::new(WriteBuffer::new(User::<[u8]>::uninitialized(WRITE_BUFFER_SIZE))), + write_state: State::New(()), + #[cfg(compiler_has_send_sgx_types)] + read_buf: Some(User::<[u8]>::uninitialized(READ_BUFFER_SIZE)), + #[cfg(not(compiler_has_send_sgx_types))] + read_buf: Some(MakeSend::new(User::<[u8]>::uninitialized(READ_BUFFER_SIZE))), + read_state: State::New(()), + registration: None, + provider: None, + }))), + } + } + + pub(super) fn from_std(stream: net::TcpStream) -> (Self, SocketAddr) { + let peer_addr = stream.peer_addr().unwrap_or_else(|_| ([0; 4], 0).into()); + let stream = TcpStream::new(State::Ready(stream)); + (stream, peer_addr) + } + + pub fn connect(addr: SocketAddr) -> io::Result { + Ok(TcpStream::new(State::New(addr.to_string()))) + } + + pub fn connect_str(addr: &str) -> io::Result { + Ok(TcpStream::new(State::New(addr.to_owned()))) + } + + pub fn peer_addr(&self) -> io::Result { + self.inner() + .connect_state + .as_ready() + .ok_or_else(|| would_block()) + .and_then(|stream| stream.peer_addr()) + } + + pub fn local_addr(&self) -> io::Result { + self.inner() + .connect_state + .as_ready() + .ok_or_else(|| would_block()) + .and_then(|stream| stream.local_addr()) + } + + pub fn shutdown(&self, _how: Shutdown) -> io::Result<()> { + Ok(()) // ineffective in SGX + } + + pub fn set_nodelay(&self, _nodelay: bool) -> io::Result<()> { + Ok(()) // ineffective in SGX + } + + pub fn nodelay(&self) -> io::Result { + Ok(false) // ineffective in SGX + } + + pub fn set_ttl(&self, _ttl: u32) -> io::Result<()> { + Ok(()) // ineffective in SGX + } + + pub fn ttl(&self) -> io::Result { + Ok(DEFAULT_FAKE_TTL) // ineffective in SGX + } + + pub fn take_error(&self) -> io::Result> { + let mut inner = self.inner(); + if let Some(err) = inner.connect_state.take_error(State::Error(io::ErrorKind::Other.into())) { + return Ok(Some(err)); + } + if let Some(err) = inner.read_state.take_error(State::New(())) { + return Ok(Some(err)); + } + if let Some(err) = inner.write_state.take_error(State::New(())) { + return Ok(Some(err)); + } + Ok(None) + } + + pub fn peek(&self, _buf: &mut [u8]) -> io::Result { + Ok(0) // undocumented current behavior in std::net::TcpStream for SGX target. + } + + fn inner(&self) -> MutexGuard<'_, StreamInner> { + self.imp.inner() + } +} + +impl StreamImp { + fn inner(&self) -> MutexGuard<'_, StreamInner> { + self.0.lock().unwrap() + } + + fn schedule_connect_or_read(&self, inner: &mut StreamInner) { + match inner.connect_state { + State::New(_) => self.schedule_connect(inner), + State::Ready(_) => self.post_connect(inner), + State::Pending(_) | State::Error(_) => {}, + } + } + + fn schedule_connect(&self, inner: &mut StreamInner) { + let provider = match inner.provider.as_ref() { + Some(provider) => provider, + None => return, + }; + let addr = match inner.connect_state { + State::New(ref addr) => addr.as_str(), + _ => return, + }; + let weak_ref = Arc::downgrade(&self.0); + let cancel_handle = provider.connect_stream(addr, move |res| { + let imp = match weak_ref.upgrade() { + Some(arc) => StreamImp(arc), + None => return, + }; + let mut inner = imp.inner(); + assert!(inner.connect_state.is_pending()); + inner.connect_state = res.into(); + imp.post_connect(&mut inner); + }); + inner.connect_state = State::Pending(Some(cancel_handle)); + } + + fn post_connect(&self, inner: &mut StreamInner) { + if inner.connect_state.is_ready() { + inner.push_event(EventKind::Writable); + self.schedule_read(inner); + } + if inner.connect_state.is_error() { + inner.push_event(EventKind::WriteError); + } + } + + fn schedule_read(&self, inner: &mut StreamInner) { + let provider = match inner.provider.as_ref() { + Some(provider) => provider, + None => return, + }; + let fd = match (inner.read_state.is_new(), inner.connect_state.as_ready()) { + (true, Some(stream)) => stream.as_raw_fd(), + _ => return, + }; + let read_buf = inner.read_buf.take().unwrap(); + #[cfg(not(compiler_has_send_sgx_types))] + let read_buf = read_buf.into_inner(); + let weak_ref = Arc::downgrade(&self.0); + let cancel_handle = provider.read(fd, read_buf, move |res, read_buf| { + let imp = match weak_ref.upgrade() { + Some(arc) => StreamImp(arc), + None => return, + }; + let mut inner = imp.inner(); + assert!(inner.read_state.is_pending()); + match res { + Ok(len) => { + inner.read_state = State::Ready(ReadBuffer::new(read_buf, len).into()); + inner.push_event(if len == 0 { + EventKind::ReadClosed + } else { + EventKind::Readable + }); + } + Err(e) => { + let is_closed = is_connection_closed(&e); + inner.read_state = State::Error(e); + inner.read_buf = Some(read_buf.into()); + inner.push_event(if is_closed { + EventKind::ReadClosed + } else { + EventKind::ReadError + }); + } + } + }); + inner.read_state = State::Pending(Some(cancel_handle)); + } + + fn schedule_write(&self, inner: &mut StreamInner) { + let provider = match inner.provider.as_ref() { + Some(provider) => provider, + None => return, + }; + let fd = match (inner.write_state.is_new(), inner.connect_state.as_ready()) { + (true, Some(stream)) => stream.as_raw_fd(), + _ => return, + }; + let chunk = match inner.write_buffer.consumable_chunk() { + Some(chunk) => chunk, + None => return, + }; + let imp = self.clone(); + let cancel_handle = provider.write(fd, chunk, move |res, buf| { + let mut inner = imp.inner(); + match res { + Ok(0) => { + // since we don't write 0 bytes, this signifies EOF + inner.write_state = State::Error(io::ErrorKind::WriteZero.into()); + inner.push_event(EventKind::WriteClosed); + } + Ok(n) => { + inner.write_buffer.consume(buf, n); + inner.write_state = State::New(()); + if !inner.write_buffer.is_empty() { + imp.schedule_write(&mut inner); + } else { + inner.push_event(EventKind::Writable); + } + } + Err(e) => { + let is_closed = is_connection_closed(&e); + inner.write_state = State::Error(e); + inner.push_event(if is_closed { + EventKind::WriteClosed + } else { + EventKind::WriteError + }); + } + } + }); + inner.write_state = State::Pending(Some(cancel_handle)); + } + + fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let mut inner = self.inner(); + let ret = match mem::replace(&mut inner.read_state, State::New(())) { + State::New(()) => Err(would_block()), + State::Pending(cancel_handle) => { + inner.read_state = State::Pending(cancel_handle); + return Err(would_block()); + } + #[cfg_attr(not(compiler_has_send_sgx_types), allow(unused_mut))] + State::Ready(mut read_buf) => { + #[cfg(not(compiler_has_send_sgx_types))] + let mut read_buf = read_buf.into_inner(); + let mut r = 0; + for buf in bufs { + r += read_buf.read(buf); + } + #[cfg(compiler_has_send_sgx_types)] + match read_buf.remaining_bytes() { + // Only schedule another read if the previous one returned some bytes. + // Otherwise assume subsequent reads will always return 0 bytes, so just + // stay at Ready state and always return 0 bytes from this point on. + 0 if read_buf.len() > 0 => inner.read_buf = Some(read_buf.into_inner()), + _ => inner.read_state = State::Ready(read_buf), + } + #[cfg(not(compiler_has_send_sgx_types))] + match read_buf.remaining_bytes() { + // Only schedule another read if the previous one returned some bytes. + // Otherwise assume subsequent reads will always return 0 bytes, so just + // stay at Ready state and always return 0 bytes from this point on. + 0 if read_buf.len() > 0 => inner.read_buf = Some(MakeSend::new(read_buf.into_inner())), + _ => inner.read_state = State::Ready(MakeSend::new(read_buf)), + } + Ok(r) + } + State::Error(e) => Err(e), + }; + self.schedule_read(&mut inner); + ret + } + + fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + let mut inner = self.inner(); + if let Some(e) = inner.write_state.take_error(State::New(())) { + return Err(e); + } + if !inner.connect_state.is_ready() { + return Err(would_block()); + } + let written = inner.write_buffer.write_vectored(bufs); + if written == 0 { + return Err(would_block()); + } + self.schedule_write(&mut inner); + Ok(written) + } +} + +impl StreamInner { + fn push_event(&self, kind: EventKind) { + if let Some(ref registration) = self.registration { + registration.push_event(kind); + } + } + + fn announce_current_state(&self) { + if self.connect_state.is_ready() { + self.push_event(EventKind::Writable); + } + if self.connect_state.is_error() { + self.push_event(EventKind::WriteError); + } + if self.read_state.is_ready() { + self.push_event(EventKind::Readable); + } + if self.read_state.is_error() { + self.push_event(EventKind::ReadError); + } + } +} + +impl From for TcpStream { + fn from(stream: net::TcpStream) -> Self { + TcpStream::new(State::Ready(stream)) + } +} + +impl event::Source for TcpStream { + fn register( + &mut self, + registry: &Registry, + token: Token, + interest: Interest, + ) -> io::Result<()> { + let mut inner = self.inner(); + match inner.registration { + Some(_) => return Err(other("I/O source already registered with a `Registry`")), + None => inner.registration = Some(Registration::new(registry.selector(), token, interest)), + } + inner.provider = Some(Provider::new(registry.selector())); + self.imp.schedule_connect_or_read(&mut inner); + Ok(()) + } + + fn reregister( + &mut self, + _registry: &Registry, + token: Token, + interest: Interest, + ) -> io::Result<()> { + let mut inner = self.inner(); + let changed = match inner.registration { + Some(ref mut registration) => registration.change_details(token, interest), + None => return Err(other("I/O source not registered with `Registry`")), + }; + if changed { + inner.announce_current_state(); + } + Ok(()) + } + + fn deregister(&mut self, _registry: &Registry) -> io::Result<()> { + let mut inner = self.inner(); + match inner.registration { + Some(_) => inner.registration = None, + None => return Err(other("I/O source not registered with `Registry`")), + } + Ok(()) + } +} + +impl fmt::Debug for TcpStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let inner = self.inner(); + let mut res = f.debug_struct("TcpStream"); + res.field("connect_state", &inner.connect_state); + res.field("read_state", &inner.read_state); + res.field("write_state", &inner.write_state); + res.finish() + } +} + +impl Read for TcpStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.imp.read_vectored(&mut [IoSliceMut::new(buf)]) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + self.imp.read_vectored(bufs) + } +} + +impl<'a> Read for &'a TcpStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.imp.read_vectored(&mut [IoSliceMut::new(buf)]) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + self.imp.read_vectored(bufs) + } +} + +impl Write for TcpStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.imp.write_vectored(&[IoSlice::new(buf)]) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.imp.write_vectored(bufs) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) // same as in `impl Write for std::net::TcpStream` + } +} + +impl<'a> Write for &'a TcpStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.imp.write_vectored(&[IoSlice::new(buf)]) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.imp.write_vectored(bufs) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) // same as in `impl Write for std::net::TcpStream` + } +} + +impl Drop for TcpStream { + fn drop(&mut self) { + let mut inner = self.inner(); + // deregister so we don't send events after drop + inner.registration = None; + if let Some(cancel_handle) = inner.connect_state.as_pending_mut().and_then(|opt| opt.take()) { + cancel_handle.cancel(); + } + if let Some(cancel_handle) = inner.read_state.as_pending_mut().and_then(|opt| opt.take()) { + cancel_handle.cancel(); + } + // NOTE: We don't cancel write since we have promised to write those bytes before drop. + // Also note that the callback in schedule_write() holds an Arc not a Weak, so it can + // continue writing the remaining bytes in the write buffer. + } +} + +fn is_connection_closed(e: &io::Error) -> bool { + match e.kind() { + io::ErrorKind::ConnectionReset + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::BrokenPipe => true, + _ => false, + } +} diff --git a/src/sys/sgx/waker.rs b/src/sys/sgx/waker.rs new file mode 100644 index 000000000..ee5265aa0 --- /dev/null +++ b/src/sys/sgx/waker.rs @@ -0,0 +1,38 @@ +use crate::sys::sgx::selector::{EventKind, Registration}; +use crate::sys::Selector; +use crate::{Interest, Token}; +use std::fmt; +use std::io; +use std::sync::Arc; + +pub struct Waker(Arc); + +impl Waker { + pub fn new(selector: &Selector, token: Token) -> io::Result { + Ok(Waker(Arc::new(Registration::new( + selector, + token, + Interest::READABLE, + )))) + } + + pub fn wake(&self) -> io::Result<()> { + let weak_ref = Arc::downgrade(&self.0); + self.0.provider().insecure_time(move |_| { + let inner = match weak_ref.upgrade() { + Some(arc) => arc, + None => return, + }; + inner.push_event(EventKind::Readable); + }); + Ok(()) + } +} + +impl fmt::Debug for Waker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Waker") + .field("token", &self.0.token()) + .finish() + } +} diff --git a/src/sys/unix/mod.rs b/src/sys/unix/mod.rs index 7804236da..86e5007a5 100644 --- a/src/sys/unix/mod.rs +++ b/src/sys/unix/mod.rs @@ -18,6 +18,7 @@ cfg_os_poll! { pub(crate) use self::selector::{event, Event, Events, Selector}; mod sourcefd; + #[cfg(all(unix, feature = "os-ext"))] pub use self::sourcefd::SourceFd; mod waker; diff --git a/tests/poll.rs b/tests/poll.rs index 22db7b739..1e1136e79 100644 --- a/tests/poll.rs +++ b/tests/poll.rs @@ -8,7 +8,9 @@ use std::time::Duration; use std::{fmt, io}; use mio::event::Source; -use mio::net::{TcpListener, TcpStream, UdpSocket}; +use mio::net::{TcpListener, TcpStream}; +#[cfg(not(target_env = "sgx"))] +use mio::net::UdpSocket; use mio::{event, Events, Interest, Poll, Registry, Token}; mod util; @@ -233,6 +235,7 @@ pub fn registry_ops_flow( registry.reregister(source, token, final_interests) } +#[cfg(not(target_env = "sgx"))] // no UDP support in SGX. #[test] fn registry_operations_are_thread_safe() { let (mut poll, mut events) = init_with_poll(); @@ -319,6 +322,7 @@ fn registry_operations_are_thread_safe() { handle3.join().unwrap(); } +#[cfg(not(target_env = "sgx"))] // no UDP support in SGX. #[test] fn register_during_poll() { let (mut poll, mut events) = init_with_poll(); @@ -367,6 +371,7 @@ fn register_during_poll() { // - `reregister` can use the same token as `register` // - `reregister` can use different token from `register` // - multiple `reregister` are ok +#[cfg(not(target_env = "sgx"))] // no UDP support in SGX. #[test] fn reregister_interest_token_usage() { let (mut poll, mut events) = init_with_poll(); @@ -421,6 +426,10 @@ pub fn double_register_different_token() { } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "this test expects connect to make progress before registering" +)] fn poll_ok_after_cancelling_pending_ops() { let (mut poll, mut events) = init_with_poll(); diff --git a/tests/registering.rs b/tests/registering.rs index c8415b831..bdc6f239b 100644 --- a/tests/registering.rs +++ b/tests/registering.rs @@ -6,6 +6,7 @@ use std::thread::sleep; use std::time::Duration; use log::{debug, info, trace}; +#[cfg(not(target_env = "sgx"))] #[cfg(debug_assertions)] use mio::net::UdpSocket; use mio::net::{TcpListener, TcpStream}; @@ -194,6 +195,7 @@ fn tcp_register_multiple_event_loops() { } #[test] +#[cfg(not(target_env = "sgx"))] // no UDP support in SGX. #[cfg(debug_assertions)] // Check is only present when debug assertions are enabled. fn udp_register_multiple_event_loops() { init(); diff --git a/tests/tcp.rs b/tests/tcp.rs index 82ed6bcbf..ccf36873e 100644 --- a/tests/tcp.rs +++ b/tests/tcp.rs @@ -4,7 +4,8 @@ use mio::net::{TcpListener, TcpStream}; use mio::{Events, Interest, Poll, Token}; use std::io::{self, Read, Write}; -use std::net::{self, Shutdown}; +use std::net; +use std::net::Shutdown; use std::sync::mpsc::channel; use std::thread::{self, sleep}; use std::time::Duration; @@ -13,8 +14,10 @@ use std::time::Duration; mod util; use util::{ any_local_address, assert_send, assert_sync, expect_events, expect_no_events, init, - init_with_poll, set_linger_zero, ExpectEvent, + init_with_poll, ExpectEvent, }; +#[cfg(not(target_env = "sgx"))] +use util::set_linger_zero; const LISTEN: Token = Token(0); const CLIENT: Token = Token(1); @@ -342,6 +345,12 @@ fn write() { } } } + + #[cfg(target_env = "sgx")] // some writes may not have finished yet and to make progress we need to poll. + for _ in 0..3 { + poll.poll(&mut events, Some(Duration::from_millis(10))).unwrap(); + } + handle.join().unwrap(); } @@ -466,10 +475,19 @@ fn multiple_writes_immediate_success() { s.write_all(&[1; 1024]).unwrap(); } + #[cfg(target_env = "sgx")] // some writes may not have finished yet and to make progress we need to poll. + for _ in 0..3 { + poll.poll(&mut events, Some(Duration::from_millis(10))).unwrap(); + } + handle.join().unwrap(); } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "No set_linger_zero() on SGX" +)] fn connection_reset_by_peer() { init(); @@ -483,6 +501,7 @@ fn connection_reset_by_peer() { // Connect client let mut client = TcpStream::connect(addr).unwrap(); + #[cfg(not(target_env = "sgx"))] set_linger_zero(&client); // Register server @@ -572,6 +591,7 @@ fn connect_error() { for event in &events { if event.token() == Token(0) { assert!(event.is_writable()); + #[cfg(not(target_env = "sgx"))] assert!(event.is_write_closed()); break 'outer; } @@ -664,6 +684,10 @@ macro_rules! wait { } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "Socket shutdown is ineffective in SGX" +)] fn write_shutdown() { init(); @@ -940,8 +964,69 @@ fn tcp_no_events_after_deregister() { checked_write!(stream2.write(&[1, 2, 3, 4])); expect_no_events(&mut poll, &mut events); + #[cfg(target_env = "sgx")] // make progress by polling, but we still expect no events + for _ in 0..2 { + expect_no_events(&mut poll, &mut events); + } + sleep(Duration::from_millis(200)); expect_read!(stream.read(&mut buf), &[1, 2, 3, 4]); expect_no_events(&mut poll, &mut events); } + +#[cfg(target_env = "sgx")] // SGX-specific API +#[test] +fn bind_str() { + let mut listener = TcpListener::bind_str("localhost:0").unwrap(); + let addr = listener.local_addr().unwrap(); + assert!(addr.ip().is_loopback()); + + let handle = thread::spawn(move || { + net::TcpStream::connect(addr).unwrap(); + }); + + let mut poll = Poll::new().unwrap(); + + poll.registry() + .register(&mut listener, Token(1), Interest::READABLE) + .unwrap(); + + let mut events = Events::with_capacity(16); + while events.is_empty() { + poll.poll(&mut events, None).unwrap(); + } + assert_eq!(events.iter().count(), 1); + assert_eq!(events.iter().next().unwrap().token(), Token(1)); + + listener.accept().unwrap(); + handle.join().unwrap(); +} + +#[cfg(target_env = "sgx")] // SGX-specific API +#[test] +fn connect_str() { + let listener = net::TcpListener::bind("localhost:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = thread::spawn(move || { + listener.accept().unwrap(); + }); + + let addr = format!("localhost:{}", addr.port()); + let mut stream = TcpStream::connect_str(&addr).unwrap(); + + let mut poll = Poll::new().unwrap(); + + poll.registry() + .register(&mut stream, Token(1), Interest::WRITABLE) + .unwrap(); + + let mut events = Events::with_capacity(16); + while events.is_empty() { + poll.poll(&mut events, None).unwrap(); + } + assert_eq!(events.iter().count(), 1); + assert_eq!(events.iter().next().unwrap().token(), Token(1)); + handle.join().unwrap(); +} diff --git a/tests/tcp_listener.rs b/tests/tcp_listener.rs index 086e619f7..59fe42ffe 100644 --- a/tests/tcp_listener.rs +++ b/tests/tcp_listener.rs @@ -92,6 +92,10 @@ where } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "set_ttl() is ineffective on SGX", +)] fn set_get_ttl() { init(); diff --git a/tests/tcp_stream.rs b/tests/tcp_stream.rs index df9b3f2e3..e131922ef 100644 --- a/tests/tcp_stream.rs +++ b/tests/tcp_stream.rs @@ -14,13 +14,15 @@ use mio::{Interest, Token}; #[macro_use] mod util; -#[cfg(not(target_os = "windows"))] +#[cfg(not(any(target_os = "windows", target_env = "sgx")))] use util::init; use util::{ any_local_address, any_local_ipv6_address, assert_send, assert_socket_close_on_exec, assert_socket_non_blocking, assert_sync, assert_would_block, expect_events, expect_no_events, - init_with_poll, set_linger_zero, ExpectEvent, Readiness, + init_with_poll, ExpectEvent, Readiness, }; +#[cfg(not(target_env = "sgx"))] +use util::set_linger_zero; const DATA1: &[u8] = b"Hello world!"; const DATA2: &[u8] = b"Hello mars!"; @@ -81,6 +83,7 @@ where ); let mut buf = [0; 16]; + #[cfg(not(target_env = "sgx"))] // peek always returns Ok(0) in SGX. assert_would_block(stream.peek(&mut buf)); assert_would_block(stream.read(&mut buf)); @@ -99,6 +102,7 @@ where vec![ExpectEvent::new(ID1, Interest::READABLE)], ); + #[cfg(not(target_env = "sgx"))] // peek always returns Ok(0) in SGX. expect_read!(stream.peek(&mut buf), DATA1); expect_read!(stream.read(&mut buf), DATA1); @@ -135,6 +139,10 @@ where } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "set_ttl() is ineffective on SGX", +)] fn set_get_ttl() { let (mut poll, mut events) = init_with_poll(); @@ -197,6 +205,10 @@ fn get_ttl_without_previous_set() { } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "set_nodelay() is ineffective on SGX", +)] fn set_get_nodelay() { let (mut poll, mut events) = init_with_poll(); @@ -351,6 +363,10 @@ fn shutdown_write() { } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "shutdown is ineffective on SGX", +)] fn shutdown_both() { let (mut poll, mut events) = init_with_poll(); @@ -395,6 +411,7 @@ fn shutdown_both() { expect_read!(stream.read(&mut buf), &[]); } + #[cfg_attr(not(any(unix, windows)), allow(unused_variables))] let err = stream.write(DATA2).unwrap_err(); #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); @@ -497,6 +514,7 @@ fn no_events_after_deregister() { // Also, write should work let mut buf = [0; 16]; + #[cfg(not(target_env = "sgx"))] // peek always returns Ok(0) in SGX. assert_would_block(stream.peek(&mut buf)); assert_would_block(stream.read(&mut buf)); @@ -510,6 +528,10 @@ fn no_events_after_deregister() { } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "shutdown is ineffective on SGX", +)] #[cfg_attr( windows, ignore = "fails on Windows; client read closed events are not triggered" @@ -550,6 +572,10 @@ fn tcp_shutdown_client_read_close_event() { any(target_os = "android", target_os = "illumos", target_os = "linux"), ignore = "fails; client write_closed events are not found" )] +#[cfg_attr( + target_env = "sgx", + ignore = "shutdown is ineffective on SGX", +)] fn tcp_shutdown_client_write_close_event() { let (mut poll, mut events) = init_with_poll(); let barrier = Arc::new(Barrier::new(2)); @@ -581,6 +607,10 @@ fn tcp_shutdown_client_write_close_event() { } #[test] +#[cfg_attr( + target_env = "sgx", + ignore = "shutdown is ineffective on SGX", +)] fn tcp_shutdown_server_write_close_event() { let (mut poll, mut events) = init_with_poll(); let barrier = Arc::new(Barrier::new(2)); @@ -662,6 +692,10 @@ fn tcp_reset_close_event() { any(target_os = "illumos"), ignore = "fails; client write_closed events are not found" )] +#[cfg_attr( + target_env = "sgx", + ignore = "shutdown is ineffective on SGX", +)] fn tcp_shutdown_client_both_close_event() { let (mut poll, mut events) = init_with_poll(); let barrier = Arc::new(Barrier::new(2)); @@ -755,6 +789,7 @@ fn start_listener( (thread_handle, receiver.recv().unwrap()) } +#[cfg(not(target_env = "sgx"))] // no TcpSocket in SGX #[test] fn hup_event_on_disconnect() { use mio::net::TcpListener; diff --git a/tests/udp_socket.rs b/tests/udp_socket.rs index 2a3a20c9b..5756bcba7 100644 --- a/tests/udp_socket.rs +++ b/tests/udp_socket.rs @@ -1,4 +1,4 @@ -#![cfg(not(target_os = "wasi"))] +#![cfg(not(any(target_os = "wasi", target_env = "sgx")))] #![cfg(all(feature = "os-poll", feature = "net"))] use log::{debug, info}; diff --git a/tests/util/mod.rs b/tests/util/mod.rs index 7a192d9b0..01343509a 100644 --- a/tests/util/mod.rs +++ b/tests/util/mod.rs @@ -3,18 +3,23 @@ #![cfg(not(target_os = "wasi"))] #![cfg(all(feature = "os-poll", feature = "net"))] +#[cfg(any(unix, windows))] use std::mem::size_of; use std::net::SocketAddr; use std::ops::BitOr; #[cfg(unix)] use std::os::unix::io::AsRawFd; +#[cfg(unix)] use std::path::PathBuf; use std::sync::Once; use std::time::Duration; -use std::{env, fmt, fs, io}; +#[cfg(unix)] +use std::{env, fs}; +use std::{fmt, io}; use log::{error, warn}; use mio::event::Event; +#[cfg(any(unix, windows))] use mio::net::TcpStream; use mio::{Events, Interest, Poll, Token}; @@ -24,10 +29,13 @@ pub fn init() { INIT.call_once(|| { env_logger::try_init().expect("unable to initialise logger"); - // Remove all temporary files from previous test runs. - let dir = temp_dir(); - let _ = fs::remove_dir_all(&dir); - fs::create_dir_all(&dir).expect("unable to create temporary directory"); + #[cfg(unix)] + { + // Remove all temporary files from previous test runs. + let dir = temp_dir(); + let _ = fs::remove_dir_all(&dir); + fs::create_dir_all(&dir).expect("unable to create temporary directory"); + } }) } @@ -214,6 +222,11 @@ pub fn assert_socket_non_blocking(_: &S) { // No way to get this information... } +#[cfg(target_env = "sgx")] +pub fn assert_socket_non_blocking(_: &S) { + // Does not apply to SGX async model +} + /// Assert that `CLOEXEC` is set on `socket`. #[cfg(unix)] pub fn assert_socket_close_on_exec(socket: &S) @@ -229,6 +242,11 @@ pub fn assert_socket_close_on_exec(_: &S) { // Windows doesn't have this concept. } +#[cfg(target_env = "sgx")] +pub fn assert_socket_close_on_exec(_: &S) { + // Does not apply to SGX async model +} + /// Bind to any port on localhost. pub fn any_local_address() -> SocketAddr { "127.0.0.1:0".parse().unwrap() @@ -299,6 +317,7 @@ pub fn set_linger_zero(socket: &TcpStream) { } /// Returns a path to a temporary file using `name` as filename. +#[cfg(unix)] pub fn temp_file(name: &'static str) -> PathBuf { let mut path = temp_dir(); path.push(name); @@ -306,6 +325,7 @@ pub fn temp_file(name: &'static str) -> PathBuf { } /// Returns the temporary directory for Mio test files. +#[cfg(unix)] fn temp_dir() -> PathBuf { let mut path = env::temp_dir(); path.push("mio_tests");