Skip to content

Commit

Permalink
Implements flatmap (#44792)
Browse files Browse the repository at this point in the history
flatmap is the composition of map and flatten. It is important for functional programming patterns.

Some tasks that can be easily attained with list-comprehensions, including the composition of filter and mapping, or flattening a list of computed lists, can only be attained with do-syntax style if a flatmap functor is available. (Or appending a `|> flatten`, etc.)

Filtering can be implemented by outputing empty lists or singleton lists for the values to be removed or kept.
  • Loading branch information
nlw0 authored Apr 7, 2022
1 parent ad047d0 commit badad9d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
25 changes: 24 additions & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import .Base:
getindex, setindex!, get, iterate,
popfirst!, isdone, peek

export enumerate, zip, rest, countfrom, take, drop, takewhile, dropwhile, cycle, repeated, product, flatten, partition
export enumerate, zip, rest, countfrom, take, drop, takewhile, dropwhile, cycle, repeated, product, flatten, partition, flatmap

"""
Iterators.map(f, iterators...)
Expand Down Expand Up @@ -1162,6 +1162,29 @@ end
reverse(f::Flatten) = Flatten(reverse(itr) for itr in reverse(f.it))
last(f::Flatten) = last(last(f.it))

"""
Iterators.flatmap(f, iterators...)
Equivalent to `flatten(map(f, iterators...))`.
# Examples
```jldoctest
julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect
9-element Vector{Int64}:
-1
1
-2
0
2
-3
-1
1
3
```
"""
# flatmap = flatten ∘ map
flatmap(f, c...) = flatten(map(f, c...))

"""
partition(collection, n)
Expand Down
23 changes: 23 additions & 0 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,29 @@ end
# see #29112, #29464, #29548
@test Base.return_types(Base.IteratorEltype, Tuple{Array}) == [Base.HasEltype]

# flatmap
# -------
@test flatmap(1:3) do j flatmap(1:3) do k
j!=k ? ((j,k),) : ()
end end |> collect == [(j,k) for j in 1:3 for k in 1:3 if j!=k]
# Test inspired by the monad associativity law
fmf(x) = x<0 ? () : (x^2,)
fmg(x) = x<1 ? () : (x/2,)
fmdata = -2:0.75:2
fmv1 = flatmap(tuple.(fmdata)) do h
flatmap(h) do x
gx = fmg(x)
flatmap(gx) do x
fmf(x)
end
end
end
fmv2 = flatmap(tuple.(fmdata)) do h
gh = flatmap(h) do x fmg(x) end
flatmap(gh) do x fmf(x) end
end
@test all(fmv1 .== fmv2)

# partition(c, n)
let v = collect(partition([1,2,3,4,5], 1))
@test all(i->v[i][1] == i, v)
Expand Down

0 comments on commit badad9d

Please sign in to comment.