Showing 4 changed files with 311 additions and 4 deletions.
@@ -0,0 +1,71 @@
import os
import select
import subprocess as sp
from enum import Enum


class RootSched(Enum):
    SLURM = 1
    LSF = 2


def _run_daemon(cmd, shell=False):
    # Start the command as a background process and capture its
    # stdout/stderr as line-buffered text streams.
    print(f"Going to run {cmd}")
    proc = sp.Popen(cmd, shell=shell, stdout=sp.PIPE, stderr=sp.PIPE, bufsize=1, text=True)
    return proc

def _read_flux_uri(proc, timeout=5):
    """
    Reads lines from the flux start command's stdout and returns the first one
    containing the remote (ssh) FLUX URI.
    :param proc: The process from which to read stdout.
    :param timeout: The maximum time (in seconds) we wait for the URI to appear.
    """

    # Time already spent waiting for I/O
    total_wait_time = 0
    poll_interval = 0.5  # Poll interval in seconds

    while total_wait_time < timeout:
        # Check if there is data to read from stdout
        ready_to_read = select.select([proc.stdout], [], [], poll_interval)[0]
        if ready_to_read:
            first_line = proc.stdout.readline()
            print("First line is", first_line)
            if "ssh" in first_line:
                return first_line
        total_wait_time += poll_interval
        print(f"Waited for {total_wait_time}s")
    return None

def spawn_rmq_broker(flux_uri):
    # TODO: We need to implement this; the current specification is limited.
    # We probably need access to flux to spawn a daemon inside the flux allocation.
    raise NotImplementedError("spawn_rmq_broker is not implemented, spawn it manually and provide the credentials")

def start_flux(scheduler, nnodes=None):
    def bootstrap_with_slurm(nnodes):
        if nnodes is None:
            nnodes = os.environ.get("SLURM_NNODES", None)

        bootstrap_cmd = f"srun -N {nnodes} -n {nnodes} --pty --mpi=none --mpibind=off flux start"
        # Escape '$' so FLUX_URI is expanded inside the flux instance,
        # not by the shell that launches srun.
        flux_get_uri_cmd = "flux uri --remote \\$FLUX_URI; sleep inf"

        daemon = _run_daemon(f'{bootstrap_cmd} "{flux_get_uri_cmd}"', shell=True)
        flux_uri = _read_flux_uri(daemon, timeout=10)
        print("Got flux uri: ", flux_uri)
        if flux_uri is None:
            print("Fatal error, cannot read flux URI")
            daemon.terminate()
            raise RuntimeError("Cannot get FLUX URI")

        return daemon, flux_uri

    if scheduler == RootSched.SLURM:
        return bootstrap_with_slurm(nnodes)

    raise NotImplementedError("We only support bootstrapping through SLURM")
@@ -0,0 +1,163 @@
from dataclasses import dataclass, field
from pathlib import Path
import os
import sys
import shutil
from warnings import warn
from typing import List, Dict, Optional, ClassVar
from flux.job import JobspecV1
import flux.job as fjob

from ams.loader import load_class

@dataclass(kw_only=True)
class BaseJob:
    """
    Class modeling a job scheduled by AMS. There can be five types of jobs
    (Physics, Stagers, Training, RMQServer and TrainingDispatcher).
    """

    name: str
    executable: str
    nodes: int
    tasks_per_node: int
    # Mutable defaults must use default_factory in dataclasses
    args: List[str] = field(default_factory=list)
    exclusive: bool = True
    cores_per_task: int = 1
    environ: Dict[str, str] = field(default_factory=dict)
    orderId: ClassVar[int] = 0
    gpus_per_task: Optional[int] = None
    stdout: Optional[str] = None
    stderr: Optional[str] = None

    def _construct_command(self):
        command = [self.executable] + self.args
        return command

    def _construct_environ(self, forward_environ):
        environ = self.environ
        if forward_environ is not None:
            if not isinstance(forward_environ, type(os.environ)) and not isinstance(forward_environ, dict):
                raise TypeError(f"Unsupported forward_environ type ({type(forward_environ)})")
            for k, v in forward_environ.items():
                if k in environ:
                    warn(f"Key {k} already exists in environment ({environ[k]}), ignoring forwarded value ({v})")
                else:
                    environ[k] = v
        return environ

    def _construct_redirect_paths(self, redirectDir):
        stdDir = Path.cwd()
        if redirectDir is not None:
            stdDir = Path(redirectDir)

        if self.stdout is None:
            stdout = f"{stdDir}/{self.name}_{BaseJob.orderId}.out"
        else:
            stdout = f"{stdDir}/{self.stdout}_{BaseJob.orderId}.out"

        if self.stderr is None:
            stderr = f"{stdDir}/{self.name}_{BaseJob.orderId}.err"
        else:
            stderr = f"{stdDir}/{self.stderr}_{BaseJob.orderId}.err"

        BaseJob.orderId += 1

        return stdout, stderr

    def schedule(self, flux_handle, forward_environ=None, redirectDir=None, pre_signed=False, waitable=True):
        jobspec = JobspecV1.from_command(
            command=self._construct_command(),
            num_tasks=self.tasks_per_node * self.nodes,
            num_nodes=self.nodes,
            cores_per_task=self.cores_per_task,
            gpus_per_task=self.gpus_per_task,
            exclusive=self.exclusive,
        )

        stdout, stderr = self._construct_redirect_paths(redirectDir)
        environ = self._construct_environ(forward_environ)
        jobspec.environment = environ
        jobspec.stdout = stdout
        jobspec.stderr = stderr

        return jobspec, fjob.submit(flux_handle, jobspec, pre_signed=pre_signed, waitable=waitable)

@dataclass(kw_only=True)
class PhysicsJob(BaseJob):
    def _verify(self):
        is_executable = shutil.which(self.executable) is not None
        is_path = Path(self.executable).is_file()
        return is_executable or is_path

    def __post_init__(self):
        if not self._verify():
            raise RuntimeError(
                f"[PhysicsJob] executable is neither an executable file nor a system command: {self.executable}"
            )

@dataclass(kw_only=True, init=False)
class Stager(BaseJob):
    def _get_stager_default_cores(self):
        """
        We need the following cores:
        1 RMQ client to receive messages
        1 process to store to the filesystem
        1 process to make data public to kosh
        """
        return 3

    def _verify(self, pruner_path, pruner_cls):
        assert Path(pruner_path).is_file(), "Path to Pruner class should exist"
        user_class = load_class(pruner_path, pruner_cls)
        print(f"Loaded Pruner Class {user_class.__name__}")

    def __init__(
        self,
        name: str,
        num_cores: int,
        db_path: str,
        pruner_cls: str,
        pruner_path: str,
        pruner_args: List[str],
        num_gpus: Optional[int],
        **kwargs,
    ):
        executable = sys.executable

        self._verify(pruner_path, pruner_cls)

        # TODO: Here we access both the stager arguments and the pruner arguments. It is an opportunity to emit
        # an early error message, but this would require extending argparse or something else. Noting for future reference.
        cli_arguments = [
            "-m",
            "ams_wf.AMSDBStage",
            "-db",
            db_path,
            "--policy",
            "process",
            "--dest",
            str(Path(db_path) / Path("candidates")),
            "--db-type",
            "dhdf5",
            "--store",
            "-m",
            "fs",
            "--class",
            pruner_cls,
        ]
        cli_arguments += pruner_args

        num_cores = self._get_stager_default_cores() + num_cores
        super().__init__(
            name=name,
            executable=executable,
            nodes=1,
            tasks_per_node=1,
            cores_per_task=num_cores,
            args=cli_arguments,
            gpus_per_task=num_gpus,
            **kwargs,
        )
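A short sketch of how these job classes might be driven, assuming an already running Flux instance; PhysicsJob and schedule() come from this diff, the handle creation and wait call follow the flux-core Python bindings, and the concrete executable is a stand-in:

import flux
import flux.job as fjob

# Hypothetical job; any command on PATH passes PhysicsJob._verify().
job = PhysicsJob(name="physics", executable="hostname", nodes=2, tasks_per_node=4)

handle = flux.Flux()  # Connect to the enclosing flux instance
jobspec, jobid = job.schedule(handle, redirectDir="/tmp")
print(fjob.wait(handle, jobid))  # waitable=True by default, so we can block here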
@@ -0,0 +1,71 @@
import argparse
import logging
import sys
import json

from ams.deploy_tools import spawn_rmq_broker
from ams.deploy_tools import RootSched
from ams.deploy_tools import start_flux
from ams.rmq_async import broker_status

logger = logging.getLogger(__name__)

def get_rmq_credentials(flux_uri, rmq_creds, rmq_cert):
    if rmq_creds is None:
        # TODO: Here we need to spawn our own broker
        rmq_creds, rmq_cert = spawn_rmq_broker(flux_uri)
    with open(rmq_creds, "r") as fd:
        rmq_creds = json.load(fd)

    return rmq_creds, rmq_cert

def main():
    parser = argparse.ArgumentParser(description="AMS workflow deployment")

    parser.add_argument("--rmq-creds", help="Credentials file (JSON)")
    parser.add_argument("--rmq-cert", help="TLS certificate file")
    parser.add_argument("--flux-uri", help="Flux URI of an already existing allocation")
    parser.add_argument("--nnodes", help="Number of nodes to use for this AMS deployment")
    parser.add_argument(
        "--root-scheduler",
        dest="scheduler",
        choices=[e.name for e in RootSched],
        help="The scheduler of the cluster",
    )

    args = parser.parse_args()

    # Verify the system is in a "valid" state
    if args.flux_uri is None and args.scheduler is None:
        print("Please provide either a flux URI handle to connect to or the base job scheduler")
        sys.exit(1)

    flux_process = None
    flux_uri = args.flux_uri
    if flux_uri is None:
        flux_process, flux_uri = start_flux(RootSched[args.scheduler], args.nnodes)

    rmq_creds, rmq_cert = get_rmq_credentials(flux_uri, args.rmq_creds, args.rmq_cert)

    if not broker_status(rmq_creds, rmq_cert):
        # If we created a subprocess in the background to run flux, we should terminate it
        if flux_process is not None:
            flux_process.terminate()
        print("RMQ broker is not connected, exiting ...")
        sys.exit(1)

    # We have the FLUX URI, and here we know rmq_creds and rmq_cert are valid,
    # so we can start scheduling jobs


if __name__ == "__main__":
    main()
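A hypothetical invocation of this entry point; the script's filename is not shown in the diff, and the credential and certificate paths are illustrative:

python deploy.py --root-scheduler SLURM --nnodes 4 --rmq-creds rmq_creds.json --rmq-cert rmq.pem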