Skip to content

Commit

Permalink
Simplify CLI args validation and ensure CLI args take precedence over…
Browse files Browse the repository at this point in the history
… config file. (#2757)

* Remove unnecessary args.debug statement

* Add expected test failure for config sub-sections

* Remove redundancy in config file args parsing

* Make config file --cpu logic more explicit
  • Loading branch information
Iain-S authored May 9, 2024
1 parent afc2c99 commit 724824a
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 22 deletions.
32 changes: 10 additions & 22 deletions src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ def _validate_launch_command(args):
warned = []
mp_from_config_flag = False
# Get the default from the config file.
if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu:
if args.config_file is not None or (os.path.isfile(default_config_file) and not args.cpu):
defaults = load_config_from_file(args.config_file)
if (
not args.multi_gpu
Expand Down Expand Up @@ -954,31 +954,19 @@ def _validate_launch_command(args):
# Update args with the defaults
for name, attr in defaults.__dict__.items():
if isinstance(attr, dict):
for k in defaults.deepspeed_config:
setattr(args, k, defaults.deepspeed_config[k])
for k in defaults.fsdp_config:
arg_to_set = k
if "fsdp" not in arg_to_set:
arg_to_set = "fsdp_" + arg_to_set
setattr(args, arg_to_set, defaults.fsdp_config[k])
for k in defaults.megatron_lm_config:
setattr(args, k, defaults.megatron_lm_config[k])
for k in defaults.dynamo_config:
setattr(args, k, defaults.dynamo_config[k])
for k in defaults.ipex_config:
setattr(args, k, defaults.ipex_config[k])
for k in defaults.mpirun_config:
setattr(args, k, defaults.mpirun_config[k])
continue

# Those args are handled separately
if (
# Copy defaults.somedict.somearg to args.somearg and
# defaults.fsdp_config.x to args.fsdp_x
for key, value in attr.items():
if name == "fsdp_config" and not key.startswith("fsdp"):
key = "fsdp_" + key
if getattr(args, key, None) is None:
setattr(args, key, value)
elif (
name not in ["compute_environment", "mixed_precision", "distributed_type"]
and getattr(args, name, None) is None
):
# Those args are handled separately
setattr(args, name, attr)
if not args.debug:
args.debug = defaults.debug

if not args.mixed_precision:
if defaults.mixed_precision is None:
Expand Down
55 changes: 55 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,61 @@ def test_mpi_multicpu_config_cmd(self):
self.assertEqual(python_script_cmd[1], str(self.test_file_path))
self.assertEqual(python_script_cmd[2], test_file_arg)

def test_validate_launch_command(self):
    """Check how `_validate_launch_command` merges CLI arguments with config-file values.

    Four scenarios are exercised in sequence, each with a freshly parsed namespace:

    1. An option missing from the CLI is filled in from the config file.
    2. An option given on the CLI wins over the config file.
    3. Entries of the `fsdp_config` sub-section end up on the namespace, with
       the `fsdp_` prefix added to keys that lack it.
    4. A CLI option wins over an entry of the `megatron_lm_config` sub-section.
    """
    parser = launch_command_parser()
    script = "test.py"

    def config_file_args(filename):
        # argv pair selecting one of the YAML files in the test config directory.
        return ["--config-file", str(self.test_config_path / filename)]

    # 1. `debug` is not passed on the CLI, so it starts out falsy and is only
    #    switched on by validation (debug.yaml sets `debug: true`).
    namespace = parser.parse_args([*config_file_args("debug.yaml"), script])
    self.assertFalse(namespace.debug)
    _validate_launch_command(namespace)
    self.assertTrue(namespace.debug)

    # 2. `--num-processes 2` on the CLI must survive validation even though
    #    one_proc.yaml sets `num_processes: 1`.
    namespace = parser.parse_args(["--num-processes", "2", *config_file_args("one_proc.yaml"), script])
    _validate_launch_command(namespace)
    self.assertEqual(2, namespace.num_processes)

    # 3. Sub-section entries from fsdp.yaml: `fsdp_sync_module_states` keeps
    #    its name, while the bare key `x` is exposed as `fsdp_x`.
    namespace = parser.parse_args([*config_file_args("fsdp.yaml"), script])
    _validate_launch_command(namespace)
    self.assertTrue(namespace.fsdp_sync_module_states)
    self.assertEqual(1, namespace.fsdp_x)

    # 4. The CLI value "true" must not be overwritten by the `false` stored in
    #    the `megatron_lm_config` sub-section of megatron.yaml.
    namespace = parser.parse_args(
        [
            *config_file_args("megatron.yaml"),
            "--megatron_lm_recompute_activations",
            "true",
            script,
        ]
    )
    _validate_launch_command(namespace)
    self.assertEqual("true", namespace.megatron_lm_recompute_activations)


class LaunchArgTester(unittest.TestCase):
"""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_configs/debug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
compute_environment: LOCAL_MACHINE
debug: true
num_processes: 1
distributed_type: 'NO'
21 changes: 21 additions & 0 deletions tests/test_configs/downcast.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: 'NO'
downcast_bf16: 'no'
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
use_cpu: false
tpu_name: 'test-tpu'
tpu_zone: 'us-central1-a'
commands: null
command_file: tests/test_samples/test_command_file.sh
6 changes: 6 additions & 0 deletions tests/test_configs/fsdp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
compute_environment: LOCAL_MACHINE
distributed_type: MEGATRON_LM
num_processes: 1
fsdp_config:
  x: 1
  fsdp_sync_module_states: true
4 changes: 4 additions & 0 deletions tests/test_configs/ipex.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
compute_environment: LOCAL_MACHINE
debug: True
num_processes: 1
distributed_type: 'NO'
5 changes: 5 additions & 0 deletions tests/test_configs/megatron.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
compute_environment: LOCAL_MACHINE
distributed_type: MEGATRON_LM
num_processes: 1
megatron_lm_config:
  megatron_lm_recompute_activations: false
4 changes: 4 additions & 0 deletions tests/test_configs/one_proc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
compute_environment: LOCAL_MACHINE
debug: true
num_processes: 1
distributed_type: 'NO'

0 comments on commit 724824a

Please sign in to comment.