Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: introduce runtime representation of broadcast fusion #23692

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 225 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -859,4 +859,229 @@ macro __dot__(x)
esc(__dot__(x))
end

############################################################
## The parser turns dotted calls into the equivalent Fusion expression.
## Effectively, this turns the Expr tree into a runtime AST,
## for a limited subset of expression types.
#
## For example, in the expression:
# d = sin.((a .+ (b .* c))...)
## The kernel becomes
# d' = Fusion{3}(
# FusionApply(
# sin,
# ( FusionCall(
# +,
# ( FusionArg{1}(),
# FusionCall(
# *,
# ( FusionArg{2}(),
# FusionArg{3}() )), )), )),
# (:a, :b, :c))
## and then the final expansion becomes:
# d = broadcast(d', a, b, c)

struct Fusion{N, vararg#=::Bool=#, T}
f::T
# Debugging Metadata:
# names::NTuple{N, Symbol}
# source::LineNumberNode
function Fusion{N, vararg}(f) where {N, vararg}
return new{N, vararg::Bool, typeof(f)}(f)
end
end

struct FusionArg{N}
end

struct FusionConstant{T}
c::T
function FusionConstant(c) where {}
return new{typeof(c)}(c)
end
end

struct FusionCall{F, Args<:Tuple}
f::F
args::Args
function FusionCall(f, args::Tuple) where {}
return new{typeof(f), typeof(args)}(f, args)
end
end

struct FusionApply{N, F, Args<:NTuple{N, Any}}
f::F
args::Args
function FusionApply(f, args::NTuple{N, Any}) where {N}
return new{N, typeof(f), typeof(args)}(f, args)
end
end

function kw_to_vec(kws::Vector{Any})
kwargs = Vector{Any}(2 * length(kws))
for i in 1:2:length(kws)
kw = kws[i]::Tuple{Any, Any}
kwargs[i] = getfield(kw, 1)
kwargs[i + 1] = getfield(kw, 2)
end
return kwargs
end

struct FusionKWCall{F, Args<:Tuple}
f::F
args::Args
kwargs::Vector{Any}
function FusionKWCall(f, args::Tuple; kwargs...) where {}
return new{typeof(f), typeof(args)}(f, args, kw_to_vec(kwargs))
end
end

struct FusionKWApply{F, Args<:Tuple}
f::F
args::Args
kwargs::Vector{Any}
function FusionKWApply(f, args::Tuple; kwargs...) where {}
return new{typeof(f), typeof(args)}(f, args, kw_to_vec(kwargs))
end
end

function tuplehead(t::Tuple, N::Val)
return ntuple(i -> t[i], N)
end
@generated function tupletail(t::NTuple{M, Any}, ::Val{N}) where {N, M}
# alternative, non-generated versions,
# enable when inference is improved:
#tupletail(t, Nreq) = ntuple(i -> t[i + Nreq], length(t) - Nreq)
#tupletail(t, Nreq) = t[(Nreq + 1):end]
args = Any[ :(getfield(t, $i)) for i in (N + 1):M ]
tpl = Expr(:tuple)
tpl.args = args
return tpl
end

@inline (f::Fusion{N, false})(args::Vararg{Any, N}) where {N} = f.f(args...)
function (f::Fusion{Nreq, true})(args::Vararg{Any, M}) where {Nreq, M}
M >= Nreq || throw(MethodError(f, args))
fargs = tuplehead(args, Val(Nreq))
vararg = tupletail(args, Val(Nreq))
return f.f((fargs..., vararg)...)
end
@inline (f::FusionArg{N})(args...) where {N} = args[N]
@inline (f::FusionConstant)(args...) = f.c
@inline (f::FusionCall)(args...) = f.f(map(a -> a(args...), f.args)...)
# TODO: calling _apply on map _apply is not handled by inference
# for now, we unroll some cases and generate others, to help it out
#@inline (f::FusionApply)(args...) = Core._apply(f.f, map(a -> a(args...), f.args)...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

equivalently, I think this is also f.f(flatten(map(a -> a(args...), f.args)...)...)
where flatten(args...) = Core._apply(tuple, args...)

@inline (f::FusionApply{0})(args...) = f.f()
@inline (f::FusionApply{1})(args...) = f.f(f.args[1](args...)...)
@inline (f::FusionApply{2})(args...) = f.f(f.args[1](args...)..., f.args[2](args...)...)
@inline (f::FusionApply{3})(args...) = f.f(f.args[1](args...)..., f.args[2](args...)..., f.args[3](args...)...)
@generated function (f::FusionApply{N})(args...) where {N}
fargs = Any[ :(getfield(f.args, $i)(args...)) for i in 1:N ]
return Expr(:call, GlobalRef(Core, :_apply), :(f.f), fargs...)
end
@inline function (f::FusionKWCall)(args...)
fargs = map(a -> a(args...), f.args)
# return f.f(args...; kwargs...)
if isempty(f.kwargs)
return f.f(fargs...)
else
return Core.kwfunc(f.f)(f.kwargs, f.f, fargs...)
end
end
@inline function (f::FusionKWApply)(args...)
fargs = map(a -> a(args...), f.args)
# return Core._apply(f.f, args...; kwargs...)
if isempty(f.kwargs)
return Core._apply(f.f, fargs...)
else
return Core._apply(Core.kwfunc(f.f), (f.kwargs,), (f.f,), fargs...)
end
end

function Base.show(io::IO, f::Fusion{N, vararg}) where {N, vararg}
nargs = (vararg ? N + 1 : N)
names = String[ "a_$i" for i in 1:nargs ] # f.names
print(io, "(")
join(io, names, ", ")
vararg && print(io, "...")
print(io, ") -> ")
show_fusion(io, f.f, names)
end

function show_fusion(io::IO, f::FusionArg{N}, names) where N
print(io, names[N])
nothing
end

function show_fusion(io::IO, f::FusionConstant{N}, names) where N
print(io, f.c)
nothing
end

function show_fusion(io::IO, f::FusionCall, names)
Base.show(io, f.f)
print(io, '(')
first = true
for i in f.args
first || print(io, ", ")
first = false
show_fusion(io, i, names)
end
print(io, ')')
nothing
end

function show_fusion(io::IO, f::FusionApply, names)
print(io, "Core._apply(")
Base.show(io, f.f)
for i in f.args
print(io, ", ")
show_fusion(io, i, names)
end
print(io, ')')
nothing
end

function show_fusion(io::IO, f::FusionKWCall, names)
Base.show(io, f.f)
print(io, '(')
first = true
for i in f.args
first || print(io, ", ")
first = false
show_fusion(io, i, names)
end
print(io, "; ")
first = true
for i in 1:2:length(f.kwargs)
first || print(io, ", ")
first = false
print(io, f.kwargs[i])
print(io, "=")
end
print(io, ')')
nothing
end


function show_fusion(io::IO, f::FusionKWApply, names)
print(io, "Core._apply(")
Base.show(io, f.f)
for i in f.args
print(io, ", ")
show_fusion(io, i, names)
end
print(io, "; #=kwargs=#...)")
nothing
end


function show_fusion(io::IO, @nospecialize(f), names)
print(io, "#= unexpected expression ")
show(io, f)
print(io, " =#")
nothing
end

end # module
2 changes: 1 addition & 1 deletion base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import Core: _apply, svec, apply_type, Builtin, IntrinsicFunction, MethodInstanc

#### parameters limiting potentially-infinite types ####
const MAX_TYPEUNION_LEN = 3
const MAX_TYPE_DEPTH = 8
const MAX_TYPE_DEPTH = 10
const TUPLE_COMPLEXITY_LIMIT_DEPTH = 3

const MAX_INLINE_CONST_SIZE = 256
Expand Down
Loading