diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 5921b53b434..0d295b81982 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -28,6 +28,7 @@ from .server.enclave import Enclave from .server.gateway import Gateway from .server.uvicorn import serve_server +from .service.queue.queue import ConsumerType from .service.response import SyftInfo from .types.errors import SyftException from .util.util import get_random_available_port @@ -182,6 +183,7 @@ def deploy_to_python( log_level: str | int | None = None, debug: bool = False, migrate: bool = False, + consumer_type: ConsumerType | None = None, ) -> ServerHandle: worker_classes = { ServerType.DATASITE: Datasite, @@ -213,6 +215,7 @@ def deploy_to_python( "debug": debug, "migrate": migrate, "deployment_type": deployment_type_enum, + "consumer_type": consumer_type, } if port: @@ -325,6 +328,7 @@ def launch( debug: bool = False, migrate: bool = False, from_state_folder: str | Path | None = None, + consumer_type: ConsumerType | None = None, ) -> ServerHandle: if from_state_folder is not None: with open(f"{from_state_folder}/config.json") as f: @@ -373,6 +377,7 @@ def launch( background_tasks=background_tasks, debug=debug, migrate=migrate, + consumer_type=consumer_type, ) display( SyftInfo( diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 102b34ad99a..f9a05ca1279 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -61,6 +61,7 @@ from ..service.queue.base_queue import QueueConsumer from ..service.queue.base_queue import QueueProducer from ..service.queue.queue import APICallMessageHandler +from ..service.queue.queue import ConsumerType from ..service.queue.queue import QueueManager from ..service.queue.queue_stash import APIEndpointQueueItem from ..service.queue.queue_stash import ActionQueueItem @@ -338,6 +339,7 @@ def __init__( smtp_host: str | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, + consumer_type: ConsumerType | None = None, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this # less horrible or add some convenience functions @@ -381,10 +383,15 @@ def __init__( self.association_request_auto_approval = association_request_auto_approval + consumer_type = ( + consumer_type or ConsumerType.Thread + if thread_workers + else ConsumerType.Process + ) self.queue_config = self.create_queue_config( n_consumers=n_consumers, create_producer=create_producer, - thread_workers=thread_workers, + consumer_type=consumer_type, queue_port=queue_port, queue_config=queue_config, ) @@ -578,7 +585,7 @@ def create_queue_config( self, n_consumers: int, create_producer: bool, - thread_workers: bool, + consumer_type: ConsumerType, queue_port: int | None, queue_config: QueueConfig | None, ) -> QueueConfig: @@ -587,13 +594,14 @@ def create_queue_config( elif queue_port is not None or n_consumers > 0 or create_producer: if not create_producer and queue_port is None: logger.warn("No queue port defined to bind consumers.") + queue_config_ = ZMQQueueConfig( client_config=ZMQClientConfig( create_producer=create_producer, queue_port=queue_port, n_consumers=n_consumers, ), - thread_workers=thread_workers, + consumer_type=consumer_type, ) else: queue_config_ = ZMQQueueConfig() @@ -727,6 +735,7 @@ def named( in_memory_workers: bool = True, association_request_auto_approval: bool = False, background_tasks: bool = False, + consumer_type: ConsumerType | None = None, ) -> Server: uid = get_named_server_uid(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() @@ -757,6 +766,7 @@ def named( reset=reset, association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, + consumer_type=consumer_type, ) def is_root(self, credentials: SyftVerifyKey) -> bool: diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index da1ded8bd70..aa2e99c6ba4 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -1,4 +1,5 @@ # stdlib +from enum import Enum import logging from multiprocessing import Process import threading @@ -35,6 +36,13 @@ logger = logging.getLogger(__name__) +@serializable(canonical_name="WorkerType", version=1) +class ConsumerType(str, Enum): + Thread = "thread" + Process = "process" + Synchronous = "synchronous" + + class MonitorThread(threading.Thread): def __init__( self, @@ -300,17 +308,17 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: logger.info( f"Handling queue item: id={queue_item.id}, method={queue_item.method} " f"args={queue_item.args}, kwargs={queue_item.kwargs} " - f"service={queue_item.service}, as_thread={queue_config.thread_workers}" + f"service={queue_item.service}, as={queue_config.consumer_type}" ) - if queue_config.thread_workers: + if queue_config.consumer_type == ConsumerType.Thread: thread = Thread( target=handle_message_multiprocessing, args=(worker_settings, queue_item, credentials), ) thread.start() thread.join() - else: + elif queue_config.consumer_type == ConsumerType.Process: # if psutil.pid_exists(job_item.job_pid): # psutil.Process(job_item.job_pid).terminate() process = Process( @@ -321,3 +329,5 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: job_item.job_pid = process.pid worker.job_stash.set_result(credentials, job_item).unwrap() process.join() + elif queue_config.consumer_type == ConsumerType.Synchronous: + handle_message_multiprocessing(worker_settings, queue_item, credentials) diff --git a/packages/syft/src/syft/service/queue/zmq_client.py b/packages/syft/src/syft/service/queue/zmq_client.py index deeeb97a32b..9265d9edd3d 100644 --- a/packages/syft/src/syft/service/queue/zmq_client.py +++ b/packages/syft/src/syft/service/queue/zmq_client.py @@ -16,6 +16,7 @@ from .base_queue import QueueClient from .base_queue import QueueClientConfig from .base_queue import QueueConfig +from .queue import ConsumerType from .queue_stash import QueueStash from .zmq_consumer import ZMQConsumer from .zmq_producer import ZMQProducer @@ -76,6 +77,7 @@ def add_producer( else: port = self.config.queue_port + print(f"Adding producer for queue: {queue_name} on: {get_queue_address(port)}") producer = ZMQProducer( queue_name=queue_name, queue_stash=queue_stash, @@ -183,8 +185,8 @@ def __init__( self, client_type: type[ZMQClient] | None = None, client_config: ZMQClientConfig | None = None, - thread_workers: bool = False, + consumer_type: ConsumerType = ConsumerType.Process, ): self.client_type = client_type or ZMQClient self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() - self.thread_workers = thread_workers + self.consumer_type = consumer_type diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py index 4de8da60494..f6993d6b032 100644 --- a/packages/syft/src/syft/service/queue/zmq_consumer.py +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -1,5 +1,6 @@ # stdlib import logging +import subprocess # nosec import threading from threading import Event @@ -28,6 +29,26 @@ logger = logging.getLogger(__name__) +def last_created_port() -> int: + command = ( + "lsof -i -P -n | grep '*:[0-9]* (LISTEN)' | grep python | awk '{print $9, $1, $2}' | " + "sort -k2,2 -k3,3n | tail -n 1 | awk '{print $1}' | cut -d':' -f2" + ) + # 1. Lists open files (including network connections) with lsof -i -P -n + # 2. Filters for listening ports with grep '*:[0-9]* (LISTEN)' + # 3. Further filters for Python processes with grep python + # 4. Sorts based on the 9th field (which is likely the port number) with sort -k9 + # 5. Takes the last 10 entries with tail -n 10 + # 6. Prints only the 9th field (port and address) with awk '{print $9}' + # 7. Extracts only the port number with cut -d':' -f2 + + process = subprocess.Popen( # nosec + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True + ) + out, err = process.communicate() + return int(out.decode("utf-8").strip()) + + @serializable(attrs=["_subscriber"], canonical_name="ZMQConsumer", version=1) class ZMQConsumer(QueueConsumer): def __init__( @@ -54,6 +75,36 @@ def __init__( self.worker_stash = worker_stash self.post_init() + @classmethod + def default(cls, address: str | None = None, **kwargs: dict) -> "ZMQConsumer": + # relative + from ...types.uid import UID + from ..worker.utils import DEFAULT_WORKER_POOL_NAME + from .queue import APICallMessageHandler + + if address is None: + try: + address = f"tcp://localhost:{last_created_port()}" + except Exception: + raise Exception( + "Could not auto-assign ZMQConsumer address. Please provide one." + ) + print(f"Auto-assigning ZMQConsumer address: {address}. Please verify.") + default_kwargs = { + "message_handler": APICallMessageHandler, + "queue_name": APICallMessageHandler.queue_name, + "service_name": DEFAULT_WORKER_POOL_NAME, + "syft_worker_id": UID(), + "verbose": True, + "address": address, + } + + for key, value in kwargs.items(): + if key in default_kwargs: + default_kwargs[key] = value + + return cls(**default_kwargs) + def reconnect_to_producer(self) -> None: """Connect or reconnect to producer""" if self.socket: