diff --git a/chimerapy/orchestrator/cli/__main__.py b/chimerapy/orchestrator/cli/__main__.py index 0378740..658cb28 100644 --- a/chimerapy/orchestrator/cli/__main__.py +++ b/chimerapy/orchestrator/cli/__main__.py @@ -1,43 +1,66 @@ +import asyncio import json import sys -import time from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from pathlib import Path -from typing import Iterable +from typing import Dict, Iterable, List, Set -from chimerapy.engine import Manager -from chimerapy.engine.config import set +import tqdm + +from chimerapy.engine import Manager, Worker +from chimerapy.engine import config as cpe_config +from chimerapy.engine.utils import async_waiting_for from chimerapy.orchestrator.models.pipeline_config import ( ChimeraPyPipelineConfig, ) from chimerapy.orchestrator.orchestrator_config import OrchestratorConfig -def _wait_for_workers(manager: Manager, remote_workers: Iterable[str]): - while True: - if all( - [ - remote_worker in manager.workers - for remote_worker in remote_workers - ] - ): - print("All remote workers connected!") - break +def _check_remote_workers(manager: Manager, remote_workers: Iterable[str]): + return all( + [remote_worker in manager.workers for remote_worker in remote_workers] + ) -def orchestrate(config: ChimeraPyPipelineConfig): - manager, pipeline, mappings, remote_workers = config.pipeline_graph() +async def _connect_workers( + manager: Manager, config: ChimeraPyPipelineConfig +) -> Set[Worker]: + # Create Local Workers and Connect + remote_workers = set() + local_workers = set() + for wc in config.workers.instances: + if not wc.remote: + w = Worker(name=wc.name, id=wc.id, port=0, delete_temp=True) + await w.aserve() + await w.async_connect(method="zeroconf", timeout=20) + local_workers.add(w) + else: + remote_workers.add(wc.id) # Wait until workers connect - _wait_for_workers(manager, remote_workers) - - # Commit the graph - manager.commit_graph(graph=pipeline, mapping=mappings).result( - timeout=config.timeouts.commit_timeout + print("Waiting for workers to connect...") + await async_waiting_for( + lambda: _check_remote_workers(manager, remote_workers), ) + print("All remote workers connected!") + return local_workers - if config.mode == "preview": - manager.start().result(timeout=config.timeouts.preview_timeout) + +def _get_mappings( + config: ChimeraPyPipelineConfig, created_nodes: Dict +) -> Dict[str, List[str]]: + mp = {} + for worker_id in config.mappings: + if mp.get(worker_id) is None: + mp[worker_id] = [] + + for node_name in config.mappings[worker_id]: + mp[worker_id].append(created_nodes[node_name].id) + return mp + + +async def _pipeline_preview(manager: Manager) -> None: + await manager.async_start() # Wait until user stops while True: @@ -45,28 +68,57 @@ def orchestrate(config: ChimeraPyPipelineConfig): if q.lower() == "y": break - if config.mode == "record": - manager.start().result(timeout=config.timeouts.preview_timeout) + await manager.async_record() - manager.record().result(timeout=config.timeouts.record_timeout) - # Wait until user stops +async def _pipeline_record(manager: Manager) -> None: + while True: + q = input("Ready to start? (Y/n)") + if q.lower() == "y": + break + + await manager.async_start() + await manager.async_record() + + +async def aorchestrate(config: ChimeraPyPipelineConfig) -> None: + """Orchestrate the pipeline.""" + pipeline, created_nodes = config.get_cp_graph_map() + manager = config.instantiate_manager() + + await manager.aserve() + await manager.async_zeroconf(enable=True) + + local_workers = await _connect_workers(manager, config) + mappings = _get_mappings(config, created_nodes) + + # Commit the graph + await manager.async_commit(graph=pipeline, mapping=mappings) + + if config.mode == "preview": + await _pipeline_preview(manager) + else: + await _pipeline_record(manager) + if config.runtime is None: while True: q = input("Stop? (Y/n)") if q.lower() == "y": break - else: # Wait for runtime to elapse - start_time = time.time() - elapsed_time = time.time() - start_time - while elapsed_time < config.runtime: - elapsed_time = time.time() - start_time + else: + for _ in tqdm.tqdm(range(config.runtime), desc="Running..."): + await asyncio.sleep(1) - manager.stop().result(timeout=config.timeouts.stop_timeout) - manager.collect().result(timeout=config.timeouts.collect_timeout) + await manager.async_stop() + await manager.async_collect() + cpe_config.set( + "manager.timeout.worker-shutdown", config.timeouts.shutdown_timeout + ) - set("manager.timeout.worker-shutdown", config.timeouts.shutdown_timeout) - manager.shutdown(blocking=True) + await manager.async_shutdown() + print("Shutting down local workers...") + for worker in local_workers: + await worker.async_shutdown() def orchestrate_worker( @@ -222,7 +274,7 @@ def run(args=None): if args.subcommand == "orchestrate": if args.mode and cp_config.mode != args.mode: cp_config.mode = args.mode - orchestrate(cp_config) + asyncio.run(aorchestrate(cp_config)) elif args.subcommand == "orchestrate-worker": orchestrate_worker(cp_config, args.worker_id, args.timeout) diff --git a/chimerapy/orchestrator/models/pipeline_config.py b/chimerapy/orchestrator/models/pipeline_config.py index ce02f00..f7d8c4a 100644 --- a/chimerapy/orchestrator/models/pipeline_config.py +++ b/chimerapy/orchestrator/models/pipeline_config.py @@ -6,7 +6,6 @@ List, Literal, Optional, - Set, Tuple, Type, ) @@ -155,7 +154,6 @@ def instantiate_manager(self) -> cpe.Manager: mode="python", exclude={"zeroconf"} ) ) - m.zeroconf(enable=self.manager_config.zeroconf) return m def get_registered_node( @@ -164,9 +162,7 @@ def get_registered_node( wrapped_node = get_registered_node(name, package) return wrapped_node - def pipeline_graph( - self, - ) -> Tuple[cpe.Manager, cpe.Graph, Dict[str, List[str]], Set[str]]: + def get_cp_graph_map(self) -> Tuple[cpe.Graph, Dict[str, cpe.Node]]: created_nodes = {} for node_config in self.nodes: @@ -186,35 +182,7 @@ def pipeline_graph( for edge in edges: pipeline.add_edge(*edge) - workers = {} - remote_workers = set() - for wc in self.workers.instances: - if not wc.remote: - wo = cpe.Worker(name=wc.name, id=wc.id, port=0) - workers[wo.name] = wo - else: - remote_workers.add(wc.id) - - manager = self.instantiate_manager() - - [ - w.connect(host=manager.host, port=manager.port) - for w in workers.values() - ] - - mp = {} - for worker in self.mappings: - try: - mp[workers[worker].id] = [ - created_nodes[node_name].id - for node_name in self.mappings[worker] - ] - except KeyError: - mp[worker] = [ - created_nodes[node_name].id - for node_name in self.mappings[worker] - ] - return manager, pipeline, mp, remote_workers + return pipeline, created_nodes def instantiate_remote_worker(self, worker_id) -> cpe.Worker: for wc in self.workers.instances: