Skip to content
2 changes: 1 addition & 1 deletion toolz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Aliases
comp = compose

from . import curried, sandbox
from . import curried, exceptions, sandbox

functoolz._sigs.create_signature_registry()

Expand Down
6 changes: 6 additions & 0 deletions toolz/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

__all__ = ('IterationError',)


class IterationError(RuntimeError):
pass
36 changes: 27 additions & 9 deletions toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from random import Random
from collections.abc import Sequence
from toolz.utils import no_default
from toolz.exceptions import IterationError


__all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave',
Expand Down Expand Up @@ -373,7 +374,9 @@ def first(seq):
>>> first('ABC')
'A'
"""
return next(iter(seq))
for rv in seq:
return rv
raise IterationError("Received empty sequence")


def second(seq):
Expand All @@ -383,8 +386,14 @@ def second(seq):
'B'
"""
seq = iter(seq)
next(seq)
return next(seq)
for item in seq:
break
else:
raise IterationError("Received empty sequence")
for item in seq:
return item
else:
raise IterationError("Length of sequence is < 2")


def nth(n, seq):
Expand All @@ -396,7 +405,9 @@ def nth(n, seq):
if isinstance(seq, (tuple, list, Sequence)):
return seq[n]
else:
return next(itertools.islice(seq, n, None))
for rv in itertools.islice(seq, n, None):
return rv
raise IterationError("Length of seq is < %d" % n)


def last(seq):
Expand Down Expand Up @@ -531,8 +542,9 @@ def interpose(el, seq):
[1, 'a', 2, 'a', 3]
"""
inposed = concat(zip(itertools.repeat(el), seq))
next(inposed)
return inposed
for _ in inposed:
return inposed
raise IterationError("Received empty sequence")


def frequencies(seq):
Expand Down Expand Up @@ -722,13 +734,16 @@ def partition_all(n, seq):
"""
args = [iter(seq)] * n
it = zip_longest(*args, fillvalue=no_pad)

try:
prev = next(it)
except StopIteration:
return

for item in it:
yield prev
prev = item

if prev[-1] is no_pad:
try:
# If seq defines __len__, then
Expand Down Expand Up @@ -997,8 +1012,11 @@ def peek(seq):
[0, 1, 2, 3, 4]
"""
iterator = iter(seq)
item = next(iterator)
return item, itertools.chain((item,), iterator)
for peeked in iterator:
break
else:
raise IterationError("Received empty sequence")
return peeked, itertools.chain((peeked,), iterator)


def peekn(n, seq):
Expand All @@ -1016,7 +1034,7 @@ def peekn(n, seq):
"""
iterator = iter(seq)
peeked = tuple(take(n, iterator))
return peeked, itertools.chain(iter(peeked), iterator)
return peeked, itertools.chain(peeked, iterator)


def random_sample(prob, seq, random_state=None):
Expand Down
13 changes: 9 additions & 4 deletions toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
sliding_window, count, partition,
partition_all, take_nth, pluck, join,
diff, topk, peek, peekn, random_sample)
from operator import add, mul

from toolz.exceptions import IterationError

from operator import add, mul

# is comparison will fail between this and no_default
no_default2 = loads(dumps('__no__default__'))
Expand Down Expand Up @@ -127,7 +129,7 @@ def test_nth():
assert nth(2, iter('ABCDE')) == 'C'
assert nth(1, (3, 2, 1)) == 2
assert nth(0, {'foo': 'bar'}) == 'foo'
assert raises(StopIteration, lambda: nth(10, {10: 'foo'}))
assert raises(IterationError, lambda: nth(10, {10: 'foo'}))
assert nth(-2, 'ABCDE') == 'D'
assert raises(ValueError, lambda: nth(-2, iter('ABCDE')))

Expand All @@ -136,12 +138,15 @@ def test_first():
assert first('ABCDE') == 'A'
assert first((3, 2, 1)) == 3
assert isinstance(first({0: 'zero', 1: 'one'}), int)
assert raises(IterationError, lambda: first([]))


def test_second():
assert second('ABCDE') == 'B'
assert second((3, 2, 1)) == 2
assert isinstance(second({0: 'zero', 1: 'one'}), int)
assert raises(IterationError, lambda: second([1]))
assert raises(IterationError, lambda: second([]))


def test_last():
Expand Down Expand Up @@ -228,6 +233,7 @@ def test_interpose():
assert "tXaXrXzXaXn" == "".join(interpose("X", "tarzan"))
assert list(interpose(0, itertools.repeat(1, 4))) == [1, 0, 1, 0, 1, 0, 1]
assert list(interpose('.', ['a', 'b', 'c'])) == ['a', '.', 'b', '.', 'c']
assert raises(IterationError, lambda: interpose('a', []))


def test_frequencies():
Expand Down Expand Up @@ -510,8 +516,7 @@ def test_peek():
element, blist = peek(alist)
assert element == alist[0]
assert list(blist) == alist

assert raises(StopIteration, lambda: peek([]))
assert raises(IterationError, lambda: peek([]))


def test_peekn():
Expand Down