Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(vector): Update legacy vector to optimize nth operations on iterators #634

Merged
merged 3 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified examples/non-fungible-token/res/non_fungible_token.wasm
Binary file not shown.
2 changes: 1 addition & 1 deletion near-sdk/src/collections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub use lookup_map::LookupMap;
mod lookup_set;
pub use lookup_set::LookupSet;

mod vector;
pub mod vector;
pub use vector::Vector;

mod unordered_map;
Expand Down
144 changes: 137 additions & 7 deletions near-sdk/src/collections/vector.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! A vector implemented on a trie. Unlike standard vector does not support insertion and removal
//! of an element results in the last element being placed in the empty position.
use core::ops::Range;
use std::iter::FusedIterator;
use std::marker::PhantomData;

use borsh::{BorshDeserialize, BorshSerialize};
Expand Down Expand Up @@ -126,11 +128,8 @@ impl<T> Vector<T> {
}

/// Iterate over raw serialized elements.
pub fn iter_raw(&self) -> impl Iterator<Item = Vec<u8>> + '_ {
(0..self.len).map(move |i| {
let lookup_key = self.index_to_lookup_key(i);
expect_consistent_state(env::storage_read(&lookup_key))
})
pub fn iter_raw(&self) -> RawIter<T> {
RawIter::new(self)
}

/// Extends vector from the given collection of serialized elements.
Expand Down Expand Up @@ -206,8 +205,8 @@ where
}

/// Iterate over deserialized elements.
pub fn iter(&self) -> impl Iterator<Item = T> + '_ {
self.iter_raw().map(|raw_element| Self::deserialize_element(&raw_element))
pub fn iter(&self) -> Iter<T> {
Iter::new(self)
}

pub fn to_vec(&self) -> Vec<T> {
Expand Down Expand Up @@ -237,6 +236,110 @@ impl<T: std::fmt::Debug + BorshDeserialize> std::fmt::Debug for Vector<T> {
}
}

/// An iterator over raw serialized bytes of each element in the [`Vector`].
pub struct RawIter<'a, T> {
vec: &'a Vector<T>,
range: Range<u64>,
}

impl<'a, T> RawIter<'a, T> {
fn new(vec: &'a Vector<T>) -> Self {
Self { vec, range: Range { start: 0, end: vec.len() } }
}

/// Returns number of elements left to iterate.
fn remaining(&self) -> usize {
(self.range.end - self.range.start) as usize
}
}

impl<'a, T> Iterator for RawIter<'a, T> {
type Item = Vec<u8>;

fn next(&mut self) -> Option<Self::Item> {
<Self as Iterator>::nth(self, 0)
}

fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = self.remaining();
(remaining, Some(remaining))
}

fn count(self) -> usize {
self.remaining()
}

fn nth(&mut self, n: usize) -> Option<Self::Item> {
let idx = self.range.nth(n)?;
self.vec.get_raw(idx)
}
}

impl<'a, T> ExactSizeIterator for RawIter<'a, T> {}
impl<'a, T> FusedIterator for RawIter<'a, T> {}

impl<'a, T> DoubleEndedIterator for RawIter<'a, T> {
fn next_back(&mut self) -> Option<Self::Item> {
<Self as DoubleEndedIterator>::nth_back(self, 0)
}

fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
let idx = self.range.nth_back(n)?;
self.vec.get_raw(idx)
}
}

/// An iterator over each element deserialized in the [`Vector`].
pub struct Iter<'a, T> {
inner: RawIter<'a, T>,
}

impl<'a, T> Iter<'a, T> {
fn new(vec: &'a Vector<T>) -> Self {
Self { inner: RawIter::new(vec) }
}
}

impl<'a, T> Iterator for Iter<'a, T>
where
T: BorshDeserialize,
{
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
<Self as Iterator>::nth(self, 0)
}

fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = self.inner.remaining();
(remaining, Some(remaining))
}

fn count(self) -> usize {
self.inner.remaining()
}

fn nth(&mut self, n: usize) -> Option<Self::Item> {
self.inner.nth(n).map(|raw_element| Vector::deserialize_element(&raw_element))
}
}

impl<'a, T> ExactSizeIterator for Iter<'a, T> where T: BorshDeserialize {}
impl<'a, T> FusedIterator for Iter<'a, T> where T: BorshDeserialize {}

impl<'a, T> DoubleEndedIterator for Iter<'a, T>
where
T: BorshDeserialize,
{
fn next_back(&mut self) -> Option<Self::Item> {
<Self as DoubleEndedIterator>::nth_back(self, 0)
}

fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
self.inner.nth_back(n).map(|raw_element| Vector::deserialize_element(&raw_element))
}
}

#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -393,4 +496,31 @@ mod tests {
);
}
}

#[test]
pub fn iterator_checks() {
let mut vec = Vector::new(b"v");
let mut baseline = vec![];
for i in 0..10 {
vec.push(&i);
baseline.push(i);
}

let mut vec_iter = vec.iter();
let mut bl_iter = baseline.iter();
assert_eq!(vec_iter.next(), bl_iter.next().copied());
assert_eq!(vec_iter.next_back(), bl_iter.next_back().copied());
assert_eq!(vec_iter.nth(3), bl_iter.nth(3).copied());
assert_eq!(vec_iter.nth_back(2), bl_iter.nth_back(2).copied());

// Check to make sure indexing overflow is handled correctly
assert!(vec_iter.nth(5).is_none());
assert!(bl_iter.nth(5).is_none());

assert!(vec_iter.next().is_none());
assert!(bl_iter.next().is_none());

// Count check
assert_eq!(vec.iter().count(), baseline.len());
}
}