Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add itertoolz.flatten #547

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions toolz/curried/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
# Each name below is rebound to ``toolz.curry`` applied to the toolz function
# of the same name, so this namespace mirrors the top-level one.
drop = toolz.curry(toolz.drop)
excepts = toolz.curry(toolz.excepts)
filter = toolz.curry(toolz.filter)
flatten = toolz.curry(toolz.flatten)
get = toolz.curry(toolz.get)
get_in = toolz.curry(toolz.get_in)
groupby = toolz.curry(toolz.groupby)
Expand Down
59 changes: 57 additions & 2 deletions toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import operator
from functools import partial
from itertools import filterfalse, zip_longest
from collections.abc import Sequence
from collections.abc import Sequence, Mapping
from toolz.utils import no_default


Expand All @@ -13,7 +13,8 @@
'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv',
'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate',
'sliding_window', 'partition', 'partition_all', 'count', 'pluck',
'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample')
'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample',
'flatten')


def remove(predicate, seq):
Expand Down Expand Up @@ -484,6 +485,7 @@ def concat(seqs):

See also:
itertools.chain.from_iterable equivalent
flatten
"""
return itertools.chain.from_iterable(seqs)

Expand Down Expand Up @@ -1051,3 +1053,56 @@ def random_sample(prob, seq, random_state=None):

random_state = Random(random_state)
return filter(lambda _: random_state.random() < prob, seq)


def _default_descend(x):
return not isinstance(x, (str, bytes, Mapping))


def flatten(level, seq, descend=_default_descend):
    """ Flatten a possibly nested sequence

    Inspired by Javascript's Array.flat(), this is a recursive,
    depth limited flattening generator.  A level 0 flattening will
    yield the input sequence unchanged.  A level -1 flattening
    will flatten all possible levels of nesting.

    >>> list(flatten(0, [1, [2], [[3]]]))  # flatten 0 levels
    [1, [2], [[3]]]
    >>> list(flatten(1, [1, [2], [[3]]]))  # flatten 1 level
    [1, 2, [3]]
    >>> list(flatten(2, [1, [2], [[3]]]))
    [1, 2, 3]
    >>> list(flatten(-1, [1, [[[[2]]]]]))  # flatten all levels
    [1, 2]

    An optional ``descend`` function can be provided by the user
    to determine which iterable objects to recurse into.  This function
    should return a boolean with True meaning it is permissible to descend
    another level of recursion.  The recursion limit of the Python interpreter
    is the ultimate bounding factor on depth.  By default, or when
    ``descend`` is None, strings, bytes, and mappings are exempted.

    >>> list(flatten(-1, ['abc', [{'a': 2}, [b'123']]]))
    ['abc', {'a': 2}, b'123']
    >>> list(flatten(1, [1, [2]], descend=None))  # None selects the default
    [1, 2]

    With ``level=-1``, a ``descend`` that accepts single-character strings
    never terminates normally, because iterating such a string yields the
    same string again.

    Raises ValueError if ``level`` is less than -1 or if ``descend`` is
    neither None nor callable.  Because this is a generator, the checks
    run when iteration starts, not when ``flatten`` is called.

    See also:
        concat
    """
    if level < -1:
        raise ValueError("Level must be >= -1")
    # None means "use the default": the default predicate is private, so
    # wrappers that expose their own optional ``descend`` and code that
    # picks the argument programmatically need a public way to ask for it.
    if descend is None:
        descend = _default_descend
    if not callable(descend):
        raise ValueError("descend must be a callable boolean function")

    def flat(level, seq):
        if level == 0:
            yield from seq
            return

        for item in seq:
            if isiterable(item) and descend(item):
                # Starting from -1 the level only decreases and never
                # reaches 0, which is what makes -1 mean "all levels".
                yield from flat(level - 1, item)
            else:
                yield item

    yield from flat(level, seq)
37 changes: 36 additions & 1 deletion toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import itertools
from itertools import starmap
from toolz.utils import raises
Expand All @@ -13,7 +14,7 @@
reduceby, iterate, accumulate,
sliding_window, count, partition,
partition_all, take_nth, pluck, join,
diff, topk, peek, peekn, random_sample)
diff, topk, peek, peekn, random_sample, flatten)
from operator import add, mul


Expand Down Expand Up @@ -547,3 +548,37 @@ def test_random_sample():
assert mk_rsample(b"a") == mk_rsample(u"a")

assert raises(TypeError, lambda: mk_rsample([]))


def test_flat():
    # An already-flat sequence is unchanged at levels 0 and 1
    flat_seq = [1, 2, 3, 4]
    for depth in (0, 1):
        assert list(flatten(depth, flat_seq)) == flat_seq

    # Each extra level removes one layer of nesting
    nested = [1, [2, [3]]]
    by_depth = [
        (0, nested),
        (1, [1, 2, [3]]),
        (2, [1, 2, 3]),
    ]
    for depth, expected in by_depth:
        assert list(flatten(depth, nested)) == expected

    # Mappings are yielded whole rather than iterated
    with_mapping = [{'a': 1}, [1, 2, 3]]
    assert list(flatten(0, with_mapping)) == with_mapping
    assert list(flatten(1, with_mapping)) == [{'a': 1}, 1, 2, 3]

    # Strings and bytes are yielded whole, even when flattening all levels
    texts = ["asgf", b"abcd"]
    assert list(flatten(-1, texts)) == texts

    # A custom descend function decides what gets expanded: here only
    # strings longer than one character
    def multichar_strings_only(x):
        return isinstance(x, str) and len(x) != 1

    mixed = ["asdf", [1, 2, 3]]
    result = list(flatten(1, mixed, descend=multichar_strings_only))
    assert result == ["a", "s", "d", "f", [1, 2, 3]]

    # A non-callable descend is rejected
    with pytest.raises(ValueError):
        list(flatten(0, [1, 2], descend=True))

    # Levels below -1 are rejected
    with pytest.raises(ValueError):
        list(flatten(-2, [1, 2]))