Skip to content

Commit

Permalink
Fix saving of submit yaml env values to workflow.
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleGower committed Jan 10, 2025
1 parent 166633b commit d2ab51f
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 37 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-48245.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updated GenericWorkflowJob creation code to handle variables and pre-existing environment variables in submit yaml's environment section.
57 changes: 39 additions & 18 deletions python/lsst/ctrl/bps/bps_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import re
import string
from os.path import expandvars
from typing import Any

from lsst.daf.butler import Config
from lsst.resources import ResourcePath
Expand Down Expand Up @@ -372,28 +373,48 @@ def search(self, key, opt=None):
value = re.sub(r"\$(\S+)", r"<ENV:\1>", value)

if opt.get("replaceVars", True):
# default only applies to original search key
# Instead of doing deep copies of opt (especially with
# the recursive calls), temporarily remove default value
# and put it back.
default = opt.pop("default", _NO_SEARCH_DEFAULT_VALUE)

# Temporarily replace any env vars so formatter doesn't try to
# replace them.
value = re.sub(r"\${([^}]+)}", r"<BPSTMP:\1>", value)

value = self.formatter.format(value, self, opt)

# Replace any temporary env place holders.
value = re.sub(r"<BPSTMP:([^>]+)>", r"${\1}", value)

# if default was originally in opt
if default != _NO_SEARCH_DEFAULT_VALUE:
opt["default"] = default
value = self.replace_vars(value, opt)

_LOG.debug("after format=%s", value)

if found and isinstance(value, Config):
value = BpsConfig(value, search_order=[])

return found, value

def replace_vars(self, value: str, opt: dict[str, Any]) -> str:
"""Replace variables in string with values except those
in opt['skipNames'].
Parameters
----------
value : `str`
Value in which to replace variables.
opt : `dict` [`str`, Any]
Options to be used when searching and replacing values.
In particular "skipNames" lists variable names to
not replace.
"""
# default only applies to original search key
# Instead of doing deep copies of opt (especially with
# the recursive calls), temporarily remove default value
# and put it back.
default = opt.pop("default", _NO_SEARCH_DEFAULT_VALUE)

# Temporarily replace any env vars so formatter doesn't try to
# replace them.
value = re.sub(r"\${([^}]+)}", r"<BPSTMP:\1>", value)
for name in opt.get("skipNames", {}):
value = value.replace(f"{{{name}}}", f"<BPSTMP2:{name}>")

value = self.formatter.format(value, self, opt)

# Replace any temporary place holders.
value = re.sub(r"<BPSTMP:([^>]+)>", r"${\1}", value)
value = re.sub(r"<BPSTMP2:([^>]+)>", r"{\1}", value)

# if default was originally in opt
if default != _NO_SEARCH_DEFAULT_VALUE:
opt["default"] = default

return value
55 changes: 38 additions & 17 deletions python/lsst/ctrl/bps/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,26 @@ def _get_job_values(config, search_opt, cmd_line_key):
else:
job_values[attr] = getattr(default_gwjob, attr)

# Need to replace all config variables in environment values
# While replacing variables, convert to plain dict
if job_values["environment"]:
old_searchobj = search_opt.get("searchobj", None)
old_replace_vars = search_opt.get("replaceVars", None)
job_env = job_values["environment"]
search_opt["searchobj"] = job_env
search_opt["replaceVars"] = True
job_values["environment"] = {}
for name in job_env:
job_values["environment"][name] = config.search(name, search_opt)[1]
if old_searchobj is None:
del search_opt["searchobj"]
else:
search_opt["searchobj"] = old_searchobj
if old_replace_vars is None:
del search_opt["replaceVars"]
else:
search_opt["replaceVars"] = old_replace_vars

# If the automatic memory scaling is enabled (i.e. the memory multiplier
# is set and it is a positive number greater than 1.0), adjust number
# of retries when necessary. If the memory multiplier is invalid, disable
Expand Down Expand Up @@ -676,7 +696,6 @@ def create_generic_workflow(
cached_pipetask_values[qnode.taskDef.label] = _get_job_values(
config, search_opt, "runQuantumCommand"
)

_handle_job_values(cached_pipetask_values[qnode.taskDef.label], gwjob, unset_attributes)

# Update job with workflow attribute and profile values.
Expand Down Expand Up @@ -841,9 +860,14 @@ def create_final_command(config: BpsConfig, prefix: str) -> tuple[GenericWorkflo
Executable object for the final script.
arguments : `str`
Command line needed to call the final script.
Raises
------
RuntimeError if no commands found.
"""
search_opt = {
"replaceVars": False,
"replaceVars": True,
"skipNames": ["butlerConfig", "qgraphFile"],
"replaceEnvVars": False,
"expandEnvVars": False,
"searchobj": config["finalJob"],
Expand All @@ -858,28 +882,25 @@ def create_final_command(config: BpsConfig, prefix: str) -> tuple[GenericWorkflo
print("qgraphFile=$1", file=fh)
print("butlerConfig=$2", file=fh)

command_len = 0 # Make sure at least write one actual command
i = 1
found, command = config.search(f"command{i}", opt=search_opt)
while found:
# Temporarily replace any env vars so formatter doesn't try to
# replace them.
command = re.sub(r"\${([^}]+)}", r"<BPSTMP:\1>", command)

# butlerConfig will be args to script and set to env vars
command = command.replace("{qgraphFile}", "<BPSTMP:qgraphFile>")
command = command.replace("{butlerConfig}", "<BPSTMP:butlerConfig>")

# Replace all other vars in command string
search_opt["replaceVars"] = True
command = config.formatter.format(command, config, search_opt)
search_opt["replaceVars"] = False

# Replace any temporary env placeholders.
command = re.sub(r"<BPSTMP:([^>]+)>", r"${\1}", command)
# The files will be args to script, so change to shell vars
command = command.replace("{qgraphFile}", "${qgraphFile}")
command = command.replace("{butlerConfig}", "${butlerConfig}")

print(command, file=fh)
command_len += len(command.strip())

# Search for next command
i += 1
found, command = config.search(f"command{i}", opt=search_opt)
command_len += len(command.strip())
if not command_len:
raise RuntimeError(
"No finalJob commands were found. Use NEVER for finalJob.whenRun to turn off finalJob"
)
os.chmod(script_file, 0o755)
executable = GenericWorkflowExec(os.path.basename(script_file), script_file, True)

Expand Down
15 changes: 15 additions & 0 deletions tests/test_bpsconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,21 @@ def testDefault(self):
self.assertEqual(found, True)
self.assertEqual(value, 4)

def testReplaceVars(self):
"""Test replaceVars method."""
test_opt = {"default": 555}
orig_str = "<ENV:GARPLY>/waldo/{qux:03}/{notthere}"
value = self.config.replace_vars(orig_str, opt=test_opt)
self.assertEqual(value, "<ENV:GARPLY>/waldo/002/")
self.assertEqual(test_opt["default"], 555)

def testReplaceVarsSkipNames(self):
test_opt = {"default": 555, "skipNames": ["qgraphFile", "butlerConfig"]}
orig_str = "<ENV:GARPLY>/waldo/{qux:03} {qgraphFile} {butlerConfig}"
value = self.config.replace_vars(orig_str, opt=test_opt)
self.assertEqual(value, "<ENV:GARPLY>/waldo/002 {qgraphFile} {butlerConfig}")
self.assertEqual(test_opt["default"], 555)

def testVariables(self):
"""Test combinations of expandEnvVars, replaceEnvVars,
and replaceVars.
Expand Down
158 changes: 156 additions & 2 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@

from cqg_test_utils import make_test_clustered_quantum_graph
from lsst.ctrl.bps import BPS_SEARCH_ORDER, BpsConfig, GenericWorkflowJob
from lsst.ctrl.bps.transform import _get_job_values, create_generic_workflow, create_generic_workflow_config
from lsst.ctrl.bps.transform import (
_get_job_values,
create_final_command,
create_generic_workflow,
create_generic_workflow_config,
)

TESTDIR = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -78,7 +83,7 @@ def setUp(self):
},
# Needed because transform assumes they exist
"whenSaveJobQgraph": "NEVER",
"finalJob": {"whenRun": "ALWAYS"},
"finalJob": {"whenRun": "ALWAYS", "command1": "/usr/bin/env"},
},
BPS_SEARCH_ORDER,
)
Expand Down Expand Up @@ -177,6 +182,155 @@ def testRetrievingCmdLine(self):
self.assertEqual(job_values["executable"].src_uri, "/path/to/foo")
self.assertEqual(job_values["arguments"], "bar.txt")

def testEnvironment(self):
config = BpsConfig(
{
"var1": "two",
"environment": {"TEST_INT": 1, "TEST_SPACES": "one {var1} three"},
}
)
job_values = _get_job_values(config, {}, None)
truth = BpsConfig({"TEST_INT": 1, "TEST_SPACES": "one two three"}, {}, None)
self.assertEqual(truth, job_values["environment"])

def testEnvironmentOptions(self):
config = BpsConfig(
{
"var1": "two",
"environment": {"TEST_INT": 1, "TEST_SPACES": "one {var1} three"},
"finalJob": {"requestMemory": 8096, "command1": "/usr/bin/env"},
}
)
search_obj = config["finalJob"]
search_opts = {"replaceVars": False, "searchobj": search_obj}
job_values = _get_job_values(config, search_opts, None)
truth = {"TEST_INT": 1, "TEST_SPACES": "one two three"}
self.assertEqual(truth, job_values["environment"])
self.assertEqual(search_opts["replaceVars"], False)
self.assertEqual(search_opts["searchobj"]["requestMemory"], 8096)
self.assertEqual(job_values["request_memory"], 8096)


class TestCreateFinalCommand(unittest.TestCase):
"""Tests for the create_final_command function."""

def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.script_beginning = [
"#!/bin/bash\n",
"\n",
"set -e\n",
"set -x\n",
"qgraphFile=$1\n",
"butlerConfig=$2\n",
]

def tearDown(self):
self.tmpdir.cleanup()

def testSingleCommand(self):
"""Test with single final job command."""
config_butler = f"{self.tmpdir.name}/test_repo"
config = BpsConfig(
{
"var1": "42a",
"var2": "42b",
"var3": "42c",
"butlerConfig": config_butler,
"finalJob": {"command1": "/usr/bin/echo {var1} {qgraphFile} {var2} {butlerConfig} {var3}"},
}
)
gwf_exec, args = create_final_command(config, self.tmpdir.name)
self.assertEqual(args, f"<FILE:runQgraphFile> {config_butler}")
final_script = f"{self.tmpdir.name}/final_job.bash"
self.assertEqual(gwf_exec.src_uri, final_script)
with open(final_script) as infh:
lines = infh.readlines()
self.assertEqual(
lines, self.script_beginning + ["/usr/bin/echo 42a ${qgraphFile} 42b ${butlerConfig} 42c\n"]
)

def testMultipleCommands(self):
config_butler = f"{self.tmpdir.name}/test_repo"
config = BpsConfig(
{
"var1": "42a",
"var2": "42b",
"var3": "42c",
"butlerConfig": config_butler,
"finalJob": {
"command1": "/usr/bin/echo {var1} {qgraphFile} {var2} {butlerConfig} {var3}",
"command2": "/usr/bin/uptime",
},
}
)
gwf_exec, args = create_final_command(config, self.tmpdir.name)
self.assertEqual(args, f"<FILE:runQgraphFile> {config_butler}")
final_script = f"{self.tmpdir.name}/final_job.bash"
self.assertEqual(gwf_exec.src_uri, final_script)
with open(final_script) as infh:
lines = infh.readlines()
self.assertEqual(
lines,
self.script_beginning
+ ["/usr/bin/echo 42a ${qgraphFile} 42b ${butlerConfig} 42c\n", "/usr/bin/uptime\n"],
)

def testZeroCommands(self):
config_butler = f"{self.tmpdir.name}/test_repo"
config = BpsConfig(
{
"var1": "42a",
"var2": "42b",
"var3": "42c",
"butlerConfig": config_butler,
"finalJob": {
"cmd1": "/usr/bin/echo {var1} {qgraphFile} {var2} {butlerConfig} {var3}",
"cmd2": "/usr/bin/uptime",
},
}
)
with self.assertRaisesRegex(RuntimeError, "finalJob.whenRun"):
_, _ = create_final_command(config, self.tmpdir.name)

def testWhiteSpaceOnlyCommand(self):
config_butler = f"{self.tmpdir.name}/test_repo"
config = BpsConfig(
{
"butlerConfig": config_butler,
"finalJob": {"command1": "", "command2": "\t \n"},
}
)
with self.assertRaisesRegex(RuntimeError, "finalJob.whenRun"):
_, _ = create_final_command(config, self.tmpdir.name)

def testSkipCommandUsingWhiteSpace(self):
config_butler = f"{self.tmpdir.name}/test_repo"
config = BpsConfig(
{
"var1": "42a",
"var2": "42b",
"var3": "42c",
"butlerConfig": config_butler,
"finalJob": {
"command1": "/usr/bin/echo {var1} {qgraphFile} {var2} {butlerConfig} {var3}",
"command2": "", # test skipping a command (i.e., overriding a default)
"command3": "/usr/bin/uptime",
},
}
)
gwf_exec, args = create_final_command(config, self.tmpdir.name)
self.assertEqual(args, f"<FILE:runQgraphFile> {config_butler}")
final_script = f"{self.tmpdir.name}/final_job.bash"
self.assertEqual(gwf_exec.src_uri, final_script)
with open(final_script) as infh:
lines = infh.readlines()
self.assertEqual(
lines,
self.script_beginning
+ ["/usr/bin/echo 42a ${qgraphFile} 42b ${butlerConfig} 42c\n", "\n", "/usr/bin/uptime\n"],
)


if __name__ == "__main__":
unittest.main()

0 comments on commit d2ab51f

Please sign in to comment.