diff --git a/benchmarks/local/zero/Dockerfile b/benchmarks/local/zero/Dockerfile index 521bd59..dcd64fc 100755 --- a/benchmarks/local/zero/Dockerfile +++ b/benchmarks/local/zero/Dockerfile @@ -1,4 +1,6 @@ FROM python:3.9-slim -COPY . . -RUN pip install -r requirements.txt \ No newline at end of file +COPY requirements.txt requirements.txt +RUN pip install -r requirements.txt + +COPY . . \ No newline at end of file diff --git a/zero/client_server/client.py b/zero/client_server/client.py index ff58ace..476e9de 100644 --- a/zero/client_server/client.py +++ b/zero/client_server/client.py @@ -1,5 +1,6 @@ import asyncio import logging +import threading from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union from zero import config @@ -7,7 +8,6 @@ from zero.error import MethodNotFoundException, RemoteException, TimeoutException from zero.utils import util from zero.zero_mq import AsyncZeroMQClient, ZeroMQClient, get_async_client, get_client -from zero.zero_mq.helpers import zpipe_async T = TypeVar("T") @@ -30,6 +30,8 @@ def __init__( If the connection is dropped the client might timeout. But in the next call the connection will be re-established. + For different threads/processes, different connections are created. + Parameters ---------- host: str @@ -39,20 +41,23 @@ def __init__( Port of the ZeroServer. default_timeout: int - Default timeout for the ZeroClient for all calls. - Default is 2000 ms. + Default timeout for all calls. Default is 2000 ms. encoder: Optional[Encoder] Encoder to encode/decode messages from/to client. Default is msgspec. - If any other encoder is used, the server should use the same encoder. + If any other encoder is used, make sure the server should use the same encoder. Implement custom encoder by inheriting from `zero.encoder.Encoder`. """ self._address = f"tcp://{host}:{port}" self._default_timeout = default_timeout self._encoder = encoder or get_encoder(config.ENCODER) - self.zmqc: ZeroMQClient = None # type: ignore + self.client_pool = ZeroMQClientPool( + self._address, + self._default_timeout, + self._encoder, + ) def call( self, @@ -84,8 +89,9 @@ def call( Returns ------- - Any + T The return value of the rpc function. + If return_type is set, the response will be parsed to the return_type. Raises ------ @@ -101,28 +107,26 @@ def call( Or zeromq cannot receive the response from the server. Mainly represents zmq.error.Again exception. """ - self._ensure_connected() + zmqc = self.client_pool.get() _timeout = self._default_timeout if timeout is None else timeout def _poll_data(): - if not self.zmqc.poll(_timeout): + if not zmqc.poll(_timeout): raise TimeoutException( f"Timeout while sending message at {self._address}" ) resp_id, resp_data = ( - self._encoder.decode(self.zmqc.recv()) + self._encoder.decode(zmqc.recv()) if return_type is None - else self._encoder.decode_type( - self.zmqc.recv(), Tuple[str, return_type] - ) + else self._encoder.decode_type(zmqc.recv(), Tuple[str, return_type]) ) return resp_id, resp_data req_id = util.unique_id() frames = [req_id, rpc_func_name, "" if msg is None else msg] - self.zmqc.send(self._encoder.encode(frames)) + zmqc.send(self._encoder.encode(frames)) resp_id, resp_data = None, None # as the client is synchronous, we know that the response will be available any next poll @@ -137,26 +141,7 @@ def _poll_data(): return resp_data # type: ignore def close(self): - if self.zmqc is not None: - self.zmqc.close() - self.zmqc = None - - def _ensure_connected(self): - if self.zmqc is not None: - return - - self._init() - self._try_connect() - - def _init(self): - self.zmqc = get_client(config.ZEROMQ_PATTERN, self._default_timeout) - self.zmqc.connect(self._address) - - def _try_connect(self): - frames = [util.unique_id(), "connect", ""] - self.zmqc.send(self._encoder.encode(frames)) - self._encoder.decode(self.zmqc.recv()) - logging.info("Connected to server at %s", self._address) + self.client_pool.close() class AsyncZeroClient: @@ -179,6 +164,7 @@ def __init__( If the connection is dropped the client might timeout. But in the next call the connection will be re-established. + For different threads/processes, different connections are created. Parameters ---------- @@ -189,8 +175,7 @@ def __init__( Port of the ZeroServer. default_timeout: int - Default timeout for the AsyncZeroClient for all calls. - Default is 2000 ms. + Default timeout for all calls. Default is 2000 ms. encoder: Optional[Encoder] Encoder to encode/decode messages from/to client. @@ -203,8 +188,11 @@ def __init__( self._encoder = encoder or get_encoder(config.ENCODER) self._resp_map: Dict[str, Any] = {} - self.zmqc: AsyncZeroMQClient = None # type: ignore - self.peer1 = self.peer2 = None + self.client_pool = AsyncZeroMQClientPool( + self._address, + self._default_timeout, + self._encoder, + ) async def call( self, @@ -254,7 +242,7 @@ async def call( Or zeromq cannot receive the response from the server. Mainly represents zmq.error.Again exception. """ - await self._ensure_connected() + zmqc = await self.client_pool.get() _timeout = self._default_timeout if timeout is None else timeout expire_at = util.current_time_us() + (_timeout * 1000) @@ -264,7 +252,7 @@ async def _poll_data(): # if not await self.zmq_client.poll(_timeout): # raise TimeoutException(f"Timeout while sending message at {self._address}") - resp = await self.zmqc.recv() + resp = await zmqc.recv() resp_id, resp_data = ( self._encoder.decode(resp) if return_type is None @@ -277,14 +265,14 @@ async def _poll_data(): req_id = util.unique_id() frames = [req_id, rpc_func_name, "" if msg is None else msg] - await self.zmqc.send(self._encoder.encode(frames)) + await zmqc.send(self._encoder.encode(frames)) # every request poll the data, so whenever a response comes, it will be stored in __resps # dont need to poll again in the while loop await _poll_data() while req_id not in self._resp_map and util.current_time_us() <= expire_at: - # TODO the problem with the pipe is that we can miss some response + # TODO the problem with the zpipe is that we can miss some response # when we come to this line # await self.peer2.recv() await asyncio.sleep(1e-6) @@ -301,51 +289,91 @@ async def _poll_data(): return resp_data def close(self): - if self.zmqc is not None: - self.zmqc.close() - self.zmqc = None - self._resp_map = {} + self.client_pool.close() + self._resp_map = {} - async def _ensure_connected(self): - if self.zmqc is not None: - return - self._init() - await self._try_connect() +def check_response(resp_data): + if isinstance(resp_data, dict): + if exc := resp_data.get("__zerror__function_not_found"): + raise MethodNotFoundException(exc) + if exc := resp_data.get("__zerror__server_exception"): + raise RemoteException(exc) - def _init(self): - self.zmqc = get_async_client(config.ZEROMQ_PATTERN, self._default_timeout) - self.zmqc.connect(self._address) - self._resp_map: Dict[str, Any] = {} +class ZeroMQClientPool: + """ + Connections are based on different threads and processes. + Each time a call is made it tries to get the connection from the pool, + based on the thread/process id. + If the connection is not available, it creates a new connection and stores it in the pool. + """ - self.peer1, self.peer2 = zpipe_async(self.zmqc.context, 10000) - # TODO try to use pipe instead of sleep - # asyncio.create_task(self._poll_data()) + __slots__ = ["_pool", "_address", "_timeout", "_encoder"] - async def _try_connect(self): + def __init__( + self, address: str, timeout: int = 2000, encoder: Optional[Encoder] = None + ): + self._pool: Dict[int, ZeroMQClient] = {} + self._address = address + self._timeout = timeout + self._encoder = encoder or get_encoder(config.ENCODER) + + def get(self) -> ZeroMQClient: + thread_id = threading.get_ident() + if thread_id not in self._pool: + self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout) + self._pool[thread_id].connect(self._address) + self._try_connect_ping(self._pool[thread_id]) + return self._pool[thread_id] + + def _try_connect_ping(self, client: ZeroMQClient): frames = [util.unique_id(), "connect", ""] - await self.zmqc.send(self._encoder.encode(frames)) - self._encoder.decode(await self.zmqc.recv()) + client.send(self._encoder.encode(frames)) + self._encoder.decode(client.recv()) logging.info("Connected to server at %s", self._address) - async def _poll_data(self): # pragma: no cover - while True: - try: - if not await self.zmqc.poll(self._default_timeout): - continue + def close(self): + for client in self._pool.values(): + client.close() + self._pool = {} - frames = await self.zmqc.recv() - resp_id, data = self._encoder.decode(frames) - self._resp_map[resp_id] = data - await self.peer1.send(b"") - except Exception as exc: # pylint: disable=broad-except - logging.error("Error while polling data: %s", exc) +class AsyncZeroMQClientPool: + """ + Connections are based on different threads and processes. + Each time a call is made it tries to get the connection from the pool, + based on the thread/process id. + If the connection is not available, it creates a new connection and stores it in the pool. + """ -def check_response(resp_data): - if isinstance(resp_data, dict): - if exc := resp_data.get("__zerror__function_not_found"): - raise MethodNotFoundException(exc) - if exc := resp_data.get("__zerror__server_exception"): - raise RemoteException(exc) + __slots__ = ["_pool", "_address", "_timeout", "_encoder"] + + def __init__( + self, address: str, timeout: int = 2000, encoder: Optional[Encoder] = None + ): + self._pool: Dict[int, AsyncZeroMQClient] = {} + self._address = address + self._timeout = timeout + self._encoder = encoder or get_encoder(config.ENCODER) + + async def get(self) -> AsyncZeroMQClient: + thread_id = threading.get_ident() + if thread_id not in self._pool: + self._pool[thread_id] = get_async_client( + config.ZEROMQ_PATTERN, self._timeout + ) + self._pool[thread_id].connect(self._address) + await self._try_connect_ping(self._pool[thread_id]) + return self._pool[thread_id] + + async def _try_connect_ping(self, client: AsyncZeroMQClient): + frames = [util.unique_id(), "connect", ""] + await client.send(self._encoder.encode(frames)) + self._encoder.decode(await client.recv()) + logging.info("Connected to server at %s", self._address) + + def close(self): + for client in self._pool.values(): + client.close() + self._pool = {} diff --git a/zero/zero_mq/queue_device/worker.py b/zero/zero_mq/queue_device/worker.py index 97d69c7..526a4d9 100644 --- a/zero/zero_mq/queue_device/worker.py +++ b/zero/zero_mq/queue_device/worker.py @@ -1,9 +1,10 @@ import logging from typing import Callable, Optional -# import zmq.green as zmq import zmq +# import zmq.green as zmq + class ZeroMQWorker: def __init__(self, worker_id: int): @@ -12,12 +13,9 @@ def __init__(self, worker_id: int): self.socket: zmq.Socket = self.context.socket(zmq.DEALER) self.socket.setsockopt(zmq.LINGER, 0) # dont buffer messages - self.socket.setsockopt(zmq.RCVTIMEO, 2000) + # self.socket.setsockopt(zmq.RCVTIMEO, 2000) self.socket.setsockopt(zmq.SNDTIMEO, 2000) - self.poller = zmq.Poller() - self.poller.register(self.socket, zmq.POLLIN) - def listen( self, address: str, msg_handler: Callable[[bytes], Optional[bytes]] ) -> None: @@ -25,16 +23,18 @@ def listen( logging.info("Starting worker %d", self.worker_id) while True: - socks = dict(self.poller.poll(100)) - if self.socket in socks: - frames = self.socket.recv_multipart(flags=zmq.NOBLOCK) - if len(frames) != 2: - logging.error("invalid message received: %s", frames) - continue - - ident, message = frames - response = msg_handler(message) - self.socket.send_multipart([ident, response], zmq.NOBLOCK) + try: + frames = self.socket.recv_multipart() + except zmq.error.Again: + continue + + if len(frames) != 2: + logging.error("invalid message received: %s", frames) + continue + + ident, message = frames + response = msg_handler(message) + self.socket.send_multipart([ident, response], zmq.NOBLOCK) def close(self) -> None: self.socket.close()