diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index 96700389b3d..54c8a80d25a 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -36,6 +36,7 @@ PrepareForLaunch, _filter_args, check_cuda_p2p_ib_support, + convert_dict_to_env_variables, is_bf16_available, is_deepspeed_available, is_mlu_available, @@ -738,10 +739,9 @@ def deepspeed_launcher(args): if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: with open(".deepspeed_env", "a") as f: - for key, value in current_env.items(): - if ";" in value or " " in value: - continue - f.write(f"{key}={value}\n") + valid_env_items = convert_dict_to_env_variables(current_env) + if len(valid_env_items) > 1: + f.writelines(valid_env_items) process = subprocess.Popen(cmd, env=current_env) process.wait() diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index e8a8ca67dac..0ef022fe425 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -55,6 +55,7 @@ are_libraries_initialized, check_cuda_p2p_ib_support, check_fp8_capability, + convert_dict_to_env_variables, get_int_from_env, parse_choice_from_env, parse_flag_from_env, diff --git a/src/accelerate/utils/environment.py b/src/accelerate/utils/environment.py index ff99207e09b..133c330d42b 100644 --- a/src/accelerate/utils/environment.py +++ b/src/accelerate/utils/environment.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import platform import subprocess @@ -23,6 +24,34 @@ from packaging.version import parse +logger = logging.getLogger(__name__) + + +def convert_dict_to_env_variables(current_env: dict): + """ + Verifies that all keys and values in `current_env` do not contain illegal keys or values, and returns a list of + strings as the result. 
+ + Example: + ```python + >>> from accelerate.utils.environment import convert_dict_to_env_variables + + >>> env = {"ACCELERATE_DEBUG_MODE": "1", "BAD_ENV_NAME": "<mything", "OTHER_ENV": "2"} + >>> valid_env_items = convert_dict_to_env_variables(env) + >>> print(valid_env_items) + ["ACCELERATE_DEBUG_MODE=1\n", "OTHER_ENV=2\n"] + ``` + """ + forbidden_chars = [";", "\n", "<", ">", " "] + valid_env_items = [] + for key, value in current_env.items(): + if all(char not in (key + value) for char in forbidden_chars) and len(key) >= 1 and len(value) >= 1: + valid_env_items.append(f"{key}={value}\n") + else: + logger.warning(f"WARNING: Skipping {key}={value} as it contains forbidden characters or missing values.") + return valid_env_items + + def str_to_bool(value) -> int: """ Converts a string representation of truth to `True` (1) or `False` (0). diff --git a/tests/test_utils.py b/tests/test_utils.py index 92805a30456..063bf262036 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -31,6 +31,7 @@ CannotPadNestedTensorWarning, check_os_kernel, clear_environment, + convert_dict_to_env_variables, convert_outputs_to_fp32, convert_to_fp32, extract_model_from_parallel, @@ -355,3 +356,9 @@ class Second(QuantTensorBase): self.assertFalse(is_namedtuple((1, 2))) self.assertFalse(is_namedtuple("hey")) self.assertFalse(is_namedtuple(object())) + + def test_convert_dict_to_env_variables(self): + env = {"ACCELERATE_DEBUG_MODE": "1", "BAD_ENV_NAME": "