diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 03dc5efce..5caa11b81 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -139,7 +139,7 @@ where S: Data, { debug_assert!(self.pointer_is_inbounds()); - unsafe { ArrayView::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) } + unsafe { ArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) } } /// Return a read-write view of the array @@ -148,7 +148,7 @@ where S: DataMut, { self.ensure_unique(); - unsafe { ArrayViewMut::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) } + unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } } /// Return an uniquely owned copy of the array. @@ -1313,7 +1313,7 @@ where /// Return a raw view of the array. #[inline] pub fn raw_view(&self) -> RawArrayView { - unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) } + unsafe { RawArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) } } /// Return a raw mutable view of the array. @@ -1323,7 +1323,7 @@ where S: RawDataMut, { self.try_ensure_unique(); // for RcArray - unsafe { RawArrayViewMut::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) } + unsafe { RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } } /// Return the array’s data as a slice, if it is contiguous and in standard order. @@ -1620,7 +1620,7 @@ where Some(st) => st, None => return None, }; - unsafe { Some(ArrayView::new_(self.ptr.as_ptr(), dim, broadcast_strides)) } + unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } } /// Swap axes `ax` and `bx`. diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index 154f86688..643753e7d 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -1,3 +1,6 @@ +use std::mem; +use std::ptr::NonNull; + use crate::dimension::{self, stride_offset}; use crate::extension::nonnull::nonnull_debug_checked_from_ptr; use crate::imp_prelude::*; @@ -11,16 +14,20 @@ where /// /// Unsafe because caller is responsible for ensuring that the array will /// meet all of the invariants of the `ArrayBase` type. - #[inline(always)] - pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self { + #[inline] + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { RawArrayView { data: RawViewRepr::new(), - ptr: nonnull_debug_checked_from_ptr(ptr as *mut _), + ptr, dim, strides, } } + unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self { + Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides) + } + /// Create an `RawArrayView` from shape information and a raw pointer /// to the elements. /// @@ -76,7 +83,7 @@ where /// ensure that all of the data is valid and choose the correct lifetime. #[inline] pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { - ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) + ArrayView::new(self.ptr, self.dim, self.strides) } /// Split the array view along `axis` and return one array pointer strictly @@ -105,6 +112,32 @@ where (left, right) } + + /// Cast the raw pointer of the raw array view to a different type + /// + /// **Panics** if element size is not compatible. + /// + /// Lack of panic does not imply it is a valid cast. The cast works the same + /// way as regular raw pointer casts. + /// + /// While this method is safe, for the same reason as regular raw pointer + /// casts are safe, access through the produced raw view is only possible + /// in an unsafe block or function. + pub fn cast(self) -> RawArrayView { + assert_eq!( + mem::size_of::(), + mem::size_of::(), + "size mismatch in raw view cast" + ); + let ptr = self.ptr.cast::(); + debug_assert!( + is_aligned(ptr.as_ptr()), + "alignment mismatch in raw view cast" + ); + /* Alignment checked with debug assertion: alignment could be dynamically correct, + * and we don't have a check that compiles out for that. */ + unsafe { RawArrayView::new(ptr, self.dim, self.strides) } + } } impl RawArrayViewMut @@ -115,16 +148,20 @@ where /// /// Unsafe because caller is responsible for ensuring that the array will /// meet all of the invariants of the `ArrayBase` type. - #[inline(always)] - pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self { + #[inline] + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { RawArrayViewMut { data: RawViewRepr::new(), - ptr: nonnull_debug_checked_from_ptr(ptr), + ptr, dim, strides, } } + unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self { + Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides) + } + /// Create an `RawArrayViewMut` from shape information and a raw /// pointer to the elements. /// @@ -176,7 +213,7 @@ where /// Converts to a non-mutable `RawArrayView`. #[inline] pub(crate) fn into_raw_view(self) -> RawArrayView { - unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) } + unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) } } /// Converts to a read-only view of the array. @@ -186,7 +223,7 @@ where /// ensure that all of the data is valid and choose the correct lifetime. #[inline] pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { - ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) + ArrayView::new(self.ptr, self.dim, self.strides) } /// Converts to a mutable view of the array. @@ -196,7 +233,7 @@ where /// ensure that all of the data is valid and choose the correct lifetime. #[inline] pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> { - ArrayViewMut::new_(self.ptr.as_ptr(), self.dim, self.strides) + ArrayViewMut::new(self.ptr, self.dim, self.strides) } /// Split the array view along `axis` and return one array pointer strictly @@ -207,9 +244,35 @@ where let (left, right) = self.into_raw_view().split_at(axis, index); unsafe { ( - Self::new_(left.ptr.as_ptr(), left.dim, left.strides), - Self::new_(right.ptr.as_ptr(), right.dim, right.strides), + Self::new(left.ptr, left.dim, left.strides), + Self::new(right.ptr, right.dim, right.strides), ) } } + + /// Cast the raw pointer of the raw array view to a different type + /// + /// **Panics** if element size is not compatible. + /// + /// Lack of panic does not imply it is a valid cast. The cast works the same + /// way as regular raw pointer casts. + /// + /// While this method is safe, for the same reason as regular raw pointer + /// casts are safe, access through the produced raw view is only possible + /// in an unsafe block or function. + pub fn cast(self) -> RawArrayViewMut { + assert_eq!( + mem::size_of::(), + mem::size_of::(), + "size mismatch in raw view cast" + ); + let ptr = self.ptr.cast::(); + debug_assert!( + is_aligned(ptr.as_ptr()), + "alignment mismatch in raw view cast" + ); + /* Alignment checked with debug assertion: alignment could be dynamically correct, + * and we don't have a check that compiles out for that. */ + unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) } + } } diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index 765e88e7a..efa854e51 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -6,6 +6,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use std::ptr::NonNull; + use crate::dimension; use crate::error::ShapeError; use crate::extension::nonnull::nonnull_debug_checked_from_ptr; @@ -200,11 +202,11 @@ where /// Convert the view into an `ArrayViewMut<'b, A, D>` where `'b` is a lifetime /// outlived by `'a'`. - pub fn reborrow<'b>(mut self) -> ArrayViewMut<'b, A, D> + pub fn reborrow<'b>(self) -> ArrayViewMut<'b, A, D> where 'a: 'b, { - unsafe { ArrayViewMut::new_(self.as_mut_ptr(), self.dim, self.strides) } + unsafe { ArrayViewMut::new(self.ptr, self.dim, self.strides) } } } @@ -217,14 +219,24 @@ where /// /// Unsafe because: `ptr` must be valid for the given dimension and strides. #[inline(always)] - pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self { + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { + if cfg!(debug_assertions) { + assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned."); + dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); + } ArrayView { data: ViewRepr::new(), - ptr: nonnull_debug_checked_from_ptr(ptr as *mut A), + ptr, dim, strides, } } + + /// Unsafe because: `ptr` must be valid for the given dimension and strides. + #[inline] + pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self { + Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides) + } } impl<'a, A, D> ArrayViewMut<'a, A, D> @@ -235,17 +247,24 @@ where /// /// Unsafe because: `ptr` must be valid for the given dimension and strides. #[inline(always)] - pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self { + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { if cfg!(debug_assertions) { - assert!(!ptr.is_null(), "The pointer must be non-null."); - assert!(is_aligned(ptr), "The pointer must be aligned."); + assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned."); dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); } ArrayViewMut { data: ViewRepr::new(), - ptr: nonnull_debug_checked_from_ptr(ptr), + ptr, dim, strides, } } + + /// Create a new `ArrayView` + /// + /// Unsafe because: `ptr` must be valid for the given dimension and strides. + #[inline(always)] + pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self { + Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides) + } } diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index 0c2222be3..303541b8b 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -26,7 +26,7 @@ where where 'a: 'b, { - unsafe { ArrayView::new_(self.as_ptr(), self.dim, self.strides) } + unsafe { ArrayView::new(self.ptr, self.dim, self.strides) } } /// Return the array’s data as a slice, if it is contiguous and in standard order. @@ -53,7 +53,7 @@ where /// Converts to a raw array view. pub(crate) fn into_raw_view(self) -> RawArrayView { - unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) } + unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) } } } @@ -161,12 +161,12 @@ where { // Convert into a read-only view pub(crate) fn into_view(self) -> ArrayView<'a, A, D> { - unsafe { ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) } + unsafe { ArrayView::new(self.ptr, self.dim, self.strides) } } /// Converts to a mutable raw array view. pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut { - unsafe { RawArrayViewMut::new_(self.ptr.as_ptr(), self.dim, self.strides) } + unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) } } #[inline] diff --git a/src/lib.rs b/src/lib.rs index 1390655c1..a164bee93 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1517,7 +1517,7 @@ where let ptr = self.ptr; let mut strides = dim.clone(); strides.slice_mut().copy_from_slice(self.strides.slice()); - unsafe { ArrayView::new_(ptr.as_ptr(), dim, strides) } + unsafe { ArrayView::new(ptr, dim, strides) } } fn raw_strides(&self) -> D { diff --git a/src/zip/mod.rs b/src/zip/mod.rs index d0003bf7d..33317253e 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -73,7 +73,7 @@ where type Output = ArrayView<'a, A, E::Dim>; fn broadcast_unwrap(self, shape: E) -> Self::Output { let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension()); - unsafe { ArrayView::new_(res.ptr.as_ptr(), res.dim, res.strides) } + unsafe { ArrayView::new(res.ptr, res.dim, res.strides) } } private_impl! {} } diff --git a/tests/raw_views.rs b/tests/raw_views.rs new file mode 100644 index 000000000..09e01aebd --- /dev/null +++ b/tests/raw_views.rs @@ -0,0 +1,86 @@ +use ndarray::prelude::*; +use ndarray::Zip; + +use std::cell::Cell; +#[cfg(debug_assertions)] +use std::mem; + +#[test] +fn raw_view_cast_cell() { + // Test .cast() by creating an ArrayView> + + let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32); + let answer = &a + 1.; + + { + let raw_cell_view = a.raw_view_mut().cast::>(); + let cell_view = unsafe { raw_cell_view.deref_into_view() }; + + Zip::from(cell_view).apply(|elt| elt.set(elt.get() + 1.)); + } + assert_eq!(a, answer); +} + +#[test] +fn raw_view_cast_reinterpret() { + // Test .cast() by reinterpreting u16 as [u8; 2] + let a = Array::from_shape_fn((5, 5).f(), |(i, j)| (i as u16) << 8 | j as u16); + let answer = a.mapv(u16::to_ne_bytes); + + let raw_view = a.raw_view().cast::<[u8; 2]>(); + let view = unsafe { raw_view.deref_into_view() }; + assert_eq!(view, answer); +} + +#[test] +fn raw_view_cast_zst() { + struct Zst; + + let a = Array::<(), _>::default((250, 250)); + let b: RawArrayView = a.raw_view().cast::(); + assert_eq!(a.shape(), b.shape()); + assert_eq!(a.as_ptr() as *const u8, b.as_ptr() as *const u8); +} + +#[test] +#[should_panic] +fn raw_view_invalid_size_cast() { + let data = [0i32; 16]; + ArrayView::from(&data[..]).raw_view().cast::(); +} + +#[test] +#[should_panic] +fn raw_view_mut_invalid_size_cast() { + let mut data = [0i32; 16]; + ArrayViewMut::from(&mut data[..]) + .raw_view_mut() + .cast::(); +} + +#[test] +#[cfg(debug_assertions)] +#[should_panic = "alignment mismatch"] +fn raw_view_invalid_align_cast() { + #[derive(Copy, Clone, Debug)] + #[repr(transparent)] + struct A([u8; 16]); + #[derive(Copy, Clone, Debug)] + #[repr(transparent)] + struct B([f64; 2]); + + unsafe { + const LEN: usize = 16; + let mut buffer = [0u8; mem::size_of::() * (LEN + 1)]; + // Take out a slice of buffer as &[A] which is misaligned for B + let mut ptr = buffer.as_mut_ptr(); + if ptr as usize % mem::align_of::() == 0 { + ptr = ptr.add(1); + } + + let view = RawArrayViewMut::from_shape_ptr(LEN, ptr as *mut A); + + // misaligned cast - test debug assertion + view.cast::(); + } +}