diff --git a/beeflow/common/worker/flux_worker.py b/beeflow/common/worker/flux_worker.py index 353f38214..29f4ff41b 100644 --- a/beeflow/common/worker/flux_worker.py +++ b/beeflow/common/worker/flux_worker.py @@ -1,5 +1,6 @@ """Flux worker interface.""" +import io import os from beeflow.common import log as bee_logging from beeflow.common.worker.worker import Worker @@ -59,6 +60,19 @@ def build_jobspec(self, task): # TODO: What to do with the MPI version? # mpi_version = task.get_requirement('beeflow:MPIRequirement', 'mpiVersion', # default='pmi2') + scripts_enabled = task.get_requirement('beeflow:ScriptRequirement', 'enabled', + default=False) + if scripts_enabled: + # We use StringIO here to properly break the script up into lines with readlines + pre_script = io.StringIO(task.get_requirement('beeflow:ScriptRequirement', + 'pre_script')).readlines() + post_script = io.StringIO(task.get_requirement('beeflow:ScriptRequirement', + 'post_script')).readlines() + + # Pre commands + if scripts_enabled: + for cmd in pre_script: + script.append(cmd) for cmd in crt_res.pre_commands: if cmd.type == 'one-per-node': @@ -67,7 +81,7 @@ def build_jobspec(self, task): cmd_args = ['flux', 'run', ' '.join(cmd.args)] script.append(' '.join(cmd_args)) - # Set up the main command + # Main command args = ['flux', 'run', '-N', str(nodes), '-n', str(ntasks)] if task.stdout is not None: args.extend(['--output', task.stdout]) @@ -77,6 +91,7 @@ def build_jobspec(self, task): log.info(args) script.append(' '.join(args)) + # Post commands for cmd in crt_res.post_commands: if cmd.type == 'one-per-node': cmd_args = ['flux', 'run', '-N', str(nodes), '-n', str(nodes), ' '.join(cmd.args)] @@ -84,6 +99,10 @@ def build_jobspec(self, task): cmd_args = ['flux', 'run', ' '.join(cmd.args)] script.append(' '.join(cmd_args)) + if scripts_enabled: + for cmd in post_script: + script.append(cmd) + script = '\n'.join(script) jobspec = self.job.JobspecV1.from_batch_command(script, task.name, num_slots=ntasks, @@ -93,8 +112,8 @@ def build_jobspec(self, task): jobspec.stderr = f'{task_save_path}/{task.name}-{task.id}.err' jobspec.environment = dict(os.environ) # Save the script for later reference - with open(f'{task_save_path}/{task.name}-{task.id}.sh', 'w', encoding='utf-8') as fp: - fp.write(script) + with open(f'{task_save_path}/{task.name}-{task.id}.sh', 'w', encoding='utf-8') as f_path: + f_path.write(script) return jobspec def submit_task(self, task):