From d20de400817aece1c2590a96844329444e01bfdc Mon Sep 17 00:00:00 2001 From: sagudev <16504129+sagudev@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:19:34 +0200 Subject: [PATCH] Impl unsafe `deref_mut` on IpcSharedMemory Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com> --- src/ipc.rs | 18 ++++++++++++++++++ src/platform/inprocess/mod.rs | 10 ++++++++++ src/platform/macos/mod.rs | 10 ++++++++++ src/platform/unix/mod.rs | 9 ++++++++- src/platform/windows/mod.rs | 8 ++++++++ 5 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/ipc.rs b/src/ipc.rs index 823f107a..6f49f59e 100644 --- a/src/ipc.rs +++ b/src/ipc.rs @@ -570,6 +570,24 @@ impl Deref for IpcSharedMemory { } } +impl IpcSharedMemory { + /// Returns a mutable reference to the deref of this [`IpcSharedMemory`]. + /// + /// # Safety + /// + /// This is safe if there is only one reader/writer on the data. + /// User can achieve this by not cloning [`IpcSharedMemory`] + /// and serializing/deserializing only once. + #[inline] + pub unsafe fn deref_mut(&mut self) -> &mut [u8] { + if let Some(os_shared_memory) = &mut self.os_shared_memory { + os_shared_memory.deref_mut() + } else { + &mut [] + } + } +} + impl<'de> Deserialize<'de> for IpcSharedMemory { fn deserialize(deserializer: D) -> Result where diff --git a/src/platform/inprocess/mod.rs b/src/platform/inprocess/mod.rs index 7ce68ec5..36df8472 100644 --- a/src/platform/inprocess/mod.rs +++ b/src/platform/inprocess/mod.rs @@ -395,6 +395,16 @@ impl Deref for OsIpcSharedMemory { } } +impl OsIpcSharedMemory { + #[inline] + pub unsafe fn deref_mut(&mut self) -> &mut [u8] { + if self.ptr.is_null() { + panic!("attempted to access a consumed `OsIpcSharedMemory`") + } + unsafe { slice::from_raw_parts_mut(self.ptr, self.length) } + } +} + impl OsIpcSharedMemory { pub fn from_byte(byte: u8, length: usize) -> OsIpcSharedMemory { let mut v = Arc::new(vec![byte; length]); diff --git a/src/platform/macos/mod.rs b/src/platform/macos/mod.rs index 99beef83..651f1f2f 100644 --- a/src/platform/macos/mod.rs +++ b/src/platform/macos/mod.rs @@ -940,6 +940,16 @@ impl Deref for OsIpcSharedMemory { } } +impl OsIpcSharedMemory { + #[inline] + pub unsafe fn deref_mut(&mut self) -> &mut [u8] { + if self.ptr.is_null() && self.length > 0 { + panic!("attempted to access a consumed `OsIpcSharedMemory`") + } + unsafe { slice::from_raw_parts_mut(self.ptr, self.length) } + } +} + impl OsIpcSharedMemory { unsafe fn from_raw_parts(ptr: *mut u8, length: usize) -> OsIpcSharedMemory { OsIpcSharedMemory { diff --git a/src/platform/unix/mod.rs b/src/platform/unix/mod.rs index ec3c7128..19743275 100644 --- a/src/platform/unix/mod.rs +++ b/src/platform/unix/mod.rs @@ -31,7 +31,7 @@ use std::hash::BuildHasherDefault; use std::io; use std::marker::PhantomData; use std::mem; -use std::ops::{Deref, RangeFrom}; +use std::ops::{Deref, DerefMut, RangeFrom}; use std::os::fd::RawFd; use std::ptr; use std::slice; @@ -866,6 +866,13 @@ impl Deref for OsIpcSharedMemory { } } +impl OsIpcSharedMemory { + #[inline] + pub unsafe fn deref_mut(&mut self) -> &mut [u8] { + unsafe { slice::from_raw_parts_mut(self.ptr, self.length) } + } +} + impl OsIpcSharedMemory { unsafe fn from_raw_parts( ptr: *mut u8, diff --git a/src/platform/windows/mod.rs b/src/platform/windows/mod.rs index 0fa91c07..73abc1a3 100644 --- a/src/platform/windows/mod.rs +++ b/src/platform/windows/mod.rs @@ -1834,6 +1834,14 @@ impl Deref for OsIpcSharedMemory { } } +impl OsIpcSharedMemory { + #[inline] + pub unsafe fn deref_mut(&mut self) -> &mut [u8] { + assert!(!self.view_handle.Value.is_null() && self.handle.is_valid()); + unsafe { slice::from_raw_parts_mut(self.view_handle.Value as _, self.length) } + } +} + impl OsIpcSharedMemory { fn new(length: usize) -> Result { unsafe {