diff --git a/src/AMSWorkflow/ams/deploy_tools.py b/src/AMSWorkflow/ams/deploy_tools.py
new file mode 100644
index 00000000..f8190705
--- /dev/null
+++ b/src/AMSWorkflow/ams/deploy_tools.py
@@ -0,0 +1,73 @@
+import sys
+import os
+import select
+import subprocess as sp
+from enum import Enum
+
+
+class RootSched(Enum):
+    SLURM = 1
+    LSF = 2
+
+
+def _run_daemon(cmd, shell=False):
+    print(f"Going to run {cmd}")
+    proc = sp.Popen(cmd, shell=shell, stdout=sp.PIPE, stderr=sp.PIPE, bufsize=1, text=True)
+    return proc
+
+
+def _read_flux_uri(proc, timeout=5):
+    """
+    Reads stdout of the flux start command until the flux URI appears or the timeout expires.
+    :param proc: The process from which to read stdout.
+    :param timeout: The maximum time (in seconds) we wait for the URI to be written to stdout.
+    """
+
+    # Total time spent waiting for I/O so far
+    total_wait_time = 0
+    poll_interval = 0.5  # Poll interval in seconds
+
+    while total_wait_time < timeout:
+        # Check if there is data to read from stdout
+        ready_to_read = select.select([proc.stdout], [], [], poll_interval)[0]
+        if ready_to_read:
+            first_line = proc.stdout.readline()
+            print("First line is", first_line)
+            if "ssh" in first_line:
+                return first_line
+        total_wait_time += poll_interval
+        print(f"Waited for {total_wait_time} seconds")
+    return None
+
+
+def spawn_rmq_broker(flux_uri):
+    # TODO: We need to implement this; the current specification is limited.
+    # We probably need access to flux to spawn a daemon inside the flux allocation.
+    raise NotImplementedError("spawn_rmq_broker is not implemented, spawn it manually and provide the credentials")
+    return None, None
+
+
+def start_flux(scheduler, nnodes=None):
+    def bootstrap_with_slurm(nnodes):
+        if nnodes is None:
+            nnodes = os.environ.get("SLURM_NNODES", None)
+        if nnodes is None:
+            raise RuntimeError("Cannot deduce the number of nodes (SLURM_NNODES is not set)")
+
+        bootstrap_cmd = f"srun -N {nnodes} -n {nnodes} --pty --mpi=none --mpibind=off flux start"
+        flux_get_uri_cmd = "flux uri --remote \\$FLUX_URI; sleep inf"
+
+        daemon = _run_daemon(f'{bootstrap_cmd} "{flux_get_uri_cmd}"', shell=True)
+        flux_uri = _read_flux_uri(daemon, timeout=10)
+        print("Got flux uri: ", flux_uri)
+        if flux_uri is None:
+            print("Fatal error: cannot read the flux URI")
+            daemon.terminate()
+            raise RuntimeError("Cannot get the flux URI")
+
+        return daemon, flux_uri
+
+    if scheduler == RootSched.SLURM:
+        return bootstrap_with_slurm(nnodes)
+
+    raise NotImplementedError("Bootstrapping is currently only supported through SLURM")
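Review note: a minimal sketch of how `start_flux` is meant to be driven (assuming an active SLURM allocation; the try/finally cleanup is illustrative and not part of this patch):

```python
# Hypothetical driver for deploy_tools (illustration only, not part of this patch).
from ams.deploy_tools import RootSched, start_flux

# Bootstrap a flux instance across the SLURM allocation and grab its ssh URI.
daemon, flux_uri = start_flux(RootSched.SLURM, nnodes=2)
try:
    print(f"Flux is reachable at {flux_uri.strip()}")
    # ... deploy AMS components against flux_uri here ...
finally:
    # The daemon runs 'flux uri ...; sleep inf' and never exits on its own.
    daemon.terminate()
    daemon.wait()
```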
diff --git a/src/AMSWorkflow/ams/job_types.py b/src/AMSWorkflow/ams/job_types.py
new file mode 100644
index 00000000..1bcd8067
--- /dev/null
+++ b/src/AMSWorkflow/ams/job_types.py
@@ -0,0 +1,163 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+import os
+import sys
+import shutil
+from warnings import warn
+from typing import List, Dict, Optional, ClassVar
+from flux.job import JobspecV1
+import flux.job as fjob
+
+from ams.loader import load_class
+
+
+@dataclass(kw_only=True)
+class BaseJob:
+    """
+    Class modeling a job scheduled by AMS. There can be five types of jobs (Physics, Stagers, Training, RMQServer and TrainingDispatcher).
+    """
+
+    name: str
+    executable: str
+    nodes: int
+    tasks_per_node: int
+    args: List[str] = field(default_factory=list)
+    exclusive: bool = True
+    cores_per_task: int = 1
+    environ: Dict[str, str] = field(default_factory=dict)
+    orderId: ClassVar[int] = 0
+    gpus_per_task: Optional[int] = None
+    stdout: Optional[str] = None
+    stderr: Optional[str] = None
+
+    def _construct_command(self):
+        command = [self.executable] + self.args
+        return command
+
+    def _construct_environ(self, forward_environ):
+        environ = dict(self.environ)
+        if forward_environ is not None:
+            if not isinstance(forward_environ, type(os.environ)) and not isinstance(forward_environ, dict):
+                raise TypeError(f"Unsupported forward_environ type ({type(forward_environ)})")
+            for k, v in forward_environ.items():
+                if k in environ:
+                    warn(f"Key {k} already exists in the environment ({environ[k]}), ignoring the forwarded value ({v})")
+                else:
+                    environ[k] = v
+        return environ
+
+    def _construct_redirect_paths(self, redirectDir):
+        stdDir = Path.cwd()
+        if redirectDir is not None:
+            stdDir = Path(redirectDir)
+
+        if self.stdout is None:
+            stdout = f"{stdDir}/{self.name}_{BaseJob.orderId}.out"
+        else:
+            stdout = f"{stdDir}/{self.stdout}_{BaseJob.orderId}.out"
+
+        if self.stderr is None:
+            stderr = f"{stdDir}/{self.name}_{BaseJob.orderId}.err"
+        else:
+            stderr = f"{stdDir}/{self.stderr}_{BaseJob.orderId}.err"
+
+        BaseJob.orderId += 1
+
+        return stdout, stderr
+
+    def schedule(self, flux_handle, forward_environ=None, redirectDir=None, pre_signed=False, waitable=True):
+        jobspec = JobspecV1.from_command(
+            command=self._construct_command(),
+            num_tasks=self.tasks_per_node * self.nodes,
+            num_nodes=self.nodes,
+            cores_per_task=self.cores_per_task,
+            gpus_per_task=self.gpus_per_task,
+            exclusive=self.exclusive,
+        )
+
+        stdout, stderr = self._construct_redirect_paths(redirectDir)
+        environ = self._construct_environ(forward_environ)
+        jobspec.environment = environ
+        jobspec.stdout = stdout
+        jobspec.stderr = stderr
+
+        return jobspec, fjob.submit(flux_handle, jobspec, pre_signed=pre_signed, waitable=waitable)
+
+
+@dataclass(kw_only=True)
+class PhysicsJob(BaseJob):
+    def _verify(self):
+        is_executable = shutil.which(self.executable) is not None
+        is_path = Path(self.executable).is_file()
+        return is_executable or is_path
+
+    def __post_init__(self):
+        if not self._verify():
+            raise RuntimeError(
+                f"[PhysicsJob] executable is neither an executable file nor a system command: {self.executable}"
+            )
+
+
+@dataclass(kw_only=True, init=False)
+class Stager(BaseJob):
+    def _get_stager_default_cores(self):
+        """
+        We need the following cores:
+        1 RMQ client to receive messages
+        1 process to store to the filesystem
+        1 process to make data public to kosh
+        """
+        return 3
+
+    def _verify(self, pruner_path, pruner_cls):
+        assert Path(pruner_path).is_file(), "Path to the pruner class should exist"
+        user_class = load_class(pruner_path, pruner_cls)
+        print(f"Loaded Pruner Class {user_class.__name__}")
+
+    def __init__(
+        self,
+        name: str,
+        num_cores: int,
+        db_path: str,
+        pruner_cls: str,
+        pruner_path: str,
+        pruner_args: List[str],
+        num_gpus: Optional[int],
+        **kwargs,
+    ):
+        executable = sys.executable
+
+        self._verify(pruner_path, pruner_cls)
+
+        # TODO: Here we access both the stager arguments and the pruner arguments. This is an opportunity to emit
+        # an early error message, but it would require extending argparse or something similar. Noting for future reference.
+        cli_arguments = [
+            "-m",
+            "ams_wf.AMSDBStage",
+            "-db",
+            db_path,
+            "--policy",
+            "process",
+            "--dest",
+            str(Path(db_path) / Path("candidates")),
+            "--db-type",
+            "dhdf5",
+            "--store",
+            "-m",
+            "fs",
+            "--class",
+            pruner_cls,
+        ]
+        cli_arguments += pruner_args
+
+        num_cores = self._get_stager_default_cores() + num_cores
+        super().__init__(
+            name=name,
+            executable=executable,
+            nodes=1,
+            tasks_per_node=1,
+            cores_per_task=num_cores,
+            args=cli_arguments,
+            gpus_per_task=num_gpus,
+            **kwargs,
+        )
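Review note: a sketch of how the scheduling path above composes, assuming an attachable flux instance (the executable path is a placeholder, not part of this patch):

```python
# Illustrative use of PhysicsJob.schedule() (not part of this patch).
import flux
import flux.job as fjob

from ams.job_types import PhysicsJob

job = PhysicsJob(
    name="physics",
    executable="/usr/bin/hostname",  # placeholder binary; any executable works
    nodes=1,
    tasks_per_node=1,
)

handle = flux.Flux()  # connect to the enclosing flux instance
jobspec, jobid = job.schedule(handle, redirectDir="/tmp", waitable=True)
print(fjob.wait(handle, jobid))  # block until the job completes
```

Note that `schedule()` bumps the class-level `BaseJob.orderId`, so repeated submissions with the same job name still get distinct `.out`/`.err` files.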
diff --git a/src/AMSWorkflow/ams/rmq_async.py b/src/AMSWorkflow/ams/rmq_async.py
index 54a76e56..36f24610 100644
--- a/src/AMSWorkflow/ams/rmq_async.py
+++ b/src/AMSWorkflow/ams/rmq_async.py
@@ -402,7 +402,7 @@ def stop(self):
             print("Already closed?")
 
 
-def broker_running(credentials, cacert):
+def broker_status(credentials, cacert):
     ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
     ssl_context.verify_mode = ssl.CERT_REQUIRED
     ssl_context.load_verify_locations(cacert)
@@ -410,9 +410,9 @@
     pika_credentials = pika.PlainCredentials(credentials["rabbitmq-user"], credentials["rabbitmq-password"])
 
     parameters = pika.ConnectionParameters(
-        host=credentials["server"],
-        port=credentials["port"],
-        virtual_host=credentials["virtual_host"],
+        host=credentials["service-host"],
+        port=credentials["service-port"],
+        virtual_host=credentials["rabbitmq-vhost"],
         credentials=pika_credentials,
         ssl_options=pika.SSLOptions(ssl_context),
     )
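Review note: the key renames above imply the credentials JSON now has the shape below. This is inferred from this hunk alone; only the keys actually read by `broker_status` are certain, and all values are placeholders:

```python
# Shape of the credentials file consumed by broker_status() (placeholder values).
import json

creds = {
    "service-host": "rmq.example.com",  # placeholder broker host
    "service-port": 5671,               # placeholder TLS port
    "rabbitmq-vhost": "ams",            # placeholder virtual host
    "rabbitmq-user": "ams-user",        # placeholder user
    "rabbitmq-password": "secret",      # placeholder password
}
with open("rmq-credentials.json", "w") as fd:
    json.dump(creds, fd, indent=2)
```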
diff --git a/src/AMSWorkflow/ams_wf/AMSDeploy.py b/src/AMSWorkflow/ams_wf/AMSDeploy.py
new file mode 100644
index 00000000..a5a13f04
--- /dev/null
+++ b/src/AMSWorkflow/ams_wf/AMSDeploy.py
@@ -0,0 +1,69 @@
+import argparse
+import logging
+import sys
+import json
+
+from ams.deploy_tools import spawn_rmq_broker
+from ams.deploy_tools import RootSched
+from ams.deploy_tools import start_flux
+from ams.rmq_async import broker_status
+
+logger = logging.getLogger(__name__)
+
+
+def get_rmq_credentials(flux_uri, rmq_creds, rmq_cert):
+    if rmq_creds is None:
+        # TODO: Here we need to spawn our own broker
+        rmq_creds, rmq_cert = spawn_rmq_broker(flux_uri)
+    with open(rmq_creds, "r") as fd:
+        rmq_creds = json.load(fd)
+
+    return rmq_creds, rmq_cert
+
+
+def main():
+    parser = argparse.ArgumentParser(description="AMS workflow deployment")
+
+    parser.add_argument("--rmq-creds", help="Credentials file (JSON)")
+    parser.add_argument("--rmq-cert", help="TLS certificate file")
+    parser.add_argument("--flux-uri", help="Flux URI of an already existing allocation")
+    parser.add_argument("--nnodes", type=int, help="Number of nodes to use for this AMS deployment")
+    parser.add_argument(
+        "--root-scheduler",
+        dest="scheduler",
+        choices=[e.name for e in RootSched],
+        help="The base scheduler of the cluster",
+    )
+
+    args = parser.parse_args()
+
+    """
+    Verify the system is in a valid state
+    """
+
+    if args.flux_uri is None and args.scheduler is None:
+        print("Please provide either a flux URI handle to connect to or the base job scheduler")
+        sys.exit(1)
+
+    flux_process = None
+    flux_uri = args.flux_uri
+    if flux_uri is None:
+        flux_process, flux_uri = start_flux(RootSched[args.scheduler], args.nnodes)
+
+    rmq_creds, rmq_cert = get_rmq_credentials(flux_uri, args.rmq_creds, args.rmq_cert)
+
+    if not broker_status(rmq_creds, rmq_cert):
+        # If we created a subprocess in the background to run flux, we should terminate it
+        if flux_process is not None:
+            flux_process.terminate()
+        print("Cannot connect to the RMQ broker, exiting ...")
+        sys.exit(1)
+
+    """
+    We have a flux URI, and we know that rmq_creds and rmq_cert are valid, so we can start
+    scheduling jobs
+    """
+
+
+if __name__ == "__main__":
+    main()
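Review note: `main()` currently stops after validating the broker. A sketch of the follow-up step hinted at by the closing comment, assuming `ams.job_types` is used for submission (the function name and executable path are placeholders, not part of this patch):

```python
# Hypothetical continuation of AMSDeploy.main() (not part of this patch).
import flux

from ams.job_types import PhysicsJob


def schedule_jobs(flux_uri, rmq_creds, rmq_cert):
    # Attach to the flux instance bootstrapped (or provided) earlier.
    handle = flux.Flux(flux_uri.strip())
    job = PhysicsJob(
        name="physics",
        executable="/path/to/physics_binary",  # placeholder
        nodes=1,
        tasks_per_node=1,
    )
    jobspec, jobid = job.schedule(handle)
    return jobid
```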