Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference cache #405

Merged
merged 6 commits into from
Dec 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 42 additions & 65 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,29 @@ end

const DEBUG_INTERP = Ref(false)

# Rewrite type unstable calls to recurse into call_with_reactant to ensure
# they continue to use our interpreter. Reset the derived return type
# to Any if our interpreter would change the return type of any result.
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
# screws up type inference after this (TODO this should be fixed).
function rewrite_insts!(ir, interp)
any_changed = false
for (i, inst) in enumerate(ir.stmts)
@static if VERSION < v"1.11"
changed, next = rewrite_inst(inst[:inst], ir, interp)
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
else
changed, next = rewrite_inst(inst[:stmt], ir, interp)
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
end
if changed
any_changed = true
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
end
end
return ir, any_changed
end

# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter
# In particular this entails two pieces:
# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance
Expand Down Expand Up @@ -320,72 +343,28 @@ function call_with_reactant_generator(
match.spec_types,
match.sparams,
)

result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp))
frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :local : :no, interp) #=cache_mode=#
@assert frame !== nothing
Core.Compiler.typeinf(interp, frame)
@static if VERSION >= v"1.11"
# `typeinf` doesn't update the cfg. We need to do it manually.
# frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code)
end
@assert Core.Compiler.is_inferred(frame)

method = match.method

# The original julia code (on 1.11+) has the potential constprop, for now
# we assume this outermost function does not constprop, for ease.
#if Core.Compiler.result_is_constabi(interp, frame.result)
# rt = frame.result.result::Core.Compiler.Const
# src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
#else
#
opt = Core.Compiler.OptimizationState(frame, interp)

if DEBUG_INTERP[]
safe_print("opt.src", opt.src)
end

caller = frame.result
@static if VERSION < v"1.11-"
ir = Core.Compiler.run_passes(opt.src, opt, caller)
method = mi.def

@static if VERSION < v"1.11"
# For older Julia versions, we vendor in some of the code to prevent
# having to build the MethodInstance twice.
result = CC.InferenceResult(mi, CC.typeinf_lattice(interp))
frame = CC.InferenceState(result, :no, interp)
@assert !isnothing(frame)
CC.typeinf(interp, frame)
ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing)
rt = CC.widenconst(CC.ignorelimited(result.result))
else
ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller)
@static if VERSION < v"1.12-"
else
Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller)
end
end

if DEBUG_INTERP[]
safe_print("ir1", ir)
end

# Rewrite type unstable calls to recurse into call_with_reactant to ensure
# they continue to use our interpreter. Reset the derived return type
# to Any if our interpreter would change the return type of any result.
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
# screws up type inference after this (TODO this should be fixed).
any_changed = false
if should_rewrite_ft(args[1]) && !is_reactant_method(mi)
for (i, inst) in enumerate(ir.stmts)
@static if VERSION < v"1.11"
changed, next = rewrite_inst(inst[:inst], ir, interp)
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
else
changed, next = rewrite_inst(inst[:stmt], ir, interp)
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
end
if changed
any_changed = true
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
end
end
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
end

Core.Compiler.finish(interp, opt, ir, caller)

src = Core.Compiler.ir_to_codeinf!(opt)
ir, any_changed = rewrite_insts!(ir, interp)
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
src.slotnames = fill(:none, length(ir.argtypes) + 1)
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
src.slottypes = copy(ir.argtypes)
src.rettype = rt
src = CC.ir_to_codeinf!(src, ir)

if DEBUG_INTERP[]
safe_print("src", src)
Expand Down Expand Up @@ -488,8 +467,6 @@ function call_with_reactant_generator(
end
end

rt = Base.Experimental.compute_ir_rettype(ir)

# ocva = method.isva

ocva = false # method.isva
Expand Down
Loading