Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support execution of code from external abstract interpreters #52964

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

vchuravy
Copy link
Member

@vchuravy vchuravy commented Jan 18, 2024

Currently external abstract interpreters are succesfully used to power analysis,
compilation for external targets like GPU, but there currently doesn't exist a
mechanism within Julia to execute the output of these abstract interpreters and
the community has developed various work-arounds.

  • Enzyme.jl uses GPUCompiler.jl + LLVM.jl to spin up it's own external JIT and
    uses ccall to call from Julia native to Enzyme controlled land. It furthermore
    needs to detect dynamic callsites and replaces them to functions controlled by Enzyme.
  • CassetteOverlay.jl uses the "cassette" transform to replace all calls f(args...)
    to calls to overdub(ctx, f, args...). While this has worked in Cassette.jl for a
    long time, it also leads to hard to read backtraces and overly relies on generated functions.

This PR is a combination of two ideas:

Concretly this PR introduces abstract type CompilerInstance end that allows for the creation of temporary AbstractInterpreter instances,
it then uses a new built-in call_within to switch between compiler instances, furthermore the compiler instance is also the owner of
the corresponding CodeInstances.

After that it is mostly and exercise of threading the compiler instance through in the right places.

TODO

  • Go through all uses of explicit jl_nothing as a owner token and use the correct one.
  • Expose the compiler instance to reflection methods
  • Add abstract interpreter support for call_within
  • Concrete evaluation will need to use call_within
  • Interpreter support? JuliaInterpreter.jl support? Cthulhu support?
  • Finalizers?

Example: Cassette style tracer

This PR doesn't do anything about how to work with the IR and write compiler plugins,
it also doesn't provide any hooks for compiler instances to modify the LLVM pipeline,
but a below is a prototype of a Cassette type tracer. Where we have a prehook
and a posthook function to execute over the callgraph.

const CC = Core.Compiler

import .CC: SSAValue, GlobalRef

struct Tracer <: CC.AbstractCompiler end
CC.abstract_interpreter(::Tracer, world::UInt) =
    TracerInterp(; world)

struct TracerInterp <: CC.AbstractInterpreter
    world::UInt
    inf_params::CC.InferenceParams
    opt_params::CC.OptimizationParams
    inf_cache::Vector{CC.InferenceResult}
    code_cache::CC.InternalCodeCache
    compiler::Tracer
    function TracerInterp(;
                world::UInt = Base.get_world_counter(),
                compiler::Tracer = Tracer(),
                inf_params::CC.InferenceParams = CC.InferenceParams(),
                opt_params::CC.OptimizationParams = CC.OptimizationParams(),
                inf_cache::Vector{CC.InferenceResult} = CC.InferenceResult[],
                code_cache::CC.InternalCodeCache = CC.InternalCodeCache(compiler))
        return new(world, inf_params, opt_params, inf_cache, code_cache, compiler)
    end
end

CC.InferenceParams(interp::TracerInterp) = interp.inf_params
CC.OptimizationParams(interp::TracerInterp) = interp.opt_params
CC.get_world_counter(interp::TracerInterp) = interp.world
CC.get_inference_cache(interp::TracerInterp) = interp.inf_cache
CC.code_cache(interp::TracerInterp) = CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
CC.cache_owner(interp::TracerInterp) = interp.compiler

import Core.Compiler: retrieve_code_info, maybe_validate_code
# Replace usage sited of `retrieve_code_info`, OptimizationState is one such, but in all interesting use-cases
# it is derived from an InferenceState. There is a third one in `typeinf_ext` in case the module forbids inference.
function CC.InferenceState(result::CC.InferenceResult, cache_mode::UInt8, interp::TracerInterp)
    world = CC.get_world_counter(interp)
    src = retrieve_code_info(result.linfo, world)
    src === nothing && return nothing
    maybe_validate_code(result.linfo, src, "lowered")
    src = transform(interp, result.linfo, src)
    maybe_validate_code(result.linfo, src, "transformed")
    return CC.InferenceState(result, src, cache_mode, interp)
end

##
# Cassette style early transform
##

# Allows for Cassette Pass transforms

function static_eval(mod, name)
    if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
        return getfield(mod, name)
    else
        return nothing
    end
end

function prehook end
function posthook end

function transform(interp, mi, src)
    method = mi.def
    f = static_eval(method.module, method.name)
    ccall(:jl_, Cvoid, (Any,), mi)
    if f === Core._apply
        return src
    end
    if f isa Core.Builtin
        error("Transforming builtin")
    end
    # if f === prehook || f === posthook
    #     return src
    # end
    ci = copy(src)
    transform!(mi, ci)
    ci.ssavaluetypes = length(ci.code)
    # XXX we need to copy flags
    ci.ssaflags = [0x00 for _ in 1:length(ci.code)]
    # ccall(:jl_, Cvoid, (Any,), ci)
    return ci
end

function ir_element(x, code::Vector)
    while isa(x, Core.SSAValue)
        x = code[x.id]
    end
    return x
end

"""
    is_ir_element(x, y, code::Vector)

Return `true` if `x === y` or if `x` is an `SSAValue` such that
`is_ir_element(code[x.id], y, code)` is `true`.
See also: [`replace_match!`](@ref), [`insert_statements!`](@ref)
"""
function is_ir_element(x, y, code::Vector)
    result = false
    while true # break by default
        if x === y #
            result = true
            break
        elseif isa(x, Core.SSAValue)
            x = code[x.id]
        else
            break
        end
    end
    return result
end

"""
    insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)


For every statement `stmt` at position `i` in `code` for which `stmtcount(stmt, i)` returns
an `Int`, remove `stmt`, and in its place, insert the statements returned by
`newstmts(stmt, i)`. If `stmtcount(stmt, i)` returns `nothing`, leave `stmt` alone.

For every insertion, all downstream `SSAValue`s, label indices, etc. are incremented
appropriately according to number of inserted statements.

Proper usage of this function dictates that following properties hold true:

- `code` is expected to be a valid value for the `code` field of a `CodeInfo` object.
- `codelocs` is expected to be a valid value for the `codelocs` field of a `CodeInfo` object.
- `newstmts(stmt, i)` should return a `Vector` of valid IR statements.
- `stmtcount` and `newstmts` must obey `stmtcount(stmt, i) == length(newstmts(stmt, i))` if
    `isa(stmtcount(stmt, i), Int)`.

To gain a mental model for this function's behavior, consider the following scenario. Let's
say our `code` object contains several statements:
code = Any[oldstmt1, oldstmt2, oldstmt3, oldstmt4, oldstmt5, oldstmt6]
codelocs = Int[1, 2, 3, 4, 5, 6]

Let's also say that for our `stmtcount` returns `2` for `stmtcount(oldstmt2, 2)`, returns `3`
for `stmtcount(oldstmt5, 5)`, and returns `nothing` for all other inputs. From this setup, we
can think of `code`/`codelocs` being modified in the following manner:
newstmts2 = newstmts(oldstmt2, 2)
newstmts5 = newstmts(oldstmt5, 5)
code = Any[oldstmt1,
           newstmts2[1], newstmts2[2],
           oldstmt3, oldstmt4,
           newstmts5[1], newstmts5[2], newstmts5[3],
           oldstmt6]
codelocs = Int[1, 2, 2, 3, 4, 5, 5, 5, 6]

See also: [`replace_match!`](@ref), [`is_ir_element`](@ref)
"""
function insert_statements!(code, codelocs, stmtcount, newstmts)
    ssachangemap = fill(0, length(code))
    labelchangemap = fill(0, length(code))
    worklist = Tuple{Int,Int}[]
    for i in 1:length(code)
        stmt = code[i]
        nstmts = stmtcount(stmt, i)
        if nstmts !== nothing
            addedstmts = nstmts - 1
            push!(worklist, (i, addedstmts))
            ssachangemap[i] = addedstmts
            if i < length(code)
                labelchangemap[i + 1] = addedstmts
            end
        end
    end
    Core.Compiler.renumber_ir_elements!(code, ssachangemap, labelchangemap)
    for (i, addedstmts) in worklist
        i += ssachangemap[i] - addedstmts # correct the index for accumulated offsets
        stmts = newstmts(code[i], i)
        @assert length(stmts) == (addedstmts + 1)
        code[i] = stmts[end]
        for j in 1:(length(stmts) - 1) # insert in reverse to maintain the provided ordering
            insert!(code, i, stmts[end - j])
            insert!(codelocs, i, codelocs[i])
        end
    end
end

function transform!(mi, src)
    stmtcount = (x, i) -> begin
        isassign = Base.Meta.isexpr(x, :(=))
        stmt = isassign ? x.args[2] : x
        if Base.Meta.isexpr(stmt, :call)
            return 4
        end
        return nothing
    end
    newstmts = (x, i) -> begin
        callstmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
        isapplycall = is_ir_element(callstmt.args[1], GlobalRef(Core, :_apply), src.code)
        isapplyiteratecall = is_ir_element(callstmt.args[1], GlobalRef(Core, :_apply_iterate), src.code)
        if isapplycall || isapplyiteratecall
            callf = callstmt.args[2]
            callargs = callstmt.args[3:end]
            stmts = Any[
                Expr(:call,
                     GlobalRef(Core, :_call_within), nothing,
                     prehook, callf, callargs...),
                callstmt,
                Expr(:call,
                     GlobalRef(Core, :_call_within), nothing,
                     posthook, SSAValue(i + 1), callf, callargs...),
                Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], SSAValue(i + 1)) : SSAValue(i + 1)
            ]
        else
            stmts = Any[
                Expr(:call, GlobalRef(Core, :_call_within), nothing, prehook, callstmt.args...),
                callstmt,
                Expr(:call, GlobalRef(Core, :_call_within), nothing, posthook, SSAValue(i + 1), callstmt.args...),
                Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], SSAValue(i + 1)) : SSAValue(i + 1)
            ]
        end
        return stmts
    end
    insert_statements!(src.code, src.codelocs, stmtcount, newstmts)
    return nothing
end

# TODO:
# - anything called from prehook is still instrumented
#   - either we can stop instrumentation / or switch to native (high overhead)
# - Similar problem with invokelatest, we somehow got recursion
# - invokelatest trace sees prehook call -- do we double transform?
struct Call
    parent
    f
    args
    children
end

# TODO: Handle task-safety
const call_tree = ScopedValue{Ref{Call}}()

function prehook(f, args...)
    parent = call_tree[][]
    current = Call(parent, f, args, Call[])
    push!(parent.children, current)
    call_tree[][] = current
end

function posthook(_, f, args...)
    current = call_tree[][]
    call_tree[][] = current.parent
end

function f()
end

function trace(f, args...)
    top = Call(nothing, f, args, Call[])
    @with call_tree => Ref(top) begin
        Base.invoke_within(Tracer(), f, args...)
    end
    return top
end

trace(f)

@vchuravy vchuravy added speculative Whether the change will be implemented is speculative compiler:plugins labels Jan 18, 2024
@vchuravy
Copy link
Member Author

One thing I particularly enjoyed about Cassette was that composition is well defined. This proposal ignores composition, compiler instances don't form a stack and there is no expectation that the output of one is meant to be consumed by another.

For Cassette uninferred IR was a viable communication layer, but with compiler instances customization can occur along many levels and we run into the pipeline ordering problem. The hope would be that using compiler instance we could build an actual "compiler plugins" infrastructure that allows for the registration of passes/intrinsics and provides sensible composition, but that seems further away and I do think we need some experimentation with compiler customization first before we tackle that.

1. Introduced a task-local inherited (nee ScopedValue) compiler field
2. Allow compilation and execution of code generated by foreign abstract
   interpreters

A new primitive `call_within` is introduced that switches the compiler
instance. The compiler instance is used for cache-lookups, compilation
request, and dispatch.

# FIXME: Currently doesn't infer and ends in "Skipped call_within since compiler plugin not constant"
overlay(f, args...) = CustomMethodTables.overlay(CustomMT, f, args...)
@test_broken overlay(sin, 1.0) == cos(1.0) # Bug in inference, not using the method_table for initial lookup
Copy link
Member Author

@vchuravy vchuravy Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is due to jl_lookup_generic not taking into account the custom method table.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
compiler:plugins speculative Whether the change will be implemented is speculative
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant