Skip to content

Commit

Permalink
improve concurrency safety for Compiler.finish!
Browse files Browse the repository at this point in the history
Similar to #57229, this commit ensures that
`Compiler.finish!` properly synchronizes the operations to set
`max_world` for cached `CodeInstance`s by holding the world counter
lock. Previously, `Compiler.finish!` relied on a narrow timing window to
avoid race conditions, which was not a robust approach in a concurrent
execution environment.

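This change ensures that the finishing step performs this promotion while
holding the appropriate lock: after calling `Compiler.finish!`,
`finish_nocycle` calls `jl_promote_ci_to_current` and `finish_cycle` calls the
new batch variant `jl_promote_cis_to_current`.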
  • Loading branch information
aviatesk committed Feb 7, 2025
1 parent b65f004 commit f503808
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
29 changes: 20 additions & 9 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ If set to `true`, record per-method-instance timings within type inference in th
__set_measure_typeinf(onoff::Bool) = __measure_typeinf__[] = onoff
const __measure_typeinf__ = RefValue{Bool}(false)

function finish!(interp::AbstractInterpreter, caller::InferenceState)
function finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt)
result = caller.result
opt = result.src
if opt isa OptimizationState
Expand All @@ -108,12 +108,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
ci = result.ci
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
if last(result.valid_worlds) >= get_world_counter()
# TODO: this should probably come after all store_backedges (after optimizations) for the entire graph in finish_cycle
# since we should be requiring that all edges first get their backedges set, as a batch
result.valid_worlds = WorldRange(first(result.valid_worlds), typemax(UInt))
end
if last(result.valid_worlds) == typemax(UInt)
if last(result.valid_worlds) >= validation_world
# if we can record all of the backedges in the global reverse-cache,
# we can now widen our applicability in the global cache too
store_backedges(ci, edges)
Expand Down Expand Up @@ -202,7 +197,14 @@ function finish_nocycle(::AbstractInterpreter, frame::InferenceState)
if opt isa OptimizationState # implies `may_optimize(caller.interp) === true`
optimize(frame.interp, opt, frame.result)
end
finish!(frame.interp, frame)
validation_world = get_world_counter()
finish!(frame.interp, frame, validation_world)
if isdefined(frame.result, :ci)
# After validation, under the world_counter_lock, set max_world to typemax(UInt) for all dependencies
# (recursively). From that point onward the ordinary backedge mechanism is responsible for maintaining
# validity.
ccall(:jl_promote_ci_to_current, Cvoid, (Any, UInt), frame.result.ci, validation_world)
end
if frame.cycleid != 0
frames = frame.callstack::Vector{AbsIntState}
@assert frames[end] === frame
Expand Down Expand Up @@ -236,10 +238,19 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
optimize(caller.interp, opt, caller.result)
end
end
validation_world = get_world_counter()
cis = CodeInstance[]
for frameid = cycleid:length(frames)
caller = frames[frameid]::InferenceState
finish!(caller.interp, caller)
finish!(caller.interp, caller, validation_world)
if isdefined(caller.result, :ci)
push!(cis, caller.result.ci)
end
end
# After validation, under the world_counter_lock, set max_world to typemax(UInt) for all dependencies
# (recursively). From that point onward the ordinary backedge mechanism is responsible for maintaining
# validity.
ccall(:jl_promote_cis_to_current, Cvoid, (Ptr{CodeInstance}, Csize_t, UInt), cis, length(cis), validation_world)
resize!(frames, cycleid - 1)
return nothing
end
Expand Down
16 changes: 16 additions & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -4408,6 +4408,22 @@ JL_DLLEXPORT void jl_promote_ci_to_current(jl_code_instance_t *ci, size_t valida
JL_UNLOCK(&world_counter_lock);
}

/*
 * Array variant of jl_promote_ci_to_current.
 *
 * Calls _jl_promote_ci_to_current(cis[i], validated_world) for each of the
 * `n` entries of `cis`, but only when the global world counter still equals
 * `validated_world` at the time world_counter_lock is held. If the counter
 * differs from `validated_world`, nothing is promoted.
 *
 *   cis             - array of code instances to promote
 *   n               - number of entries in `cis`
 *   validated_world - world counter value the caller validated against
 */
JL_DLLEXPORT void jl_promote_cis_to_current(jl_code_instance_t **cis, size_t n, size_t validated_world)
{
    // Unlocked fast path: the world counter has already advanced past the
    // validated world, so skip taking the lock altogether.
    if (jl_atomic_load_relaxed(&jl_world_counter) > validated_world)
        return;
    JL_LOCK(&world_counter_lock);
    // Read the counter again now that the lock is held; the decision to
    // promote is based on this value, not on the unlocked read above.
    size_t world_now = jl_atomic_load_relaxed(&jl_world_counter);
    if (world_now == validated_world) {
        size_t idx = 0;
        while (idx < n) {
            _jl_promote_ci_to_current(cis[idx], validated_world);
            idx++;
        }
    }
    JL_UNLOCK(&world_counter_lock);
}

#ifdef __cplusplus
}
#endif

0 comments on commit f503808

Please sign in to comment.