Skip to content

Commit

Permalink
make flatten more compiler frendly
Browse files Browse the repository at this point in the history
1. add inference test
2. add more test and remove `cat_nested`:
   `cat_nested` failed to infer in some cases.
    It has been inserted into `make_makeargs`, thus unneeded.
  • Loading branch information
N5N3 committed Mar 12, 2022
1 parent 2cba553 commit 1f9864d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 74 deletions.
110 changes: 39 additions & 71 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,21 +314,21 @@ some cases.
"""
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
# concatenate the nested arguments into {a, b, c, d}
args = cat_nested(bc)
# build a function `makeargs` that takes a "flat" argument list and
# 1. concatenate the nested arguments into {a, b, c, d}
# 2. build a function `makeargs` that takes a "flat" argument list and
# and creates the appropriate input arguments for `f`, e.g.,
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
#
# `makeargs` is built recursively and looks a bit like this:
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
# = (w, g(x, y), makeargs2(z)...)
# = (w, g(x, y), z)
let makeargs = make_makeargs(()->(), bc.args), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted{Style}(newf, args, bc.axes)
let (makeargs, args) = make_makeargs((), bc.args), f = bc.f
_make(::NTuple{N,Any}) where {N} =
@inline function (args::Vararg{Any,N})
f(makeargs(args...)...)
end
return Broadcasted{Style}(_make(args), args, bc.axes)
end
end

Expand All @@ -338,79 +338,47 @@ _isflat(args::NestedTuple) = false
_isflat(args::Tuple) = _isflat(tail(args))
_isflat(args::Tuple{}) = true

cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
cat_nested() = ()

"""
make_makeargs(makeargs_tail::Function, t::Tuple) -> Function
make_makeargs(args::Tuple, t::Tuple) -> Function, Tuple
Each element of `t` is one (consecutive) node in a broadcast tree.
Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is
to return a function that takes in flattened argument list and returns a
tuple (each entry corresponding to an entry in `t`, having evaluated
the corresponding element in the broadcast tree). As an additional
complication, the passed in tuple may be longer than the number of leaves
in the subtree described by `t`. The `makeargs_tail` function should
be called on such additional arguments (but not the arguments consumed
by `t`).
`args` contains the rest arguments on the "right" side of `t`.
The jobs of `make_makeargs` are:
1. append the flattened arguments in `t` at the beginning of `args`.
2. return a function that takes in flattened argument list and returns a
tuple (each entry corresponding to an entry in `t`, having evaluated
the corresponding element in the broadcast tree).
"""
@inline make_makeargs(makeargs_tail, t::Tuple{}) = makeargs_tail
@inline function make_makeargs(makeargs_tail, t::Tuple)
makeargs = make_makeargs(makeargs_tail, tail(t))
(head, tail...)->(head, makeargs(tail...)...)
@inline function make_makeargs(args, t::Tuple{})
_make(::NTuple{N,Any}) where {N} = (args::Vararg{Any,N}) -> args
_make(args), args
end
function make_makeargs(makeargs_tail, t::Tuple{<:Broadcasted, Vararg{Any}})
@inline function make_makeargs(args, t::Tuple)
makeargs, args′ = make_makeargs(args, tail(t))
_make(::NTuple{N,Any}) where {N} =
@inline function (head, tail::Vararg{Any,N})
(head, makeargs(tail...)...)
end
_make(args′), (t[1], args′...)
end
function make_makeargs(args, t::Tuple{<:Broadcasted,Vararg{Any}})
bc = t[1]
# c.f. the same expression in the function on leaf nodes above. Here
# we recurse into siblings in the broadcast tree.
let makeargs_tail = make_makeargs(makeargs_tail, tail(t)),
# Here we recurse into children. It would be valid to pass in makeargs_tail
# here, and not use it below. However, in that case, our recursion is no
# longer purely structural because we're building up one argument (the closure)
# while destructuing another.
makeargs_head = make_makeargs((args...)->args, bc.args),
f = bc.f
# Create two functions, one that splits of the first length(bc.args)
# elements from the tuple and one that yields the remaining arguments.
# N.B. We can't call headargs on `args...` directly because
# args is flattened (i.e. our children have not been evaluated
# yet).
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs_head(args...)
a, b = headargs(args1...), makeargs_tail(tailargs(args1...)...)
(f(a...), b...)
end
end
end

@inline function make_headargs(t::Tuple)
let headargs = make_headargs(tail(t))
return @inline function(head, tail::Vararg{Any,N}) where N
(head, headargs(tail...)...)
end
let (makeargs, args′) = make_makeargs(args, tail(t)), f = bc.f
# Here we recurse into children. We can pass in `args′` here,
# and get `args″` directly, but it is more compiler frendly to
# treat `bc` as a new parent "node".
makeargs_head, argsˢ = make_makeargs((), bc.args)
args″ = (argsˢ..., args′...)
_make(::NTuple{L,Any}, ::NTuple{N,Any}) where {L,N} =
@inline function (args::Vararg{Any,N})
a, b = Base.IteratorsMD.split(args, Val(L)) # split `args...` directly
(f(makeargs_head(a...)...), makeargs(b...)...)
end
_make(argsˢ, args″), args″
end
end
@inline function make_headargs(::Tuple{})
return @inline function(tail::Vararg{Any,N}) where N
()
end
end

@inline function make_tailargs(t::Tuple)
let tailargs = make_tailargs(tail(t))
return @inline function(head, tail::Vararg{Any,N}) where N
tailargs(tail...)
end
end
end
@inline function make_tailargs(::Tuple{})
return @inline function(tail::Vararg{Any,N}) where N
tail
end
end

## Broadcasting utilities ##

## logic for deciding the BroadcastStyle
Expand Down
12 changes: 9 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -775,13 +775,19 @@ end

# issue #27988: inference of Broadcast.flatten
using .Broadcast: Broadcasted
let
let _cat_nested(bc) = Broadcast.flatten(bc).args
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(_cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(_cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
# 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3
bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2
# @. 1 + 1 * (1 + 1 + 1 + 1)
bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1))))))
@test @inferred(_cat_nested(bc)) == (1,1,1,1,1,1) # `cat_nested` failed to infer this
end

let
Expand Down

0 comments on commit 1f9864d

Please sign in to comment.