From b947b59079a6197d7930dfb535818ac4896113e8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 16 Feb 2024 10:14:24 +0100 Subject: [PATCH] Adding support for integer indexing `[0, :2, -1]`. (#440) * Adding support for integer indexing `[0, :2, -1]`. * Clean up error for too large indexing. --- bindings/python/Cargo.toml | 2 +- bindings/python/src/lib.rs | 102 +++++++++++++++++++-------- bindings/python/tests/test_simple.py | 91 ++++++++++++++++++++++++ safetensors/src/slice.rs | 100 ++++++++++++++++++++++++-- 4 files changed, 257 insertions(+), 38 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 37e2aa0e..8c7a3017 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -14,5 +14,5 @@ memmap2 = "0.5" serde_json = "1.0" [dependencies.safetensors] -version = "0.4.2-dev.0" +version = "0.4.3-dev.0" path = "../../safetensors" diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index d4cd89fb..2eee2968 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -167,24 +167,41 @@ fn deserialize(py: Python, bytes: &[u8]) -> PyResult Result { - let py_start = slice.getattr(intern!(slice.py(), "start"))?; - let start: Option = py_start.extract()?; - let start = if let Some(start) = start { - Bound::Included(start) - } else { - Bound::Unbounded - }; - - let py_stop = slice.getattr(intern!(slice.py(), "stop"))?; - let stop: Option = py_stop.extract()?; - let stop = if let Some(stop) = stop { - Bound::Excluded(stop) - } else { - Bound::Unbounded - }; - - Ok(TensorIndexer::Narrow(start, stop)) +fn slice_to_indexer( + (dim_idx, (slice_index, dim)): (usize, (SliceIndex, usize)), +) -> Result { + match slice_index { + SliceIndex::Slice(slice) => { + let py_start = slice.getattr(intern!(slice.py(), "start"))?; + let start: Option = py_start.extract()?; + let start = if let Some(start) = start { + Bound::Included(start) + } else { + Bound::Unbounded + }; + + let py_stop = 
slice.getattr(intern!(slice.py(), "stop"))?; + let stop: Option = py_stop.extract()?; + let stop = if let Some(stop) = stop { + Bound::Excluded(stop) + } else { + Bound::Unbounded + }; + Ok(TensorIndexer::Narrow(start, stop)) + } + SliceIndex::Index(idx) => { + if idx < 0 { + let idx = dim + .checked_add_signed(idx as isize) + .ok_or(SafetensorError::new_err(format!( + "Invalid index {idx} for dimension {dim_idx} of size {dim}" + )))?; + Ok(TensorIndexer::Select(idx)) + } else { + Ok(TensorIndexer::Select(idx as usize)) + } + } + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -730,10 +747,30 @@ struct PySafeSlice { } #[derive(FromPyObject)] -enum Slice<'a> { - // Index(usize), +enum SliceIndex<'a> { Slice(&'a PySlice), - Slices(Vec<&'a PySlice>), + Index(i32), +} + +#[derive(FromPyObject)] +enum Slice<'a> { + Slice(SliceIndex<'a>), + Slices(Vec>), +} + +use std::fmt; +struct Disp(Vec); + +/// Should be more readable than the standard +/// `Debug` +impl fmt::Display for Disp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for item in &self.0 { + write!(f, "{item}")?; + } + write!(f, "]") + } } #[pymethods] @@ -780,37 +817,42 @@ impl PySafeSlice { Ok(dtype) } - pub fn __getitem__(&self, slices: Slice) -> PyResult { - let slices: Vec<&PySlice> = match slices { - Slice::Slice(slice) => vec![slice], - Slice::Slices(slices) => slices, - }; - + pub fn __getitem__(&self, slices: &PyAny) -> PyResult { match &self.storage.as_ref() { Storage::Mmap(mmap) => { + let slices: Slice = slices.extract()?; + let slices: Vec = match slices { + Slice::Slice(slice) => vec![slice], + Slice::Slices(slices) => slices, + }; let data = &mmap[self.info.data_offsets.0 + self.offset ..self.info.data_offsets.1 + self.offset]; + let shape = self.info.shape.clone(); + let tensor = TensorView::new(self.info.dtype, self.info.shape.clone(), data) .map_err(|e| { SafetensorError::new_err(format!("Error preparing tensor view: {e:?}")) })?; let slices: Vec = slices 
.into_iter() + .zip(shape) + .enumerate() .map(slice_to_indexer) .collect::>()?; let iterator = tensor.sliced_data(&slices).map_err(|e| { SafetensorError::new_err(format!( - "Error during slicing {slices:?} vs {:?}: {:?}", - self.info.shape, e + "Error during slicing {} with shape {:?}: {:?}", + Disp(slices), + self.info.shape, + e )) })?; let newshape = iterator.newshape(); let mut offset = 0; let length = iterator.remaining_byte_len(); - Python::with_gil(|py| { let array: PyObject = PyByteArray::new_with(py, length, |bytes: &mut [u8]| { diff --git a/bindings/python/tests/test_simple.py b/bindings/python/tests/test_simple.py index c46320c0..e8543c72 100644 --- a/bindings/python/tests/test_simple.py +++ b/bindings/python/tests/test_simple.py @@ -231,3 +231,94 @@ def test_exception(self): with self.assertRaises(SafetensorError): serialize(flattened) + + def test_torch_slice(self): + A = torch.randn((10, 5)) + tensors = { + "a": A, + } + save_file_pt(tensors, "./slice.safetensors") + + # Now loading + with safe_open("./slice.safetensors", framework="pt", device="cpu") as f: + slice_ = f.get_slice("a") + tensor = slice_[:] + self.assertEqual(list(tensor.shape), [10, 5]) + torch.testing.assert_close(tensor, A) + + tensor = slice_[:2] + self.assertEqual(list(tensor.shape), [2, 5]) + torch.testing.assert_close(tensor, A[:2]) + + tensor = slice_[:, :2] + self.assertEqual(list(tensor.shape), [10, 2]) + torch.testing.assert_close(tensor, A[:, :2]) + + tensor = slice_[0, :2] + self.assertEqual(list(tensor.shape), [2]) + torch.testing.assert_close(tensor, A[0, :2]) + + tensor = slice_[2:, 0] + self.assertEqual(list(tensor.shape), [8]) + torch.testing.assert_close(tensor, A[2:, 0]) + + tensor = slice_[2:, 1] + self.assertEqual(list(tensor.shape), [8]) + torch.testing.assert_close(tensor, A[2:, 1]) + + tensor = slice_[2:, -1] + self.assertEqual(list(tensor.shape), [8]) + torch.testing.assert_close(tensor, A[2:, -1]) + + def test_numpy_slice(self): + A = np.random.rand(10, 5) + 
tensors = { + "a": A, + } + save_file(tensors, "./slice.safetensors") + + # Now loading + with safe_open("./slice.safetensors", framework="np", device="cpu") as f: + slice_ = f.get_slice("a") + tensor = slice_[:] + self.assertEqual(list(tensor.shape), [10, 5]) + self.assertTrue(np.allclose(tensor, A)) + + tensor = slice_[:2] + self.assertEqual(list(tensor.shape), [2, 5]) + self.assertTrue(np.allclose(tensor, A[:2])) + + tensor = slice_[:, :2] + self.assertEqual(list(tensor.shape), [10, 2]) + self.assertTrue(np.allclose(tensor, A[:, :2])) + + tensor = slice_[0, :2] + self.assertEqual(list(tensor.shape), [2]) + self.assertTrue(np.allclose(tensor, A[0, :2])) + + tensor = slice_[2:, 0] + self.assertEqual(list(tensor.shape), [8]) + self.assertTrue(np.allclose(tensor, A[2:, 0])) + + tensor = slice_[2:, 1] + self.assertEqual(list(tensor.shape), [8]) + self.assertTrue(np.allclose(tensor, A[2:, 1])) + + tensor = slice_[2:, -1] + self.assertEqual(list(tensor.shape), [8]) + self.assertTrue(np.allclose(tensor, A[2:, -1])) + + tensor = slice_[2:, -5] + self.assertEqual(list(tensor.shape), [8]) + self.assertTrue(np.allclose(tensor, A[2:, -5])) + + with self.assertRaises(SafetensorError) as cm: + tensor = slice_[2:, -6] + self.assertEqual(str(cm.exception), "Invalid index -6 for dimension 1 of size 5") + + with self.assertRaises(SafetensorError) as cm: + tensor = slice_[2:, 20] + self.assertEqual( + str(cm.exception), + "Error during slicing [2:20] with shape [10, 5]: SliceOutOfRange { dim_index: 1, asked: 20, dim_size: 5 }", + ) diff --git a/safetensors/src/slice.rs b/safetensors/src/slice.rs index 276bbb70..d19b4b59 100644 --- a/safetensors/src/slice.rs +++ b/safetensors/src/slice.rs @@ -1,5 +1,6 @@ //! Module handling lazy loading via iterating on slices on the original buffer. 
use crate::tensor::TensorView; +use std::fmt; use std::ops::{ Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, }; @@ -9,22 +10,54 @@ use std::ops::{ pub enum InvalidSlice { /// When the client asked for more slices than the tensors has dimensions TooManySlices, + /// When the client asked for a slice that exceeds the allowed bounds + SliceOutOfRange { + /// The index of the dimension in which the out-of-bounds access occurred + dim_index: usize, + /// The problematic value + asked: usize, + /// The dimension size we shouldn't go over. + dim_size: usize, + }, } #[derive(Debug, Clone)] /// Generic structure used to index a slice of the tensor pub enum TensorIndexer { - //Select(usize), + /// This selects a single index along a dimension, dropping that dimension from the result + Select(usize), /// This is a regular slice, purely indexing a chunk of the tensor Narrow(Bound, Bound), //IndexSelect(Tensor), } -// impl From for TensorIndexer { -// fn from(index: usize) -> Self { -// TensorIndexer::Select(index) -// } -// } +fn display_bound(bound: &Bound) -> String { + match bound { + Bound::Unbounded => "".to_string(), + Bound::Excluded(n) => format!("{n}"), + Bound::Included(n) => format!("{n}"), + } +} + +/// Intended for Python users mostly or at least for its conventions +impl fmt::Display for TensorIndexer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TensorIndexer::Select(n) => { + write!(f, "{n}") + } + TensorIndexer::Narrow(left, right) => { + write!(f, "{}:{}", display_bound(left), display_bound(right)) + } + } + } +} + +impl From for TensorIndexer { + fn from(index: usize) -> Self { + TensorIndexer::Select(index) + } +} // impl From<&[usize]> for TensorIndexer { // fn from(index: &[usize]) -> Self { @@ -249,8 +282,18 @@ impl<'data> SliceIterator<'data> { TensorIndexer::Narrow(Bound::Excluded(s), Bound::Included(stop)) => { (*s + 1, *stop + 1) } + TensorIndexer::Select(s) => (*s, *s + 1), }; - newshape.push(stop - start); + if start >= shape && stop > 
shape { + return Err(InvalidSlice::SliceOutOfRange { + dim_index: i, + asked: stop.saturating_sub(1), + dim_size: shape, + }); + } + if let TensorIndexer::Narrow(..) = slice { + newshape.push(stop - start); + } if indices.is_empty() { if start == 0 && stop == shape { // We haven't started to slice yet, just increase the span @@ -487,4 +530,47 @@ mod tests { assert_eq!(iterator.next(), Some(&data[16..24])); assert_eq!(iterator.next(), None); } + + #[test] + fn test_slice_select() { + let data: Vec = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0] + .into_iter() + .flat_map(|f| f.to_le_bytes()) + .collect(); + + let attn_0 = TensorView::new(Dtype::F32, vec![2, 3], &data).unwrap(); + + let mut iterator = SliceIterator::new( + &attn_0, + &[ + TensorIndexer::Select(1), + TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(3)), + ], + ) + .unwrap(); + assert_eq!(iterator.next(), Some(&data[16..24])); + assert_eq!(iterator.next(), None); + + let mut iterator = SliceIterator::new( + &attn_0, + &[ + TensorIndexer::Select(0), + TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(3)), + ], + ) + .unwrap(); + assert_eq!(iterator.next(), Some(&data[4..12])); + assert_eq!(iterator.next(), None); + + let mut iterator = SliceIterator::new( + &attn_0, + &[ + TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(2)), + TensorIndexer::Select(0), + ], + ) + .unwrap(); + assert_eq!(iterator.next(), Some(&data[12..16])); + assert_eq!(iterator.next(), None); + } }