Skip to content

Commit

Permalink
fix #41546, make using thread-safe (#41602)
Browse files Browse the repository at this point in the history
use more precision when handling loading lock, merge with TOML lock
(since we typically are needing both, sometimes in unpredictable
orders), and unlock before call most user code

Co-authored-by: Jameson Nash <vtjnash@gmail.com>
  • Loading branch information
JeffBezanson and vtjnash authored Oct 25, 2021
1 parent 5ecf5fc commit 3d4b213
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 27 deletions.
76 changes: 52 additions & 24 deletions base/loading.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Base.require is the implementation for the `import` statement
const require_lock = ReentrantLock()

# Cross-platform case-sensitive path canonicalization

Expand Down Expand Up @@ -129,6 +130,7 @@ end
const ns_dummy_uuid = UUID("fe0723d6-3a44-4c41-8065-ee0f42c8ceab")

function dummy_uuid(project_file::String)
@lock require_lock begin
cache = LOADING_CACHE[]
if cache !== nothing
uuid = get(cache.dummy_uuid, project_file, nothing)
Expand All @@ -144,6 +146,7 @@ function dummy_uuid(project_file::String)
cache.dummy_uuid[project_file] = uuid
end
return uuid
end
end

## package path slugs: turning UUID + SHA1 into a pair of 4-byte "slugs" ##
Expand Down Expand Up @@ -236,8 +239,7 @@ struct TOMLCache
end
const TOML_CACHE = TOMLCache(TOML.Parser(), Dict{String, Dict{String, Any}}())

const TOML_LOCK = ReentrantLock()
parsed_toml(project_file::AbstractString) = parsed_toml(project_file, TOML_CACHE, TOML_LOCK)
parsed_toml(project_file::AbstractString) = parsed_toml(project_file, TOML_CACHE, require_lock)
function parsed_toml(project_file::AbstractString, toml_cache::TOMLCache, toml_lock::ReentrantLock)
lock(toml_lock) do
cache = LOADING_CACHE[]
Expand Down Expand Up @@ -337,13 +339,15 @@ Use [`dirname`](@ref) to get the directory part and [`basename`](@ref)
to get the file name part of the path.
"""
function pathof(m::Module)
pkgid = get(Base.module_keys, m, nothing)
@lock require_lock begin
pkgid = get(module_keys, m, nothing)
pkgid === nothing && return nothing
origin = get(Base.pkgorigins, pkgid, nothing)
origin = get(pkgorigins, pkgid, nothing)
origin === nothing && return nothing
path = origin.path
path === nothing && return nothing
return fixup_stdlib_path(path)
end
end

"""
Expand All @@ -366,7 +370,7 @@ julia> pkgdir(Foo, "src", "file.jl")
The optional argument `paths` requires at least Julia 1.7.
"""
function pkgdir(m::Module, paths::String...)
rootmodule = Base.moduleroot(m)
rootmodule = moduleroot(m)
path = pathof(rootmodule)
path === nothing && return nothing
return joinpath(dirname(dirname(path)), paths...)
Expand All @@ -383,6 +387,7 @@ const preferences_names = ("JuliaLocalPreferences.toml", "LocalPreferences.toml"
# - `true`: `env` is an implicit environment
# - `path`: the path of an explicit project file
function env_project_file(env::String)::Union{Bool,String}
@lock require_lock begin
cache = LOADING_CACHE[]
if cache !== nothing
project_file = get(cache.env_project_file, env, nothing)
Expand All @@ -406,6 +411,7 @@ function env_project_file(env::String)::Union{Bool,String}
cache.env_project_file[env] = project_file
end
return project_file
end
end

function project_deps_get(env::String, name::String)::Union{Nothing,PkgId}
Expand Down Expand Up @@ -473,6 +479,7 @@ end

# find project file's corresponding manifest file
function project_file_manifest_path(project_file::String)::Union{Nothing,String}
@lock require_lock begin
cache = LOADING_CACHE[]
if cache !== nothing
manifest_path = get(cache.project_file_manifest_path, project_file, missing)
Expand Down Expand Up @@ -501,6 +508,7 @@ function project_file_manifest_path(project_file::String)::Union{Nothing,String}
cache.project_file_manifest_path[project_file] = manifest_path
end
return manifest_path
end
end

# given a directory (implicit env from LOAD_PATH) and a name,
Expand Down Expand Up @@ -688,7 +696,7 @@ function implicit_manifest_deps_get(dir::String, where::PkgId, name::String)::Un
@assert where.uuid !== nothing
project_file = entry_point_and_project_file(dir, where.name)[2]
project_file === nothing && return nothing # a project file is mandatory for a package with a uuid
proj = project_file_name_uuid(project_file, where.name, )
proj = project_file_name_uuid(project_file, where.name)
proj == where || return nothing # verify that this is the correct project file
# this is the correct project, so stop searching here
pkg_uuid = explicit_project_deps_get(project_file, name)
Expand Down Expand Up @@ -753,19 +761,26 @@ function _include_from_serialized(path::String, depmods::Vector{Any})
if isa(sv, Exception)
return sv
end
restored = sv[1]
if !isa(restored, Exception)
for M in restored::Vector{Any}
M = M::Module
if isdefined(M, Base.Docs.META)
push!(Base.Docs.modules, M)
end
if parentmodule(M) === M
register_root_module(M)
end
sv = sv::SimpleVector
restored = sv[1]::Vector{Any}
for M in restored
M = M::Module
if isdefined(M, Base.Docs.META)
push!(Base.Docs.modules, M)
end
if parentmodule(M) === M
register_root_module(M)
end
end
inits = sv[2]::Vector{Any}
if !isempty(inits)
unlock(require_lock) # temporarily _unlock_ during these callbacks
try
ccall(:jl_init_restored_modules, Cvoid, (Any,), inits)
finally
lock(require_lock)
end
end
isassigned(sv, 2) && ccall(:jl_init_restored_modules, Cvoid, (Any,), sv[2])
return restored
end

Expand Down Expand Up @@ -873,7 +888,7 @@ function _require_search_from_serialized(pkg::PkgId, sourcepath::String, depth::
end

# to synchronize multiple tasks trying to import/using something
const package_locks = Dict{PkgId,Condition}()
const package_locks = Dict{PkgId,Threads.Condition}()

# to notify downstream consumers that a module was successfully loaded
# Callbacks take the form (mod::Base.PkgId) -> nothing.
Expand All @@ -896,7 +911,9 @@ function _include_dependency(mod::Module, _path::AbstractString)
path = normpath(joinpath(dirname(prev), _path))
end
if _track_dependencies[]
@lock require_lock begin
push!(_require_dependencies, (mod, path, mtime(path)))
end
end
return path, prev
end
Expand Down Expand Up @@ -968,6 +985,7 @@ For more details regarding code loading, see the manual sections on [modules](@r
[parallel computing](@ref code-availability).
"""
function require(into::Module, mod::Symbol)
@lock require_lock begin
LOADING_CACHE[] = LoadingCache()
try
uuidkey = identify_package(into, String(mod))
Expand Down Expand Up @@ -1019,6 +1037,7 @@ function require(into::Module, mod::Symbol)
finally
LOADING_CACHE[] = nothing
end
end
end

mutable struct PkgOrigin
Expand All @@ -1030,6 +1049,7 @@ PkgOrigin() = PkgOrigin(nothing, nothing)
const pkgorigins = Dict{PkgId,PkgOrigin}()

function require(uuidkey::PkgId)
@lock require_lock begin
if !root_module_exists(uuidkey)
cachefile = _require(uuidkey)
if cachefile !== nothing
Expand All @@ -1041,15 +1061,19 @@ function require(uuidkey::PkgId)
end
end
return root_module(uuidkey)
end
end

const loaded_modules = Dict{PkgId,Module}()
const module_keys = IdDict{Module,PkgId}() # the reverse

is_root_module(m::Module) = haskey(module_keys, m)
root_module_key(m::Module) = module_keys[m]
is_root_module(m::Module) = @lock require_lock haskey(module_keys, m)
root_module_key(m::Module) = @lock require_lock module_keys[m]

function register_root_module(m::Module)
# n.b. This is called from C after creating a new module in `Base.__toplevel__`,
# instead of adding them to the binding table there.
@lock require_lock begin
key = PkgId(m, String(nameof(m)))
if haskey(loaded_modules, key)
oldm = loaded_modules[key]
Expand All @@ -1059,6 +1083,7 @@ function register_root_module(m::Module)
end
loaded_modules[key] = m
module_keys[m] = key
end
nothing
end

Expand All @@ -1074,12 +1099,13 @@ using Base
end

# get a top-level Module from the given key
root_module(key::PkgId) = loaded_modules[key]
root_module(key::PkgId) = @lock require_lock loaded_modules[key]
root_module(where::Module, name::Symbol) =
root_module(identify_package(where, String(name)))
maybe_root_module(key::PkgId) = @lock require_lock get(loaded_modules, key, nothing)

root_module_exists(key::PkgId) = haskey(loaded_modules, key)
loaded_modules_array() = collect(values(loaded_modules))
root_module_exists(key::PkgId) = @lock require_lock haskey(loaded_modules, key)
loaded_modules_array() = @lock require_lock collect(values(loaded_modules))

function unreference_module(key::PkgId)
if haskey(loaded_modules, key)
Expand All @@ -1098,7 +1124,7 @@ function _require(pkg::PkgId)
wait(loading)
return
end
package_locks[pkg] = Condition()
package_locks[pkg] = Threads.Condition(require_lock)

last = toplevel_load[]
try
Expand Down Expand Up @@ -1166,10 +1192,12 @@ function _require(pkg::PkgId)
if uuid !== old_uuid
ccall(:jl_set_module_uuid, Cvoid, (Any, NTuple{2, UInt64}), __toplevel__, uuid)
end
unlock(require_lock)
try
include(__toplevel__, path)
return
finally
lock(require_lock)
if uuid !== old_uuid
ccall(:jl_set_module_uuid, Cvoid, (Any, NTuple{2, UInt64}), __toplevel__, old_uuid)
end
Expand Down
2 changes: 1 addition & 1 deletion base/toml_parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function Parser(str::String; filepath=nothing)
IdSet{TOMLDict}(), # defined_tables
root,
filepath,
isdefined(Base, :loaded_modules) ? get(Base.loaded_modules, DATES_PKGID, nothing) : nothing,
isdefined(Base, :maybe_root_module) ? Base.maybe_root_module(DATES_PKGID) : nothing,
)
startup(l)
return l
Expand Down
5 changes: 3 additions & 2 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2228,8 +2228,6 @@ static jl_array_t *jl_finalize_deserializer(jl_serializer_state *s, arraylist_t

JL_DLLEXPORT void jl_init_restored_modules(jl_array_t *init_order)
{
if (!init_order)
return;
int i, l = jl_array_len(init_order);
for (i = 0; i < l; i++) {
jl_value_t *mod = jl_array_ptr_ref(init_order, i);
Expand Down Expand Up @@ -2683,6 +2681,9 @@ static jl_value_t *_jl_restore_incremental(ios_t *f, jl_array_t *mod_array)
jl_recache_other(); // make all of the other objects identities correct (needs to be after insert methods)
htable_free(&uniquing_table);
jl_array_t *init_order = jl_finalize_deserializer(&s, tracee_list); // done with f and s (needs to be after recache)
if (init_order == NULL)
init_order = (jl_array_t*)jl_an_empty_vec_any;
assert(jl_isa((jl_value_t*)init_order, jl_array_any_type));

JL_GC_PUSH4(&init_order, &restored, &external_backedges, &external_edges);
jl_gc_enable(en); // subtyping can allocate a lot, not valid before recache-other
Expand Down
28 changes: 28 additions & 0 deletions test/threads_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -912,3 +912,31 @@ end
@test reproducible_rand(r, 10) == val
end
end

# issue #41546, thread-safe package loading
@testset "package loading" begin
ch = Channel{Bool}(nthreads())
barrier = Base.Event()
old_act_proj = Base.ACTIVE_PROJECT[]
try
pushfirst!(LOAD_PATH, "@")
Base.ACTIVE_PROJECT[] = joinpath(@__DIR__, "TestPkg")
@sync begin
for _ in 1:nthreads()
Threads.@spawn begin
put!(ch, true)
wait(barrier)
@eval using TestPkg
end
end
for _ in 1:nthreads()
take!(ch)
end
notify(barrier)
end
@test Base.root_module(@__MODULE__, :TestPkg) isa Module
finally
Base.ACTIVE_PROJECT[] = old_act_proj
popfirst!(LOAD_PATH)
end
end

0 comments on commit 3d4b213

Please sign in to comment.