Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve shadowing, support async sockets/callbacks in ZMQStream #1785

Merged
merged 7 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion zmq/backend/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ class Socket:

# specific option types
FD: int

def __init__(
self,
context: Optional[Context] = None,
socket_type: int = 0,
shadow: int = 0,
copy_threshold: Optional[int] = zmq.COPY_THRESHOLD,
) -> None: ...
def close(self, linger: Optional[int] = ...) -> None: ...
def get(self, option: int) -> Union[int, bytes, str]: ...
def set(self, option: int, value: Union[int, bytes, str]) -> None: ...
Expand Down Expand Up @@ -84,7 +92,7 @@ class Socket:

class Context:
underlying: int
def __init__(self, io_threads: int = 1, shadow: Any = None): ...
def __init__(self, io_threads: int = 1, shadow: int = 0): ...
def get(self, option: int) -> Union[int, bytes, str]: ...
def set(self, option: int, value: Union[int, bytes, str]) -> None: ...
def socket(self, socket_type: int) -> Socket: ...
Expand Down
10 changes: 6 additions & 4 deletions zmq/backend/cffi/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,13 @@ class Socket:
_shadow = False
copy_threshold = 0

def __init__(self, context=None, socket_type=None, shadow=None):
def __init__(self, context=None, socket_type=None, shadow=0, copy_threshold=None):
if copy_threshold is None:
copy_threshold = zmq.COPY_THRESHOLD
self.copy_threshold = copy_threshold

self.context = context
if shadow is not None:
if isinstance(shadow, Socket):
shadow = shadow.underlying
if shadow:
self._zmq_socket = ffi.cast("void *", shadow)
self._shadow = True
else:
Expand Down
8 changes: 2 additions & 6 deletions zmq/backend/cython/socket.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -311,20 +311,16 @@ cdef class Socket:
.Context.socket : method for creating a socket bound to a Context.
"""

def __init__(self, context=None, socket_type=-1, shadow=0, copy_threshold=None):
def __init__(self, context=None, socket_type=-1, size_t shadow=0, copy_threshold=None):
if copy_threshold is None:
copy_threshold = zmq.COPY_THRESHOLD
self.copy_threshold = copy_threshold

self.handle = NULL
self.context = context
cdef size_t c_shadow
if shadow:
if isinstance(shadow, Socket):
shadow = shadow.underlying
c_shadow = shadow
self._shadow = True
self.handle = <void *>c_shadow
self.handle = <void *>shadow
else:
if context is None:
raise TypeError("context must be specified")
Expand Down
55 changes: 50 additions & 5 deletions zmq/eventloop/zmqstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,21 @@

"""

import asyncio
import pickle
import warnings
from queue import Queue
from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload
from typing import (
Any,
Awaitable,
Callable,
List,
Optional,
Sequence,
Union,
cast,
overload,
)

import zmq
from zmq import POLLIN, POLLOUT
Expand Down Expand Up @@ -64,7 +75,7 @@ class ZMQStream:
register a callback to be run every time the socket has something to receive
* **on_send(callback):**
register a callback to be run every time you call send
* **send(self, msg, flags=0, copy=False, callback=None):**
* **send_multipart(self, msg, flags=0, copy=False, callback=None):**
perform a send that will trigger the callback
if callback is passed, on_send is also called.

Expand All @@ -86,6 +97,17 @@ class ZMQStream:
>>> stream.bind is stream.socket.bind
True


.. versionadded:: 25

send/recv callbacks can be coroutines.

.. versionadded:: 25

ZMQStreams can be created from async Sockets.
Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the
arguments passed to callback functions.
Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods.
"""

socket: zmq.Socket
Expand All @@ -103,7 +125,16 @@ class ZMQStream:
def __init__(
self, socket: "zmq.Socket", io_loop: Optional["tornado.ioloop.IOLoop"] = None
):
if type(socket) is not zmq.Socket:
# shadow back to base zmq.Socket,
# otherwise callbacks like `on_recv` will get the wrong types.
# We know async sockets don't work,
# but other socket subclasses _may_.
# should we allow that?
# TODO: warn here?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe explicitly override if the socket subclasses the asyncio socket, and warn otherwise?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about warning in the async case? This is where behavior is changing, so a warning makes sense. But also now that shadowing is added, it works fine if the new behavior is what you expect, so a warning might be annoying!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe instead error out if the original socket is async but the callback is not? That way you're exposing the previously silent bad behavior.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, that's a little tricky since on_recv is called later, and by then the socket has been cast to sync already.

You also can't tell for sure whether a callback is truly async without calling it. (it could be a sync function e.g. via a decorator that schedules an async handler, which is fine, or calls future.add_done_callback).

Maybe the best is to always warn/cast, and make explicit: ZMQStream only accepts zmq.Socket. If you give it anything else, it'll cast to sync, but will ask you to do the casting for next time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, that works

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Went with explicit warning for async sockets, so if it casts, it warns.

Decided to ignore random custom socket classes, since they are rare, and likely implement custom serialization rather than something async. If they are subclasses of the async sockets, it'll the warning path.

Not sure 100%

socket = zmq.Socket(shadow=socket)
self.socket = socket

self.io_loop = io_loop or IOLoop.current()
self.poller = zmq.Poller()
self._fd = cast(int, self.socket.FD)
Expand Down Expand Up @@ -552,15 +583,29 @@ def _run_callback(self, callback, *args, **kwargs):
"""Wrap running callbacks in try/except to allow us to
close our socket."""
try:
# Use a NullContext to ensure that all StackContexts are run
# inside our blanket exception handler rather than outside.
callback(*args, **kwargs)
f = callback(*args, **kwargs)
if isinstance(f, Awaitable):
f = asyncio.ensure_future(f)
else:
f = None
except Exception:
gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
# Re-raise the exception so that IOLoop.handle_callback_exception
# can see it and log the error
raise

if f is not None:
# handle async callbacks
def _log_error(f):
try:
f.result()
except Exception:
gen_log.error(
"Uncaught exception in ZMQStream callback", exc_info=True
)

f.add_done_callback(_log_error)

def _handle_events(self, fd, events):
"""This method is the actual handler for IOLoop, that gets called whenever
an event on my socket is posted. It dispatches to _handle_recv, etc."""
Expand Down
97 changes: 85 additions & 12 deletions zmq/sugar/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,25 @@
import atexit
import os
from threading import Lock
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
overload,
)
from warnings import warn
from weakref import WeakSet

from zmq.backend import Context as ContextBase
from zmq.constants import ContextOption, Errno, SocketOption
from zmq.error import ZMQError
from zmq.utils.interop import cast_int_addr

from .attrsettr import AttributeSetter, OptValT
from .socket import Socket
Expand Down Expand Up @@ -49,24 +61,68 @@ class Context(ContextBase, AttributeSetter, Generic[ST]):
but means that unclean destruction of contexts
(with sockets left open) is not safe
if sockets are managed in other threads.

.. versionadded:: 25

Contexts can now be shadowed by passing another Context.
This helps in creating an async copy of a sync context or vice versa::

ctx = zmq.Context(async_ctx)

Which previously had to be::

ctx = zmq.Context.shadow(async_ctx.underlying)
"""

sockopts: Dict[int, Any]
_instance: Any = None
_instance_lock = Lock()
_instance_pid: Optional[int] = None
_shadow = False
_shadow_obj = None
_warn_destroy_close = False
_sockets: WeakSet
# mypy doesn't like a default value here
_socket_class: Type[ST] = Socket # type: ignore

def __init__(self: "Context[Socket]", io_threads: int = 1, **kwargs: Any) -> None:
super().__init__(io_threads=io_threads, **kwargs)
if kwargs.get('shadow', False):
@overload
def __init__(self: "Context[Socket]", io_threads: int = 1):
...

@overload
def __init__(self: "Context[Socket]", io_threads: "Context"):
# this should be positional-only, but that requires 3.8
...

@overload
def __init__(self: "Context[Socket]", *, shadow: Union["Context", int]):
...

def __init__(
self: "Context[Socket]",
io_threads: Union[int, "Context"] = 1,
shadow: Union["Context", int] = 0,
) -> None:
if isinstance(io_threads, Context):
# allow positional shadow `zmq.Context(zmq.asyncio.Context())`
# this s
shadow = io_threads
io_threads = 1

shadow_address: int = 0
if shadow:
self._shadow = True
# hold a reference to the shadow object
self._shadow_obj = shadow
if not isinstance(shadow, int):
try:
shadow = shadow.underlying
except AttributeError:
pass
shadow_address = cast_int_addr(shadow)
else:
self._shadow = False
super().__init__(io_threads=io_threads, shadow=shadow_address)
self.sockopts = {}
self._sockets = WeakSet()

Expand Down Expand Up @@ -127,17 +183,18 @@ def __copy__(self: T, memo: Any = None) -> T:
__deepcopy__ = __copy__

@classmethod
def shadow(cls: Type[T], address: int) -> T:
def shadow(cls: Type[T], address: Union[int, "Context"]) -> T:
"""Shadow an existing libzmq context

address is the integer address of the libzmq context
or an FFI pointer to it.
address is a zmq.Context or an integer (or FFI pointer)
representing the address of the libzmq context.

.. versionadded:: 14.1
"""
from zmq.utils.interop import cast_int_addr

address = cast_int_addr(address)
.. versionadded:: 25
Support for shadowing `zmq.Context` objects,
instead of just integer addresses.
"""
return cls(shadow=address)

@classmethod
Expand Down Expand Up @@ -274,7 +331,12 @@ def destroy(self, linger: Optional[int] = None) -> None:

self.term()

def socket(self: T, socket_type: int, **kwargs: Any) -> ST:
def socket(
self: T,
socket_type: int,
socket_class: Callable[[T, int], ST] = None,
**kwargs: Any,
) -> ST:
"""Create a Socket associated with this Context.

Parameters
Expand All @@ -283,12 +345,20 @@ def socket(self: T, socket_type: int, **kwargs: Any) -> ST:
The socket type, which can be any of the 0MQ socket types:
REQ, REP, PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc.

socket_class: zmq.Socket or a subclass
The socket class to instantiate, if different from the default for this Context.
e.g. for creating an asyncio socket attached to a default Context or vice versa.

.. versionadded:: 25

kwargs:
will be passed to the __init__ method of the socket class.
"""
if self.closed:
raise ZMQError(Errno.ENOTSUP)
s: ST = self._socket_class( # set PYTHONTRACEMALLOC=2 to get the calling frame
if socket_class is None:
socket_class = self._socket_class
s: ST = socket_class( # set PYTHONTRACEMALLOC=2 to get the calling frame
self, socket_type, **kwargs
)
for opt, value in self.sockopts.items():
Expand Down Expand Up @@ -337,6 +407,9 @@ def _get_attr_opt(self, name: str, opt: int) -> OptValT:

def __delattr__(self, key: str) -> None:
"""delete default sockopts as attributes"""
if key in self.__dict__:
self.__dict__.pop(key)
return
key = key.upper()
try:
opt = getattr(SocketOption, key)
Expand Down
Loading