Skip to content

Commit

Permalink
support async sockets and callbacks in zmqstream
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Sep 27, 2022
1 parent 6ee341d commit f3f1d4c
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 8 deletions.
55 changes: 50 additions & 5 deletions zmq/eventloop/zmqstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,21 @@
"""

import asyncio
import pickle
import warnings
from queue import Queue
from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload
from typing import (
Any,
Awaitable,
Callable,
List,
Optional,
Sequence,
Union,
cast,
overload,
)

import zmq
from zmq import POLLIN, POLLOUT
Expand Down Expand Up @@ -64,7 +75,7 @@ class ZMQStream:
register a callback to be run every time the socket has something to receive
* **on_send(callback):**
register a callback to be run every time you call send
* **send(self, msg, flags=0, copy=False, callback=None):**
* **send_multipart(self, msg, flags=0, copy=False, callback=None):**
perform a send that will trigger the callback
if callback is passed, on_send is also called.
Expand All @@ -86,6 +97,17 @@ class ZMQStream:
>>> stream.bind is stream.socket.bind
True
.. versionadded:: 25
send/recv callbacks can be coroutines.
.. versionadded:: 25
ZMQStreams can be created from async Sockets.
Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the
arguments passed to callback functions.
Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods.
"""

socket: zmq.Socket
Expand All @@ -103,7 +125,16 @@ class ZMQStream:
def __init__(
self, socket: "zmq.Socket", io_loop: Optional["tornado.ioloop.IOLoop"] = None
):
if type(socket) is not zmq.Socket:
# shadow back to base zmq.Socket,
# otherwise callbacks like `on_recv` will get the wrong types.
# We know async sockets don't work,
# but other socket subclasses _may_.
# should we allow that?
# TODO: warn here?
socket = zmq.Socket(shadow=socket)
self.socket = socket

self.io_loop = io_loop or IOLoop.current()
self.poller = zmq.Poller()
self._fd = cast(int, self.socket.FD)
Expand Down Expand Up @@ -552,15 +583,29 @@ def _run_callback(self, callback, *args, **kwargs):
"""Wrap running callbacks in try/except to allow us to
close our socket."""
try:
# Use a NullContext to ensure that all StackContexts are run
# inside our blanket exception handler rather than outside.
callback(*args, **kwargs)
f = callback(*args, **kwargs)
if isinstance(f, Awaitable):
f = asyncio.ensure_future(f)
else:
f = None
except Exception:
gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
# Re-raise the exception so that IOLoop.handle_callback_exception
# can see it and log the error
raise

if f is not None:
# handle async callbacks
def _log_error(f):
try:
f.result()
except Exception:
gen_log.error(
"Uncaught exception in ZMQStream callback", exc_info=True
)

f.add_done_callback(_log_error)

def _handle_events(self, fd, events):
"""This method is the actual handler for IOLoop, that gets called whenever
an event on my socket is posted. It dispatches to _handle_recv, etc."""
Expand Down
45 changes: 42 additions & 3 deletions zmq/tests/test_zmqstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


import asyncio
import logging
from functools import partial
from unittest import TestCase

Expand All @@ -18,11 +19,20 @@
except ImportError:
tornado = None # type: ignore

caplog = None


@pytest.fixture
def get_caplog(caplog):
globals()["caplog"] = caplog


@pytest.mark.usefixtures("get_caplog")
class TestZMQStream(TestCase):
def setUp(self):
if tornado is None:
pytest.skip()
self._timeout_task = None
self.context = zmq.Context()
self.loop = ioloop.IOLoop(make_current=False)
if tornado and tornado.version_info < (5,):
Expand All @@ -36,7 +46,6 @@ async def _make_sockets():
port = self.push.bind_to_random_port('tcp://127.0.0.1')
self.pull.connect('tcp://127.0.0.1:%i' % port)
self.stream = self.push
self._timeout_task = None

def tearDown(self):
if self._timeout_task:
Expand Down Expand Up @@ -87,7 +96,6 @@ def callback(msg):
self.loop.stop()

self.loop.run_sync(partial(self.push.send_multipart, sent))
self.loop.call_later(1, lambda: self.pull.on_recv(callback))
self.pull.on_recv(callback)
self.run_until_timeout()

Expand All @@ -99,5 +107,36 @@ def callback(msg):
self.loop.stop()

self.pull.on_recv(callback)
self.loop.call_later(1, lambda: self.push.send_multipart(sent))
self.loop.call_later(0.5, lambda: self.push.send_multipart(sent))
self.run_until_timeout()

def test_on_recv_async(self):
sent = [b'wake']

async def callback(msg):
assert msg == sent
self.loop.stop()

self.pull.on_recv(callback)
self.loop.call_later(0.5, lambda: self.push.send_multipart(sent))
self.run_until_timeout()

def test_on_recv_async_error(self):
sent = [b'wake']

async def callback(msg):
ioloop.IOLoop.current().call_later(0.5, lambda: self.loop.stop())
assert msg == sent
1 / 0

self.pull.on_recv(callback)
self.loop.call_later(0.5, lambda: self.push.send_multipart(sent))
with caplog.at_level(logging.ERROR, logger=zmqstream.gen_log.name):
self.run_until_timeout()

messages = [
x.message
for x in caplog.get_records("call")
if x.name == zmqstream.gen_log.name
]
assert "Uncaught exception in ZMQStream callback" in "\n".join(messages)

0 comments on commit f3f1d4c

Please sign in to comment.