Skip to content

Commit

Permalink
Merge pull request #9276 from OpenMined/aziz/syncronous_handler
Browse files Browse the repository at this point in the history
add synchronous message handler
  • Loading branch information
koenvanderveen authored Sep 9, 2024
2 parents 22f3aec + 3c2114a commit 56fd893
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 8 deletions.
5 changes: 5 additions & 0 deletions packages/syft/src/syft/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .server.enclave import Enclave
from .server.gateway import Gateway
from .server.uvicorn import serve_server
from .service.queue.queue import ConsumerType
from .service.response import SyftInfo
from .types.errors import SyftException
from .util.util import get_random_available_port
Expand Down Expand Up @@ -182,6 +183,7 @@ def deploy_to_python(
log_level: str | int | None = None,
debug: bool = False,
migrate: bool = False,
consumer_type: ConsumerType | None = None,
) -> ServerHandle:
worker_classes = {
ServerType.DATASITE: Datasite,
Expand Down Expand Up @@ -213,6 +215,7 @@ def deploy_to_python(
"debug": debug,
"migrate": migrate,
"deployment_type": deployment_type_enum,
"consumer_type": consumer_type,
}

if port:
Expand Down Expand Up @@ -325,6 +328,7 @@ def launch(
debug: bool = False,
migrate: bool = False,
from_state_folder: str | Path | None = None,
consumer_type: ConsumerType | None = None,
) -> ServerHandle:
if from_state_folder is not None:
with open(f"{from_state_folder}/config.json") as f:
Expand Down Expand Up @@ -373,6 +377,7 @@ def launch(
background_tasks=background_tasks,
debug=debug,
migrate=migrate,
consumer_type=consumer_type,
)
display(
SyftInfo(
Expand Down
16 changes: 13 additions & 3 deletions packages/syft/src/syft/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from ..service.queue.base_queue import QueueConsumer
from ..service.queue.base_queue import QueueProducer
from ..service.queue.queue import APICallMessageHandler
from ..service.queue.queue import ConsumerType
from ..service.queue.queue import QueueManager
from ..service.queue.queue_stash import APIEndpointQueueItem
from ..service.queue.queue_stash import ActionQueueItem
Expand Down Expand Up @@ -338,6 +339,7 @@ def __init__(
smtp_host: str | None = None,
association_request_auto_approval: bool = False,
background_tasks: bool = False,
consumer_type: ConsumerType | None = None,
):
# 🟡 TODO 22: change our ENV variable format and default init args to make this
# less horrible or add some convenience functions
Expand Down Expand Up @@ -381,10 +383,15 @@ def __init__(

self.association_request_auto_approval = association_request_auto_approval

consumer_type = (
consumer_type or ConsumerType.Thread
if thread_workers
else ConsumerType.Process
)
self.queue_config = self.create_queue_config(
n_consumers=n_consumers,
create_producer=create_producer,
thread_workers=thread_workers,
consumer_type=consumer_type,
queue_port=queue_port,
queue_config=queue_config,
)
Expand Down Expand Up @@ -578,7 +585,7 @@ def create_queue_config(
self,
n_consumers: int,
create_producer: bool,
thread_workers: bool,
consumer_type: ConsumerType,
queue_port: int | None,
queue_config: QueueConfig | None,
) -> QueueConfig:
Expand All @@ -587,13 +594,14 @@ def create_queue_config(
elif queue_port is not None or n_consumers > 0 or create_producer:
if not create_producer and queue_port is None:
logger.warn("No queue port defined to bind consumers.")

queue_config_ = ZMQQueueConfig(
client_config=ZMQClientConfig(
create_producer=create_producer,
queue_port=queue_port,
n_consumers=n_consumers,
),
thread_workers=thread_workers,
consumer_type=consumer_type,
)
else:
queue_config_ = ZMQQueueConfig()
Expand Down Expand Up @@ -727,6 +735,7 @@ def named(
in_memory_workers: bool = True,
association_request_auto_approval: bool = False,
background_tasks: bool = False,
consumer_type: ConsumerType | None = None,
) -> Server:
uid = get_named_server_uid(name)
name_hash = hashlib.sha256(name.encode("utf8")).digest()
Expand Down Expand Up @@ -757,6 +766,7 @@ def named(
reset=reset,
association_request_auto_approval=association_request_auto_approval,
background_tasks=background_tasks,
consumer_type=consumer_type,
)

def is_root(self, credentials: SyftVerifyKey) -> bool:
Expand Down
16 changes: 13 additions & 3 deletions packages/syft/src/syft/service/queue/queue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
from enum import Enum
import logging
from multiprocessing import Process
import threading
Expand Down Expand Up @@ -35,6 +36,13 @@
logger = logging.getLogger(__name__)


# NOTE(review): canonical_name is "WorkerType" while the class is named
# ConsumerType — presumably kept for serialization backward compatibility;
# confirm before ever renaming, since changing it breaks wire compatibility.
@serializable(canonical_name="WorkerType", version=1)
class ConsumerType(str, Enum):
    """How a queue consumer executes its message handler.

    Thread       -- handle each message on a new thread (joined immediately).
    Process      -- handle each message in a spawned subprocess.
    Synchronous  -- handle the message inline on the consumer's own thread.
    """

    Thread = "thread"
    Process = "process"
    Synchronous = "synchronous"


class MonitorThread(threading.Thread):
def __init__(
self,
Expand Down Expand Up @@ -300,17 +308,17 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None:
logger.info(
f"Handling queue item: id={queue_item.id}, method={queue_item.method} "
f"args={queue_item.args}, kwargs={queue_item.kwargs} "
f"service={queue_item.service}, as_thread={queue_config.thread_workers}"
f"service={queue_item.service}, as={queue_config.consumer_type}"
)

if queue_config.thread_workers:
if queue_config.consumer_type == ConsumerType.Thread:
thread = Thread(
target=handle_message_multiprocessing,
args=(worker_settings, queue_item, credentials),
)
thread.start()
thread.join()
else:
elif queue_config.consumer_type == ConsumerType.Process:
# if psutil.pid_exists(job_item.job_pid):
# psutil.Process(job_item.job_pid).terminate()
process = Process(
Expand All @@ -321,3 +329,5 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None:
job_item.job_pid = process.pid
worker.job_stash.set_result(credentials, job_item).unwrap()
process.join()
elif queue_config.consumer_type == ConsumerType.Synchronous:
handle_message_multiprocessing(worker_settings, queue_item, credentials)
6 changes: 4 additions & 2 deletions packages/syft/src/syft/service/queue/zmq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .base_queue import QueueClient
from .base_queue import QueueClientConfig
from .base_queue import QueueConfig
from .queue import ConsumerType
from .queue_stash import QueueStash
from .zmq_consumer import ZMQConsumer
from .zmq_producer import ZMQProducer
Expand Down Expand Up @@ -76,6 +77,7 @@ def add_producer(
else:
port = self.config.queue_port

print(f"Adding producer for queue: {queue_name} on: {get_queue_address(port)}")
producer = ZMQProducer(
queue_name=queue_name,
queue_stash=queue_stash,
Expand Down Expand Up @@ -183,8 +185,8 @@ def __init__(
self,
client_type: type[ZMQClient] | None = None,
client_config: ZMQClientConfig | None = None,
thread_workers: bool = False,
consumer_type: ConsumerType = ConsumerType.Process,
):
self.client_type = client_type or ZMQClient
self.client_config: ZMQClientConfig = client_config or ZMQClientConfig()
self.thread_workers = thread_workers
self.consumer_type = consumer_type
51 changes: 51 additions & 0 deletions packages/syft/src/syft/service/queue/zmq_consumer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
import logging
import subprocess # nosec
import threading
from threading import Event

Expand Down Expand Up @@ -28,6 +29,26 @@
logger = logging.getLogger(__name__)


def last_created_port() -> int:
    """Best-effort guess at the port of the most recently started Python
    listener on this machine, by scraping ``lsof`` output.

    Returns:
        The port number extracted from the last matching listening socket.

    Raises:
        ValueError: if the pipeline produces no output (no matching
            listener, or ``lsof`` is unavailable), since ``int("")`` fails.
    """
    # Pipeline breakdown (kept in sync with the command below):
    # 1. lsof -i -P -n             -> list open network sockets, numeric ports
    # 2. grep '*:[0-9]* (LISTEN)'  -> keep wildcard-address listening sockets
    # 3. grep python               -> keep sockets owned by python processes
    # 4. awk '{print $9, $1, $2}'  -> emit "addr:port command pid"
    # 5. sort -k2,2 -k3,3n         -> order by command name, then pid (numeric)
    # 6. tail -n 1                 -> keep the last entry (highest pid)
    # 7. awk '{print $1}'          -> keep only the "addr:port" field
    # 8. cut -d':' -f2             -> extract just the port number
    command = (
        "lsof -i -P -n | grep '*:[0-9]* (LISTEN)' | grep python | awk '{print $9, $1, $2}' | "
        "sort -k2,2 -k3,3n | tail -n 1 | awk '{print $1}' | cut -d':' -f2"
    )
    # shell=True is required for the pipes; the command string is a constant,
    # so no untrusted input ever reaches the shell.
    result = subprocess.run(  # nosec
        command, shell=True, capture_output=True, check=False
    )
    return int(result.stdout.decode("utf-8").strip())


@serializable(attrs=["_subscriber"], canonical_name="ZMQConsumer", version=1)
class ZMQConsumer(QueueConsumer):
def __init__(
Expand All @@ -54,6 +75,36 @@ def __init__(
self.worker_stash = worker_stash
self.post_init()

@classmethod
def default(cls, address: str | None = None, **kwargs: object) -> "ZMQConsumer":
    """Build a ZMQConsumer with sensible defaults for interactive use.

    Args:
        address: Producer address to connect to. When ``None``, the port of
            the most recently created Python listener is auto-detected via
            ``last_created_port()`` — a heuristic, hence the printed warning.
        **kwargs: Overrides for the default constructor arguments; keys not
            recognized by the defaults are silently ignored (deliberate, so
            stray kwargs don't break construction).

    Returns:
        A fully constructed ZMQConsumer.

    Raises:
        Exception: if no address was given and auto-detection failed.
    """
    # relative
    from ...types.uid import UID
    from ..worker.utils import DEFAULT_WORKER_POOL_NAME
    from .queue import APICallMessageHandler

    if address is None:
        try:
            address = f"tcp://localhost:{last_created_port()}"
        except Exception as e:
            # Chain the cause so the underlying lsof/parse failure is visible.
            raise Exception(
                "Could not auto-assign ZMQConsumer address. Please provide one."
            ) from e
        print(f"Auto-assigning ZMQConsumer address: {address}. Please verify.")

    default_kwargs: dict[str, object] = {
        "message_handler": APICallMessageHandler,
        "queue_name": APICallMessageHandler.queue_name,
        "service_name": DEFAULT_WORKER_POOL_NAME,
        "syft_worker_id": UID(),
        "verbose": True,
        "address": address,
    }
    # Apply only the overrides that correspond to known constructor arguments.
    default_kwargs.update(
        {key: value for key, value in kwargs.items() if key in default_kwargs}
    )

    return cls(**default_kwargs)

def reconnect_to_producer(self) -> None:
"""Connect or reconnect to producer"""
if self.socket:
Expand Down

0 comments on commit 56fd893

Please sign in to comment.