Skip to content

Commit

Permalink
Refactor heartbeat to shutdown cleanly (#1135)
Browse files Browse the repository at this point in the history
* Refactor heartbeat to shutdown cleanly

From ZMQ docs: "zmq_proxy() runs in the current thread and returns only
if/when the current context is closed."

The heartbeat socket doesn't need to be global, as nothing else touches
it. BUT, if we create the heartbeat socket in a `Context` that has a global ref,
we can close the context, which will cause zmq_proxy to return and then
that thread to end/finish.

Doing that before shutting down helps avoid a segfault on shutdown.

* Update src/heartbeat.jl

Co-authored-by: Steven G. Johnson <stevenj@mit.edu>

* Create heartbeat and context in `init.jl`

---------

Co-authored-by: Steven G. Johnson <stevenj@mit.edu>
  • Loading branch information
halleysfifthinc and stevengj authored Jan 29, 2025
1 parent eac46ab commit 2164bfc
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ Conda = "1"
JSON = "0.18,0.19,0.20,0.21,1"
MbedTLS = "0.5,0.6,0.7,1"
SoftGlobalScope = "1"
ZMQ = "1"
ZMQ = "1.3"
julia = "1.6"
3 changes: 3 additions & 0 deletions src/handlers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ function connect_request(socket, msg)
end

function shutdown_request(socket, msg)
# stop heartbeat thread by closing the context
close(heartbeat_context[])

send_ipython(requests[], msg_reply(msg, "shutdown_reply",
msg.content))
sleep(0.1) # short delay (like in ipykernel), to hopefully ensure shutdown_reply is sent
Expand Down
19 changes: 10 additions & 9 deletions src/heartbeat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import Libdl

const threadid = zeros(Int, 128) # sizeof(uv_thread_t) <= 8 on Linux, OSX, Win
const zmq_proxy = Ref(C_NULL)

# entry point for new thread
function heartbeat_thread(sock::Ptr{Cvoid})
function heartbeat_thread(heartbeat::Ptr{Cvoid})
@static if VERSION v"1.9.0-DEV.1588" # julia#46609
# julia automatically "adopts" this thread because
# we entered a Julia cfunction. We then have to enable
Expand All @@ -19,14 +18,16 @@ function heartbeat_thread(sock::Ptr{Cvoid})
# (see julia#47196)
ccall(:jl_gc_safe_enter, Int8, ())
end
ccall(zmq_proxy[], Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
sock, sock, C_NULL)
nothing
ret = ZMQ.lib.zmq_proxy(heartbeat, heartbeat, C_NULL)
@static if VERSION v"1.9.0-DEV.1588" # julia#46609
# leave safe region if zmq_proxy returns (when context is closed)
ccall(:jl_gc_safe_leave, Int8, ())
end
return ret
end

function start_heartbeat(sock)
zmq_proxy[] = Libdl.dlsym(Libdl.dlopen(ZMQ.libzmq), :zmq_proxy)
heartbeat_c = @cfunction(heartbeat_thread, Cvoid, (Ptr{Cvoid},))
function start_heartbeat(heartbeat)
heartbeat_c = @cfunction(heartbeat_thread, Cint, (Ptr{Cvoid},))
ccall(:uv_thread_create, Cint, (Ptr{Int}, Ptr{Cvoid}, Ptr{Cvoid}),
threadid, heartbeat_c, sock)
threadid, heartbeat_c, heartbeat)
end
6 changes: 4 additions & 2 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const raw_input = Ref{Socket}()
const requests = Ref{Socket}()
const control = Ref{Socket}()
const heartbeat = Ref{Socket}()
const heartbeat_context = Ref{Context}()
const profile = Dict{String,Any}()
const read_stdout = Ref{Base.PipeEndpoint}()
const read_stderr = Ref{Base.PipeEndpoint}()
Expand Down Expand Up @@ -87,7 +88,8 @@ function init(args)
raw_input[] = Socket(ROUTER)
requests[] = Socket(ROUTER)
control[] = Socket(ROUTER)
heartbeat[] = Socket(ROUTER)
heartbeat_context[] = Context()
heartbeat = Socket(heartbeat_context[], ROUTER)
sep = profile["transport"]=="ipc" ? "-" : ":"
bind(publish[], "$(profile["transport"])://$(profile["ip"])$(sep)$(profile["iopub_port"])")
bind(requests[], "$(profile["transport"])://$(profile["ip"])$(sep)$(profile["shell_port"])")
Expand All @@ -97,7 +99,7 @@ function init(args)

# associate a lock with each socket so that multi-part messages
# on a given socket don't get inter-mingled between tasks.
for s in (publish[], raw_input[], requests[], control[], heartbeat[])
for s in (publish[], raw_input[], requests[], control[])
socket_locks[s] = ReentrantLock()
end

Expand Down

0 comments on commit 2164bfc

Please sign in to comment.