Skip to content

Commit 19b75eb

Browse files
Add tests
Signed-off-by: Saurabh <saurabhkoshatwar1996@gmail.com>
1 parent 412541c commit 19b75eb

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/unit/launcher/test_user_args.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,18 @@
66
import pytest
77
import subprocess
88

9+
from types import SimpleNamespace
10+
911
from deepspeed.accelerator import get_accelerator
12+
from deepspeed.launcher.multinode_runner import MultiNodeRunner
13+
14+
class DummyRunner(MultiNodeRunner):
15+
def backend_exists(self):
16+
return True
1017

18+
def get_cmd(self, environment, active_resources):
19+
return []
20+
1121
if not get_accelerator().is_available():
1222
pytest.skip("only supported in accelerator environments.", allow_module_level=True)
1323

@@ -37,6 +47,11 @@ def cmd(user_script_fp, prompt, multi_node):
3747
cmd = ("deepspeed", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt)
3848
return cmd
3949

50+
@pytest.fixture
51+
def dummy_runner():
52+
args = SimpleNamespace(user_args=[], user_script="dummy_script.py")
53+
return DummyRunner(args, "dummy_world_info")
54+
4055

4156
@pytest.mark.parametrize("prompt", [
4257
'''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""",
@@ -64,3 +79,20 @@ def test_bash_string_args(tmpdir, user_script_fp):
6479
p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
6580
out, err = p.communicate()
6681
assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"
82+
83+
84+
def test_add_export_with_special_characters(dummy_runner):
85+
"""
86+
Values with special characters (e.g., 64(x2)) must be quoted to avoid bash syntax errors.
87+
"""
88+
dummy_runner.add_export("SLURM_JOB_CPUS_PER_NODE", "64(x2)")
89+
assert dummy_runner.exports["SLURM_JOB_CPUS_PER_NODE"] == "\"64(x2)\""
90+
91+
92+
def test_add_export_no_special_characters(dummy_runner):
93+
"""
94+
Values without special characters should remain unquoted (e.g., PYTHONPATH).
95+
This avoids issues where unnecessary quotes break module imports.
96+
"""
97+
dummy_runner.add_export("PYTHONPATH", "/usr/local/lib/python3.9/site-packages")
98+
assert dummy_runner.exports["PYTHONPATH"] == "/usr/local/lib/python3.9/site-packages"

0 commit comments

Comments
 (0)