Skip to content

Commit

Permalink
Make connection pool for each threads in client
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananto30 committed Oct 23, 2023
1 parent 042009b commit b197811
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 96 deletions.
6 changes: 4 additions & 2 deletions benchmarks/local/zero/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
FROM python:3.9-slim

COPY . .
RUN pip install -r requirements.txt
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt

COPY . .
60 changes: 60 additions & 0 deletions tests/unit/test_zero_mq_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import unittest
from unittest.mock import Mock, patch

import pytest
import zmq

from zero.zero_mq.queue_device.worker import ZeroMQWorker


class TestWorker(unittest.TestCase):
def test_create_zeromq_worker(self):
worker = ZeroMQWorker(1)
self.assertIsInstance(worker, ZeroMQWorker)
self.assertEqual(worker.worker_id, 1)
# self.assertEqual(worker.context, zmq.Context.instance())
# self.assertEqual(worker.socket, worker.context.socket(zmq.DEALER))
self.assertIsNotNone(worker.context)
self.assertIsNotNone(worker.socket)
self.assertEqual(worker.socket.getsockopt(zmq.LINGER), 0)
self.assertEqual(worker.socket.getsockopt(zmq.SNDTIMEO), 2000)
self.assertEqual(worker.socket.getsockopt(zmq.RCVTIMEO), -1)
worker.close()
self.assertEqual(worker.socket.closed, True)
self.assertEqual(worker.context.closed, True)

@pytest.mark.skip(reason="hard to test infinite loop")
def test_listen_timeout(self):
worker = ZeroMQWorker(1)
mock_msg_handler = Mock(return_value=b"response")
with (
patch.object(worker.socket, "connect", return_value=None),
patch.object(worker, "_recv_and_process", side_effect=zmq.error.Again),
):
worker.listen("tcp://example.com:5555", mock_msg_handler)

def test_recv_and_process(self):
worker = ZeroMQWorker(1)
mock_msg_handler = Mock(return_value=b"response")
with (
patch.object(
worker.socket, "recv_multipart", return_value=[b"ident", b"request"]
),
patch.object(worker.socket, "send_multipart", return_value=None),
):
worker._recv_and_process(mock_msg_handler)
mock_msg_handler.assert_called_once()
worker.socket.send_multipart.assert_called_once()
worker.socket.recv_multipart.assert_called_once()

def test_recv_and_process_invalid_message(self):
worker = ZeroMQWorker(1)
mock_msg_handler = Mock(return_value=b"response")
with (
patch.object(worker.socket, "recv_multipart", return_value=[b"ident"]),
patch.object(worker.socket, "send_multipart", return_value=None),
):
worker._recv_and_process(mock_msg_handler)
mock_msg_handler.assert_not_called()
worker.socket.send_multipart.assert_not_called()
worker.socket.recv_multipart.assert_called_once()
182 changes: 105 additions & 77 deletions zero/client_server/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import asyncio
import logging
import threading
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

from zero import config
from zero.encoder import Encoder, get_encoder
from zero.error import MethodNotFoundException, RemoteException, TimeoutException
from zero.utils import util
from zero.zero_mq import AsyncZeroMQClient, ZeroMQClient, get_async_client, get_client
from zero.zero_mq.helpers import zpipe_async

T = TypeVar("T")

Expand All @@ -30,6 +30,8 @@ def __init__(
If the connection is dropped the client might timeout.
But in the next call the connection will be re-established.
For different threads/processes, different connections are created.
Parameters
----------
host: str
Expand All @@ -39,20 +41,23 @@ def __init__(
Port of the ZeroServer.
default_timeout: int
Default timeout for the ZeroClient for all calls.
Default is 2000 ms.
Default timeout for all calls. Default is 2000 ms.
encoder: Optional[Encoder]
Encoder to encode/decode messages from/to client.
Default is msgspec.
If any other encoder is used, the server should use the same encoder.
If any other encoder is used, make sure the server should use the same encoder.
Implement custom encoder by inheriting from `zero.encoder.Encoder`.
"""
self._address = f"tcp://{host}:{port}"
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)

self.zmqc: ZeroMQClient = None # type: ignore
self.client_pool = ZeroMQClientPool(
self._address,
self._default_timeout,
self._encoder,
)

def call(
self,
Expand Down Expand Up @@ -84,8 +89,9 @@ def call(
Returns
-------
Any
T
The return value of the rpc function.
If return_type is set, the response will be parsed to the return_type.
Raises
------
Expand All @@ -101,28 +107,26 @@ def call(
Or zeromq cannot receive the response from the server.
Mainly represents zmq.error.Again exception.
"""
self._ensure_connected()
zmqc = self.client_pool.get()

_timeout = self._default_timeout if timeout is None else timeout

def _poll_data():
if not self.zmqc.poll(_timeout):
if not zmqc.poll(_timeout):
raise TimeoutException(
f"Timeout while sending message at {self._address}"
)

resp_id, resp_data = (
self._encoder.decode(self.zmqc.recv())
self._encoder.decode(zmqc.recv())
if return_type is None
else self._encoder.decode_type(
self.zmqc.recv(), Tuple[str, return_type]
)
else self._encoder.decode_type(zmqc.recv(), Tuple[str, return_type])
)
return resp_id, resp_data

req_id = util.unique_id()
frames = [req_id, rpc_func_name, "" if msg is None else msg]
self.zmqc.send(self._encoder.encode(frames))
zmqc.send(self._encoder.encode(frames))

resp_id, resp_data = None, None
# as the client is synchronous, we know that the response will be available any next poll
Expand All @@ -137,26 +141,7 @@ def _poll_data():
return resp_data # type: ignore

def close(self):
if self.zmqc is not None:
self.zmqc.close()
self.zmqc = None

def _ensure_connected(self):
if self.zmqc is not None:
return

self._init()
self._try_connect()

def _init(self):
self.zmqc = get_client(config.ZEROMQ_PATTERN, self._default_timeout)
self.zmqc.connect(self._address)

def _try_connect(self):
frames = [util.unique_id(), "connect", ""]
self.zmqc.send(self._encoder.encode(frames))
self._encoder.decode(self.zmqc.recv())
logging.info("Connected to server at %s", self._address)
self.client_pool.close()


class AsyncZeroClient:
Expand All @@ -179,6 +164,7 @@ def __init__(
If the connection is dropped the client might timeout.
But in the next call the connection will be re-established.
For different threads/processes, different connections are created.
Parameters
----------
Expand All @@ -189,8 +175,7 @@ def __init__(
Port of the ZeroServer.
default_timeout: int
Default timeout for the AsyncZeroClient for all calls.
Default is 2000 ms.
Default timeout for all calls. Default is 2000 ms.
encoder: Optional[Encoder]
Encoder to encode/decode messages from/to client.
Expand All @@ -203,8 +188,11 @@ def __init__(
self._encoder = encoder or get_encoder(config.ENCODER)
self._resp_map: Dict[str, Any] = {}

self.zmqc: AsyncZeroMQClient = None # type: ignore
self.peer1 = self.peer2 = None
self.client_pool = AsyncZeroMQClientPool(
self._address,
self._default_timeout,
self._encoder,
)

async def call(
self,
Expand Down Expand Up @@ -254,7 +242,7 @@ async def call(
Or zeromq cannot receive the response from the server.
Mainly represents zmq.error.Again exception.
"""
await self._ensure_connected()
zmqc = await self.client_pool.get()

_timeout = self._default_timeout if timeout is None else timeout
expire_at = util.current_time_us() + (_timeout * 1000)
Expand All @@ -264,7 +252,7 @@ async def _poll_data():
# if not await self.zmq_client.poll(_timeout):
# raise TimeoutException(f"Timeout while sending message at {self._address}")

resp = await self.zmqc.recv()
resp = await zmqc.recv()
resp_id, resp_data = (
self._encoder.decode(resp)
if return_type is None
Expand All @@ -277,14 +265,14 @@ async def _poll_data():

req_id = util.unique_id()
frames = [req_id, rpc_func_name, "" if msg is None else msg]
await self.zmqc.send(self._encoder.encode(frames))
await zmqc.send(self._encoder.encode(frames))

# every request poll the data, so whenever a response comes, it will be stored in __resps
# dont need to poll again in the while loop
await _poll_data()

while req_id not in self._resp_map and util.current_time_us() <= expire_at:
# TODO the problem with the pipe is that we can miss some response
# TODO the problem with the zpipe is that we can miss some response
# when we come to this line
# await self.peer2.recv()
await asyncio.sleep(1e-6)
Expand All @@ -301,51 +289,91 @@ async def _poll_data():
return resp_data

def close(self):
if self.zmqc is not None:
self.zmqc.close()
self.zmqc = None
self._resp_map = {}
self.client_pool.close()
self._resp_map = {}

async def _ensure_connected(self):
if self.zmqc is not None:
return

self._init()
await self._try_connect()
def check_response(resp_data):
if isinstance(resp_data, dict):
if exc := resp_data.get("__zerror__function_not_found"):
raise MethodNotFoundException(exc)
if exc := resp_data.get("__zerror__server_exception"):
raise RemoteException(exc)

def _init(self):
self.zmqc = get_async_client(config.ZEROMQ_PATTERN, self._default_timeout)
self.zmqc.connect(self._address)

self._resp_map: Dict[str, Any] = {}
class ZeroMQClientPool:
"""
Connections are based on different threads and processes.
Each time a call is made it tries to get the connection from the pool,
based on the thread/process id.
If the connection is not available, it creates a new connection and stores it in the pool.
"""

self.peer1, self.peer2 = zpipe_async(self.zmqc.context, 10000)
# TODO try to use pipe instead of sleep
# asyncio.create_task(self._poll_data())
__slots__ = ["_pool", "_address", "_timeout", "_encoder"]

async def _try_connect(self):
def __init__(
self, address: str, timeout: int = 2000, encoder: Optional[Encoder] = None
):
self._pool: Dict[int, ZeroMQClient] = {}
self._address = address
self._timeout = timeout
self._encoder = encoder or get_encoder(config.ENCODER)

def get(self) -> ZeroMQClient:
thread_id = threading.get_ident()
if thread_id not in self._pool:
self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout)
self._pool[thread_id].connect(self._address)
self._try_connect_ping(self._pool[thread_id])
return self._pool[thread_id]

def _try_connect_ping(self, client: ZeroMQClient):
frames = [util.unique_id(), "connect", ""]
await self.zmqc.send(self._encoder.encode(frames))
self._encoder.decode(await self.zmqc.recv())
client.send(self._encoder.encode(frames))
self._encoder.decode(client.recv())
logging.info("Connected to server at %s", self._address)

async def _poll_data(self): # pragma: no cover
while True:
try:
if not await self.zmqc.poll(self._default_timeout):
continue
def close(self):
for client in self._pool.values():
client.close()
self._pool = {}

frames = await self.zmqc.recv()
resp_id, data = self._encoder.decode(frames)
self._resp_map[resp_id] = data
await self.peer1.send(b"")
except Exception as exc: # pylint: disable=broad-except
logging.error("Error while polling data: %s", exc)

class AsyncZeroMQClientPool:
"""
Connections are based on different threads and processes.
Each time a call is made it tries to get the connection from the pool,
based on the thread/process id.
If the connection is not available, it creates a new connection and stores it in the pool.
"""

def check_response(resp_data):
if isinstance(resp_data, dict):
if exc := resp_data.get("__zerror__function_not_found"):
raise MethodNotFoundException(exc)
if exc := resp_data.get("__zerror__server_exception"):
raise RemoteException(exc)
__slots__ = ["_pool", "_address", "_timeout", "_encoder"]

def __init__(
self, address: str, timeout: int = 2000, encoder: Optional[Encoder] = None
):
self._pool: Dict[int, AsyncZeroMQClient] = {}
self._address = address
self._timeout = timeout
self._encoder = encoder or get_encoder(config.ENCODER)

async def get(self) -> AsyncZeroMQClient:
thread_id = threading.get_ident()
if thread_id not in self._pool:
self._pool[thread_id] = get_async_client(
config.ZEROMQ_PATTERN, self._timeout
)
self._pool[thread_id].connect(self._address)
await self._try_connect_ping(self._pool[thread_id])
return self._pool[thread_id]

async def _try_connect_ping(self, client: AsyncZeroMQClient):
frames = [util.unique_id(), "connect", ""]
await client.send(self._encoder.encode(frames))
self._encoder.decode(await client.recv())
logging.info("Connected to server at %s", self._address)

def close(self):
for client in self._pool.values():
client.close()
self._pool = {}
Loading

0 comments on commit b197811

Please sign in to comment.