Skip to content

Commit

Permalink
Use async utilities to orchestrate pipeline from the CLI. Closes #162 (
Browse files Browse the repository at this point in the history
…#163)

* Use async utilities to orchestrate pipeline from the CLI. Closes #162

* Async/Await Error.

* Shutdown local workers on pipeline completion

* loop correction
  • Loading branch information
umesh-timalsina authored Oct 13, 2023
1 parent cb1f879 commit 1bcafce
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 71 deletions.
126 changes: 89 additions & 37 deletions chimerapy/orchestrator/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,124 @@
import asyncio
import json
import sys
import time
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from pathlib import Path
from typing import Iterable
from typing import Dict, Iterable, List, Set

from chimerapy.engine import Manager
from chimerapy.engine.config import set
import tqdm

from chimerapy.engine import Manager, Worker
from chimerapy.engine import config as cpe_config
from chimerapy.engine.utils import async_waiting_for
from chimerapy.orchestrator.models.pipeline_config import (
ChimeraPyPipelineConfig,
)
from chimerapy.orchestrator.orchestrator_config import OrchestratorConfig


def _wait_for_workers(manager: Manager, remote_workers: Iterable[str]):
while True:
if all(
[
remote_worker in manager.workers
for remote_worker in remote_workers
]
):
print("All remote workers connected!")
break
def _check_remote_workers(manager: Manager, remote_workers: Iterable[str]):
return all(
[remote_worker in manager.workers for remote_worker in remote_workers]
)


def orchestrate(config: ChimeraPyPipelineConfig):
manager, pipeline, mappings, remote_workers = config.pipeline_graph()
async def _connect_workers(
manager: Manager, config: ChimeraPyPipelineConfig
) -> Set[Worker]:
# Create Local Workers and Connect
remote_workers = set()
local_workers = set()
for wc in config.workers.instances:
if not wc.remote:
w = Worker(name=wc.name, id=wc.id, port=0, delete_temp=True)
await w.aserve()
await w.async_connect(method="zeroconf", timeout=20)
local_workers.add(w)
else:
remote_workers.add(wc.id)

# Wait until workers connect
_wait_for_workers(manager, remote_workers)

# Commit the graph
manager.commit_graph(graph=pipeline, mapping=mappings).result(
timeout=config.timeouts.commit_timeout
print("Waiting for workers to connect...")
await async_waiting_for(
lambda: _check_remote_workers(manager, remote_workers),
)
print("All remote workers connected!")
return local_workers

if config.mode == "preview":
manager.start().result(timeout=config.timeouts.preview_timeout)

def _get_mappings(
config: ChimeraPyPipelineConfig, created_nodes: Dict
) -> Dict[str, List[str]]:
mp = {}
for worker_id in config.mappings:
if mp.get(worker_id) is None:
mp[worker_id] = []

for node_name in config.mappings[worker_id]:
mp[worker_id].append(created_nodes[node_name].id)
return mp


async def _pipeline_preview(manager: Manager) -> None:
await manager.async_start()

# Wait until user stops
while True:
q = input("Ready to start? (Y/n)")
if q.lower() == "y":
break

if config.mode == "record":
manager.start().result(timeout=config.timeouts.preview_timeout)
await manager.async_record()

manager.record().result(timeout=config.timeouts.record_timeout)

# Wait until user stops
async def _pipeline_record(manager: Manager) -> None:
while True:
q = input("Ready to start? (Y/n)")
if q.lower() == "y":
break

await manager.async_start()
await manager.async_record()


async def aorchestrate(config: ChimeraPyPipelineConfig) -> None:
"""Orchestrate the pipeline."""
pipeline, created_nodes = config.get_cp_graph_map()
manager = config.instantiate_manager()

await manager.aserve()
await manager.async_zeroconf(enable=True)

local_workers = await _connect_workers(manager, config)
mappings = _get_mappings(config, created_nodes)

# Commit the graph
await manager.async_commit(graph=pipeline, mapping=mappings)

if config.mode == "preview":
await _pipeline_preview(manager)
else:
await _pipeline_record(manager)

if config.runtime is None:
while True:
q = input("Stop? (Y/n)")
if q.lower() == "y":
break
else: # Wait for runtime to elapse
start_time = time.time()
elapsed_time = time.time() - start_time
while elapsed_time < config.runtime:
elapsed_time = time.time() - start_time
else:
for _ in tqdm.tqdm(range(config.runtime), desc="Running..."):
await asyncio.sleep(1)

manager.stop().result(timeout=config.timeouts.stop_timeout)
manager.collect().result(timeout=config.timeouts.collect_timeout)
await manager.async_stop()
await manager.async_collect()
cpe_config.set(
"manager.timeout.worker-shutdown", config.timeouts.shutdown_timeout
)

set("manager.timeout.worker-shutdown", config.timeouts.shutdown_timeout)
manager.shutdown(blocking=True)
await manager.async_shutdown()
print("Shutting down local workers...")
for worker in local_workers:
await worker.async_shutdown()


def orchestrate_worker(
Expand Down Expand Up @@ -222,7 +274,7 @@ def run(args=None):
if args.subcommand == "orchestrate":
if args.mode and cp_config.mode != args.mode:
cp_config.mode = args.mode
orchestrate(cp_config)
asyncio.run(aorchestrate(cp_config))

elif args.subcommand == "orchestrate-worker":
orchestrate_worker(cp_config, args.worker_id, args.timeout)
Expand Down
36 changes: 2 additions & 34 deletions chimerapy/orchestrator/models/pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
List,
Literal,
Optional,
Set,
Tuple,
Type,
)
Expand Down Expand Up @@ -155,7 +154,6 @@ def instantiate_manager(self) -> cpe.Manager:
mode="python", exclude={"zeroconf"}
)
)
m.zeroconf(enable=self.manager_config.zeroconf)
return m

def get_registered_node(
Expand All @@ -164,9 +162,7 @@ def get_registered_node(
wrapped_node = get_registered_node(name, package)
return wrapped_node

def pipeline_graph(
self,
) -> Tuple[cpe.Manager, cpe.Graph, Dict[str, List[str]], Set[str]]:
def get_cp_graph_map(self) -> Tuple[cpe.Graph, Dict[str, cpe.Node]]:
created_nodes = {}

for node_config in self.nodes:
Expand All @@ -186,35 +182,7 @@ def pipeline_graph(
for edge in edges:
pipeline.add_edge(*edge)

workers = {}
remote_workers = set()
for wc in self.workers.instances:
if not wc.remote:
wo = cpe.Worker(name=wc.name, id=wc.id, port=0)
workers[wo.name] = wo
else:
remote_workers.add(wc.id)

manager = self.instantiate_manager()

[
w.connect(host=manager.host, port=manager.port)
for w in workers.values()
]

mp = {}
for worker in self.mappings:
try:
mp[workers[worker].id] = [
created_nodes[node_name].id
for node_name in self.mappings[worker]
]
except KeyError:
mp[worker] = [
created_nodes[node_name].id
for node_name in self.mappings[worker]
]
return manager, pipeline, mp, remote_workers
return pipeline, created_nodes

def instantiate_remote_worker(self, worker_id) -> cpe.Worker:
for wc in self.workers.instances:
Expand Down

0 comments on commit 1bcafce

Please sign in to comment.