diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 016da21aba0..7cc61f0ac0a 100755 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -587,10 +587,16 @@ pub trait Stream { for_each::new(self, f) } - /// Creates a new stream of at most `amt` items. + /// Creates a new stream of at most `amt` items of the underlying stream. /// /// Once `amt` items have been yielded from this stream then it will always /// return that the stream is done. + /// + /// # Errors + /// + /// Any errors yielded from underlying stream, before the desired amount of + /// items is reached, are passed through and do not affect the total number + /// of items taken. fn take(self, amt: u64) -> Take where Self: Sized { @@ -601,6 +607,11 @@ pub trait Stream { /// /// Once `amt` items have been skipped from this stream then it will always /// return the remaining items on this stream. + /// + /// # Errors + /// + /// All errors yielded from underlying stream are passed through and do not + /// affect the total number of items skipped. fn skip(self, amt: u64) -> Skip where Self: Sized { diff --git a/src/stream/take.rs b/src/stream/take.rs index aa7425ad397..9c43252d24b 100644 --- a/src/stream/take.rs +++ b/src/stream/take.rs @@ -29,13 +29,12 @@ impl Stream for Take if self.remaining == 0 { Ok(Async::Ready(None)) } else { - match self.stream.poll() { - e @ Ok(Async::Ready(Some(_))) | e @ Err(_) => { - self.remaining -= 1; - e - } - other => other, + let next = try_ready!(self.stream.poll()); + match next { + Some(_) => self.remaining -= 1, + None => self.remaining = 0, } + Ok(Async::Ready(next)) } } } diff --git a/tests/stream.rs b/tests/stream.rs index 7c5b2e40725..d2186206a32 100644 --- a/tests/stream.rs +++ b/tests/stream.rs @@ -90,6 +90,18 @@ fn skip() { assert_done(|| list().skip(2).collect(), Ok(vec![3])); } +#[test] +fn skip_passes_errors_through() { + let mut s = iter(vec![Err(1), Err(2), Ok(3), Ok(4), Ok(5)]) + .skip(1) + .wait(); + assert_eq!(s.next(), Some(Err(1))); + assert_eq!(s.next(), Some(Err(2))); + assert_eq!(s.next(), Some(Ok(4))); + assert_eq!(s.next(), Some(Ok(5))); + assert_eq!(s.next(), None); +} + #[test] fn skip_while() { assert_done(|| list().skip_while(|e| Ok(*e % 2 == 1)).collect(), @@ -100,6 +112,21 @@ fn take() { assert_done(|| list().take(2).collect(), Ok(vec![1, 2])); } +#[test] +fn take_passes_errors_through() { + let mut s = iter(vec![Err(1), Err(2), Ok(3), Ok(4), Err(4)]) + .take(1) + .wait(); + assert_eq!(s.next(), Some(Err(1))); + assert_eq!(s.next(), Some(Err(2))); + assert_eq!(s.next(), Some(Ok(3))); + assert_eq!(s.next(), None); + + let mut s = iter(vec![Ok(1), Err(2)]).take(1).wait(); + assert_eq!(s.next(), Some(Ok(1))); + assert_eq!(s.next(), None); +} + #[test] fn peekable() { assert_done(|| list().peekable().collect(), Ok(vec![1, 2, 3]));