Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pre/post script support to flux worker #818

Merged
merged 7 commits into from
May 8, 2024
25 changes: 22 additions & 3 deletions beeflow/common/worker/flux_worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Flux worker interface."""

import io
import os
from beeflow.common import log as bee_logging
from beeflow.common.worker.worker import Worker
Expand Down Expand Up @@ -59,6 +60,19 @@ def build_jobspec(self, task):
# TODO: What to do with the MPI version?
# mpi_version = task.get_requirement('beeflow:MPIRequirement', 'mpiVersion',
# default='pmi2')
scripts_enabled = task.get_requirement('beeflow:ScriptRequirement', 'enabled',
default=False)
if scripts_enabled:
# We use StringIO here to properly break the script up into lines with readlines
pre_script = io.StringIO(task.get_requirement('beeflow:ScriptRequirement',
'pre_script')).readlines()
post_script = io.StringIO(task.get_requirement('beeflow:ScriptRequirement',
'post_script')).readlines()

# Pre commands
if scripts_enabled:
for cmd in pre_script:
script.append(cmd)

for cmd in crt_res.pre_commands:
if cmd.type == 'one-per-node':
Expand All @@ -67,7 +81,7 @@ def build_jobspec(self, task):
cmd_args = ['flux', 'run', ' '.join(cmd.args)]
script.append(' '.join(cmd_args))

# Set up the main command
# Main command
args = ['flux', 'run', '-N', str(nodes), '-n', str(ntasks)]
if task.stdout is not None:
args.extend(['--output', task.stdout])
Expand All @@ -77,13 +91,18 @@ def build_jobspec(self, task):
log.info(args)
script.append(' '.join(args))

# Post commands
for cmd in crt_res.post_commands:
if cmd.type == 'one-per-node':
cmd_args = ['flux', 'run', '-N', str(nodes), '-n', str(nodes), ' '.join(cmd.args)]
else:
cmd_args = ['flux', 'run', ' '.join(cmd.args)]
script.append(' '.join(cmd_args))

if scripts_enabled:
for cmd in post_script:
script.append(cmd)

script = '\n'.join(script)
jobspec = self.job.JobspecV1.from_batch_command(script, task.name,
num_slots=ntasks,
Expand All @@ -93,8 +112,8 @@ def build_jobspec(self, task):
jobspec.stderr = f'{task_save_path}/{task.name}-{task.id}.err'
jobspec.environment = dict(os.environ)
# Save the script for later reference
with open(f'{task_save_path}/{task.name}-{task.id}.sh', 'w', encoding='utf-8') as fp:
fp.write(script)
with open(f'{task_save_path}/{task.name}-{task.id}.sh', 'w', encoding='utf-8') as f_path:
f_path.write(script)
return jobspec

def submit_task(self, task):
Expand Down