Skip to content

Commit 2759c3c

Browse files
jumerckxgithub-actions[bot]Jules Merckx
authored
Inference cache (EnzymeAD#405)
* add inference cache * start from `typeinf_ircode` * julia 1.10 * Apply formatting suggestions Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * remove debug logging * vendor in type inference code for v1.10 To avoid having to build a MethodInstance twice (performance hazard) --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Jules Merckx <jumerckx@mac.local>
1 parent f9c43ad commit 2759c3c

File tree

1 file changed

+42
-65
lines changed

1 file changed

+42
-65
lines changed

src/utils.jl

Lines changed: 42 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,29 @@ end
259259

260260
const DEBUG_INTERP = Ref(false)
261261

262+
# Rewrite type unstable calls to recurse into call_with_reactant to ensure
263+
# they continue to use our interpreter. Reset the derived return type
264+
# to Any if our interpreter would change the return type of any result.
265+
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
266+
# screws up type inference after this (TODO this should be fixed).
267+
function rewrite_insts!(ir, interp)
268+
any_changed = false
269+
for (i, inst) in enumerate(ir.stmts)
270+
@static if VERSION < v"1.11"
271+
changed, next = rewrite_inst(inst[:inst], ir, interp)
272+
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
273+
else
274+
changed, next = rewrite_inst(inst[:stmt], ir, interp)
275+
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
276+
end
277+
if changed
278+
any_changed = true
279+
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
280+
end
281+
end
282+
return ir, any_changed
283+
end
284+
262285
# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter
263286
# In particular this entails two pieces:
264287
# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance
@@ -320,72 +343,28 @@ function call_with_reactant_generator(
320343
match.spec_types,
321344
match.sparams,
322345
)
323-
324-
result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp))
325-
frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :local : :no, interp) #=cache_mode=#
326-
@assert frame !== nothing
327-
Core.Compiler.typeinf(interp, frame)
328-
@static if VERSION >= v"1.11"
329-
# `typeinf` doesn't update the cfg. We need to do it manually.
330-
# frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code)
331-
end
332-
@assert Core.Compiler.is_inferred(frame)
333-
334-
method = match.method
335-
336-
# The original julia code (on 1.11+) has the potential constprop, for now
337-
# we assume this outermost function does not constprop, for ease.
338-
#if Core.Compiler.result_is_constabi(interp, frame.result)
339-
# rt = frame.result.result::Core.Compiler.Const
340-
# src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
341-
#else
342-
#
343-
opt = Core.Compiler.OptimizationState(frame, interp)
344-
345-
if DEBUG_INTERP[]
346-
safe_print("opt.src", opt.src)
347-
end
348-
349-
caller = frame.result
350-
@static if VERSION < v"1.11-"
351-
ir = Core.Compiler.run_passes(opt.src, opt, caller)
346+
method = mi.def
347+
348+
@static if VERSION < v"1.11"
349+
# For older Julia versions, we vendor in some of the code to prevent
350+
# having to build the MethodInstance twice.
351+
result = CC.InferenceResult(mi, CC.typeinf_lattice(interp))
352+
frame = CC.InferenceState(result, :no, interp)
353+
@assert !isnothing(frame)
354+
CC.typeinf(interp, frame)
355+
ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing)
356+
rt = CC.widenconst(CC.ignorelimited(result.result))
352357
else
353-
ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller)
354-
@static if VERSION < v"1.12-"
355-
else
356-
Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller)
357-
end
358-
end
359-
360-
if DEBUG_INTERP[]
361-
safe_print("ir1", ir)
362-
end
363-
364-
# Rewrite type unstable calls to recurse into call_with_reactant to ensure
365-
# they continue to use our interpreter. Reset the derived return type
366-
# to Any if our interpreter would change the return type of any result.
367-
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
368-
# screws up type inference after this (TODO this should be fixed).
369-
any_changed = false
370-
if should_rewrite_ft(args[1]) && !is_reactant_method(mi)
371-
for (i, inst) in enumerate(ir.stmts)
372-
@static if VERSION < v"1.11"
373-
changed, next = rewrite_inst(inst[:inst], ir, interp)
374-
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
375-
else
376-
changed, next = rewrite_inst(inst[:stmt], ir, interp)
377-
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
378-
end
379-
if changed
380-
any_changed = true
381-
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
382-
end
383-
end
358+
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
384359
end
385360

386-
Core.Compiler.finish(interp, opt, ir, caller)
387-
388-
src = Core.Compiler.ir_to_codeinf!(opt)
361+
ir, any_changed = rewrite_insts!(ir, interp)
362+
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
363+
src.slotnames = fill(:none, length(ir.argtypes) + 1)
364+
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
365+
src.slottypes = copy(ir.argtypes)
366+
src.rettype = rt
367+
src = CC.ir_to_codeinf!(src, ir)
389368

390369
if DEBUG_INTERP[]
391370
safe_print("src", src)
@@ -488,8 +467,6 @@ function call_with_reactant_generator(
488467
end
489468
end
490469

491-
rt = Base.Experimental.compute_ir_rettype(ir)
492-
493470
# ocva = method.isva
494471

495472
ocva = false # method.isva

0 commit comments

Comments
 (0)