diff --git a/platform/src/allow_rw.rs b/platform/src/allow_rw.rs index 52fcf1e2..cfa2d3f6 100644 --- a/platform/src/allow_rw.rs +++ b/platform/src/allow_rw.rs @@ -1,6 +1,10 @@ use crate::share::List; use crate::Syscalls; -use core::marker::PhantomData; +use core::{ + cell::Cell, + marker::{PhantomData, PhantomPinned}, + pin::Pin, +}; // ----------------------------------------------------------------------------- // `AllowRw` struct @@ -71,3 +75,56 @@ pub trait Config { /// By default, the non-zero buffer is ignored. fn returned_nonzero_buffer(_driver_num: u32, _buffer_num: u32) {} } + +// ----------------------------------------------------------------------------- +// `AllowRwBuffer` struct +// ----------------------------------------------------------------------------- + +pub struct AllowRwBuffer< + S: Syscalls, + const DRIVER_NUM: u32, + const BUFFER_NUM: u32, + const BUFFER_SIZE: usize, +> { + _syscalls: PhantomData, + // Flag to mark whether the buffer is shared with the kernel or not. + pub(crate) allowed: Cell, + pub(crate) buffer: [u8; BUFFER_SIZE], + // This field makes the AllowRwBuffer !Unpin so that the buffer + // cannot be moved after pinning. + _pinned: PhantomPinned, +} + +impl + AllowRwBuffer +{ + pub fn from_array(buffer: [u8; BUFFER_SIZE]) -> Self { + Self { + allowed: core::cell::Cell::new(false), + buffer, + _syscalls: Default::default(), + _pinned: Default::default(), + } + } + + pub fn get_mut_buffer(self: Pin<&mut Self>) -> &mut [u8; BUFFER_SIZE] { + if self.allowed.get() { + self.allowed.set(false); + S::unallow_rw(DRIVER_NUM, BUFFER_NUM); + } + + // SAFETY: The reference is used only to return a mutable reference + // to the `buffer` field. + &mut (unsafe { self.get_unchecked_mut() }.buffer) + } +} + +impl Drop + for AllowRwBuffer +{ + fn drop(&mut self) { + if self.allowed.get() { + S::unallow_rw(DRIVER_NUM, BUFFER_NUM); + } + } +} diff --git a/platform/src/syscalls.rs b/platform/src/syscalls.rs index 546ba6c6..1d468a23 100644 --- a/platform/src/syscalls.rs +++ b/platform/src/syscalls.rs @@ -57,6 +57,18 @@ pub trait Syscalls: RawSyscalls + Sized { buffer: &'share mut [u8], ) -> Result<(), ErrorCode>; + /// Shares a read-write buffer with the kernel. + fn allow_rw_buffer< + CONFIG: allow_rw::Config, + const DRIVER_NUM: u32, + const BUFFER_NUM: u32, + const BUFFER_SIZE: usize, + >( + allow_rw_buffer: core::pin::Pin< + &mut allow_rw::AllowRwBuffer, + >, + ) -> Result<(), ErrorCode>; + /// Revokes the kernel's access to the buffer with the given ID, overwriting /// it with a zero buffer. If no buffer is shared with the given ID, /// `unallow_rw` does nothing. diff --git a/platform/src/syscalls_impl.rs b/platform/src/syscalls_impl.rs index cd1961cf..e035fa99 100644 --- a/platform/src/syscalls_impl.rs +++ b/platform/src/syscalls_impl.rs @@ -251,6 +251,86 @@ impl Syscalls for S { unsafe { inner::(DRIVER_NUM, BUFFER_NUM, buffer) } } + fn allow_rw_buffer< + CONFIG: allow_rw::Config, + const DRIVER_NUM: u32, + const BUFFER_NUM: u32, + const BUFFER_SIZE: usize, + >( + mut allow_rw_buffer: core::pin::Pin< + &mut allow_rw::AllowRwBuffer, + >, + ) -> Result<(), ErrorCode> { + // Inner function that does the majority of the work. This is not + // monomorphized over DRIVER_NUM and BUFFER_NUM to keep code size small. + // + // Safety: `buffer` must be a reference to the buffer field of a pinned + // AllowRwBuffer that must last for at least 'share lifetime. + unsafe fn inner( + driver_num: u32, + buffer_num: u32, + buffer: &mut [u8], + ) -> Result<(), ErrorCode> { + let [r0, r1, r2, _] = unsafe { + S::syscall4::<{ syscall_class::ALLOW_RW }>([ + driver_num.into(), + buffer_num.into(), + buffer.as_mut_ptr().into(), + buffer.len().into(), + ]) + }; + + let return_variant: ReturnVariant = r0.as_u32().into(); + // TRD 104 guarantees that Read-Write Allow returns either Success + // with 2 U32 or Failure with 2 U32. We check the return variant by + // comparing against Failure with 2 U32 for 2 reasons: + // + // 1. On RISC-V with compressed instructions, it generates smaller + // code. FAILURE_2_U32 has value 2, which can be loaded into a + // register with a single compressed instruction, whereas + // loading SUCCESS_2_U32 uses an uncompressed instruction. + // 2. In the event the kernel malfunctions and returns a different + // return variant, the success path is actually safer than the + // failure path. The failure path assumes that r1 contains an + // ErrorCode, and produces UB if it has an out of range value. + // Incorrectly assuming the call succeeded will not generate + // unsoundness, and will likely lead to the application + // panicing. + if return_variant == return_variant::FAILURE_2_U32 { + // Safety: TRD 104 guarantees that if r0 is Failure with 2 U32, + // then r1 will contain a valid error code. ErrorCode is + // designed to be safely transmuted directly from a kernel error + // code. + return Err(unsafe { core::mem::transmute::(r1.as_u32()) }); + } + + // r0 indicates Success with 2 u32s. Confirm a zero buffer was + // returned, and it if wasn't then call the configured function. + // We're relying on the optimizer to remove this branch if + // returned_nozero_buffer is a no-op. + let returned_buffer: (usize, usize) = (r1.into(), r2.into()); + if returned_buffer != (0, 0) { + CONFIG::returned_nonzero_buffer(driver_num, buffer_num); + } + Ok(()) + } + + let alias = allow_rw_buffer.as_mut(); + // Safety: We do not move out of the reference. + let buffer = unsafe { &mut alias.get_unchecked_mut().buffer }; + + // Safety: The presence of a Pin<&'share mut AllowRwBuffer> + // indicates that the buffer will be valid and will either clean up this + // Allow ID when it is dropped, or leak the allowed memory. + let res = unsafe { inner::(DRIVER_NUM, BUFFER_NUM, buffer) }; + + if res.is_ok() { + allow_rw_buffer.allowed.set(true); + } + + res + } + fn unallow_rw(driver_num: u32, buffer_num: u32) { unsafe { // syscall4's documentation indicates it can be used to call diff --git a/syscalls_tests/src/allow_rw.rs b/syscalls_tests/src/allow_rw.rs index 300dd250..f36bb81b 100644 --- a/syscalls_tests/src/allow_rw.rs +++ b/syscalls_tests/src/allow_rw.rs @@ -1,3 +1,4 @@ +use libtock_platform::allow_rw::AllowRwBuffer; use libtock_platform::{allow_rw, share, CommandReturn, ErrorCode, Syscalls}; use libtock_unittest::{command_return, fake, DriverInfo, RwAllowBuffer, SyscallLogEntry}; use std::cell::Cell; @@ -123,3 +124,68 @@ fn allow_rw() { // Verify the buffer write occurred. assert_eq!(buffer2, [5, 31]); } + +#[test] +fn allow_rw_buffer() { + let kernel = fake::Kernel::new(); + let driver = Rc::new(TestDriver::default()); + kernel.add_driver(&driver); + + // Tests a call that should fail because it has an incorrect buffer + // number. + let buffer = [1, 2, 3, 4]; + let mut allow_buf = std::pin::pin!(AllowRwBuffer::from_array(buffer)); + let result = fake::Syscalls::allow_rw_buffer::(allow_buf.as_mut()); + + assert!(!CALLED.with(|c| c.get())); + assert_eq!(result, Err(ErrorCode::NoSupport)); + assert_eq!( + kernel.take_syscall_log(), + [SyscallLogEntry::AllowRw { + driver_num: 42, + buffer_num: 1, + len: 4, + }] + ); + + // Verify that no unallow occurred. + let _ = allow_buf.get_mut_buffer(); + assert_eq!(kernel.take_syscall_log(), []); + + // Tests a call that should succeed and return a nonzero buffer. + let buffer = [0, 0]; + let mut allow_buf = std::pin::pin!(AllowRwBuffer::from_array(buffer)); + let result = fake::Syscalls::allow_rw_buffer::(allow_buf.as_mut()); + + assert!(!CALLED.with(|c| c.get())); + assert_eq!(result, Ok(())); + assert_eq!( + kernel.take_syscall_log(), + [SyscallLogEntry::AllowRw { + driver_num: 42, + buffer_num: 0, + len: 2, + }] + ); + + // Mutate the buffer, which under Miri will verify the buffer has been + // shared with the kernel properly. + let mut buffer = driver.buffer_0.take(); + buffer[0] = 32; + driver.buffer_0.set(buffer); + + let buffer_ref = allow_buf.get_mut_buffer(); + + // Verify that the buffer take unallowed the buffer. + assert_eq!( + kernel.take_syscall_log(), + [SyscallLogEntry::AllowRw { + driver_num: 42, + buffer_num: 0, + len: 0, + }] + ); + + // Verify the buffer write occurred. + assert_eq!(*buffer_ref, [32, 0]); +}