diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 372bb7755..e35f96c5a 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -259,48 +259,50 @@ impl<'s> From> for tk::InputSequence<'s> { struct PyArrayUnicode(Vec); impl FromPyObject<'_> for PyArrayUnicode { fn extract(ob: &PyAny) -> PyResult { + // SAFETY Making sure the pointer is a valid numpy array requires calling numpy C code if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 { return Err(exceptions::PyTypeError::new_err("Expected an np.array")); } let arr = ob.as_ptr() as *mut npyffi::PyArrayObject; - if unsafe { (*arr).nd } != 1 { - return Err(exceptions::PyTypeError::new_err( - "Expected a 1 dimensional np.array", - )); - } - if unsafe { (*arr).flags } - & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) - == 0 - { - return Err(exceptions::PyTypeError::new_err( - "Expected a contiguous np.array", - )); - } - let n_elem = unsafe { *(*arr).dimensions } as usize; - let (type_num, elsize, alignment, data) = unsafe { + // SAFETY Getting all the metadata about the numpy array to check its sanity + let (type_num, elsize, alignment, data, nd, flags) = unsafe { let desc = (*arr).descr; ( (*desc).type_num, (*desc).elsize as usize, (*desc).alignment as usize, (*arr).data, + (*arr).nd, + (*arr).flags, ) }; + if nd != 1 { + return Err(exceptions::PyTypeError::new_err( + "Expected a 1 dimensional np.array", + )); + } + if flags & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) == 0 { + return Err(exceptions::PyTypeError::new_err( + "Expected a contiguous np.array", + )); + } if type_num != npyffi::types::NPY_TYPES::NPY_UNICODE as i32 { return Err(exceptions::PyTypeError::new_err( "Expected a np.array[dtype='U']", )); } + // SAFETY Looking at the raw numpy data to create new owned Rust strings via copies (so it's safe afterwards). unsafe { + let n_elem = *(*arr).dimensions as usize; let all_bytes = std::slice::from_raw_parts(data as *const u8, elsize * n_elem); let seq = (0..n_elem) .map(|i| { let bytes = &all_bytes[i * elsize..(i + 1) * elsize]; - #[allow(deprecated)] - let unicode = pyo3::ffi::PyUnicode_FromUnicode( + let unicode = pyo3::ffi::PyUnicode_FromKindAndData( + pyo3::ffi::PyUnicode_4BYTE_KIND as _, bytes.as_ptr() as *const _, elsize as isize / alignment as isize, );