Skip to content

Commit

Permalink
Make flattened Broadcasted better inlined
Browse files Browse the repository at this point in the history
Similar to #41139, but avoid unnecessary extra methods.
  • Loading branch information
N5N3 committed Mar 12, 2022
1 parent 1f9864d commit 58a9f23
Showing 1 changed file with 52 additions and 35 deletions.
87 changes: 52 additions & 35 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,9 @@ function flatten(bc::Broadcasted{Style}) where {Style}
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
# = (w, g(x, y), makeargs2(z)...)
# = (w, g(x, y), z)
let (makeargs, args) = make_makeargs((), bc.args), f = bc.f
_make(::NTuple{N,Any}) where {N} =
@inline function (args::Vararg{Any,N})
f(makeargs(args...)...)
end
return Broadcasted{Style}(_make(args), args, bc.axes)
end
headf, args = make_makeargs(bc.args, ())
newf = RootNode(bc.f, headf)
Broadcasted{Style}(newf, args, bc.axes)
end

const NestedTuple = Tuple{<:Broadcasted,Vararg{Any}}
Expand All @@ -339,7 +335,7 @@ _isflat(args::Tuple) = _isflat(tail(args))
_isflat(args::Tuple{}) = true

"""
make_makeargs(args::Tuple, t::Tuple) -> Function, Tuple
make_makeargs(t::Tuple, args::Tuple) -> Function, Tuple
Each element of `t` is one (consecutive) node in a broadcast tree.
`args` contains the rest arguments on the "right" side of `t`.
Expand All @@ -349,36 +345,57 @@ The jobs of `make_makeargs` are:
tuple (each entry corresponding to an entry in `t`, having evaluated
the corresponding element in the broadcast tree).
"""
@inline function make_makeargs(args, t::Tuple{})
_make(::NTuple{N,Any}) where {N} = (args::Vararg{Any,N}) -> args
_make(args), args
end
@inline function make_makeargs(args, t::Tuple)
makeargs, args′ = make_makeargs(args, tail(t))
_make(::NTuple{N,Any}) where {N} =
@inline function (head, tail::Vararg{Any,N})
(head, makeargs(tail...)...)
end
_make(args′), (t[1], args′...)
make_makeargs(::Tuple{}, args) = tuple, args

function make_makeargs(t::Tuple, args)
tailf, args′ = make_makeargs(tail(t), args)
newf = tailf === tuple ? tuple : FlatNode(tailf) # avoid unneeded recursion
newf, (t[1], args′...)
end
function make_makeargs(args, t::Tuple{<:Broadcasted,Vararg{Any}})

function make_makeargs(t::NestedTuple, args)
bc = t[1]
# c.f. the same expression in the function on leaf nodes above. Here
# we recurse into siblings in the broadcast tree.
let (makeargs, args′) = make_makeargs(args, tail(t)), f = bc.f
# Here we recurse into children. We can pass in `args′` here,
# and get `args″` directly, but it is more compiler frendly to
# treat `bc` as a new parent "node".
makeargs_head, argsˢ = make_makeargs((), bc.args)
args″ = (argsˢ..., args′...)
_make(::NTuple{L,Any}, ::NTuple{N,Any}) where {L,N} =
@inline function (args::Vararg{Any,N})
a, b = Base.IteratorsMD.split(args, Val(L)) # split `args...` directly
(f(makeargs_head(a...)...), makeargs(b...)...)
end
_make(argsˢ, args″), args″
end
# Here we recurse into siblings in the broadcast tree.
tailf, args′ = make_makeargs(tail(t), args)
# Here we recurse into children.
# It is more compiler frendly to treat `bc` as a new parent "node".
headf, argsˢ = make_makeargs(bc.args, ())
NestedNode{length(argsˢ)}(bc.f, headf, tailf), (argsˢ..., args′...)
end

# Some help structs to flatten `Broadcasted`.
# TODO: make them better printed in REPL.
struct RootNode{F,H} <: Function
f::F
prepare::H
end
RootNode(::Type{F}, prepare::H) where {F,H} = RootNode{Type{F},H}(F, prepare)
@inline (f::RootNode)(args::Vararg{Any}) = f.f(f.prepare(args...)...)

struct FlatNode{T} <: Function
rest::T
end
@inline (f::FlatNode)(x, args::Vararg{Any}) = (x, f.rest(args...)...)

struct NestedNode{L,F,H,T} <: Function
f::F
prepare::H
rest::T
end
NestedNode{L}(f::F, prepare::H, rest::T) where {L,F,T,H} = NestedNode{L,F,H,T}(f, prepare, rest)
NestedNode{L}(::Type{F}, prepare::H, rest::T) where {L,F,T,H} = NestedNode{L,Type{F},H,T}(F, prepare, rest)

# Specialize small `L` manually.
@inline (f::NestedNode{1})(x, args::Vararg{Any}) = (f.f(f.prepare(x)...), f.rest(args...)...)
@inline (f::NestedNode{2})(x1, x2, args::Vararg{Any}) = (f.f(f.prepare(x1, x2)...), f.rest(args...)...)
@inline (f::NestedNode{3})(x1, x2, x3, args::Vararg{Any}) = (f.f(f.prepare(x1, x2, x3)...), f.rest(args...)...)
@inline (f::NestedNode{4})(x1, x2, x3, x4, args::Vararg{Any}) = (f.f(f.prepare(x1, x2, x3, x4)...), f.rest(args...)...)
# Split based fallback.
@inline function (f::NestedNode{L})(args::Vararg{Any}) where {L}
head, tail = Base.IteratorsMD.split(args, Val(L))
(f.f(f.prepare(head...)...), f.rest(tail...)...)
end

## Broadcasting utilities ##

## logic for deciding the BroadcastStyle
Expand Down

0 comments on commit 58a9f23

Please sign in to comment.