From f8aa70e9998abbfd55ef09250d58bb5ab63377ae Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 20 Feb 2017 19:35:04 +0100 Subject: [PATCH 01/10] Add distributed.comm.inproc draft --- distributed/comm/core.py | 3 - distributed/comm/inproc.py | 230 +++++++++++++++++++++++++++ distributed/comm/tests/test_comms.py | 149 ++++++++++++++++- 3 files changed, 375 insertions(+), 7 deletions(-) create mode 100644 distributed/comm/inproc.py diff --git a/distributed/comm/core.py b/distributed/comm/core.py index d0e127e3321..89be2d7ef58 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -112,9 +112,6 @@ def stop(self): Stop listening. This does not shutdown already established communications, but prevents accepting new ones. """ - tcp_server, self.tcp_server = self.tcp_server, None - if tcp_server is not None: - tcp_server.stop() @abstractproperty def listen_address(self): diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py new file mode 100644 index 00000000000..ab3d2891ecd --- /dev/null +++ b/distributed/comm/inproc.py @@ -0,0 +1,230 @@ +from __future__ import print_function, division, absolute_import + +from collections import namedtuple +import itertools +import os +import sys +import threading +import weakref + +from tornado import gen, locks +from tornado.ioloop import IOLoop +from tornado.queues import Queue + +from ..compatibility import finalize +from ..utils import get_ip +from .core import (connectors, listeners, Comm, Listener, CommClosedError, + ) + + +ConnectionRequest = namedtuple('ConnectionRequest', + ('c2s_q', 's2c_q', 'c_loop', 'c_addr', + 'conn_event', 'close_request')) + +_Close = object() + + +class Manager(object): + + def __init__(self): + self.listeners = weakref.WeakValueDictionary() + self.addr_suffixes = itertools.count(1) + self.ip = get_ip() + self.lock = threading.Lock() + + def add_listener(self, addr, listener): + with self.lock: + if addr in self.listeners: + raise RuntimeError("already listening on %r" % (addr,)) + self.listeners[addr] = listener + + def remove_listener(self, addr): + with self.lock: + del self.listeners[addr] + + def get_listener_for(self, addr): + with self.lock: + self.validate_address(addr) + return self.listeners.get(addr) + + def new_address(self): + return "%s/%d/%s" % (self.ip, os.getpid(), next(self.addr_suffixes)) + + def validate_address(self, addr): + """ + Validate the address' IP and pid. + """ + ip, pid, suffix = addr.split('/') + if ip != self.ip or int(pid) != os.getpid(): + raise ValueError("inproc address %r does not match host (%r) or pid (%r)" + % (addr, self.ip, os.getpid())) + + +global_manager = Manager() + +def new_address(): + """ + Generate a new address. + """ + return 'inproc://' + global_manager.new_address() + + +class InProc(Comm): + """ + An established communication based on a pair of in-process queues. + """ + + def __init__(self, peer_addr, read_q, write_q, write_loop, + close_request, deserialize=True): + self._peer_addr = peer_addr + self.deserialize = deserialize + self._read_q = read_q + self._write_q = write_q + self._write_loop = write_loop + self._closed = False + # A "close request" event shared between both comms + self._close_request = close_request + + self._finalizer = finalize(self, self._get_finalizer()) + self._finalizer.atexit = False + + def _get_finalizer(self): + def finalize(write_q=self._write_q, write_loop=self._write_loop, + close_request=self._close_request): + if not close_request.is_set(): + logger.warn("Closing dangling queue in %s" % (r,)) + close_request.set() + write_loop.add_callback(write_q.put_nowait, _Close) + + return finalize + + def __repr__(self): + return "" % (self._peer_addr,) + + @property + def peer_address(self): + return self._peer_addr + + @gen.coroutine + def read(self, deserialize=None): + if self._closed: + raise CommClosedError + + msg = yield self._read_q.get() + if msg is _Close: + assert self._close_request.is_set() + self._closed = True + self._finalizer.detach() + raise CommClosedError + + # XXX does deserialize matter? + raise gen.Return(msg) + + @gen.coroutine + def write(self, msg): + if self._close_request.is_set(): + self._closed = True + self._finalizer.detach() + raise CommClosedError + + self._write_loop.add_callback(self._write_q.put_nowait, msg) + + raise gen.Return(1) + + @gen.coroutine + def close(self): + self.abort() + + def abort(self): + if not self._closed: + self._close_request.set() + self._write_loop.add_callback(self._write_q.put_nowait, _Close) + self._write_q = self._read_q = None + self._closed = True + self._finalizer.detach() + + def closed(self): + return self._closed + + +class InProcListener(Listener): + + def __init__(self, address, comm_handler, deserialize=True): + self.manager = global_manager + self.address = address or self.manager.new_address() + self.comm_handler = comm_handler + self.deserialize = deserialize + self.listen_q = Queue() + + @gen.coroutine + def _listen(self): + while True: + conn_req = yield self.listen_q.get() + if conn_req is None: + break + comm = InProc(peer_addr='inproc://' + conn_req.c_addr, + read_q=conn_req.c2s_q, + write_q=conn_req.s2c_q, + write_loop=conn_req.c_loop, + close_request=conn_req.close_request, + deserialize=self.deserialize) + # Notify connector + conn_req.c_loop.add_callback(conn_req.conn_event.set) + self.comm_handler(comm) + + def connect_threadsafe(self, conn_req): + self.loop.add_callback(self.listen_q.put_nowait, conn_req) + + def start(self): + self.loop = IOLoop.current() + self.loop.add_callback(self._listen) + self.manager.add_listener(self.address, self) + + def stop(self): + self.listen_q.put_nowait(None) + self.manager.remove_listener(self.address) + + @property + def listen_address(self): + return 'inproc://' + self.address + + @property + def contact_address(self): + return 'inproc://' + self.address + + +class InProcConnector(object): + + def __init__(self, manager): + self.manager = manager + + @gen.coroutine + def connect(self, address, deserialize=True): + listener = self.manager.get_listener_for(address) + if listener is None: + raise IOError("no endpoint for inproc address %r") + + conn_req = ConnectionRequest(c2s_q=Queue(), + s2c_q=Queue(), + c_loop=IOLoop.current(), + c_addr=global_manager.new_address(), + conn_event=locks.Event(), + close_request=threading.Event(), + ) + listener.connect_threadsafe(conn_req) + # Wait for connection acknowledgement + # (do not pretend we're connected if the other comm never gets + # created, for example if the listener was stopped in the meantime) + yield conn_req.conn_event.wait() + + comm = InProc(peer_addr='inproc://' + address, + read_q=conn_req.s2c_q, + write_q=conn_req.c2s_q, + write_loop=listener.loop, + close_request=conn_req.close_request, + deserialize=deserialize) + raise gen.Return(comm) + + +connectors['inproc'] = InProcConnector(global_manager) +listeners['inproc'] = InProcListener diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 7d1c19f865e..2b6ef5593b3 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -1,6 +1,7 @@ from __future__ import print_function, division, absolute_import from functools import partial +import os import pytest @@ -12,7 +13,7 @@ from distributed.utils_test import (slow, loop, gen_test, gen_cluster, requires_ipv6, has_ipv6) -from distributed.comm import (tcp, connect, listen, CommClosedError, +from distributed.comm import (tcp, inproc, connect, listen, CommClosedError, is_zmq_enabled) from distributed.comm.core import (parse_address, parse_host_port, unparse_host_port, resolve_address) @@ -191,6 +192,56 @@ def client_communicate(key, delay=0): assert set(l) == {1234} | set(range(20)) +@gen_test() +def test_inproc_specific(): + """ + Test concrete InProc API. + """ + listener_addr = inproc.global_manager.new_address() + addr_head = listener_addr.rpartition('/')[0] + + client_addresses = set() + + @gen.coroutine + def handle_comm(comm): + assert comm.peer_address.startswith('inproc://' + addr_head) + client_addresses.add(comm.peer_address) + msg = yield comm.read() + msg['op'] = 'pong' + yield comm.write(msg) + yield comm.close() + + listener = inproc.InProcListener(listener_addr, handle_comm) + listener.start() + assert listener.listen_address == listener.contact_address == 'inproc://' + listener_addr + + connector = inproc.InProcConnector(inproc.global_manager) + l = [] + + @gen.coroutine + def client_communicate(key, delay=0): + comm = yield connector.connect(listener_addr) + assert comm.peer_address == 'inproc://' + listener_addr + yield comm.write({'op': 'ping', 'data': key}) + if delay: + yield gen.sleep(delay) + msg = yield comm.read() + assert msg == {'op': 'pong', 'data': key} + l.append(key) + yield comm.close() + + yield client_communicate(key=1234) + + # Many clients at once + N = 200 + futures = [client_communicate(key=i, delay=0.05) for i in range(N)] + yield futures + assert set(l) == {1234} | set(range(N)) + + assert len(client_addresses) == N + 1 + assert listener.contact_address not in client_addresses + + @gen.coroutine def check_client_server(addr, check_listen_addr=None, check_contact_addr=None): """ @@ -217,7 +268,7 @@ def handle_comm(comm): # Check listener properties bound_addr = listener.listen_address bound_scheme, bound_loc = parse_address(bound_addr) - assert bound_scheme in ('tcp', 'zmq') + assert bound_scheme in ('inproc', 'tcp', 'zmq') assert bound_scheme == parse_address(addr)[0] if check_listen_addr is not None: @@ -256,6 +307,8 @@ def client_communicate(key, delay=0): yield futures assert set(l) == {1234} | set(range(20)) + listener.stop() + def tcp_eq(expected_host, expected_port=None): def checker(loc): @@ -270,6 +323,17 @@ def checker(loc): zmq_eq = tcp_eq +def inproc_check(): + expected_ip = get_ip() + expected_pid = os.getpid() + + def checker(loc): + ip, pid, suffix = loc.split('/') + assert ip == expected_ip + assert int(pid) == expected_pid + + return checker + @gen_test() def test_default_client_server_ipv4(): @@ -343,8 +407,14 @@ def test_zmq_client_server_ipv6(): zmq_eq('::', 3252), zmq_eq(EXTERNAL_IP6, 3252)) +@gen_test() +def test_inproc_client_server(): + yield check_client_server('inproc://', inproc_check()) + yield check_client_server(inproc.new_address(), inproc_check()) + + @gen.coroutine -def check_comm_closed_implicit(addr): +def check_comm_closed_implicit(addr, delay=None): @gen.coroutine def handle_comm(comm): yield comm.close() @@ -371,6 +441,10 @@ def test_tcp_comm_closed_implicit(): #def test_zmq_comm_closed(): #yield check_comm_closed('zmq://127.0.0.1') +@gen_test() +def test_inproc_comm_closed_implicit(): + yield check_comm_closed_implicit(inproc.new_address()) + @gen.coroutine def check_comm_closed_explicit(addr): @@ -396,6 +470,9 @@ def handle_comm(comm): with pytest.raises(CommClosedError): yield comm.read() + yield gen.moment + + @gen_test() def test_tcp_comm_closed_explicit(): yield check_comm_closed_explicit('tcp://127.0.0.1') @@ -405,6 +482,57 @@ def test_tcp_comm_closed_explicit(): def test_zmq_comm_closed_explicit(): yield check_comm_closed_explicit('zmq://127.0.0.1') +@gen_test() +def test_inproc_comm_closed_explicit(): + yield check_comm_closed_explicit(inproc.new_address()) + +@gen_test() +def test_inproc_comm_closed_explicit_2(): + listener_errors = [] + + @gen.coroutine + def handle_comm(comm): + # Wait + try: + yield comm.read() + except CommClosedError: + assert comm.closed() + listener_errors.append(True) + else: + comm.close() + + listener = listen('inproc://', handle_comm) + listener.start() + contact_addr = listener.contact_address + + comm = yield connect(contact_addr) + comm.close() + assert comm.closed() + yield gen.sleep(0.01) + assert len(listener_errors) == 1 + + with pytest.raises(CommClosedError): + yield comm.read() + with pytest.raises(CommClosedError): + yield comm.write("foo") + + comm = yield connect(contact_addr) + comm.write("foo") + with pytest.raises(CommClosedError): + yield comm.read() + with pytest.raises(CommClosedError): + yield comm.write("foo") + assert comm.closed() + + comm = yield connect(contact_addr) + comm.write("foo") + yield gen.sleep(0.01) + # XXX comm.closed() is only true after the first time read() raises CommClosedError + #assert comm.closed() + + comm.close() + comm.close() + @gen.coroutine def check_connect_timeout(addr): @@ -419,6 +547,10 @@ def check_connect_timeout(addr): def test_tcp_connect_timeout(): yield check_connect_timeout('tcp://127.0.0.1:44444') +@gen_test() +def test_inproc_connect_timeout(): + yield check_connect_timeout(inproc.new_address()) + def check_many_listeners(addr): @gen.coroutine @@ -426,11 +558,16 @@ def handle_comm(comm): pass listeners = [] - for i in range(100): + N = 100 + + for i in range(N): listener = listen(addr, handle_comm) listener.start() listeners.append(listener) + assert len(set(l.listen_address for l in listeners)) == N + assert len(set(l.contact_address for l in listeners)) == N + for listener in listeners: listener.stop() @@ -440,3 +577,7 @@ def test_tcp_many_listeners(): check_many_listeners('tcp://127.0.0.1') check_many_listeners('tcp://0.0.0.0') check_many_listeners('tcp://') + +@gen_test() +def test_inproc_many_listeners(): + check_many_listeners('inproc://') From 1f5f2ca7dbbbd67e6a34c908194095b681ad607a Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 21 Feb 2017 10:27:49 +0100 Subject: [PATCH 02/10] Replace close_request hack with a custom peekable queue --- distributed/comm/inproc.py | 115 +++++++++++++++++++++------ distributed/comm/tests/test_comms.py | 3 +- 2 files changed, 91 insertions(+), 27 deletions(-) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index ab3d2891ecd..fa729d7a09c 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -1,15 +1,16 @@ from __future__ import print_function, division, absolute_import -from collections import namedtuple +from collections import deque, namedtuple import itertools +import logging import os import sys import threading import weakref from tornado import gen, locks +from tornado.concurrent import Future from tornado.ioloop import IOLoop -from tornado.queues import Queue from ..compatibility import finalize from ..utils import get_ip @@ -17,14 +18,17 @@ ) +logger = logging.getLogger(__name__) + ConnectionRequest = namedtuple('ConnectionRequest', ('c2s_q', 's2c_q', 'c_loop', 'c_addr', - 'conn_event', 'close_request')) - -_Close = object() + 'conn_event')) class Manager(object): + """ + An object coordinating listeners and their addresses. + """ def __init__(self): self.listeners = weakref.WeakValueDictionary() @@ -69,32 +73,86 @@ def new_address(): return 'inproc://' + global_manager.new_address() +class QueueEmpty(Exception): + pass + + +class Queue(object): + """ + A single-reader, single-writer, non-threadsafe, peekable queue. + """ + + def __init__(self): + self._q = deque() + self._read_future = None + + def get_nowait(self): + q = self._q + if not q: + raise QueueEmpty + return q.popleft() + + def get(self): + assert not self._read_future, "Only one reader allowed" + fut = Future() + q = self._q + if q: + fut.set_result(q.popleft()) + else: + self._read_future = fut + return fut + + def put_nowait(self, value): + q = self._q + fut = self._read_future + if fut is not None: + assert len(q) == 0 + self._read_future = None + fut.set_result(value) + else: + q.append(value) + + put = put_nowait + + _omitted = object() + + def peek(self, default=_omitted): + """ + Get the next object in the queue without removing it from the queue. + """ + q = self._q + if q: + return q[0] + elif default is not self._omitted: + return default + else: + raise QueueEmpty + + +_EOF = object() + class InProc(Comm): """ An established communication based on a pair of in-process queues. """ def __init__(self, peer_addr, read_q, write_q, write_loop, - close_request, deserialize=True): + deserialize=True): self._peer_addr = peer_addr self.deserialize = deserialize self._read_q = read_q self._write_q = write_q self._write_loop = write_loop self._closed = False - # A "close request" event shared between both comms - self._close_request = close_request self._finalizer = finalize(self, self._get_finalizer()) self._finalizer.atexit = False def _get_finalizer(self): def finalize(write_q=self._write_q, write_loop=self._write_loop, - close_request=self._close_request): - if not close_request.is_set(): - logger.warn("Closing dangling queue in %s" % (r,)) - close_request.set() - write_loop.add_callback(write_q.put_nowait, _Close) + r=repr(self)): + logger.warn("Closing dangling queue in %s" % (r,)) + write_loop.add_callback(write_q.put_nowait, _EOF) return finalize @@ -111,8 +169,7 @@ def read(self, deserialize=None): raise CommClosedError msg = yield self._read_q.get() - if msg is _Close: - assert self._close_request.is_set() + if msg is _EOF: self._closed = True self._finalizer.detach() raise CommClosedError @@ -122,9 +179,7 @@ def read(self, deserialize=None): @gen.coroutine def write(self, msg): - if self._close_request.is_set(): - self._closed = True - self._finalizer.detach() + if self.closed(): raise CommClosedError self._write_loop.add_callback(self._write_q.put_nowait, msg) @@ -136,15 +191,28 @@ def close(self): self.abort() def abort(self): - if not self._closed: - self._close_request.set() - self._write_loop.add_callback(self._write_q.put_nowait, _Close) + if not self.closed(): + # Putting EOF is cheap enough that we do it on abort() too + self._write_loop.add_callback(self._write_q.put_nowait, _EOF) self._write_q = self._read_q = None self._closed = True self._finalizer.detach() def closed(self): - return self._closed + """ + Whether this InProc comm is closed. It is closed iff: + 1) close() or abort() was called on this comm + 2) close() or abort() was called on the other end and the + read queue is empty + """ + if self._closed: + return True + if self._read_q.peek(None) is _EOF: + self._closed = True + self._finalizer.detach() + return True + else: + return False class InProcListener(Listener): @@ -166,7 +234,6 @@ def _listen(self): read_q=conn_req.c2s_q, write_q=conn_req.s2c_q, write_loop=conn_req.c_loop, - close_request=conn_req.close_request, deserialize=self.deserialize) # Notify connector conn_req.c_loop.add_callback(conn_req.conn_event.set) @@ -209,7 +276,6 @@ def connect(self, address, deserialize=True): c_loop=IOLoop.current(), c_addr=global_manager.new_address(), conn_event=locks.Event(), - close_request=threading.Event(), ) listener.connect_threadsafe(conn_req) # Wait for connection acknowledgement @@ -221,7 +287,6 @@ def connect(self, address, deserialize=True): read_q=conn_req.s2c_q, write_q=conn_req.c2s_q, write_loop=listener.loop, - close_request=conn_req.close_request, deserialize=deserialize) raise gen.Return(comm) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 2b6ef5593b3..49e209e3648 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -527,8 +527,7 @@ def handle_comm(comm): comm = yield connect(contact_addr) comm.write("foo") yield gen.sleep(0.01) - # XXX comm.closed() is only true after the first time read() raises CommClosedError - #assert comm.closed() + assert comm.closed() comm.close() comm.close() From 0615b3c74f70e18bffd46bca5a256386c8175e8d Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 21 Feb 2017 11:45:26 +0100 Subject: [PATCH 03/10] Add deserialization tests --- distributed/comm/core.py | 7 + distributed/comm/inproc.py | 38 ++++- distributed/comm/tests/test_comms.py | 146 +++++++++++++++++++ distributed/protocol/serialize.py | 15 ++ distributed/protocol/tests/test_serialize.py | 20 ++- 5 files changed, 223 insertions(+), 3 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 89be2d7ef58..61452cd9a4d 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -127,6 +127,13 @@ def contact_address(self): address such as 'tcp://0.0.0.0:123'. """ + def __enter__(self): + self.start() + return self + + def __exit__(self, *exc): + self.stop() + def parse_address(addr): """ diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index fa729d7a09c..c4553a18c22 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -13,6 +13,7 @@ from tornado.ioloop import IOLoop from ..compatibility import finalize +from ..protocol import deserialize, Serialize, Serialized from ..utils import get_ip from .core import (connectors, listeners, Comm, Listener, CommClosedError, ) @@ -129,8 +130,41 @@ def peek(self, default=_omitted): raise QueueEmpty +def _maybe_deserialize(msg): + """ + Replace all nested Serialize and Serialized values in *msg* + with their original object. Returns a copy of *msg*. + """ + def replace_inner(x): + if type(x) is dict: + x = x.copy() + for k, v in x.items(): + typ = type(v) + if typ is dict or typ is list: + x[k] = replace_inner(v) + elif typ is Serialize: + x[k] = v.data + elif typ is Serialized: + x[k] = deserialize(v.header, v.frames) + + elif type(x) is list: + x = list(x) + for k, v in enumerate(x): + typ = type(v) + if typ is dict or typ is list: + x[k] = replace_inner(v) + elif typ is Serialize: + x[k] = v.data + elif typ is Serialized: + x[k] = deserialize(v.header, v.frames) + + return x + + return replace_inner(msg) + _EOF = object() + class InProc(Comm): """ An established communication based on a pair of in-process queues. @@ -174,7 +208,9 @@ def read(self, deserialize=None): self._finalizer.detach() raise CommClosedError - # XXX does deserialize matter? + deserialize = deserialize if deserialize is not None else self.deserialize + if deserialize: + msg = _maybe_deserialize(msg) raise gen.Return(msg) @gen.coroutine diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 49e209e3648..b4ec9cc2258 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -13,6 +13,9 @@ from distributed.utils_test import (slow, loop, gen_test, gen_cluster, requires_ipv6, has_ipv6) +from distributed.protocol import (loads, dumps, + to_serialize, Serialized, serialize, deserialize) + from distributed.comm import (tcp, inproc, connect, listen, CommClosedError, is_zmq_enabled) from distributed.comm.core import (parse_address, parse_host_port, @@ -44,6 +47,10 @@ def debug_loop(): yield gen.sleep(0.50) +# +# Test utility functions +# + def test_parse_host_port(): f = parse_host_port @@ -105,6 +112,10 @@ def test_resolve_address(): assert f('zmq://localhost:789') == 'zmq://127.0.0.1:789' +# +# Test concrete transport APIs +# + @gen_test() def test_tcp_specific(): """ @@ -242,6 +253,10 @@ def client_communicate(key, delay=0): assert listener.contact_address not in client_addresses +# +# Test communications through the abstract API +# + @gen.coroutine def check_client_server(addr, check_listen_addr=None, check_contact_addr=None): """ @@ -413,6 +428,10 @@ def test_inproc_client_server(): yield check_client_server(inproc.new_address(), inproc_check()) +# +# Test communication closing +# + @gen.coroutine def check_comm_closed_implicit(addr, delay=None): @gen.coroutine @@ -533,6 +552,10 @@ def handle_comm(comm): comm.close() +# +# Various stress tests +# + @gen.coroutine def check_connect_timeout(addr): t1 = time() @@ -580,3 +603,126 @@ def test_tcp_many_listeners(): @gen_test() def test_inproc_many_listeners(): check_many_listeners('inproc://') + + +# +# Test deserialization +# + +@gen.coroutine +def check_listener_deserialize(addr, deserialize, in_value, check_out): + q = queues.Queue() + + @gen.coroutine + def handle_comm(comm): + msg = yield comm.read() + q.put_nowait(msg) + yield comm.close() + + with listen(addr, handle_comm, deserialize=deserialize) as listener: + comm = yield connect(listener.contact_address) + + yield comm.write(in_value) + yield comm.close() + + out_value = yield q.get() + check_out(out_value) + +@gen.coroutine +def check_connector_deserialize(addr, deserialize, in_value, check_out): + q = queues.Queue() + + @gen.coroutine + def handle_comm(comm): + msg = yield q.get() + yield comm.write(msg) + yield comm.close() + + with listen(addr, handle_comm) as listener: + comm = yield connect(listener.contact_address, deserialize=deserialize) + + q.put_nowait(in_value) + out_value = yield comm.read() + yield comm.close() + check_out(out_value) + +@gen.coroutine +def check_deserialize(addr): + # Create a valid Serialized object + # (if using serialize(), it will lack a compression header) + ser = loads(dumps({'x': to_serialize(456)}), deserialize=False)['x'] + assert isinstance(ser, Serialized) + + # Test with Serialize and Serialized objects + + msg = {'op': 'update', + 'x': b'abc', + 'to_ser': [to_serialize(123)], + 'ser': ser, + } + msg_orig = msg.copy() + + def check_out_false(out_value): + # Check output with deserialize=False + out_value = out_value.copy() # in case transport passed the object as-is + to_ser = out_value.pop('to_ser') + ser = out_value.pop('ser') + expected_msg = msg_orig.copy() + del expected_msg['ser'] + del expected_msg['to_ser'] + assert out_value == expected_msg + + assert isinstance(ser, Serialized) + assert deserialize(ser.header, ser.frames) == 456 + + assert isinstance(to_ser, list) + to_ser, = to_ser + # The to_serialize() value could have been actually serialized + # or not (it's a transport-specific optimization) + if isinstance(to_ser, Serialized): + assert deserialize(to_ser.header, to_ser.frames) == 123 + else: + assert to_ser == to_serialize(123) + + def check_out_true(out_value): + # Check output with deserialize=True + expected_msg = msg.copy() + expected_msg['ser'] = 456 + expected_msg['to_ser'] = [123] + assert out_value == expected_msg + + yield check_listener_deserialize(addr, False, msg, check_out_false) + yield check_connector_deserialize(addr, False, msg, check_out_false) + + yield check_listener_deserialize(addr, True, msg, check_out_true) + yield check_connector_deserialize(addr, True, msg, check_out_true) + + # Test with a long bytestring + + msg = {'op': 'update', + 'x': b'abc', + 'y': b'def\n' * (2 ** 20), + } + msg_orig = msg.copy() + + def check_out(out_value): + assert out_value == msg_orig + + yield check_listener_deserialize(addr, False, msg, check_out) + yield check_connector_deserialize(addr, False, msg, check_out) + + yield check_listener_deserialize(addr, True, msg, check_out) + yield check_connector_deserialize(addr, True, msg, check_out) + + +@gen_test() +def test_tcp_deserialize(): + yield check_deserialize('tcp://') + +@gen_test() +def test_zmq_deserialize(): + yield check_deserialize('zmq://0.0.0.0') + +@gen_test() +def test_inproc_deserialize(): + yield check_deserialize('inproc://') diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index cec077ac1d9..3b1a6844063 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -174,6 +174,13 @@ def __str__(self): __repr__ = __str__ + def __eq__(self, other): + return (isinstance(other, Serialize) and + other.data == self.data) + + def __ne__(self, other): + return not (self == other) + to_serialize = Serialize @@ -195,6 +202,14 @@ def deserialize(self): frames = decompress(self.header, self.frames) return deserialize(self.header, frames) + def __eq__(self, other): + return (isinstance(other, Serialized) and + other.header == self.header and + other.frames == self.frames) + + def __ne__(self, other): + return not (self == other) + def container_copy(c): typ = type(c) diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index c7bf24d491c..e0e5e5a736e 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -58,8 +58,24 @@ def test_Serialize(): assert '123' in str(s) assert s.data == 123 - s = Serialize((1, 2)) - assert str(s) + t = Serialize((1, 2)) + assert str(t) + + u = Serialize(123) + assert s == u + assert not (s != u) + assert s != t + assert not (s == t) + + +def test_Serialized(): + s = Serialized(*serialize(123)) + t = Serialized(*serialize((1, 2))) + u = Serialized(*serialize(123)) + assert s == u + assert not (s != u) + assert s != t + assert not (s == t) from distributed.utils_test import gen_cluster From aba7147bb5c21b6cd9ae6db314a4df87eef80eb6 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 21 Feb 2017 16:07:21 +0100 Subject: [PATCH 04/10] Add multi-thread test for inproc queues --- distributed/comm/inproc.py | 6 ++- distributed/comm/tests/test_comms.py | 72 ++++++++++++++++++++++------ 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index c4553a18c22..6c03d360996 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -168,6 +168,9 @@ def replace_inner(x): class InProc(Comm): """ An established communication based on a pair of in-process queues. + + Reminder: a Comm must always be used from a single thread. + Its peer Comm can be running in any thread. """ def __init__(self, peer_addr, read_q, write_q, write_loop, @@ -218,6 +221,7 @@ def write(self, msg): if self.closed(): raise CommClosedError + # Ensure we feed the queue in the same thread it is read from. self._write_loop.add_callback(self._write_q.put_nowait, msg) raise gen.Return(1) @@ -236,7 +240,7 @@ def abort(self): def closed(self): """ - Whether this InProc comm is closed. It is closed iff: + Whether this comm is closed. An InProc comm is closed if: 1) close() or abort() was called on this comm 2) close() or abort() was called on the other end and the read queue is empty diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index b4ec9cc2258..8ff8c238672 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -2,10 +2,13 @@ from functools import partial import os +import sys +import threading import pytest from tornado import gen, ioloop, queues +from tornado.concurrent import Future from distributed.core import pingpong from distributed.metrics import time @@ -154,9 +157,10 @@ def client_communicate(key, delay=0): yield client_communicate(key=1234) # Many clients at once - futures = [client_communicate(key=i, delay=0.05) for i in range(20)] + N = 100 + futures = [client_communicate(key=i, delay=0.05) for i in range(N)] yield futures - assert set(l) == {1234} | set(range(20)) + assert set(l) == {1234} | set(range(N)) @requires_zmq @@ -198,13 +202,14 @@ def client_communicate(key, delay=0): yield client_communicate(key=1234) # Many clients at once - futures = [client_communicate(key=i, delay=0.05) for i in range(20)] + N = 20 + futures = [client_communicate(key=i, delay=0.05) for i in range(N)] yield futures - assert set(l) == {1234} | set(range(20)) + assert set(l) == {1234} | set(range(N)) -@gen_test() -def test_inproc_specific(): +@gen.coroutine +def check_inproc_specific(run_client): """ Test concrete InProc API. """ @@ -213,13 +218,16 @@ def test_inproc_specific(): client_addresses = set() + N_MSGS = 3 + @gen.coroutine def handle_comm(comm): assert comm.peer_address.startswith('inproc://' + addr_head) client_addresses.add(comm.peer_address) - msg = yield comm.read() - msg['op'] = 'pong' - yield comm.write(msg) + for i in range(N_MSGS): + msg = yield comm.read() + msg['op'] = 'pong' + yield comm.write(msg) yield comm.close() listener = inproc.InProcListener(listener_addr, handle_comm) @@ -233,19 +241,22 @@ def handle_comm(comm): def client_communicate(key, delay=0): comm = yield connector.connect(listener_addr) assert comm.peer_address == 'inproc://' + listener_addr - yield comm.write({'op': 'ping', 'data': key}) - if delay: - yield gen.sleep(delay) - msg = yield comm.read() + for i in range(N_MSGS): + yield comm.write({'op': 'ping', 'data': key}) + if delay: + yield gen.sleep(delay) + msg = yield comm.read() assert msg == {'op': 'pong', 'data': key} l.append(key) yield comm.close() + client_communicate = partial(run_client, client_communicate) + yield client_communicate(key=1234) # Many clients at once - N = 200 - futures = [client_communicate(key=i, delay=0.05) for i in range(N)] + N = 20 + futures = [client_communicate(key=i, delay=0.001) for i in range(N)] yield futures assert set(l) == {1234} | set(range(N)) @@ -253,6 +264,37 @@ def client_communicate(key, delay=0): assert listener.contact_address not in client_addresses +def run_coro(func, *args, **kwargs): + return func(*args, **kwargs) + +def run_coro_in_thread(func, *args, **kwargs): + fut = Future() + main_loop = ioloop.IOLoop.current() + + def run(): + thread_loop = ioloop.IOLoop() # need fresh IO loop for run_sync() + try: + res = thread_loop.run_sync(partial(func, *args, **kwargs), + timeout=10) + except: + main_loop.add_callback(fut.set_exc_info, sys.exc_info()) + else: + main_loop.add_callback(fut.set_result, res) + + t = threading.Thread(target=run) + t.start() + return fut + + +@gen_test() +def test_inproc_specific_same_thread(): + yield check_inproc_specific(run_coro) + +@gen_test() +def test_inproc_specific_different_threads(): + yield check_inproc_specific(run_coro_in_thread) + + # # Test communications through the abstract API # From b8dc6681a8eebbd4fef9fff246773c70e71ac398 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 21 Feb 2017 16:34:14 +0100 Subject: [PATCH 05/10] Add inproc tests to test_core --- distributed/comm/__init__.py | 1 + distributed/comm/inproc.py | 2 +- distributed/tests/test_core.py | 162 +++++++++++++++++++++------------ 3 files changed, 106 insertions(+), 59 deletions(-) diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index ae79f6a9f95..b38297d074f 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -13,6 +13,7 @@ def is_zmq_enabled(): def _register_transports(): + from . import inproc from . import tcp if is_zmq_enabled(): diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 6c03d360996..7e8854f584f 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -309,7 +309,7 @@ def __init__(self, manager): def connect(self, address, deserialize=True): listener = self.manager.get_listener_for(address) if listener is None: - raise IOError("no endpoint for inproc address %r") + raise IOError("no endpoint for inproc address %r" % (address,)) conn_req = ConnectionRequest(c2s_q=Queue(), s2c_q=Queue(), diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 7ab200e56de..f64b8e8c6b1 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from functools import partial +import os import socket from tornado import gen, ioloop @@ -14,7 +15,8 @@ from distributed.utils import get_ip, get_ipv6 from distributed.utils_test import ( slow, loop, gen_test, gen_cluster, has_ipv6, - assert_can_connect, assert_can_connect_from_everywhere_4, + assert_can_connect, assert_cannot_connect, + assert_can_connect_from_everywhere_4, assert_can_connect_from_everywhere_4_6, assert_can_connect_from_everywhere_6, assert_can_connect_locally_4, assert_can_connect_locally_6) @@ -150,29 +152,58 @@ def listen_on(cls, *args): yield assert_can_connect(server.address) yield assert_can_connect_locally_6(server.port) + # InProc -def test_rpc(loop): - @gen.coroutine - def f(): - server = Server({'ping': pingpong}) - server.listen(8883) + with listen_on(Server, 'inproc://') as server: + inproc_addr1 = server.address + assert inproc_addr1.startswith('inproc://%s/%d/' % (get_ip(), os.getpid())) + yield assert_can_connect(inproc_addr1) - with rpc('127.0.0.1:8883') as remote: - response = yield remote.ping() - assert response == b'pong' + with listen_on(Server, 'inproc://') as server2: + inproc_addr2 = server2.address + assert inproc_addr2.startswith('inproc://%s/%d/' % (get_ip(), os.getpid())) + yield assert_can_connect(inproc_addr2) - assert remote.comms - assert remote.address == 'tcp://127.0.0.1:8883' + yield assert_can_connect(inproc_addr1) + yield assert_cannot_connect(inproc_addr2) - response = yield remote.ping(close=True) - assert response == b'pong' - assert not remote.comms - assert remote.status == 'closed' +@gen.coroutine +def check_rpc(listen_arg, rpc_arg=None): + server = Server({'ping': pingpong}) + server.listen(listen_arg) + if rpc_arg is None: + rpc_arg = server.address - server.stop() + with rpc(rpc_arg) as remote: + response = yield remote.ping() + assert response == b'pong' + assert remote.comms - loop.run_sync(f) + response = yield remote.ping(close=True) + assert response == b'pong' + response = yield remote.ping() + assert response == b'pong' + + assert not remote.comms + assert remote.status == 'closed' + + server.stop() + + +@gen_test() +def test_rpc_default(): + yield check_rpc(8883, '127.0.0.1:8883') + yield check_rpc(8883) + +@gen_test() +def test_rpc_tcp(): + yield check_rpc('tcp://:8883', 'tcp://127.0.0.1:8883') + yield check_rpc('tcp://') + +@gen_test() +def test_rpc_inproc(): + yield check_rpc('inproc://', None) def test_rpc_inputs(): @@ -187,70 +218,85 @@ def test_rpc_inputs(): r.close_rpc() -def test_rpc_with_many_connections(loop): - remote = rpc(('127.0.0.1', 8885)) - +@gen.coroutine +def check_rpc_with_many_connections(listen_arg): @gen.coroutine def g(): for i in range(10): yield remote.ping() - @gen.coroutine - def f(): - server = Server({'ping': pingpong}) - server.listen(8885) + server = Server({'ping': pingpong}) + server.listen(listen_arg) - yield [g() for i in range(10)] + remote = rpc(server.address) + yield [g() for i in range(10)] - server.stop() + server.stop() - remote.close_comms() - assert all(comm.closed() for comm in remote.comms) + remote.close_comms() + assert all(comm.closed() for comm in remote.comms) - loop.run_sync(f) +@gen_test() +def test_rpc_with_many_connections_tcp(): + yield check_rpc_with_many_connections('tcp://') + +@gen_test() +def test_rpc_with_many_connections_inproc(): + yield check_rpc_with_many_connections('inproc://') -def echo(stream, x): +def echo(comm, x): return x -@slow -def test_large_packets(loop): +@gen.coroutine +def check_large_packets(listen_arg): """ tornado has a 100MB cap by default """ - @gen.coroutine - def f(): - server = Server({'echo': echo}) - server.listen(8886) + server = Server({'echo': echo}) + server.listen(listen_arg) - data = b'0' * int(200e6) # slightly more than 100MB - conn = rpc('127.0.0.1:8886') - result = yield conn.echo(x=data) - assert result == data + data = b'0' * int(200e6) # slightly more than 100MB + conn = rpc(server.address) + result = yield conn.echo(x=data) + assert result == data - d = {'x': data} - result = yield conn.echo(x=d) - assert result == d + d = {'x': data} + result = yield conn.echo(x=d) + assert result == d - conn.close_comms() - server.stop() + conn.close_comms() + server.stop() - loop.run_sync(f) +@slow +@gen_test() +def test_large_packets_tcp(): + yield check_large_packets('tcp://') -def test_identity(loop): - @gen.coroutine - def f(): - server = Server({}) - server.listen(8887) +@gen_test() +def test_large_packets_inproc(): + yield check_large_packets('inproc://') - with rpc(('127.0.0.1', 8887)) as remote: - a = yield remote.identity() - b = yield remote.identity() - assert a['type'] == 'Server' - assert a['id'] == b['id'] - server.stop() +@gen.coroutine +def check_identity(listen_arg): + server = Server({}) + server.listen(listen_arg) - loop.run_sync(f) + with rpc(server.address) as remote: + a = yield remote.identity() + b = yield remote.identity() + assert a['type'] == 'Server' + assert a['id'] == b['id'] + + server.stop() + +@gen_test() +def test_identity_tcp(): + yield check_identity('tcp://') + +@gen_test() +def test_identity_inproc(): + yield check_identity('inproc://') def test_ports(loop): From 95c41487c5c6b276b0f319fc89e5012048c0bb43 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 21 Feb 2017 17:02:09 +0100 Subject: [PATCH 06/10] Refactor address manipulation functions in a separate d.comm.addressing module --- distributed/cli/utils.py | 4 +- distributed/comm/__init__.py | 4 + distributed/comm/addressing.py | 115 +++++++++++++++++++++++++++ distributed/comm/core.py | 112 +------------------------- distributed/comm/tcp.py | 4 +- distributed/comm/tests/test_comms.py | 5 +- distributed/comm/utils.py | 1 + distributed/comm/zmq.py | 4 +- distributed/core.py | 7 +- distributed/nanny.py | 2 +- distributed/scheduler.py | 10 ++- distributed/utils.py | 2 +- distributed/worker.py | 9 ++- 13 files changed, 150 insertions(+), 129 deletions(-) create mode 100644 distributed/comm/addressing.py diff --git a/distributed/cli/utils.py b/distributed/cli/utils.py index 501092b06dd..336b42f8d79 100644 --- a/distributed/cli/utils.py +++ b/distributed/cli/utils.py @@ -12,8 +12,8 @@ """.strip() -from distributed.comm.core import (parse_address, unparse_address, - parse_host_port, unparse_host_port) +from distributed.comm import (parse_address, unparse_address, + parse_host_port, unparse_host_port) from ..utils import get_ip, ensure_ip diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index b38297d074f..ad5fe3618be 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -1,5 +1,9 @@ from __future__ import print_function, division, absolute_import +from .addressing import (parse_address, unparse_address, + normalize_address, parse_host_port, + unparse_host_port, get_address_host_port, + resolve_address) from .core import connect, listen, Comm, CommClosedError diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py new file mode 100644 index 00000000000..c72d9d504d7 --- /dev/null +++ b/distributed/comm/addressing.py @@ -0,0 +1,115 @@ +from __future__ import print_function, division, absolute_import + +import six + +from ..config import config +from ..utils import ensure_ip + + +DEFAULT_SCHEME = config.get('default-scheme', 'tcp') + + +def parse_address(addr): + """ + Split address into its scheme and scheme-dependent location string. + """ + if not isinstance(addr, six.string_types): + raise TypeError("expected str, got %r" % addr.__class__.__name__) + scheme, sep, loc = addr.rpartition('://') + if not sep: + scheme = DEFAULT_SCHEME + return scheme, loc + + +def unparse_address(scheme, loc): + """ + Undo parse_address(). + """ + return '%s://%s' % (scheme, loc) + + +def normalize_address(addr): + """ + Canonicalize address, adding a default scheme if necessary. + """ + return unparse_address(*parse_address(addr)) + + +def parse_host_port(address, default_port=None): + """ + Parse an endpoint address given in the form "host:port". + """ + if isinstance(address, tuple): + return address + if address.startswith('tcp:'): + address = address[4:] + + def _fail(): + raise ValueError("invalid address %r" % (address,)) + + def _default(): + if default_port is None: + raise ValueError("missing port number in address %r" % (address,)) + return default_port + + if address.startswith('['): + host, sep, tail = address[1:].partition(']') + if not sep: + _fail() + if not tail: + port = _default() + else: + if not tail.startswith(':'): + _fail() + port = tail[1:] + else: + host, sep, port = address.partition(':') + if not sep: + port = _default() + elif ':' in host: + _fail() + + return host, int(port) + + +def unparse_host_port(host, port=None): + """ + Undo parse_host_port(). + """ + if ':' in host and not host.startswith('['): + host = '[%s]' % host + if port: + return '%s:%s' % (host, port) + else: + return host + + +def get_address_host_port(addr): + """ + Get a (host, port) tuple out of the given address. + """ + scheme, loc = parse_address(addr) + if scheme not in ('tcp', 'zmq'): + raise ValueError("don't know how to extract host and port " + "for address %r" % (addr,)) + return parse_host_port(loc) + + +def resolve_address(addr): + """ + Apply scheme-specific address resolution to *addr*, ensuring + all symbolic references are replaced with concrete location + specifiers. + + In practice, this means hostnames are resolved to IP addresses. + """ + # XXX circular import; reorganize APIs into a distributed.comms.addressing module? + #from ..utils import ensure_ip + scheme, loc = parse_address(addr) + if scheme not in ('tcp', 'zmq'): + return addr + + host, port = parse_host_port(loc) + loc = unparse_host_port(ensure_ip(host), port) + addr = unparse_address(scheme, loc) + return addr diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 61452cd9a4d..23f1ca47ec6 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -4,13 +4,14 @@ from datetime import timedelta import logging -from six import string_types, with_metaclass +from six import with_metaclass from tornado import gen from tornado.ioloop import IOLoop from ..config import config from ..metrics import time +from .addressing import parse_address logger = logging.getLogger(__name__) @@ -31,9 +32,6 @@ } -DEFAULT_SCHEME = config.get('default-scheme', 'tcp') - - class CommClosedError(IOError): pass @@ -135,112 +133,6 @@ def __exit__(self, *exc): self.stop() -def parse_address(addr): - """ - Split address into its scheme and scheme-dependent location string. - """ - if not isinstance(addr, string_types): - raise TypeError("expected str, got %r" % addr.__class__.__name__) - scheme, sep, loc = addr.rpartition('://') - if not sep: - scheme = DEFAULT_SCHEME - return scheme, loc - - -def unparse_address(scheme, loc): - """ - Undo parse_address(). - """ - return '%s://%s' % (scheme, loc) - - -def parse_host_port(address, default_port=None): - """ - Parse an endpoint address given in the form "host:port". - """ - if isinstance(address, tuple): - return address - if address.startswith('tcp:'): - address = address[4:] - - def _fail(): - raise ValueError("invalid address %r" % (address,)) - - def _default(): - if default_port is None: - raise ValueError("missing port number in address %r" % (address,)) - return default_port - - if address.startswith('['): - host, sep, tail = address[1:].partition(']') - if not sep: - _fail() - if not tail: - port = _default() - else: - if not tail.startswith(':'): - _fail() - port = tail[1:] - else: - host, sep, port = address.partition(':') - if not sep: - port = _default() - elif ':' in host: - _fail() - - return host, int(port) - - -def unparse_host_port(host, port=None): - """ - Undo parse_host_port(). - """ - if ':' in host and not host.startswith('['): - host = '[%s]' % host - if port: - return '%s:%s' % (host, port) - else: - return host - - -def get_address_host_port(addr): - """ - Get a (host, port) tuple out of the given address. - """ - scheme, loc = parse_address(addr) - if scheme not in ('tcp', 'zmq'): - raise ValueError("don't know how to extract host and port " - "for address %r" % (addr,)) - return parse_host_port(loc) - - -def normalize_address(addr): - """ - Canonicalize address, adding a default scheme if necessary. - """ - return unparse_address(*parse_address(addr)) - - -def resolve_address(addr): - """ - Apply scheme-specific address resolution to *addr*, ensuring - all symbolic references are replaced with concrete location - specifiers. - - In practice, this means hostnames are resolved to IP addresses. - """ - # XXX circular import; reorganize APIs into a distributed.comms.addressing module? - from ..utils import ensure_ip - scheme, loc = parse_address(addr) - if scheme not in ('tcp', 'zmq'): - raise ValueError("don't know how to extract host and port " - "for address %r" % (addr,)) - host, port = parse_host_port(loc) - loc = unparse_host_port(ensure_ip(host), port) - addr = unparse_address(scheme, loc) - return addr - - @gen.coroutine def connect(addr, timeout=3, deserialize=True): """ diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 9b16d772ddb..8d282151c31 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -15,8 +15,8 @@ from .. import config from ..compatibility import finalize from ..utils import ensure_bytes -from .core import (connectors, listeners, Comm, Listener, CommClosedError, - parse_host_port, unparse_host_port) +from .addressing import parse_host_port, unparse_host_port +from .core import connectors, listeners, Comm, Listener, CommClosedError from .utils import (to_frames, from_frames, get_tcp_server_address, ensure_concrete_host) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 8ff8c238672..a3b3f52da29 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -20,9 +20,8 @@ to_serialize, Serialized, serialize, deserialize) from distributed.comm import (tcp, inproc, connect, listen, CommClosedError, - is_zmq_enabled) -from distributed.comm.core import (parse_address, parse_host_port, - unparse_host_port, resolve_address) + is_zmq_enabled, parse_address, parse_host_port, + unparse_host_port, resolve_address) if is_zmq_enabled(): from distributed.comm import zmq diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index c15a0660323..c1856ff73e2 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -1,3 +1,4 @@ +from __future__ import print_function, division, absolute_import import logging import socket diff --git a/distributed/comm/zmq.py b/distributed/comm/zmq.py index 6b71869a5a1..e8399df9d08 100644 --- a/distributed/comm/zmq.py +++ b/distributed/comm/zmq.py @@ -13,8 +13,8 @@ from .. import config from ..utils import PY3 -from .core import (connectors, listeners, Comm, CommClosedError, Listener, - parse_host_port, unparse_host_port) +from .addressing import parse_host_port, unparse_host_port +from .core import connectors, listeners, Comm, CommClosedError, Listener from .utils import to_frames, from_frames, ensure_concrete_host from . import zmqimpl diff --git a/distributed/core.py b/distributed/core.py index 15470a8e679..d4be88ce0a0 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -24,9 +24,10 @@ from tornado.ioloop import IOLoop, PeriodicCallback from tornado.locks import Event -from .comm import connect, listen, CommClosedError -from .comm.core import (parse_address, normalize_address, Comm, - unparse_host_port, get_address_host_port) +from .comm import (connect, listen, CommClosedError, + parse_address, normalize_address, + unparse_host_port, get_address_host_port) +from .comm.core import Comm from .compatibility import PY3, unicode, WINDOWS from .config import config from .metrics import time diff --git a/distributed/nanny.py b/distributed/nanny.py index 4b89431b9cb..ffbb2d3f21e 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -14,7 +14,7 @@ from tornado.ioloop import IOLoop from tornado import gen -from .comm.core import get_address_host_port +from .comm import get_address_host_port from .compatibility import JSONDecodeError from .core import Server, rpc, RPCClosed, CommClosedError, coerce_to_address from .metrics import disk_io_counters, net_io_counters, time diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e906d1ca587..119aaa697c6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -28,8 +28,8 @@ from dask.order import order from .batched import BatchedSend -from .comm.core import (normalize_address, resolve_address, - get_address_host_port, unparse_host_port) +from .comm import (normalize_address, resolve_address, + get_address_host_port, unparse_host_port) from .config import config from .core import (rpc, connect, Server, send_recv, error_message, clean_exception, CommClosedError) @@ -412,7 +412,11 @@ def start(self, addr_or_port=8786, start_queues=True): self.listen(('', addr_or_port)) else: self.listen(addr_or_port) - self.ip, _ = get_address_host_port(self.listen_address) + try: + self.ip, _ = get_address_host_port(self.listen_address) + except ValueError: + # Address scheme does not have a notion of host and port + self.ip = get_ip() # Services listen on all addresses self.start_services() diff --git a/distributed/utils.py b/distributed/utils.py index 2bdd96a0d0f..12d48388deb 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -28,7 +28,6 @@ from .compatibility import Queue, PY3, PY2, get_thread_identity from .config import config -from .comm import CommClosedError logger = logging.getLogger(__name__) @@ -301,6 +300,7 @@ def key_split(s): @contextmanager def log_errors(pdb=False): + from .comm import CommClosedError try: yield except (CommClosedError, gen.Return): diff --git a/distributed/worker.py b/distributed/worker.py index d48c13e43d7..98b1678eac4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -25,7 +25,7 @@ from tornado.locks import Event from .batched import BatchedSend -from .comm.core import get_address_host_port +from .comm import get_address_host_port from .config import config from .compatibility import reload, unicode from .core import (connect, send_recv, error_message, CommClosedError, @@ -212,13 +212,18 @@ def _start(self, addr_or_port=0): # XXX Factor this out if isinstance(addr_or_port, int): # Default ip is the required one to reach the scheduler + # XXX get local listening address self.ip = get_ip( get_address_host_port(self.scheduler.address)[0] ) self.listen((self.ip, addr_or_port)) else: self.listen(addr_or_port) - self.ip = get_address_host_port(self.address)[0] + try: + self.ip = get_address_host_port(self.address)[0] + except ValueError: + # Address scheme does not have a notion of host and port + self.ip = get_ip() self.name = self.name or self.address # Services listen on all addresses From e1e2c9258f0f5d4cbf2cf3fc527ec235b51a2af9 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 21 Feb 2017 17:26:54 +0100 Subject: [PATCH 07/10] Disable zmq test when zmq isn't explicitly enabled --- distributed/comm/tests/test_comms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index a3b3f52da29..20464d785c1 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -760,6 +760,7 @@ def check_out(out_value): def test_tcp_deserialize(): yield check_deserialize('tcp://') +@requires_zmq @gen_test() def test_zmq_deserialize(): yield check_deserialize('zmq://0.0.0.0') From 978a161ee789629c9da0185d10c1bd46d0f70352 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 21 Feb 2017 17:57:13 +0100 Subject: [PATCH 08/10] Fix loads() bug with hand-created Serialized object --- distributed/comm/tests/test_comms.py | 7 +------ distributed/protocol/core.py | 7 ++++--- distributed/protocol/tests/test_protocol.py | 22 ++++++++++++++++++++- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 20464d785c1..54d9ba2eb5b 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -689,17 +689,12 @@ def handle_comm(comm): @gen.coroutine def check_deserialize(addr): - # Create a valid Serialized object - # (if using serialize(), it will lack a compression header) - ser = loads(dumps({'x': to_serialize(456)}), deserialize=False)['x'] - assert isinstance(ser, Serialized) - # Test with Serialize and Serialized objects msg = {'op': 'update', 'x': b'abc', 'to_ser': [to_serialize(123)], - 'ser': ser, + 'ser': Serialized(*serialize(456)), } msg_orig = msg.copy() diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 77b053542ea..4c11feb56b9 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -55,7 +55,7 @@ def dumps(msg): for key, (head, frames) in data.items(): if 'lengths' not in head: - head['lengths'] = list(map(len, frames)) + head['lengths'] = tuple(map(len, frames)) if 'compression' not in head: frames = frame_split_size(frames) if frames: @@ -70,7 +70,7 @@ def dumps(msg): for key, (head, frames) in pre.items(): if 'lengths' not in head: - head['lengths'] = list(map(len, frames)) + head['lengths'] = tuple(map(len, frames)) head['count'] = len(frames) header['headers'][key] = head header['keys'].append(key) @@ -112,7 +112,8 @@ def loads(frames, deserialize=True): fs = [] if deserialize or key in bytestrings: - fs = decompress(head, fs) + if 'compression' in head: + fs = decompress(head, fs) fs = merge_frames(head, fs) value = _deserialize(head, fs) else: diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index 0a1212cd825..3e31029ba65 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -6,7 +6,8 @@ from distributed.protocol import (loads, dumps, msgpack, maybe_compress, to_serialize) -from distributed.protocol.serialize import Serialize, Serialized, deserialize +from distributed.protocol.serialize import (Serialize, Serialized, + serialize, deserialize) from distributed.utils_test import slow @@ -171,3 +172,22 @@ def test_dumps_loads_Serialize(): result3 = loads(frames2) assert result == result3 + + +def test_dumps_loads_Serialized(): + msg = {'x': 1, + 'data': Serialized(*serialize(123)), + } + frames = dumps(msg) + assert len(frames) > 2 + result = loads(frames) + assert result == {'x': 1, 'data': 123} + + result2 = loads(frames, deserialize=False) + assert result2 == msg + + frames2 = dumps(result2) + assert all(map(eq_frames, frames, frames2)) + + result3 = loads(frames2) + assert result == result3 From c55656a81070a90f561c51cfa1b4f1255a42f28c Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 21 Feb 2017 19:00:24 +0100 Subject: [PATCH 09/10] Move deserialization helper into distributed.protocol --- distributed/comm/inproc.py | 36 ++----------------- distributed/protocol/__init__.py | 3 +- distributed/protocol/serialize.py | 37 ++++++++++++++++++++ distributed/protocol/tests/test_serialize.py | 19 +++++++++- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 7e8854f584f..e22b68d6813 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop from ..compatibility import finalize -from ..protocol import deserialize, Serialize, Serialized +from ..protocol import nested_deserialize, Serialize, Serialized from ..utils import get_ip from .core import (connectors, listeners, Comm, Listener, CommClosedError, ) @@ -130,38 +130,6 @@ def peek(self, default=_omitted): raise QueueEmpty -def _maybe_deserialize(msg): - """ - Replace all nested Serialize and Serialized values in *msg* - with their original object. Returns a copy of *msg*. - """ - def replace_inner(x): - if type(x) is dict: - x = x.copy() - for k, v in x.items(): - typ = type(v) - if typ is dict or typ is list: - x[k] = replace_inner(v) - elif typ is Serialize: - x[k] = v.data - elif typ is Serialized: - x[k] = deserialize(v.header, v.frames) - - elif type(x) is list: - x = list(x) - for k, v in enumerate(x): - typ = type(v) - if typ is dict or typ is list: - x[k] = replace_inner(v) - elif typ is Serialize: - x[k] = v.data - elif typ is Serialized: - x[k] = deserialize(v.header, v.frames) - - return x - - return replace_inner(msg) - _EOF = object() @@ -213,7 +181,7 @@ def read(self, deserialize=None): deserialize = deserialize if deserialize is not None else self.deserialize if deserialize: - msg = _maybe_deserialize(msg) + msg = nested_deserialize(msg) raise gen.Return(msg) @gen.coroutine diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index a6d152535f5..1239b0d3d4d 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -4,7 +4,8 @@ from .compression import compressions, default_compression from .core import dumps, loads, maybe_compress, decompress, msgpack -from .serialize import (serialize, deserialize, Serialize, Serialized, +from .serialize import ( + serialize, deserialize, nested_deserialize, Serialize, Serialized, to_serialize, register_serialization, register_serialization_lazy) from ..utils import ignoring diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 3b1a6844063..275dcd3b189 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -271,6 +271,43 @@ def _extract_serialize(x, ser, path=()): ser[path + (k,)] = v +def nested_deserialize(x): + """ + Replace all Serialize and Serialized values nested in *x* + with the original values. Returns a copy of *x*. + + >>> msg = {'op': 'update', 'data': to_serialize(123)} + >>> nested_deserialize(msg) + {'op': 'update', 'data': 123} + """ + def replace_inner(x): + if type(x) is dict: + x = x.copy() + for k, v in x.items(): + typ = type(v) + if typ is dict or typ is list: + x[k] = replace_inner(v) + elif typ is Serialize: + x[k] = v.data + elif typ is Serialized: + x[k] = deserialize(v.header, v.frames) + + elif type(x) is list: + x = list(x) + for k, v in enumerate(x): + typ = type(v) + if typ is dict or typ is list: + x[k] = replace_inner(v) + elif typ is Serialize: + x[k] = v.data + elif typ is Serialized: + x[k] = deserialize(v.header, v.frames) + + return x + + return replace_inner(x) + + @partial(normalize_token.register, Serialized) def normalize_Serialized(o): diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index e0e5e5a736e..6bdd6ad0f41 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import copy import pickle import numpy as np @@ -7,7 +8,7 @@ from toolz import identity from distributed.protocol import (register_serialization, serialize, - deserialize, Serialize, Serialized, to_serialize) + deserialize, nested_deserialize, Serialize, Serialized, to_serialize) from distributed.protocol import decompress @@ -78,6 +79,22 @@ def test_Serialized(): assert not (s == t) +def test_nested_deserialize(): + x = {'op': 'update', + 'x': [to_serialize(123), to_serialize(456), 789], + 'y': {'a': ['abc', Serialized(*serialize('def'))], + 'b': 'ghi'} + } + x_orig = copy.deepcopy(x) + + assert nested_deserialize(x) == {'op': 'update', + 'x': [123, 456, 789], + 'y': {'a': ['abc', 'def'], + 'b': 'ghi'} + } + assert x == x_orig # x wasn't mutated + + from distributed.utils_test import gen_cluster from dask import delayed From fdda02fa4ce7228bd7ff4d391045f94f7a2f6598 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 27 Feb 2017 10:28:42 +0100 Subject: [PATCH 10/10] Add a __hash__ to Serialize --- distributed/protocol/serialize.py | 3 +++ distributed/protocol/tests/test_serialize.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 275dcd3b189..40662287be2 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -181,6 +181,9 @@ def __eq__(self, other): def __ne__(self, other): return not (self == other) + def __hash__(self): + return hash(self.data) + to_serialize = Serialize diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 6bdd6ad0f41..986e29ff7fa 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -67,6 +67,8 @@ def test_Serialize(): assert not (s != u) assert s != t assert not (s == t) + assert hash(s) == hash(u) + assert hash(s) != hash(t) # most probably def test_Serialized():