From b10d58b56a862fe318501c2a2bf44833d57e4050 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 15 Jun 2022 14:18:45 +0100 Subject: [PATCH] Mark typed buffer APIs safe (#996) (#1027) (#1866) * Mark typed buffer APIs safe (#996) (#1027) * Fix parquet * Format * Review feedback --- arrow/src/array/array_union.rs | 8 +++--- arrow/src/array/builder.rs | 4 +-- arrow/src/array/data.rs | 5 ++-- arrow/src/buffer/immutable.rs | 26 ++++++++----------- arrow/src/buffer/mutable.rs | 15 +++++------ arrow/src/buffer/ops.rs | 4 +-- arrow/src/compute/kernels/cast.rs | 2 +- arrow/src/compute/kernels/sort.rs | 6 ++--- arrow/src/compute/kernels/take.rs | 3 +-- arrow/src/datatypes/native.rs | 3 ++- parquet/src/arrow/array_reader/byte_array.rs | 4 +-- .../array_reader/byte_array_dictionary.rs | 6 ++--- parquet/src/arrow/array_reader/mod.rs | 8 +++--- parquet/src/arrow/arrow_writer/mod.rs | 2 +- parquet/src/arrow/buffer/dictionary_buffer.rs | 2 +- parquet/src/arrow/record_reader/buffer.rs | 4 +-- parquet/src/arrow/record_reader/mod.rs | 4 +-- 17 files changed, 47 insertions(+), 59 deletions(-) diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs index bae771bb9e75..4ff0a31c6529 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow/src/array/array_union.rs @@ -185,7 +185,7 @@ impl UnionArray { } // Check the type_ids - let type_id_slice: &[i8] = unsafe { type_ids.typed_data() }; + let type_id_slice: &[i8] = type_ids.typed_data(); let invalid_type_ids = type_id_slice .iter() .filter(|i| *i < &0) @@ -201,7 +201,7 @@ impl UnionArray { // Check the value offsets if provided if let Some(offset_buffer) = &value_offsets { let max_len = type_ids.len() as i32; - let offsets_slice: &[i32] = unsafe { offset_buffer.typed_data() }; + let offsets_slice: &[i32] = offset_buffer.typed_data(); let invalid_offsets = offsets_slice .iter() .filter(|i| *i < &0 || *i > &max_len) @@ -255,9 +255,7 @@ impl UnionArray { pub fn value_offset(&self, index: usize) -> i32 { assert!(index - self.offset() < self.len()); if self.is_dense() { - // safety: reinterpreting is safe since the offset buffer contains `i32` values and is - // properly aligned. - unsafe { self.data().buffers()[1].typed_data::()[index] } + self.data().buffers()[1].typed_data::()[index] } else { index as i32 } diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs index ed26d3c2f480..2df4aecf65bf 100644 --- a/arrow/src/array/builder.rs +++ b/arrow/src/array/builder.rs @@ -76,7 +76,7 @@ pub(crate) fn builder_to_mutable_buffer( /// builder.append(45); /// let buffer = builder.finish(); /// -/// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 43, 44, 45]); +/// assert_eq!(buffer.typed_data::(), &[42, 43, 44, 45]); /// # Ok(()) /// # } /// ``` @@ -380,7 +380,7 @@ impl BufferBuilder { /// /// let buffer = builder.finish(); /// - /// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 44, 46]); + /// assert_eq!(buffer.typed_data::(), &[42, 44, 46]); /// ``` #[inline] pub fn finish(&mut self) -> Buffer { diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index 0ccbe6a70178..65fbc4df9704 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -767,8 +767,7 @@ impl ArrayData { ))); } - // SAFETY: Bounds checked above - Ok(unsafe { &(buffer.typed_data::()[self.offset..self.offset + len]) }) + Ok(&buffer.typed_data::()[self.offset..self.offset + len]) } /// Does a cheap sanity check that the `self.len` values in `buffer` are valid @@ -1161,7 +1160,7 @@ impl ArrayData { // Justification: buffer size was validated above let indexes: &[T] = - unsafe { &(buffer.typed_data::()[self.offset..self.offset + self.len]) }; + &buffer.typed_data::()[self.offset..self.offset + self.len]; indexes.iter().enumerate().try_for_each(|(i, &dict_index)| { // Do not check the value is null (value can be arbitrary) diff --git a/arrow/src/buffer/immutable.rs b/arrow/src/buffer/immutable.rs index c34ea101bb3b..f5d59c5ed555 100644 --- a/arrow/src/buffer/immutable.rs +++ b/arrow/src/buffer/immutable.rs @@ -181,19 +181,15 @@ impl Buffer { /// View buffer as typed slice. /// - /// # Safety + /// # Panics /// - /// `ArrowNativeType` is public so that it can be used as a trait bound for other public - /// components, such as the `ToByteSlice` trait. However, this means that it can be - /// implemented by user defined types, which it is not intended for. - pub unsafe fn typed_data(&self) -> &[T] { - // JUSTIFICATION - // Benefit - // Many of the buffers represent specific types, and consumers of `Buffer` often need to re-interpret them. - // Soundness - // * The pointer is non-null by construction - // * alignment asserted below. - let (prefix, offsets, suffix) = self.as_slice().align_to::(); + /// This function panics if the underlying buffer is not aligned + /// correctly for type `T`. + pub fn typed_data(&self) -> &[T] { + // SAFETY + // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect + // implementation outside this crate, and this method checks alignment + let (prefix, offsets, suffix) = unsafe { self.as_slice().align_to::() }; assert!(prefix.is_empty() && suffix.is_empty()); offsets } @@ -451,7 +447,7 @@ mod tests { macro_rules! check_as_typed_data { ($input: expr, $native_t: ty) => {{ let buffer = Buffer::from_slice_ref($input); - let slice: &[$native_t] = unsafe { buffer.typed_data::<$native_t>() }; + let slice: &[$native_t] = buffer.typed_data::<$native_t>(); assert_eq!($input, slice); }}; } @@ -573,12 +569,12 @@ mod tests { ) }; - let slice = unsafe { buffer.typed_data::() }; + let slice = buffer.typed_data::(); assert_eq!(slice, &[1, 2, 3, 4, 5]); let buffer = buffer.slice(std::mem::size_of::()); - let slice = unsafe { buffer.typed_data::() }; + let slice = buffer.typed_data::(); assert_eq!(slice, &[2, 3, 4, 5]); } } diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs index ef3e35209a1c..5710a97f38e7 100644 --- a/arrow/src/buffer/mutable.rs +++ b/arrow/src/buffer/mutable.rs @@ -290,17 +290,16 @@ impl MutableBuffer { /// View this buffer asa slice of a specific type. /// - /// # Safety - /// - /// This function must only be used with buffers which are treated - /// as type `T` (e.g. extended with items of type `T`). - /// /// # Panics /// /// This function panics if the underlying buffer is not aligned /// correctly for type `T`. - pub unsafe fn typed_data_mut(&mut self) -> &mut [T] { - let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::(); + pub fn typed_data_mut(&mut self) -> &mut [T] { + // SAFETY + // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect + // implementation outside this crate, and this method checks alignment + let (prefix, offsets, suffix) = + unsafe { self.as_slice_mut().align_to_mut::() }; assert!(prefix.is_empty() && suffix.is_empty()); offsets } @@ -314,7 +313,7 @@ impl MutableBuffer { /// assert_eq!(buffer.len(), 8) // u32 has 4 bytes /// ``` #[inline] - pub fn extend_from_slice(&mut self, items: &[T]) { + pub fn extend_from_slice(&mut self, items: &[T]) { let len = items.len(); let additional = len * std::mem::size_of::(); self.reserve(additional); diff --git a/arrow/src/buffer/ops.rs b/arrow/src/buffer/ops.rs index b3571d1740b1..ea155c8d78e4 100644 --- a/arrow/src/buffer/ops.rs +++ b/arrow/src/buffer/ops.rs @@ -68,9 +68,7 @@ where let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits); - // Safety: buffer is always treated as type `u64` in the code - // below. - let result_chunks = unsafe { result.typed_data_mut::().iter_mut() }; + let result_chunks = result.typed_data_mut::().iter_mut(); result_chunks .zip(left_chunks.iter()) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 93a8ebcb6b5a..9a4638d9773f 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -2084,7 +2084,7 @@ where let list_data = array.data(); let str_values_buf = str_array.value_data(); - let offsets = unsafe { list_data.buffers()[0].typed_data::() }; + let offsets = list_data.buffers()[0].typed_data::(); let mut offset_builder = BufferBuilder::::new(offsets.len()); offsets.iter().try_for_each::<_, Result<_>>(|offset| { diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 140a57f33ed5..72ee8b68da21 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -452,8 +452,7 @@ fn sort_boolean( let mut result = MutableBuffer::new(result_capacity); // sets len to capacity so we can access the whole buffer as a typed slice result.resize(result_capacity, 0); - // Safety: the buffer is always treated as `u32` in the code below - let result_slice: &mut [u32] = unsafe { result.typed_data_mut() }; + let result_slice: &mut [u32] = result.typed_data_mut(); if options.nulls_first { let size = nulls_len.min(len); @@ -565,8 +564,7 @@ where let mut result = MutableBuffer::new(result_capacity); // sets len to capacity so we can access the whole buffer as a typed slice result.resize(result_capacity, 0); - // Safety: the buffer is always treated as `u32` in the code below - let result_slice: &mut [u32] = unsafe { result.typed_data_mut() }; + let result_slice: &mut [u32] = result.typed_data_mut(); if options.nulls_first { let size = nulls_len.min(len); diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 567bf5c8ba27..03637ec81dd6 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -688,8 +688,7 @@ where let bytes_offset = (data_len + 1) * std::mem::size_of::(); let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset); - // Safety: the buffer is always treated as as a type of `OffsetSize` in the code below - let offsets = unsafe { offsets_buffer.typed_data_mut() }; + let offsets = offsets_buffer.typed_data_mut(); let mut values = MutableBuffer::new(0); let mut length_so_far = OffsetSize::zero(); offsets[0] = length_so_far; diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index efb1d3e6b2de..d9a3f667d8e4 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -42,7 +42,8 @@ pub trait JsonSerializable: 'static { /// /// Note: in the case of floating point numbers this transmutation can result in a signalling /// NaN, which, whilst sound, can be unwieldy. In general, whilst it is perfectly sound to -/// reinterpret bytes as different types using this trait, it is likely unwise +/// reinterpret bytes as different types using this trait, it is likely unwise. For more information +/// see [f32::from_bits] and [f64::from_bits]. /// /// Note: `bool` is restricted to `0` or `1`, and so `bool: !ArrowNativeType` /// diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs index 2e29b6094741..9e0f83fa9450 100644 --- a/parquet/src/arrow/array_reader/byte_array.rs +++ b/parquet/src/arrow/array_reader/byte_array.rs @@ -125,13 +125,13 @@ impl ArrayReader for ByteArrayReader { fn get_def_levels(&self) -> Option<&[i16]> { self.def_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { self.rep_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } } diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs index 0e64f0d25b7b..0cd67206f000 100644 --- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs +++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs @@ -187,13 +187,13 @@ where fn get_def_levels(&self) -> Option<&[i16]> { self.def_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { self.rep_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } } @@ -356,7 +356,7 @@ where assert_eq!(dict.data_type(), &self.value_type); let dict_buffers = dict.data().buffers(); - let dict_offsets = unsafe { dict_buffers[0].typed_data::() }; + let dict_offsets = dict_buffers[0].typed_data::(); let dict_values = dict_buffers[1].as_slice(); values.extend_from_dictionary( diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs index 21c49b338783..6207b377d137 100644 --- a/parquet/src/arrow/array_reader/mod.rs +++ b/parquet/src/arrow/array_reader/mod.rs @@ -226,13 +226,13 @@ where fn get_def_levels(&self) -> Option<&[i16]> { self.def_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { self.rep_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } } @@ -447,13 +447,13 @@ where fn get_def_levels(&self) -> Option<&[i16]> { self.def_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { self.rep_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } } diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 44631e57409a..b64517ad19d1 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -576,7 +576,7 @@ macro_rules! def_get_binary_array_fn { fn $name(array: &$ty) -> Vec { let mut byte_array = ByteArray::new(); let ptr = crate::util::memory::ByteBufferPtr::new( - unsafe { array.value_data().typed_data::() }.to_vec(), + array.value_data().as_slice().to_vec(), ); byte_array.set_data(ptr); array diff --git a/parquet/src/arrow/buffer/dictionary_buffer.rs b/parquet/src/arrow/buffer/dictionary_buffer.rs index 7f4458507001..ffa3a4843c50 100644 --- a/parquet/src/arrow/buffer/dictionary_buffer.rs +++ b/parquet/src/arrow/buffer/dictionary_buffer.rs @@ -106,7 +106,7 @@ impl Self::Dict { keys, values } => { let mut spilled = OffsetBuffer::default(); let dict_buffers = values.data().buffers(); - let dict_offsets = unsafe { dict_buffers[0].typed_data::() }; + let dict_offsets = dict_buffers[0].typed_data::(); let dict_values = dict_buffers[1].as_slice(); if values.is_empty() { diff --git a/parquet/src/arrow/record_reader/buffer.rs b/parquet/src/arrow/record_reader/buffer.rs index fa0f919916ee..7101eaa9ccc9 100644 --- a/parquet/src/arrow/record_reader/buffer.rs +++ b/parquet/src/arrow/record_reader/buffer.rs @@ -19,7 +19,7 @@ use std::marker::PhantomData; use crate::arrow::buffer::bit_util::iter_set_bits_rev; use arrow::buffer::{Buffer, MutableBuffer}; -use arrow::datatypes::ToByteSlice; +use arrow::datatypes::ArrowNativeType; /// A buffer that supports writing new data to the end, and removing data from the front /// @@ -172,7 +172,7 @@ impl ScalarBuffer { } } -impl ScalarBuffer { +impl ScalarBuffer { pub fn push(&mut self, v: T) { self.buffer.push(v); self.len += 1; diff --git a/parquet/src/arrow/record_reader/mod.rs b/parquet/src/arrow/record_reader/mod.rs index 89d782b1aca8..023a538a2741 100644 --- a/parquet/src/arrow/record_reader/mod.rs +++ b/parquet/src/arrow/record_reader/mod.rs @@ -573,7 +573,7 @@ mod tests { // Verify result record data let actual = record_reader.consume_record_data().unwrap(); - let actual_values = unsafe { actual.typed_data::() }; + let actual_values = actual.typed_data::(); let expected = &[0, 7, 0, 6, 3, 0, 8]; assert_eq!(actual_values.len(), expected.len()); @@ -687,7 +687,7 @@ mod tests { // Verify result record data let actual = record_reader.consume_record_data().unwrap(); - let actual_values = unsafe { actual.typed_data::() }; + let actual_values = actual.typed_data::(); let expected = &[4, 0, 0, 7, 6, 3, 2, 8, 9]; assert_eq!(actual_values.len(), expected.len());