Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Taking world ages seriously #394

Merged
merged 20 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GPUCompiler"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
authors = ["Tim Besard <tim.besard@gmail.com>"]
version = "0.17.3"
version = "0.18.0"

[deps]
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand Down
4 changes: 2 additions & 2 deletions examples/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntim
kernel() = nothing

function main()
source = FunctionSpec(typeof(kernel))
source = FunctionSpec(typeof(kernel), Tuple{})
target = NativeCompilerTarget()
params = TestCompilerParams()
job = CompilerJob(target, source, params)
job = CompilerJob(source, target, params)

println(GPUCompiler.compile(:asm, job)[1])
end
Expand Down
146 changes: 114 additions & 32 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,82 @@
using Core.Compiler: retrieve_code_info, CodeInfo, MethodInstance, SSAValue, SlotNumber, ReturnNode
using Base: _methods_by_ftype

# generated function that crafts a custom code info to call the actual compiler.
# this gives us the flexibility to insert manual back edges for automatic recompilation.
# generated function that returns the world age of a compilation job. this can be used to
# drive compilation, e.g. by using it as a key for a cache, as the age will change when a
# function or any called function is redefined.


"""
get_world(ft, tt)

A special function that returns the world age in which the current definition of function
type `ft`, invoked with argument types `tt`, is defined. This can be used to cache
compilation results:

compilation_cache = Dict()
function cache_compilation(ft, tt)
world = get_world(ft, tt)
get!(compilation_cache, (ft, tt, world)) do
# compile
end
end

What makes this function special is that it is a generated function, returning a constant,
whose result is automatically invalidated when the function `ft` (or any called function) is
redefined. This makes this query ideally suited for hot code, where you want to avoid a
costly look-up of the current world age on every invocation.

Normally, you shouldn't have to use this function, as it's used by `FunctionSpec`.

!!! warning

Due to a bug in Julia, JuliaLang/julia#34962, this function's results are only
guaranteed to be correctly invalidated when the target function `ft` is executed or
processed by codegen (e.g., by calling `code_llvm`).
"""
get_world

# generate functions currently do not know which world they are invoked for, so we fall
# back to using the current world. this may be wrong when the generator is invoked in a
# different world (TODO: when does this happen?)
#
# we also increment a global specialization counter and pass it along to index the cache.

const specialization_counter = Ref{UInt}(0)
@generated function specialization_id(job::CompilerJob{<:Any,<:Any,FunctionSpec{f,tt}}) where {f,tt}
# get a hold of the method and code info of the kernel function
sig = Tuple{f, tt.parameters...}
# XXX: instead of typemax(UInt) we should use the world-age of the fspec
mthds = _methods_by_ftype(sig, -1, typemax(UInt))
# XXX: this should be fixed by JuliaLang/julia#48611

function get_world_generator(self, ::Type{Type{ft}}, ::Type{Type{tt}}) where {ft, tt}
maleadt marked this conversation as resolved.
Show resolved Hide resolved
@nospecialize

# look up the method
sig = Tuple{ft, tt.parameters...}
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results
mthds = if VERSION >= v"1.7.0-DEV.1297"
Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1,
#=world=# typemax(UInt), #=ambig=# false,
min_world, max_world, has_ambig)
# XXX: use the correct method table to support overlaying kernels
else
Base._methods_by_ftype(sig, #=lim=# -1,
#=world=# typemax(UInt), #=ambig=# false,
min_world, max_world, has_ambig)
end
# XXX: using world=-1 is wrong, but the current world isn't exposed to this generator

# check the validity of the method matches
method_error = :(throw(MethodError(ft, tt)))
mthds === nothing && return method_error
Base.isdispatchtuple(tt) || return(:(error("$tt is not a dispatch tuple")))
maleadt marked this conversation as resolved.
Show resolved Hide resolved
length(mthds) == 1 || return (:(throw(MethodError(job.source.f,job.source.tt))))
length(mthds) == 1 || return method_error

# look up the method and code instance
mtypes, msp, m = mthds[1]
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
ci = retrieve_code_info(mi)::CodeInfo

# generate a unique id to represent this specialization
# TODO: just use the lower world age bound in which this code info is valid.
# (the method instance doesn't change when called functions are changed).
# but how to get that? the ci here always has min/max world 1/-1.
# XXX: don't use `objectid(ci)` here, apparently it can alias (or the CI doesn't change?)
id = (specialization_counter[] += 1)
# XXX: we don't know the world age that this generator was requested to run in, so use
# the current world (we cannot use the mi's world because that doesn't update when
# called functions are changed). this isn't correct, but should be close.
world = Base.get_world_counter()

# prepare a new code info
new_ci = copy(ci)
Expand All @@ -34,22 +87,20 @@ const specialization_counter = Ref{UInt}(0)
resize!(new_ci.linetable, 1) # see note below
empty!(new_ci.ssaflags)
new_ci.ssavaluetypes = 0
new_ci.min_world = min_world[]
new_ci.max_world = max_world[]
new_ci.edges = MethodInstance[mi]
# XXX: setting this edge does not give us proper method invalidation, see
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
# invoking `code_llvm` also does the necessary codegen, as does calling the
# underlying C methods -- which GPUCompiler does, so everything Just Works.

# prepare the slots
new_ci.slotnames = Symbol[Symbol("#self#"), :cache, :job, :compiler, :linker]
new_ci.slotflags = UInt8[0x00 for i = 1:5]
cache = SlotNumber(2)
job = SlotNumber(3)
compiler = SlotNumber(4)
linker = SlotNumber(5)

# call the compiler
push!(new_ci.code, ReturnNode(id))
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:3]

# return the world
push!(new_ci.code, ReturnNode(world))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be trying to always return the value of the world for the get_world function definition, not the lookup world?

But this also seems very risky, as you are claiming to the compiler that this world value is a constant, but we know for a fact (based on the edges, min_world, and max_world) that the meaning of world numbers is certain to be dynamic.

The appropriate token to return here might be the MethodInstance object mi? That would encapsulate fully the MethodMatch lookup result, and exactly represent the result of the lookup for a long as that is the correct compilation target?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be fine to return the generator's world (post-JuliaLang/julia#48766)? That's what I'm doing in #403.

Returning the mi could work I guess, but would need some refactoring. We're currently returning the world so that we can use it to construct a FunctionSpec, which is the object that we use to look-up the code to compile:

# what we'll be compiling
struct FunctionSpec
ft::Type
tt::Type
world::UInt
FunctionSpec(ft::Type, tt::Type, world::Integer=get_world(ft, tt)) =
new(ft, tt, world)
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the reason we use a separate struct is that it used to carry more fields. Now that they're gone, maybe we should just use MethodInstance instead of FunctionSpec, as @vchuravy notes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is never allowable to expose the user to worlds from the compiler and vice versa. It violates the contract that only edges are observable effects.

Copy link
Member Author

@maleadt maleadt Mar 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. We do however need the world age for some cases though, e.g., to look-up in our CodeInstance cache, and to pass to jl_create_native (not sure what to use here). I would assume that both need to be the world age that we need to generate code for (i.e. what we're returning now from this generator), and not the Method's primary world. Or what do you suggest to use here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIU using mi.def.primary_world won't work because that returns too early of a world:

julia> foo(x) = x
julia> mi = methodinstance(typeof(foo), Tuple{Int})
julia> mi.def.primary_world |> Int
32455

julia> mi = methodinstance(typeof(map), Tuple{typeof(foo), Int})
julia> mi.def.primary_world |> Int
6074

i.e. using map's primary_world and passing that to, say jl_create_native won't work because we don't have foo there. That's why I was 'leaking' the world from the generated function before.

Any thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't that just be the tls_world_age ccall?

push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
push!(new_ci.codelocs, 1) # see note below
new_ci.ssavaluetypes += 1
Expand All @@ -62,17 +113,48 @@ const specialization_counter = Ref{UInt}(0)
return new_ci
end

@eval function get_world(ft, tt)
$(Expr(:meta, :generated_only))
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:get_world_generator,
Any[:get_world, :ft, :tt],
Any[],
@__LINE__,
QuoteNode(Symbol(@__FILE__)),
true)))
end

const cache_lock = ReentrantLock()

"""
cached_compilation(cache::Dict, job::CompilerJob, compiler, linker)

Compile `job` using `compiler` and `linker`, and store the result in `cache`.

The `cache` argument should be a dictionary that can be indexed using a `UInt` and store
whatever the `linker` function returns. The `compiler` function should take a `CompilerJob`
and return data that can be cached across sessions (e.g., LLVM IR). This data is then
forwarded, along with the `CompilerJob`, to the `linker` function which is allowed to create
session-dependent objects (e.g., a `CuModule`).
"""
function cached_compilation(cache::AbstractDict,
@nospecialize(job::CompilerJob),
compiler::Function, linker::Function)
# XXX: CompilerJob contains a world age, so can't be respecialized.
# have specialization_id take a f/tt and return a world to construct a CompilerJob?
key = hash(job, specialization_id(job))
force_compilation = compile_hook[] !== nothing
# NOTE: it is OK to index the compilation cache directly with the compilation job, i.e.,
# using a world age instead of intersecting world age ranges, because we expect
# that the world age is aquired through calling `get_world` and thus will only
# ever change when the kernel function is redefined.
#
# if we ever want to be able to index the cache using a compilation job that
# contains a more recent world age, yet still return an older cached object that
# would still be valid, we'd need the cache to store world ranges instead and
# use an invalidation callback to add upper bounds to entries.
key = hash(job)

# XXX: by taking the hash, we index the compilation cache directly with the world age.
# that's wrong; we should perform an intersection with the entry its bounds.
force_compilation = compile_hook[] !== nothing

# NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
lock(cache_lock)
Expand Down
31 changes: 25 additions & 6 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,13 @@ end

# get the method instance
sig = typed_signature(job)
meth = which(sig)
meth = if VERSION >= v"1.10.0-DEV.65"
Base._which(sig; world=job.source.world).method
elseif VERSION >= v"1.7.0-DEV.435"
Base._which(sig, job.source.world).method
else
ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), sig, job.source.world)
end

(ti, env) = ccall(:jl_type_intersection_with_env, Any,
(Any, Any), sig, meth.sig)::Core.SimpleVector
Expand All @@ -175,6 +181,10 @@ end
end
end

# ensure that the returned method instance is valid in the compilation world.
# otherwise, `jl_create_native` won't actually emit any code.
@assert method_instance.def.primary_world <= job.source.world <= method_instance.def.deleted_world
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seems overly strict, as it is supposed to be perfectly reasonable to run a Method in any world, irrespective of its primary_world and delete_world values (which only refer to the lookup of said Method).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is supposed to be perfectly reasonable to run a Method in any world

What would that look like (in the context of GPUCompiler)? I put that there because jl_create_native uses the same check to decide whether to emit any code: https://github.com/JuliaLang/julia/blob/7341fb9517d290be02fdc54ae453999843a0dc7e/src/aotcompile.cpp#L332-L335

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, that is just a heuristic decision, since our caller isn't designed to be particularly reliable about filtering those earlier (since we do want it to compile for both the current and typeinf worlds when applicable)


return method_instance, ()
end

Expand All @@ -189,9 +199,9 @@ Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
ptr
end

@generated function deferred_codegen(::Val{f}, ::Val{tt}) where {f,tt}
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
id = length(deferred_codegen_jobs) + 1
deferred_codegen_jobs[id] = FunctionSpec(f,tt)
deferred_codegen_jobs[id] = FunctionSpec(ft, tt)

pseudo_ptr = reinterpret(Ptr{Cvoid}, id)
quote
Expand Down Expand Up @@ -286,10 +296,19 @@ const __llvm_initialized = Ref(false)
id = convert(Int, first(operands(call)))

global deferred_codegen_jobs
dyn_job = deferred_codegen_jobs[id]
if dyn_job isa FunctionSpec
dyn_job = similar(job, dyn_job)
dyn_val = deferred_codegen_jobs[id]

# get a job in the appopriate world
dyn_job = if dyn_val isa CompilerJob
dyn_spec = FunctionSpec(dyn_val.source; world=job.source.world)
CompilerJob(dyn_val; source=dyn_spec)
elseif dyn_val isa FunctionSpec
dyn_spec = FunctionSpec(dyn_val; world=job.source.world)
CompilerJob(job; source=dyn_spec)
else
error("invalid deferred job type $(typeof(dyn_val))")
end

push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
end

Expand Down
Loading