diff --git a/SnoopCompileCore/src/snoopi_deep.jl b/SnoopCompileCore/src/snoopi_deep.jl index 3719cf84..c569d598 100644 --- a/SnoopCompileCore/src/snoopi_deep.jl +++ b/SnoopCompileCore/src/snoopi_deep.jl @@ -70,28 +70,201 @@ function addchildren!(parent::InferenceTimingNode, t::Core.Compiler.Timings.Timi end end +module SnoopiDeepParallelism + +# Mutex ordering: MUTEX > jl_typeinf_lock +const MUTEX = ReentrantLock() + +mutable struct Invocation + # start_idx is mutated when older invocations are deleted and the profile is shifted. + start_idx::Int + stop_idx::Int + start_time::UInt64 + root_start_excl_time::UInt64 + root_stop_excl_time::UInt64 +end +function Invocation(start_idx, start_root_excl_time) + # Start at the current time. + return Invocation(start_idx, 0, time_ns(), start_root_excl_time, 0) +end + +""" +Global (locked) vector tracking running snoopi calls, and when they started. +- When one finishes, we lock(inference), export results, clear the inference profiles up to +the next oldest snoopi call, then unlock(inference). + +Imagine this is an ongoing inference profile, where each letter is another inference profile +result, and we start two profiles, 1 and 2, at the times indicated below: + ABCDEFGHIJKLMNOPQRSTUVWX + 1> 2> <1 <2 + + - invocations: [(1,A), (2,D)] + - 1 ends: + copy out ABCDEFGHIJKLMNOPQRSTU + pop (1,A) from invocations + read oldest invocation: (2,D) + delete up to D. + - New profile: + DEFGHIJKLMNOPQRSTUVWX + 2> <2 + + - 2 ends: + copy out DEFGHIJKLMNOPQRSTUVWX + pop (2,D) from invocations + no active invocations, so ... + ... delete up to X (end of this profile). +""" +const invocations = Invocation[] + +function _current_profile_stats_locked() + ccall(:jl_typeinf_lock_begin, Cvoid, ()) + try + inference_root_timing = Core.Compiler.Timings._timings[1] + children = inference_root_timing.children + # Since we were able to grab the lock, we must not be in an inference profile, + # meaning we are in ROOT(). So to get an accurate ROOT timing, we have to add the + # accumulated time since the ROOT was last updated: + accum_root_time = time_ns() - inference_root_timing.cur_start_time + current_root_time = inference_root_timing.time + accum_root_time + + return length(children), current_root_time + finally + ccall(:jl_typeinf_lock_end, Cvoid, ()) + end +end + +function _fetch_profile_buffer_locked(start_idx, stop_idx) + ccall(:jl_typeinf_lock_begin, Cvoid, ()) + try + inference_root_timing = Core.Compiler.Timings._timings[1] + children = inference_root_timing.children + return children[start_idx:stop_idx] + finally + ccall(:jl_typeinf_lock_end, Cvoid, ()) + end +end + +function start_timing_invocation() + # Locking respects mutex ordering. + Base.@lock MUTEX begin + current_profile_length, current_root_time = _current_profile_stats_locked() + profile_start_idx = current_profile_length + 1 + invocation = Invocation(profile_start_idx, current_root_time) + push!(invocations, invocation) + return invocation + end +end + +function stop_timing_invocation!(invocation) + invocation.stop_idx, invocation.root_stop_excl_time = _current_profile_stats_locked() +end + +function finish_timing_invocation_and_clear_profile(invocation) + # Locking respects mutex ordering. + Base.@lock MUTEX begin + # Check if this invocation was the oldest. If so, we'll want to clear the parts of + # the profile only it was using. + if invocations[1] !== invocation + idx = findfirst(==(invocation), invocations) + @assert idx !== nothing "invocation wasn't found in invocations: $invocation." + deleteat!(invocations, idx) + return + end + + # Clear this invocation from the invocations vector. + popfirst!(invocations) + + # Now clear the global inference profile up to the start of the next invocation. + # If no next invocations, clear them all. + if isempty(invocations) + ccall(:jl_typeinf_lock_begin, Cvoid, ()) + try + Core.Compiler.Timings.reset_timings() + finally + ccall(:jl_typeinf_lock_end, Cvoid, ()) + end + return + end + + # Else, we stop at the next oldest invocation. + next_oldest = invocations[1] + start_idx = next_oldest.start_idx + to_delete = start_idx - 1 + if to_delete == 0 + return + end + # Shift back the indices for all the running invocations + for running_invocation in invocations + running_invocation.start_idx -= to_delete + running_invocation.stop_idx -= to_delete + end + # Clear the profile up to the start of the new oldest invocation. + ccall(:jl_typeinf_lock_begin, Cvoid, ()) + try + inference_root_timing = Core.Compiler.Timings._timings[1] + children = inference_root_timing.children + deleteat!(children, 1:to_delete) + finally + ccall(:jl_typeinf_lock_end, Cvoid, ()) + end + end +end + +end # module + function start_deep_timing() - Core.Compiler.Timings.reset_timings() + invocation = SnoopiDeepParallelism.start_timing_invocation() Core.Compiler.__set_measure_typeinf(true) + return invocation end -function stop_deep_timing() +function stop_deep_timing!(invocation) Core.Compiler.__set_measure_typeinf(false) - Core.Compiler.Timings.close_current_timer() + return SnoopiDeepParallelism.stop_timing_invocation!(invocation) end -function finish_snoopi_deep() - return InferenceTimingNode(Core.Compiler.Timings._timings[1]) +function finish_snoopi_deep(invocation) + buffer = SnoopiDeepParallelism._fetch_profile_buffer_locked(invocation.start_idx, invocation.stop_idx) + + # Clean up the profile buffer, so that we don't leak memory. + SnoopiDeepParallelism.finish_timing_invocation_and_clear_profile(invocation) + + root_node = _create_finished_ROOT_Timing(invocation, buffer) + return InferenceTimingNode(root_node) end +# The MethodInstance for ROOT(), and default empty values for other fields. +# Copied from julia typeinf +root_inference_frame_info() = + Core.Compiler.Timings.InferenceFrameInfo(Core.Compiler.Timings.ROOTmi, 0x0, Any[], Any[Core.Const(Core.Compiler.Timings.ROOT)], 1) + +function _create_finished_ROOT_Timing(invocation, buffer) + total_time = time_ns() - invocation.start_time + + # Create a new ROOT() node, specific to this profiling invocation, which wraps the + # current profile buffer, and contains the total time for the profile. + return Core.Compiler.Timings.Timing( + root_inference_frame_info(), + invocation.start_time, + 0, + # Total exclusive time spent in ROOT during the lifetime of this node. + invocation.root_stop_excl_time - invocation.root_start_excl_time, + # Use the copied-out section of the profile buffer as the children of ROOT() + buffer, + ) +end + + + function _snoopi_deep(cmd::Expr) return quote - start_deep_timing() + invocation = start_deep_timing() try $(esc(cmd)) finally - stop_deep_timing() + stop_deep_timing!(invocation) end - finish_snoopi_deep() + # return the timing result: + finish_snoopi_deep(invocation) end end @@ -134,5 +307,5 @@ end # These are okay to come at the top-level because we're only measuring inference, and # inference results will be cached in a `.ji` file. precompile(start_deep_timing, ()) -precompile(stop_deep_timing, ()) -precompile(finish_snoopi_deep, ()) +precompile(stop_deep_timing!, (SnoopiDeepParallelism.Invocation,)) +precompile(finish_snoopi_deep, (SnoopiDeepParallelism.Invocation,)) diff --git a/test/snoopi_deep.jl b/test/snoopi_deep.jl index c07598ba..d5d6a8db 100644 --- a/test/snoopi_deep.jl +++ b/test/snoopi_deep.jl @@ -863,6 +863,7 @@ end # pgdsgui(axs[2], rit; bystr="Inclusive", consts=true, interactive=false) end + @testset "Stale" begin cproj = Base.active_project() cd(joinpath("testmodules", "Stale")) do @@ -944,6 +945,163 @@ end Pkg.activate(cproj) end +_name(frame::SnoopCompileCore.InferenceTiming) = frame.mi_info.mi.def.name + +@testset "reentrant concurrent profiles 1 - overlap" begin + # Warmup + @eval foo1(x) = x+2 + @eval foo1(2) + + # Test: + t1 = SnoopCompileCore.start_deep_timing() + + @eval foo1(x) = x+2 + @eval foo1(2) + + t2 = SnoopCompileCore.start_deep_timing() + + @eval foo2(x) = x+2 + @eval foo2(2) + + SnoopCompileCore.stop_deep_timing!(t1) + SnoopCompileCore.stop_deep_timing!(t2) + + prof1 = SnoopCompileCore.finish_snoopi_deep(t1) + prof2 = SnoopCompileCore.finish_snoopi_deep(t2) + + @test Set(_name.(SnoopCompile.flatten(prof1))) == Set([:ROOT, :foo1, :foo2]) + @test Set(_name.(SnoopCompile.flatten(prof2))) == Set([:ROOT, :foo2]) + + # Test Cleanup + @test isempty(SnoopCompileCore.SnoopiDeepParallelism.invocations) + @test isempty(Core.Compiler.Timings._timings[1].children) +end + +@testset "reentrant concurrent profiles 2 - interleaved" begin + # Warmup + @eval foo1(x) = x+2 + @eval foo1(2) + + # Test: + t1 = SnoopCompileCore.start_deep_timing() + + @eval foo1(x) = x+2 + @eval foo1(2) + + t2 = SnoopCompileCore.start_deep_timing() + + @eval foo2(x) = x+2 + @eval foo2(2) + + SnoopCompileCore.stop_deep_timing!(t1) + + @eval foo3(x) = x+2 + @eval foo3(2) + + SnoopCompileCore.stop_deep_timing!(t2) + + @eval foo4(x) = x+2 + @eval foo4(2) + + prof1 = SnoopCompileCore.finish_snoopi_deep(t1) + + @eval foo5(x) = x+2 + @eval foo5(2) + + prof2 = SnoopCompileCore.finish_snoopi_deep(t2) + + @test Set(_name.(SnoopCompile.flatten(prof1))) == Set([:ROOT, :foo1, :foo2]) + @test Set(_name.(SnoopCompile.flatten(prof2))) == Set([:ROOT, :foo2, :foo3]) + + # Test Cleanup + @test isempty(SnoopCompileCore.SnoopiDeepParallelism.invocations) + @test isempty(Core.Compiler.Timings._timings[1].children) +end + +@testset "reentrant concurrent profiles 3 - nested" begin + # Warmup + @eval foo1(x) = x+2 + @eval foo1(2) + + # Test: + local prof1, prof2, prof3 + prof1 = SnoopCompileCore.@snoopi_deep begin + @eval foo1(x) = x+2 + @eval foo1(2) + prof2 = SnoopCompileCore.@snoopi_deep begin + @eval foo2(x) = x+2 + @eval foo2(2) + prof3 = SnoopCompileCore.@snoopi_deep begin + @eval foo3(x) = x+2 + @eval foo3(2) + end + @eval foo4(x) = x+2 + @eval foo4(2) + end + @eval foo5(x) = x+2 + @eval foo5(2) + end + + @test Set(_name.(SnoopCompile.flatten(prof1))) == Set([:ROOT, :foo1, :foo2, :foo3, :foo4, :foo5]) + @test Set(_name.(SnoopCompile.flatten(prof2))) == Set([:ROOT, :foo2, :foo3, :foo4]) + @test Set(_name.(SnoopCompile.flatten(prof3))) == Set([:ROOT, :foo3]) + + # Test Cleanup + @test isempty(SnoopCompileCore.SnoopiDeepParallelism.invocations) + @test isempty(Core.Compiler.Timings._timings[1].children) +end + +@testset "reentrant concurrent profiles 3 - parallelism + accurate timing" begin + # Warmup + @eval foo1(x) = x+2 + @eval foo1(2) + + # Test: + local ts + snoop_times = Float64[0.0, 0.0, 0.0, 0.0] + # Run it twice to ensure we warmup the eval block + for _ in 1:2 + @sync begin + ts = [ + Threads.@spawn begin + sleep((i-1) / 10) # (Divide by 10 so the test isn't too slow) + snoop_time = @timed SnoopCompile.@snoopi_deep @eval begin + $(Symbol("foo$i"))(x) = x + 1 + sleep(1.5 / 10) + $(Symbol("foo$i"))(2) + end + snoop_times[i] = snoop_time.time + return snoop_time.value + end + for i in 1:4 + ] + end + end + profs = fetch.(ts) + + @test Set(_name.(SnoopCompile.flatten(profs[1]))) == Set([:ROOT, :foo1]) + @test Set(_name.(SnoopCompile.flatten(profs[2]))) == Set([:ROOT, :foo1, :foo2]) + @test Set(_name.(SnoopCompile.flatten(profs[3]))) == Set([:ROOT, :foo2, :foo3]) + @test Set(_name.(SnoopCompile.flatten(profs[4]))) == Set([:ROOT, :foo3, :foo4]) + + # Test the sanity of the reported Timings + @testset for i in eachindex(profs) + prof = profs[i] + # Test that the time for the inference is accounted for + @test 0.15 < prof.mi_timing.exclusive_time + @test prof.mi_timing.exclusive_time < prof.mi_timing.inclusive_time + # Test that the inclusive time (the total time reported by snoopi_deep) matches + # the actual time to do the snoopi_deep, as measured by `@time`. + # These should both be approximately ~0.15 seconds. + @info prof.mi_timing.inclusive_time + @test prof.mi_timing.inclusive_time <= snoop_times[i] + end + + # Test Cleanup + @test isempty(SnoopCompileCore.SnoopiDeepParallelism.invocations) + @test isempty(Core.Compiler.Timings._timings[1].children) +end + if Base.VERSION >= v"1.7" @testset "JET integration" begin f(c) = sum(c[1])