Skip to content

Commit

Permalink
make itertools.chain be a class and from_iterable a classmethod of chain
Browse files Browse the repository at this point in the history
  • Loading branch information
iury authored and trotterdylan committed Jan 11, 2017
1 parent d9af156 commit 3adaa1a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
23 changes: 14 additions & 9 deletions lib/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,20 @@
import _collections
import sys

def chain(*iterables):
for it in iterables:
for element in it:
yield element
class chain(object):

def from_iterable(cls, iterable):
return cls(*iterable)

from_iterable = classmethod(from_iterable)

def __init__(self, *iterables):
self.iterables = iterables

def __iter__(self):
for it in self.iterables:
for element in it:
yield element

def compress(data, selectors):
return (d for d,s in izip(data, selectors) if s)
Expand Down Expand Up @@ -49,11 +59,6 @@ def dropwhile(predicate, iterable):
for x in iterable:
yield x

def from_iterable(iterables):
for it in iterables:
for element in it:
yield element

def ifilter(predicate, iterable):
if predicate is None:
predicate = bool
Expand Down
13 changes: 12 additions & 1 deletion lib/itertools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def TestDropwhile():
got = tuple(itertools.dropwhile(*args))
assert got == want, 'tuple(dropwhile%s) == %s, want %s' % (args, got, want)

def TestChain():
r = range(10)
cases = [
([r], tuple(r)),
([r, r], tuple(r) + tuple(r)),
([], ())
]
for args, want in cases:
got = tuple(itertools.chain(*args))
assert got == want, 'tuple(chain%s) == %s, want %s' % (args, got, want)

def TestFromIterable():
r = range(10)
cases = [
Expand All @@ -54,7 +65,7 @@ def TestFromIterable():
([], ())
]
for args, want in cases:
got = tuple(itertools.from_iterable(args))
got = tuple(itertools.chain.from_iterable(args))
assert got == want, 'tuple(from_iterable%s) == %s, want %s' % (args, got, want)

def TestIFilter():
Expand Down

0 comments on commit 3adaa1a

Please sign in to comment.