Skip to content

Commit

Permalink
Support semi concrete eval (#319)
Browse files Browse the repository at this point in the history
Co-authored-by: Shuhei Kadowaki <aviatesk@gmail.com>
Co-authored-by: Keno Fisher <keno@juliacomputing.com>
  • Loading branch information
3 people authored Sep 5, 2022
1 parent fc6d5b1 commit 6a5a21b
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 30 deletions.
20 changes: 17 additions & 3 deletions src/Cthulhu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ end
@static if VERSION v"1.9.0-DEV.409"
get_effects(result::CC.ConstPropResult) = get_effects(result.result)
get_effects(result::CC.ConcreteResult) = result.effects
get_effects(result::CC.SemiConcreteResult) = result.effects
else
get_effects(result::CC.ConstResult) = result.effects
end
Expand Down Expand Up @@ -380,8 +381,19 @@ function lookup_constproped_unoptimized(interp::CthulhuInterpreter, override::In
return (; src, rt, infos, slottypes, effects, codeinf)
end

function lookup_semiconcrete(interp::CthulhuInterpreter, override::SemiConcreteCallInfo, optimize::Bool)
src = CC.copy(override.ir)
rt = get_rt(override)
infos = src.stmts.info
slottypes = src.argtypes
effects = get_effects(override)
(; codeinf) = lookup(interp, get_mi(override), optimize)
return (; src, rt, infos, slottypes, effects, codeinf)
end

function get_override(@nospecialize(info))
isa(info, ConstPropCallInfo) && return info.result
isa(info, SemiConcreteCallInfo) && return info
isa(info, OCCallInfo) && return get_override(info.ci)
return nothing
end
Expand All @@ -392,7 +404,7 @@ end
# src/ui.jl provides the user facing interface to which _descend responds
##
function _descend(term::AbstractTerminal, interp::AbstractInterpreter, curs::AbstractCursor;
override::Union{Nothing,InferenceResult} = nothing,
override::Union{Nothing,InferenceResult,SemiConcreteCallInfo} = nothing,
debuginfo::Union{Symbol,DebugInfo} = CONFIG.debuginfo, # default is compact debuginfo
optimize::Bool = CONFIG.optimize, # default is true
interruptexc::Bool = CONFIG.interruptexc,
Expand Down Expand Up @@ -434,8 +446,10 @@ function _descend(term::AbstractTerminal, interp::AbstractInterpreter, curs::Abs
""")
end
while true
if override !== nothing
if isa(override, InferenceResult)
(; src, rt, infos, slottypes, codeinf, effects) = lookup_constproped(interp, curs, override, optimize)
elseif isa(override, SemiConcreteCallInfo)
(; src, rt, infos, slottypes, codeinf, effects) = lookup_semiconcrete(interp, curs, override, optimize)
else
if optimize
codeinst = get_optimized_codeinst(interp, curs)
Expand Down Expand Up @@ -466,7 +480,7 @@ function _descend(term::AbstractTerminal, interp::AbstractInterpreter, curs::Abs
end
mi = get_mi(curs)
src = preprocess_ci!(src, mi, optimize, CONFIG)
if optimize # optimization might have deleted some statements
if optimize || isa(src, IRCode) # optimization might have deleted some statements
infos = src.stmts.info
else
@assert length(src.code) == length(infos)
Expand Down
20 changes: 20 additions & 0 deletions src/callsite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ get_mi(ceci::ConcreteCallInfo) = get_mi(ceci.mi)
get_rt(ceci::ConcreteCallInfo) = get_rt(ceci.mi)
get_effects(ceci::ConcreteCallInfo) = get_effects(ceci.mi)

struct SemiConcreteCallInfo <: CallInfo
mi::CallInfo
ir::IRCode
end
get_mi(scci::SemiConcreteCallInfo) = get_mi(scci.mi)
get_rt(scci::SemiConcreteCallInfo) = get_rt(scci.mi)
get_effects(scci::SemiConcreteCallInfo) = get_effects(scci.mi)

# CUDA callsite
struct CuCallInfo <: CallInfo
cumi::MICallInfo
Expand Down Expand Up @@ -309,6 +317,13 @@ function show_callinfo(limiter, ci::ConstPropCallInfo)
__show_limited(limiter, name, tt, get_rt(ci), get_effects(ci))
end

function show_callinfo(limiter, ci::SemiConcreteCallInfo)
# XXX: The first argument could be const-overriden too
name = get_mi(ci).def.name
tt = ci.ir.argtypes[2:end]
__show_limited(limiter, name, tt, get_rt(ci), get_effects(ci))
end

function show_callinfo(limiter, ci::ConcreteCallInfo)
# XXX: The first argument could be const-overriden too
name = get_mi(ci).def.name
Expand Down Expand Up @@ -373,6 +388,11 @@ function print_callsite_info(limiter::IO, info::ConstPropCallInfo)
show_callinfo(limiter, info)
end

function print_callsite_info(limiter::IO, info::SemiConcreteCallInfo)
print(limiter, " = < semi-concrete eval > ")
show_callinfo(limiter, info)
end

function print_callsite_info(limiter::IO, info::ConcreteCallInfo)
print(limiter, "< concrete eval > ")
show_callinfo(limiter, info)
Expand Down
7 changes: 7 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ missing `$AbstractCursor` API:
lookup_constproped(interp::CthulhuInterpreter, ::CthulhuCursor, override::InferenceResult, optimize::Bool) =
lookup_constproped(interp, override, optimize)

lookup_semiconcrete(interp::AbstractInterpreter, curs::AbstractCursor, override::SemiConcreteCallInfo, optimize::Bool) = error(lazy"""
missing `$AbstractCursor` API:
`$(typeof(curs))` is required to implement the `$lookup_semicocnrete(interp::$(typeof(interp)), curs::$(typeof(curs)), override::SemiConcreteCallInfo, optimize::Bool)` interface.
""")
lookup_semiconcrete(interp::CthulhuInterpreter, ::CthulhuCursor, override::SemiConcreteCallInfo, optimize::Bool) =
lookup_semiconcrete(interp, override, optimize)

get_mi(curs::AbstractCursor) = error(lazy"""
missing `$AbstractCursor` API:
`$(typeof(curs))` is required to implement the `$get_mi(curs::$(typeof(curs))) -> MethodInstance` interface.
Expand Down
7 changes: 7 additions & 0 deletions src/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ function CC.inlining_policy(interp::CthulhuInterpreter)
end
end # @static if isdefined(CC, :is_stmt_inline)

@static if isdefined(CC, :codeinst_to_ir)
function CC.codeinst_to_ir(interp::CthulhuInterpreter, code::CodeInstance)
isa(code.inferred, Nothing) && return nothing
return CC.copy((code.inferred::OptimizedSource).ir)
end
end # @static if isdefined(CC, :codeinst_to_ir)

function CC.finish!(interp::CthulhuInterpreter, caller::InferenceResult)
effects = EFFECTS_ENABLED ? caller.ipo_effects : nothing
caller.src = create_cthulhu_source(caller.src, effects)
Expand Down
5 changes: 5 additions & 0 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ function process_const_info(interp::AbstractInterpreter, @nospecialize(thisinfo)
effects = get_effects(result)
mici = MICallInfo(linfo, rt, effects)
return ConstPropCallInfo(is_cached(optimize ? linfo : result) ? mici : UncachedCallInfo(mici), result)
elseif (@static isdefined(CC, :SemiConcreteResult) && true) && isa(result, CC.SemiConcreteResult)
linfo = result.mi
effects = get_effects(result)
mici = MICallInfo(linfo, rt, effects)
return SemiConcreteCallInfo(mici, result.ir)
elseif (@static isdefined(CC, :ConstResult) && true) && isa(result, CC.ConstResult)
linfo = result.mi
effects = get_effects(result)
Expand Down
86 changes: 59 additions & 27 deletions test/test_Cthulhu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,32 +237,59 @@ end
@test occursin("= < constprop > getproperty(", String(take!(io)))
end
end
end

@static isdefined(Core.Compiler, :ConstResult) && @testset "ConstResult" begin
# constant prop' on all the splits
let callsites = (@eval Module() begin
Base.@assume_effects :terminates_locally function issue41694(x)
res = 1
1 < x < 20 || throw("bad")
while x > 1
res *= x
x -= 1
end
return res
end
Base.@assume_effects :terminates_locally function issue41694(x)
res = 1
1 < x < 20 || throw("bad")
while x > 1
res *= x
x -= 1
end
return res
end
@static isdefined(Core.Compiler, :ConstResult) && @testset "ConstResult" begin
# constant prop' on all the splits
let callsites = find_callsites_by_ftt(; optimize = false) do
issue41694(12)
end
callinfo = only(callsites).info
@test isa(callinfo, Cthulhu.ConcreteCallInfo)
@test Cthulhu.get_rt(callinfo) == Core.Const(factorial(12))
@test Cthulhu.get_effects(callinfo) |> Core.Compiler.is_foldable
io = IOBuffer()
print(io, only(callsites))
@test occursin("= < concrete eval > issue41694(::Core.Const(12))", String(take!(io)))
end
end

$find_callsites_by_ftt(; optimize = false) do
issue41694(12)
end
end)
callinfo = only(callsites).info
@test isa(callinfo, Cthulhu.ConcreteCallInfo)
@test Cthulhu.get_rt(callinfo) == Core.Const(factorial(12))
@test Cthulhu.get_effects(callinfo) |> Core.Compiler.is_foldable
io = IOBuffer()
print(io, only(callsites))
@test occursin("= < concrete eval > issue41694(::Core.Const(12))", String(take!(io)))
let # check the performance benefit of semi concrete evaluation
param = 1000
ex = Expr(:block)
var = gensym()
push!(ex.args, :($var = x))
for _ = 1:param
newvar = gensym()
push!(ex.args, :($newvar = sin($var)))
var = newvar
end
@eval global Base.@constprop :aggressive Base.@assume_effects :nothrow function semi_concrete_eval(x::Int, _::Int)
out = $ex
out
end
end
@static isdefined(Core.Compiler, :SemiConcreteResult) && @testset "SemiConcreteResult" begin
# constant prop' on all the splits
let callsites = find_callsites_by_ftt((Int,); optimize = false) do x
semi_concrete_eval(42, x)
end
callinfo = only(callsites).info
@test isa(callinfo, Cthulhu.SemiConcreteCallInfo)
@test Cthulhu.get_rt(callinfo) == Core.Const(semi_concrete_eval(42, 0))
# @test Cthulhu.get_effects(callinfo) |> Core.Compiler.is_semiconcrete_eligible
io = IOBuffer()
print(io, only(callsites))
@test occursin("= < semi-concrete eval > semi_concrete_eval(::Core.Const(42),::$Int)", String(take!(io)))
end
end

Expand Down Expand Up @@ -402,7 +429,7 @@ invoke_constcall(a::Number, c::Bool) = c ? Number : :number
@test info.ci.rt === Core.Compiler.Const(:Int)
end

# const prop' callsite
# const prop' / semi-concrete callsite
@static hasfield(Core.Compiler.InvokeCallInfo, :result) && let callsites = find_callsites_by_ftt((Any,); optimize=false) do a
Base.@invoke invoke_constcall(a::Any, true::Bool)
end
Expand All @@ -411,12 +438,17 @@ invoke_constcall(a::Number, c::Bool) = c ? Number : :number
@test isa(info, Cthulhu.InvokeCallInfo)
@static Cthulhu.EFFECTS_ENABLED && Cthulhu.get_effects(info) |> Core.Compiler.is_total
inner = info.ci
@test isa(inner, Cthulhu.ConstPropCallInfo)
rt = Core.Compiler.Const(Any)
@test inner.result.result === rt
@test Cthulhu.get_rt(info) === rt
buf = IOBuffer()
show(buf, callsite)
@test occursin("= invoke < invoke_constcall(::Any,::$(Core.Compiler.Const(true)))::$rt", String(take!(buf)))
@static if isdefined(Core.Compiler, :SemiConcreteResult)
@test isa(inner, Cthulhu.SemiConcreteCallInfo)
@test occursin("= invoke < invoke_constcall(::Any,::$(Core.Compiler.Const(true)))::$rt", String(take!(buf)))
else
@test isa(inner, Cthulhu.ConstPropCallInfo)
@test occursin("= invoke < invoke_constcall(::Any,::$(Core.Compiler.Const(true)))::$rt", String(take!(buf)))
end
end
end

Expand Down

0 comments on commit 6a5a21b

Please sign in to comment.