Skip to content

Commit

Permalink
WIP: Re-target IPO path to Jameson's branch
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno committed Nov 14, 2024
1 parent 24b782b commit faf7d3f
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 107 deletions.
7 changes: 2 additions & 5 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 10 additions & 3 deletions src/analysis/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,12 @@ function dae_result_for_inst(interp, inst::CC.Instruction)
return result isa Union{DAEIPOResult, UncompilableIPOResult} ? result : nothing
end
else
codeinst = CC.get(CC.code_cache(interp), mi, nothing)
codeinst === nothing && return nothing
if isa(mi, MethodInstance)
codeinst = CC.get(CC.code_cache(interp), mi, nothing)
codeinst === nothing && return nothing
else
codeinst = mi::CodeInstance
end
result = CC.traverse_analysis_results(codeinst) do @nospecialize result
return result isa Union{DAEIPOResult, UncompilableIPOResult} ? result : nothing
end
Expand Down Expand Up @@ -997,8 +1001,11 @@ end
return result
end

mi_or_ci = stmt.args[1]
isva = (isa(mi_or_ci, CodeInstance) ? mi_or_ci.def.def : mi_or_ci.def).isva

callee_argtypes = CC.va_process_argtypes(CC.optimizer_lattice(analysis_interp),
CC.collect_argtypes(analysis_interp, stmt.args[2:end], nothing, irsv), UInt(length(result.argtypes)), stmt.args[1].def.isva)
CC.collect_argtypes(analysis_interp, stmt.args[2:end], nothing, irsv), UInt(length(result.argtypes)), isva)
mapping = CalleeMapping(CC.optimizer_lattice(analysis_interp), callee_argtypes, result)
end
append!(warnings, result.warnings)
Expand Down
79 changes: 38 additions & 41 deletions src/analysis/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,54 +170,47 @@ end
#=CC.=#get_inference_world(interp::DAEInterpreter) = interp.world
CC.get_inference_cache(interp::DAEInterpreter) = interp.inf_cache

if Base.__has_internal_change(v"1.12-alpha", :methodspecialization)
"""
struct AnalysisSpec
The cache partition for DAECompiler analysis results. This is essentially
equivalent to a regular type inference, except that optimization is prohibited
from inlining any functions that have frules and we perform the DAE analysis
on every ir after optimization.
"""
struct AnalysisSpec; end

"""
struct RHSSpec
Cache partition for the RHS
"""
struct RHSSpec
key::TornCacheKey
ordinal::Int
end
"""
struct AnalysisSpec
Base.show(io::IO, ms::MethodSpecialization{RHSSpec}) = print(io, "RHS Spec#$(ms.data.ordinal) for ", ms.def)
The cache partition for DAECompiler analysis results. This is essentially
equivalent to a regular type inference, except that optimization is prohibited
from inlining any functions that have frules and we perform the DAE analysis
on every ir after optimization.
"""
struct AnalysisSpec; end

"""
struct RHSSpec
Cache partition for the RHS
"""
struct RHSSpec
key::TornCacheKey
ordinal::Int
end

"""
struct SICMSpec

Cache partition for the state-invariant prologue
"""
"""
struct SICMSpec
key::TornCacheKey
end
Base.show(io::IO, ms::MethodSpecialization{SICMSpec}) = print(io, "SICM Spec for ", ms.def)
Cache partition for the state-invariant prologue
"""
struct SICMSpec
key::TornCacheKey
end

function CC.code_cache(interp::DAEInterpreter)
if interp.ipo_analysis_mode
return CC.WorldView(
CC.InternalCodeCache(Core.MethodSpecialization{AnalysisSpec}),
CC.WorldRange(CC.get_inference_world(interp)))
else
return interp.code_cache
end

function CC.code_cache(interp::DAEInterpreter)
if interp.ipo_analysis_mode
return CC.WorldView(
CC.InternalCodeCache(AnalysisSpec()),
CC.WorldRange(CC.get_inference_world(interp)))
else
return interp.code_cache
end
else
CC.cache_owner(interp::DAEInterpreter) = interp.code_cache
CC.method_table(interp::DAEInterpreter) = interp.method_table
end
# Select the cache owner for `interp`: a fresh `AnalysisSpec()` token while the
# interpreter is in IPO analysis mode, otherwise the interpreter's own
# `code_cache` field.
function CC.cache_owner(interp::DAEInterpreter)
    if interp.ipo_analysis_mode
        return AnalysisSpec()
    end
    return interp.code_cache
end

# abstract interpretation
# -----------------------
Expand Down Expand Up @@ -813,10 +806,14 @@ struct MappingInfo <: CC.CallInfo
end

function _abstract_eval_invoke_inst(interp::DAEInterpreter, inst::Union{CC.Instruction, Nothing}, @nospecialize(stmt), irsv::IRInterpretationState)
mi = stmt.args[1]
invokee = stmt.args[1]
RT = Pair{Any, Tuple{Bool, Bool}}
good_effects = (true, true)
m = mi.def
if isa(invokee, Core.CodeInstance)
m = invokee.def.def
else
m = invokee.def
end
if m === variable_method0 || m === variable_method1
# Nothing to do - we'll read the incidence out of the ssavaluetypes
return RT(nothing, good_effects)
Expand Down
10 changes: 2 additions & 8 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,6 @@ end
CalleeInternal
end

@static if !Base.__has_internal_change(v"1.12-alpha", :methodspecialization)
const MethodSpecialization = Core.MethodInstance
else
import Core: MethodSpecialization
end

struct DAEIPOResult
ir::IRCode
extended_rt::Any
Expand All @@ -90,8 +84,8 @@ struct DAEIPOResult
tearing_cache::Dict{TornCacheKey, TornIR}

# TODO: Should this be looked up via the regular code instance cache instead?
sicm_cache::Dict{TornCacheKey, MethodSpecialization}
dae_finish_cache::Dict{TornCacheKey, MethodSpecialization}
sicm_cache::Dict{TornCacheKey, CodeInstance}
dae_finish_cache::Dict{TornCacheKey, Vector{CodeInstance}}
end

struct UncompilableIPOResult
Expand Down
11 changes: 8 additions & 3 deletions src/transform/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ function compile_invokes!(ir, interp)
inst = ir.stmts[i]
e = inst[:inst]
if isexpr(e, :invoke)
mi = e.args[1]::MethodInstance
if !CC.haskey(CC.code_cache(interp), mi)
CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI)
mi = e.args[1]
if isa(mi, MethodInstance)
if !CC.haskey(CC.code_cache(interp), mi)
CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI)
end
end
end
end
Expand Down Expand Up @@ -212,6 +214,9 @@ function check_for_daecompiler_intrinstics(ir::IRCode)
inst = ir[SSAValue(i)][:inst]
isexpr(inst, :invoke) || continue
mi = inst.args[1]
if isa(mi, CodeInstance)
mi = mi.def
end
if mi.def.module == DAECompiler.Intrinsics
throw(UnexpectedIntrinsicException(inst))
end
Expand Down
51 changes: 25 additions & 26 deletions src/transform/dae_finish.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,19 @@ end

const VectorViewType = SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int}}, true}

# Reserve a new CodeInstance for `old_ci.def` under the cache owner `owner`,
# populate it from `src` and `debuginfo`, insert it into the cache of
# `old_ci.def`, and return the new CodeInstance.
#
# Arguments:
# - `old_ci`:    existing code instance; its `def`, `max_world` and
#                `ipo_purity_bits` fields are reused for the new instance.
# - `src`:       forwarded as the second argument of `jl_update_codeinst`
#                (presumably the source to store in the instance — confirm).
# - `debuginfo`: forwarded to `jl_update_codeinst`.
# - `owner`:     forwarded to `CC.engine_reserve` together with `old_ci.def`.
#
# NOTE(review): the three `ccall`s target runtime-internal entry points whose
# positional signatures are not visible here; the argument type tuples below
# must match the Julia build in use — confirm when bumping Julia versions.
function cache_dae_ci!(old_ci, src, debuginfo, owner)
# Obtain the (not yet filled) code instance for this (method instance, owner) pair.
daef_ci = CC.engine_reserve(old_ci.def, owner)
# First pass: fill in the instance's metadata. `Tuple{}` / `Union{}` / `nothing`
# are presumably the return type, exception type and constant result — TODO confirm.
ccall(:jl_fill_codeinst, Cvoid, (Any, Any, Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
daef_ci, Tuple{}, Union{}, nothing, Int32(0),
# The lower world bound is hard-coded to 1 (see the inline marker) rather than
# taken from the old instance; the upper bound is copied from `old_ci`.
UInt(1)#=ci.min_world=#, old_ci.max_world,
old_ci.ipo_purity_bits, nothing, nothing, CC.empty_edges)
# Second pass: attach `src` and `debuginfo`, repeating the same world range
# and purity bits that were passed to `jl_fill_codeinst` above.
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, UInt8, Any, Any),
daef_ci, src, Int32(0), UInt(1)#=ci.min_world=#, old_ci.max_world, old_ci.ipo_purity_bits,
nothing, 0x0, debuginfo, CC.empty_edges)
# Insert the finished instance into the cache of the original method instance.
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), old_ci.def, daef_ci)
return daef_ci
end

function dae_finish_ipo!(
interp,
ci::CodeInstance,
Expand All @@ -288,6 +301,7 @@ function dae_finish_ipo!(
old_daef_mi = nothing
assigned_slots = falses(length(result.total_incidence))

cis = Vector{CodeInstance}()
for (ir_ordinal, ir) in enumerate(torn.ir_seq)
ir = torn.ir_seq[ir_ordinal]

Expand Down Expand Up @@ -338,11 +352,11 @@ function dae_finish_ipo!(
spec_data = stmt.args[1]
callee_key = stmt.args[1][2]
callee_ordinal = stmt.args[1][end]::Int
callee_daef_mi = dae_finish_ipo!(interp, callee_ci, callee_key, callee_ordinal)
callee_daef_cis = dae_finish_ipo!(interp, callee_ci, callee_key, callee_ordinal)
# Allocate a contiguous block of variables for all callee alg and diff states

empty!(stmt.args)
push!(stmt.args, callee_daef_mi)
push!(stmt.args, callee_daef_cis[1])
push!(stmt.args, closure_env)
push!(stmt.args, in_vars)

Expand Down Expand Up @@ -402,33 +416,17 @@ function dae_finish_ipo!(
widen_extra_info!(ir)
src = ir_to_src(ir)

daef_mi = MethodSpecialization{RHSSpec}(ci.def, Tuple{}, Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Float64})
daef_mi.data = RHSSpec(key, ir_ordinal)
abi = Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Float64}
owner = Core.ABIOverwrite(abi, RHSSpec(key, ir_ordinal))
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, owner)

daef_ci = CodeInstance(daef_mi, Tuple, Union{}, nothing,
src, Int32(0), UInt(1)#=ci.min_world=#, ci.max_world, ci.ipo_purity_bits, ci.purity_bits,
nothing, 0x0, src.debuginfo)

@atomic :release daef_mi.cache = daef_ci
global nrhscompiles += 1

if old_daef_mi !== nothing
@atomic :release old_daef_mi.next = daef_mi
end
old_daef_mi = daef_mi

if rhs_ms === nothing
rhs_ms = daef_mi
end
push!(cis, daef_ci)
end

result.dae_finish_cache[key] = rhs_ms
result.dae_finish_cache[key] = cis

ms = rhs_ms
while !isa(ms.data, RHSSpec) || ms.data.ordinal != ordinal
ms = ms.next
end
return ms
return cis
end

function ir_to_src(ir::IRCode)
Expand Down Expand Up @@ -499,6 +497,7 @@ function dae_factory_gen(world::UInt, source::LineNumberNode, _, @nospecialize(f
src.ssavaluetypes = length(src.code)
src.min_world = @atomic codeinst.min_world
src.max_world = @atomic codeinst.max_world
src.edges = codeinst.edges

return src
end
Expand Down Expand Up @@ -527,7 +526,7 @@ function dae_factory_gen(interp, ci::CodeInstance, key)

argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}

daef_ci = dae_finish_ipo!(interp, ci, key, 1)
daef_cis = dae_finish_ipo!(interp, ci, key, 1)

# Create a small opaque closure to adapt from SciML ABI to our own internal
# ABI
Expand Down Expand Up @@ -582,7 +581,7 @@ function dae_factory_gen(interp, ci::CodeInstance, key)
oc_sicm = insert_node_here!(oc_compact,
NewInstruction(Expr(:call, getfield, Argument(1), 1), Tuple, line))
insert_node_here!(oc_compact,
NewInstruction(Expr(:invoke, daef_ci, oc_sicm, (), out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, Argument(6)), Nothing, line))
NewInstruction(Expr(:invoke, daef_cis[1], oc_sicm, (), out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, Argument(6)), Nothing, line))

# Manually apply mass matrix
bc = insert_node_here!(oc_compact,
Expand Down
28 changes: 7 additions & 21 deletions src/transform/tearing_schedule_ipo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ function tearing_schedule!(interp, ci::CodeInstance, key::TornCacheKey)
end

callee_codeinst = CC.get(CC.code_cache(interp), stmt.args[1], nothing)
callee_sicm_mi = tearing_schedule!(interp, callee_codeinst, callee_key)
callee_sicm_ci = tearing_schedule!(interp, callee_codeinst, callee_key)

inst[:type] = Any
inst[:flag] = UInt32(0)
Expand All @@ -630,8 +630,8 @@ function tearing_schedule!(interp, ci::CodeInstance, key::TornCacheKey)
(AssignedDiff, UnassignedDiff, Algebraic, Explicit))...)
resize!(stmt.args, 1)

if !isdefined(callee_sicm_mi.cache, :rettype_const)
new_stmt.args[1] = callee_sicm_mi
if !isdefined(callee_sicm_ci, :rettype_const)
new_stmt.args[1] = callee_sicm_ci

urs = userefs(new_stmt)
for ur in urs
Expand All @@ -646,7 +646,7 @@ function tearing_schedule!(interp, ci::CodeInstance, key::TornCacheKey)
state = insert_node_here!(compact, NewInstruction(inst; stmt=new_stmt, type=Tuple, flag=UInt32(0)))
push!(stmt.args, SICMSSAValue(state.id))
else
push!(stmt.args, callee_sicm_mi.cache.rettype_const)
push!(stmt.args, callee_sicm_ci.rettype_const)
end
elseif stmt === nothing || isa(stmt, ReturnNode)
continue
Expand Down Expand Up @@ -948,24 +948,10 @@ function tearing_schedule!(interp, ci::CodeInstance, key::TornCacheKey)
debuginfo = src.debuginfo
end

sicm_mi = MethodSpecialization{SICMSpec}(ci.def, Tuple{}, sig)
sicm_mi.data = SICMSpec(key)
sicm_ci = cache_dae_ci!(ci, src, debuginfo, Core.ABIOverwrite(sig, SICMSpec(key)))

sicm_ci = CodeInstance(sicm_mi, Tuple, Union{}, ir_sicm === nothing ? () : nothing,
src, ir_sicm === nothing ? Int32(0x3) : Int32(0), UInt(1)#=ci.min_world=#, ci.max_world, ci.ipo_purity_bits, ci.purity_bits,
nothing, 0x0, debuginfo)

@atomic :release sicm_mi.cache = sicm_ci

result.sicm_cache[key] = sicm_mi
result.sicm_cache[key] = sicm_ci
result.tearing_cache[key] = TornIR(ir_sicm, irs)

cache_mi = ci.def
while isdefined(cache_mi, :next)
cache_mi = @atomic cache_mi.next
end
@atomic :release cache_mi.next = sicm_mi
global nsicmcompiles += 1

return sicm_mi
return sicm_ci
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using Test

#=
@testset "state_mapping.jl" include("state_mapping.jl")
@testset "interpreter.jl" include("interpreter.jl")
@testset "compiler_and_lattice.jl" include("compiler_and_lattice.jl")
@testset "JITOpaqueClosures.jl" include("JITOpaqueClosures.jl")
@testset "robertson.jl" include("robertson.jl")
=#
@testset "ipo.jl" include("ipo.jl")
@testset "lorenz.jl" include("lorenz_tests.jl")
@testset "pendulum.jl" include("pendulum_tests.jl")
Expand Down

0 comments on commit faf7d3f

Please sign in to comment.