Skip to content

Commit 3acf2ce

Browse files
saurabhkoshatwarloadams
authored andcommitted
Conditionally quote env vars (deepspeedai#7071)
Resolves deepspeedai#6997 This PR conditionally quotes environment variable values—only wrapping those containing special characters (like parentheses) that could trigger bash errors. Safe values remain unquoted. --------- Signed-off-by: Saurabh <saurabhkoshatwar1996@gmail.com> Signed-off-by: Saurabh Koshatwar <saurabhkoshatwar1996@gmail.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Signed-off-by: yisheng <yi.sheng@intel.com>
1 parent ffbc40b commit 3acf2ce

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

deepspeed/launcher/multinode_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import shutil
99
import subprocess
1010
import warnings
11+
import re
1112
from shlex import split
1213
from abc import ABC, abstractmethod
1314
from deepspeed.accelerator import get_accelerator
@@ -34,7 +35,10 @@ def get_cmd(self, environment, active_resources):
3435
"""Return the command to execute on node"""
3536

3637
def add_export(self, key, var):
37-
self.exports[key.strip()] = f"\"{var.strip()}\""
38+
var = var.strip()
39+
if re.search(r'[^\w@%+=:,./-]', var):
40+
var = f"\"{var}\""
41+
self.exports[key.strip()] = var
3842

3943
def parse_user_args(self):
4044
return self.args.user_args

tests/unit/launcher/test_user_args.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,20 @@
66
import pytest
77
import subprocess
88

9+
from types import SimpleNamespace
10+
911
from deepspeed.accelerator import get_accelerator
12+
from deepspeed.launcher.multinode_runner import MultiNodeRunner
13+
14+
15+
class DummyRunner(MultiNodeRunner):
16+
17+
def backend_exists(self):
18+
return True
19+
20+
def get_cmd(self, environment, active_resources):
21+
return []
22+
1023

1124
if not get_accelerator().is_available():
1225
pytest.skip("only supported in accelerator environments.", allow_module_level=True)
@@ -38,6 +51,12 @@ def cmd(user_script_fp, prompt, multi_node):
3851
return cmd
3952

4053

54+
@pytest.fixture
55+
def dummy_runner():
56+
args = SimpleNamespace(user_args=[], user_script="dummy_script.py")
57+
return DummyRunner(args, "dummy_world_info")
58+
59+
4160
@pytest.mark.parametrize("prompt", [
4261
'''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""",
4362
'''I'm going to tell them "DeepSpeed is the best"'''
@@ -64,3 +83,20 @@ def test_bash_string_args(tmpdir, user_script_fp):
6483
p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
6584
out, err = p.communicate()
6685
assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"
86+
87+
88+
def test_add_export_with_special_characters(dummy_runner):
89+
"""
90+
Values with special characters (e.g., 64(x2)) must be quoted to avoid bash syntax errors.
91+
"""
92+
dummy_runner.add_export("SLURM_JOB_CPUS_PER_NODE", "64(x2)")
93+
assert dummy_runner.exports["SLURM_JOB_CPUS_PER_NODE"] == "\"64(x2)\""
94+
95+
96+
def test_add_export_no_special_characters(dummy_runner):
97+
"""
98+
Values without special characters should remain unquoted (e.g., PYTHONPATH).
99+
This avoids issues where unnecessary quotes break module imports.
100+
"""
101+
dummy_runner.add_export("PYTHONPATH", "/usr/local/lib/python3.9/site-packages")
102+
assert dummy_runner.exports["PYTHONPATH"] == "/usr/local/lib/python3.9/site-packages"

0 commit comments

Comments
 (0)