Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arde/fsdp activation checkpointing #25771

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,10 @@ as the model saving with FSDP activated is only available with recent fixes.
If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass.
- `limit_all_gathers` can be specified in the config file.
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers.
- `activation_checkpointing` can be specified in the config file.
If `"True"`, FSDP activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time
for reduced memory usage.

**Few caveats to be aware of**
- it is incompatible with `generate`, thus is incompatible with `--predict_with_generate`
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3896,6 +3896,15 @@ def create_accelerator_and_postprocess(self):
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
"activation_checkpointing", fsdp_plugin.activation_checkpointing
)
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
raise ValueError(
"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
"can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
"when using FSDP."
)

if self.is_deepspeed_enabled:
if getattr(self.args, "hf_deepspeed_config", None) is None:
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,10 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
If True, activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra
computation time for reduced memory usage.

deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
Expand Down