Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Broadcast.flatten(bc).f more complier frendly. (better inferred and inlined) #43322

Merged
merged 1 commit into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 43 additions & 78 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,20 +341,16 @@ function flatten(bc::Broadcasted)
isflat(bc) && return bc
# concatenate the nested arguments into {a, b, c, d}
args = cat_nested(bc)
# build a function `makeargs` that takes a "flat" argument list and
# and creates the appropriate input arguments for `f`, e.g.,
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
#
# `makeargs` is built recursively and looks a bit like this:
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
# = (w, g(x, y), makeargs2(z)...)
# = (w, g(x, y), z)
let makeargs = make_makeargs(()->(), bc.args), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted(bc.style, newf, args, bc.axes)
end
# build a tuple of functions `makeargs`. Its elements take
# the whole "flat" argument list and and generate the appropriate
# input arguments for the broadcasted function `f`, e.g.,
# makeargs[1] = ((w, x, y, z)) -> w
# makeargs[2] = ((w, x, y, z)) -> g(x, y)
# makeargs[3] = ((w, x, y, z)) -> z
makeargs = make_makeargs(bc.args)
f = Base.maybeconstructor(bc.f)
newf = (args...) -> (@inline; f(prepare_args(makeargs, args)...))
return Broadcasted(bc.style, newf, args, bc.axes)
end

const NestedTuple = Tuple{<:Broadcasted,Vararg{Any}}
Expand All @@ -363,78 +359,47 @@ _isflat(args::NestedTuple) = false
_isflat(args::Tuple) = _isflat(tail(args))
_isflat(args::Tuple{}) = true

cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
cat_nested() = ()
cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
cat_nested_args(::Tuple{}) = ()
cat_nested_args(t::Tuple{Any}) = cat_nested(t[1])
cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...)
cat_nested(a) = (a,)

"""
make_makeargs(makeargs_tail::Function, t::Tuple) -> Function
make_makeargs(t::Tuple) -> Tuple{Vararg{Function}}

Each element of `t` is one (consecutive) node in a broadcast tree.
Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is
to return a function that takes in flattened argument list and returns a
tuple (each entry corresponding to an entry in `t`, having evaluated
the corresponding element in the broadcast tree). As an additional
complication, the passed in tuple may be longer than the number of leaves
in the subtree described by `t`. The `makeargs_tail` function should
be called on such additional arguments (but not the arguments consumed
by `t`).
The returned `Tuple` are functions which take in the (whole) flattened
list and generate the inputs for the corresponding broadcasted function.
"""
@inline make_makeargs(makeargs_tail, t::Tuple{}) = makeargs_tail
@inline function make_makeargs(makeargs_tail, t::Tuple)
makeargs = make_makeargs(makeargs_tail, tail(t))
(head, tail...)->(head, makeargs(tail...)...)
make_makeargs(args::Tuple) = _make_makeargs(args, 1)[1]

# We build `makeargs` by traversing the broadcast nodes recursively.
# note: `n` indicates the flattened index of the next unused argument.
@inline function _make_makeargs(args::Tuple, n::Int)
head, n = _make_makeargs1(args[1], n)
rest, n = _make_makeargs(tail(args), n)
(head, rest...), n
end
function make_makeargs(makeargs_tail, t::Tuple{<:Broadcasted, Vararg{Any}})
bc = t[1]
# c.f. the same expression in the function on leaf nodes above. Here
# we recurse into siblings in the broadcast tree.
let makeargs_tail = make_makeargs(makeargs_tail, tail(t)),
# Here we recurse into children. It would be valid to pass in makeargs_tail
# here, and not use it below. However, in that case, our recursion is no
# longer purely structural because we're building up one argument (the closure)
# while destructuing another.
makeargs_head = make_makeargs((args...)->args, bc.args),
f = bc.f
# Create two functions, one that splits of the first length(bc.args)
# elements from the tuple and one that yields the remaining arguments.
# N.B. We can't call headargs on `args...` directly because
# args is flattened (i.e. our children have not been evaluated
# yet).
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs_head(args...)
a, b = headargs(args1...), makeargs_tail(tailargs(args1...)...)
(f(a...), b...)
end
end
_make_makeargs(::Tuple{}, n::Int) = (), n

# A help struct to store the flattened index staticly
struct Pick{N} <: Function end
(::Pick{N})(@nospecialize(args::Tuple)) where {N} = args[N]

# For flat nodes, we just consume one argument (n += 1), and return the "Pick" function
@inline _make_makeargs1(_, n::Int) = Pick{n}(), n + 1
# For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
@inline function _make_makeargs1(bc::Broadcasted, n::Int)
makeargs, n = _make_makeargs(bc.args, n)
f = Base.maybeconstructor(bc.f)
makeargs1 = (args::Tuple) -> (@inline; f(prepare_args(makeargs, args)...))
makeargs1, n
end

@inline function make_headargs(t::Tuple)
let headargs = make_headargs(tail(t))
return @inline function(head, tail::Vararg{Any,N}) where N
(head, headargs(tail...)...)
end
end
end
@inline function make_headargs(::Tuple{})
return @inline function(tail::Vararg{Any,N}) where N
()
end
end

@inline function make_tailargs(t::Tuple)
let tailargs = make_tailargs(tail(t))
return @inline function(head, tail::Vararg{Any,N}) where N
tailargs(tail...)
end
end
end
@inline function make_tailargs(::Tuple{})
return @inline function(tail::Vararg{Any,N}) where N
tail
end
end
@inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) = (makeargs[1](x), prepare_args(tail(makeargs), x)...)
@inline prepare_args(makeargs::Tuple{Any}, @nospecialize(x::Tuple)) = (makeargs[1](x),)
prepare_args(::Tuple{}, ::Tuple) = ()

## Broadcasting utilities ##

Expand Down
19 changes: 16 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -774,14 +774,27 @@ let X = zeros(2, 3)
end

# issue #27988: inference of Broadcast.flatten
using .Broadcast: Broadcasted
using .Broadcast: Broadcasted, cat_nested
let
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
# 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3
bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2
# @. 1 + 1 * (1 + 1 + 1 + 1)
bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1))))))
@test @inferred(cat_nested(bc)) == (1, 1, 1, 1, 1, 1) # `cat_nested` failed to infer this
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
# @. 1 + (1 + 1) + 1 + (1 + 1) + 1 + (1 + 1) + 1
bc = Broadcasted(+, (1, Broadcasted(+, (1, 1)), 1, Broadcasted(+, (1, 1)), 1, Broadcasted(+, (1, 1)), 1))
@test @inferred(cat_nested(bc)) == (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
bc = Broadcasted(Float32, (Broadcasted(+, (1, 1)),))
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
end

let
Expand Down