Skip to content

Commit

Permalink
Merge pull request #273 from juaml/feat/queue-shell
Browse files Browse the repository at this point in the history
[ENH]:  Specify which shell to use when using `junifer queue`
  • Loading branch information
synchon authored Apr 9, 2024
2 parents b0c0a61 + de69448 commit c36165b
Show file tree
Hide file tree
Showing 15 changed files with 176 additions and 50 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/273.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for choosing between ``bash`` and ``zsh`` shells when queueing by `Synchon Mandal`_
1 change: 0 additions & 1 deletion docs/changes/newsfragments/320.fix

This file was deleted.

1 change: 0 additions & 1 deletion docs/changes/newsfragments/321.fix

This file was deleted.

4 changes: 4 additions & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ Bugfixes
(:gh:`312`)
- Fix element access for :class:`.DMCC13Benchmark` DataGrabber by `Synchon
Mandal`_ (:gh:`314`)
- Add a validation step on the :func:`.run` function to validate the marker
collection by `Fede Raimondo`_ (:gh:`320`)
- Add the executable flag to the ants docker scripts, fsl docker scripts and
other running scripts by `Fede Raimondo`_ (:gh:`321`)
- Force ``str`` dtype when parsing elements from file by `Synchon Mandal`_
(:gh:`322`)

Expand Down
22 changes: 16 additions & 6 deletions junifer/api/queue_context/gnu_parallel_local_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class GnuParallelLocalAdapter(QueueContextAdapter):
Raises
------
ValueError
If``env`` is invalid.
If ``env.kind`` is invalid or
if ``env.shell`` is invalid.
See Also
--------
Expand Down Expand Up @@ -110,13 +111,22 @@ def _check_env(self, env: Optional[Dict[str, str]]) -> None:
f"must be one of {valid_env_kinds}"
)
else:
# Check shell
shell = env.get("shell", "bash")
valid_shells = ["bash", "zsh"]
if shell not in valid_shells:
raise_error(
f"Invalid value for `env.shell`: {shell}, "
f"must be one of {valid_shells}"
)
self._shell = shell
# Set variables
if env["kind"] == "local":
# No virtual environment
self._executable = "junifer"
self._arguments = ""
else:
self._executable = f"run_{env['kind']}.sh"
self._executable = f"run_{env['kind']}.{self._shell}"
self._arguments = f"{env['name']} junifer"
self._exec_path = self._job_dir / self._executable

Expand All @@ -135,7 +145,7 @@ def elements(self) -> str:
def pre_run(self) -> str:
"""Return pre-run commands."""
fixed = (
"#!/usr/bin/env bash\n\n"
f"#!/usr/bin/env {self._shell}\n\n"
"# This script is auto-generated by junifer.\n\n"
"# Force datalad to run in non-interactive mode\n"
"DATALAD_UI_INTERACTIVE=false\n"
Expand All @@ -146,7 +156,7 @@ def pre_run(self) -> str:
def run(self) -> str:
"""Return run commands."""
return (
"#!/usr/bin/env bash\n\n"
f"#!/usr/bin/env {self._shell}\n\n"
"# This script is auto-generated by junifer.\n\n"
"# Run pre_run.sh\n"
f"sh {self._pre_run_path.resolve()!s}\n\n"
Expand All @@ -166,7 +176,7 @@ def run(self) -> str:
def pre_collect(self) -> str:
"""Return pre-collect commands."""
fixed = (
"#!/usr/bin/env bash\n\n"
f"#!/usr/bin/env {self._shell}\n\n"
"# This script is auto-generated by junifer.\n"
)
var = self._pre_collect or ""
Expand All @@ -175,7 +185,7 @@ def pre_collect(self) -> str:
def collect(self) -> str:
"""Return collect commands."""
return (
"#!/usr/bin/env bash\n\n"
f"#!/usr/bin/env {self._shell}\n\n"
"# This script is auto-generated by junifer.\n\n"
"# Run pre_collect.sh\n"
f"sh {self._pre_collect_path.resolve()!s}\n\n"
Expand Down
21 changes: 16 additions & 5 deletions junifer/api/queue_context/htcondor_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _check_env(self, env: Optional[Dict[str, str]]) -> None:
Raises
------
ValueError
If ``env.kind`` is invalid.
If ``env.kind`` is invalid or
if ``env.shell`` is invalid.
"""
# Set env related variables
Expand All @@ -140,13 +141,22 @@ def _check_env(self, env: Optional[Dict[str, str]]) -> None:
f"must be one of {valid_env_kinds}"
)
else:
# Check shell
shell = env.get("shell", "bash")
valid_shells = ["bash", "zsh"]
if shell not in valid_shells:
raise_error(
f"Invalid value for `env.shell`: {shell}, "
f"must be one of {valid_shells}"
)
self._shell = shell
# Set variables
if env["kind"] == "local":
# No virtual environment
self._executable = "junifer"
self._arguments = ""
else:
self._executable = f"run_{env['kind']}.sh"
self._executable = f"run_{env['kind']}.{self._shell}"
self._arguments = f"{env['name']} junifer"
self._exec_path = self._job_dir / self._executable

Expand Down Expand Up @@ -181,7 +191,7 @@ def _check_collect(self, collect: str) -> str:
def pre_run(self) -> str:
"""Return pre-run commands."""
fixed = (
"#!/bin/bash\n\n"
f"#!/usr/bin/env {self._shell}\n\n"
"# This script is auto-generated by junifer.\n\n"
"# Force datalad to run in non-interactive mode\n"
"DATALAD_UI_INTERACTIVE=false\n"
Expand Down Expand Up @@ -225,12 +235,13 @@ def run(self) -> str:
def pre_collect(self) -> str:
"""Return pre-collect commands."""
fixed = (
"#!/bin/bash\n\n" "# This script is auto-generated by junifer.\n"
f"#!/usr/bin/env {self._shell}\n\n"
"# This script is auto-generated by junifer.\n"
)
var = self._pre_collect or ""
# Add commands if collect="yes"
if self._collect == "yes":
var += 'if [ "${1}" == "4" ]; then\n' " exit 1\n" "fi\n"
var += 'if [ "${1}" == "4" ]; then\n exit 1\nfi\n'
return fixed + "\n" + var

def collect(self) -> str:
Expand Down
53 changes: 41 additions & 12 deletions junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,30 @@
from junifer.api.queue_context import GnuParallelLocalAdapter


def test_GnuParallelLocalAdapter_env_error() -> None:
def test_GnuParallelLocalAdapter_env_kind_error() -> None:
"""Test error for invalid env kind."""
with pytest.raises(ValueError, match="Invalid value for `env.kind`"):
GnuParallelLocalAdapter(
job_name="check_env",
job_name="check_env_kind",
job_dir=Path("."),
yaml_config_path=Path("."),
elements=["sub01"],
env={"kind": "jambalaya"},
)


def test_GnuParallelLocalAdapter_env_shell_error() -> None:
"""Test error for invalid env shell."""
with pytest.raises(ValueError, match="Invalid value for `env.shell`"):
GnuParallelLocalAdapter(
job_name="check_env_shell",
job_dir=Path("."),
yaml_config_path=Path("."),
elements=["sub01"],
env={"kind": "conda", "shell": "fish"},
)


@pytest.mark.parametrize(
"elements, expected_text",
[
Expand Down Expand Up @@ -55,14 +67,18 @@ def test_GnuParallelLocalAdapter_elements(


@pytest.mark.parametrize(
"pre_run, expected_text",
"pre_run, expected_text, shell",
[
(None, "# Force datalad"),
("# Check this out\n", "# Check this out"),
(None, "# Force datalad", "bash"),
(None, "# Force datalad", "zsh"),
("# Check this out\n", "# Check this out", "bash"),
("# Check this out\n", "# Check this out", "zsh"),
],
)
def test_GnuParallelLocalAdapter_pre_run(
pre_run: Optional[str], expected_text: str
pre_run: Optional[str],
expected_text: str,
shell: str,
) -> None:
"""Test GnuParallelLocalAdapter pre_run().
Expand All @@ -72,28 +88,35 @@ def test_GnuParallelLocalAdapter_pre_run(
The parametrized pre run text.
expected_text : str
The parametrized expected text.
shell : str
The parametrized expected shell.
"""
adapter = GnuParallelLocalAdapter(
job_name="test_pre_run",
job_dir=Path("."),
yaml_config_path=Path("."),
elements=["sub01"],
env={"kind": "conda", "name": "junifer", "shell": shell},
pre_run=pre_run,
)
assert shell in adapter.pre_run()
assert expected_text in adapter.pre_run()


@pytest.mark.parametrize(
"pre_collect, expected_text",
"pre_collect, expected_text, shell",
[
(None, "# This script"),
("# Check this out\n", "# Check this out"),
(None, "# This script", "bash"),
(None, "# This script", "zsh"),
("# Check this out\n", "# Check this out", "bash"),
("# Check this out\n", "# Check this out", "zsh"),
],
)
def test_GnuParallelLocalAdapter_pre_collect(
pre_collect: Optional[str],
expected_text: str,
shell: str,
) -> None:
"""Test GnuParallelLocalAdapter pre_collect().
Expand All @@ -103,15 +126,19 @@ def test_GnuParallelLocalAdapter_pre_collect(
The parametrized pre collect text.
expected_text : str
The parametrized expected text.
shell : str
The parametrized expected shell.
"""
adapter = GnuParallelLocalAdapter(
job_name="test_pre_collect",
job_dir=Path("."),
yaml_config_path=Path("."),
elements=["sub01"],
env={"kind": "venv", "name": "junifer", "shell": shell},
pre_collect=pre_collect,
)
assert shell in adapter.pre_collect()
assert expected_text in adapter.pre_collect()


Expand Down Expand Up @@ -140,8 +167,10 @@ def test_GnuParallelLocalAdapter_collect() -> None:
@pytest.mark.parametrize(
"env",
[
{"kind": "conda", "name": "junifer"},
{"kind": "venv", "name": "./junifer"},
{"kind": "conda", "name": "junifer", "shell": "bash"},
{"kind": "conda", "name": "junifer", "shell": "zsh"},
{"kind": "venv", "name": "./junifer", "shell": "bash"},
{"kind": "venv", "name": "./junifer", "shell": "zsh"},
],
)
def test_GnuParallelLocalAdapter_prepare(
Expand Down Expand Up @@ -177,7 +206,7 @@ def test_GnuParallelLocalAdapter_prepare(
adapter.prepare()

assert "GNU parallel" in caplog.text
assert f"Copying run_{env['kind']}" in caplog.text
assert f"Copying run_{env['kind']}.{env['shell']}" in caplog.text
assert "Writing pre_run.sh" in caplog.text
assert "Writing run_test_prepare.sh" in caplog.text
assert "Writing pre_collect.sh" in caplog.text
Expand Down
Loading

0 comments on commit c36165b

Please sign in to comment.