Skip to content

Commit

Permalink
Enable config for fsdp activation checkpointing (huggingface#2779)
Browse files Browse the repository at this point in the history
* Enable config for fsdp activation checkpointing

* Fix ruff errors
helloworld1 authored and yhna940 committed May 16, 2024
1 parent 83f440c commit f5a5944
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/accelerate/commands/config/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,12 @@ def get_cluster_input():
default=True,
error_message="Please enter yes or no.",
)
fsdp_config["fsdp_activation_checkpointing"] = _ask_field(
"Do you want to enable FSDP activation checkpointing? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)

megatron_lm_config = {}
if distributed_type in [DistributedType.MULTI_GPU]:
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower()
current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()

if args.use_megatron_lm:
prefix = "MEGATRON_LM_"
Expand Down

0 comments on commit f5a5944

Please sign in to comment.