Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepspeed/launcher/multinode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import shutil
import subprocess
import warnings
import re
from shlex import split
from abc import ABC, abstractmethod
from deepspeed.accelerator import get_accelerator
Expand All @@ -34,7 +35,10 @@ def get_cmd(self, environment, active_resources):
"""Return the command to execute on node"""

def add_export(self, key, var):
self.exports[key.strip()] = f"\"{var.strip()}\""
var = var.strip()
if re.search(r'[^\w@%+=:,./-]', var):
var = f"\"{var}\""
self.exports[key.strip()] = var

def parse_user_args(self):
return self.args.user_args
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/launcher/test_user_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,20 @@
import pytest
import subprocess

from types import SimpleNamespace

from deepspeed.accelerator import get_accelerator
from deepspeed.launcher.multinode_runner import MultiNodeRunner


class DummyRunner(MultiNodeRunner):

def backend_exists(self):
return True

def get_cmd(self, environment, active_resources):
return []


if not get_accelerator().is_available():
pytest.skip("only supported in accelerator environments.", allow_module_level=True)
Expand Down Expand Up @@ -38,6 +51,12 @@ def cmd(user_script_fp, prompt, multi_node):
return cmd


@pytest.fixture
def dummy_runner():
args = SimpleNamespace(user_args=[], user_script="dummy_script.py")
return DummyRunner(args, "dummy_world_info")


@pytest.mark.parametrize("prompt", [
'''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""",
'''I'm going to tell them "DeepSpeed is the best"'''
Expand All @@ -64,3 +83,20 @@ def test_bash_string_args(tmpdir, user_script_fp):
p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()
assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"


def test_add_export_with_special_characters(dummy_runner):
"""
Values with special characters (e.g., 64(x2)) must be quoted to avoid bash syntax errors.
"""
dummy_runner.add_export("SLURM_JOB_CPUS_PER_NODE", "64(x2)")
assert dummy_runner.exports["SLURM_JOB_CPUS_PER_NODE"] == "\"64(x2)\""


def test_add_export_no_special_characters(dummy_runner):
"""
Values without special characters should remain unquoted (e.g., PYTHONPATH).
This avoids issues where unnecessary quotes break module imports.
"""
dummy_runner.add_export("PYTHONPATH", "/usr/local/lib/python3.9/site-packages")
assert dummy_runner.exports["PYTHONPATH"] == "/usr/local/lib/python3.9/site-packages"