Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

staticdata: Close data race after backedge insertion #57229

Merged
merged 2 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions base/staticdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,16 @@ end
function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method}, external::Bool=false)
for i = 1:length(edges)
codeinst = edges[i]::CodeInstance
verify_method_graph(codeinst, stack, visiting, mwis)
validation_world = get_world_counter()
verify_method_graph(codeinst, stack, visiting, mwis, validation_world)
# After validation, under the world_counter_lock, set max_world to typemax(UInt) for all dependencies
# (recursively). From that point onward the ordinary backedge mechanism is responsible for maintaining
# validity.
@ccall jl_promote_ci_to_current(codeinst::Any, validation_world::UInt)::Cvoid
minvalid = codeinst.min_world
maxvalid = codeinst.max_world
# Finally, if this CI is still valid in some world age and and belongs to an external method(specialization),
# poke it that mi's cache
if maxvalid ≥ minvalid && external
caller = get_ci_mi(codeinst)
@assert isdefined(codeinst, :inferred) # See #53586, #53109
Expand All @@ -55,9 +62,9 @@ function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visi
end
end

function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method})
function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method}, validation_world::UInt)
@assert isempty(stack); @assert isempty(visiting);
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting, mwis)
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting, mwis, validation_world)
@assert child_cycle == 0
@assert isempty(stack); @assert isempty(visiting);
nothing
Expand All @@ -67,15 +74,14 @@ end
# - Visit the entire call graph, starting from edges[idx] to determine if that method is valid
# - Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
# and slightly modified with an early termination option once the computation reaches its minimum
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method})
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method}, validation_world::UInt)
world = codeinst.min_world
let max_valid2 = codeinst.max_world
if max_valid2 ≠ WORLD_AGE_REVALIDATION_SENTINEL
return 0, world, max_valid2
end
end
current_world = get_world_counter()
local minworld::UInt, maxworld::UInt = 1, current_world
local minworld::UInt, maxworld::UInt = 1, validation_world
def = get_ci_mi(codeinst).def
@assert def isa Method
if haskey(visiting, codeinst)
Expand Down Expand Up @@ -177,7 +183,7 @@ function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visi
end
callee = edge
local min_valid2::UInt, max_valid2::UInt
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting, mwis)
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting, mwis, validation_world)
if minworld < min_valid2
minworld = min_valid2
end
Expand Down Expand Up @@ -209,16 +215,14 @@ function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visi
if maxworld ≠ 0
@atomic :monotonic child.min_world = minworld
end
if maxworld == current_world
@atomic :monotonic child.max_world = maxworld
if maxworld == validation_world && validation_world == get_world_counter()
Base.Compiler.store_backedges(child, child.edges)
@atomic :monotonic child.max_world = typemax(UInt)
else
@atomic :monotonic child.max_world = maxworld
end
@assert visiting[child] == length(stack) + 1
delete!(visiting, child)
invalidations = _jl_debug_method_invalidation[]
if invalidations !== nothing && maxworld < current_world
if invalidations !== nothing && maxworld < validation_world
push!(invalidations, child, "verify_methods", cause)
end
end
Expand Down
28 changes: 28 additions & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -4265,6 +4265,34 @@ JL_DLLEXPORT jl_value_t *jl_restore_package_image_from_file(const char *fname, j
return mod;
}

JL_DLLEXPORT void _jl_promote_ci_to_current(jl_code_instance_t *ci, size_t validated_world) JL_NOTSAFEPOINT
{
if (jl_atomic_load_relaxed(&ci->max_world) != validated_world)
return;
jl_atomic_store_relaxed(&ci->max_world, ~(size_t)0);
jl_svec_t *edges = jl_atomic_load_relaxed(&ci->edges);
for (size_t i = 0; i < jl_svec_len(edges); i++) {
jl_value_t *edge = jl_svecref(edges, i);
if (!jl_is_code_instance(edge))
continue;
_jl_promote_ci_to_current(ci, validated_world);
}
}

JL_DLLEXPORT void jl_promote_ci_to_current(jl_code_instance_t *ci, size_t validated_world)
{
size_t current_world = jl_atomic_load_relaxed(&jl_world_counter);
// No need to acquire the lock if we've been invalidated anyway
if (current_world > validated_world)
return;
JL_LOCK(&world_counter_lock);
current_world = jl_atomic_load_relaxed(&jl_world_counter);
if (current_world == validated_world) {
_jl_promote_ci_to_current(ci, validated_world);
}
JL_UNLOCK(&world_counter_lock);
}

#ifdef __cplusplus
}
#endif