diff --git a/docs/source/en/main_classes/trainer.md b/docs/source/en/main_classes/trainer.md
index 4a767ee0766..f433b820af8 100644
--- a/docs/source/en/main_classes/trainer.md
+++ b/docs/source/en/main_classes/trainer.md
@@ -456,6 +456,10 @@ as the model saving with FSDP activated is only available with recent fixes.
   If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass.
 - `limit_all_gathers` can be specified in the config file.
   If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers.
+- `activation_checkpointing` can be specified in the config file.
+  If `"True"`, FSDP applies activation checkpointing, a technique that reduces memory usage by clearing the
+  activations of certain layers and recomputing them during the backward pass. Effectively, this trades extra
+  computation time for reduced memory usage.
 
 **Few caveats to be aware of**
 - it is incompatible with `generate`, thus is incompatible with `--predict_with_generate`
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 2694cff70af..f11cfb9f0b7 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3896,6 +3896,15 @@ def create_accelerator_and_postprocess(self):
             fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
                 "limit_all_gathers", fsdp_plugin.limit_all_gathers
             )
+            fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
+                "activation_checkpointing", fsdp_plugin.activation_checkpointing
+            )
+            if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
+                raise ValueError(
+                    "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
+                    "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
+                    "when using FSDP."
+                )
 
         if self.is_deepspeed_enabled:
             if getattr(self.args, "hf_deepspeed_config", None) is None:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 11f812eaf2f..8549739deb1 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -482,6 +482,10 @@ class TrainingArguments:
                     Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
                     used when the xla flag is set to true, and an auto wrapping policy is specified through
                     fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
+                - activation_checkpointing (`bool`, *optional*, defaults to `False`):
+                    If `True`, activation checkpointing is used to reduce memory usage by clearing the activations of
+                    certain layers and recomputing them during the backward pass. Effectively, this trades extra
+                    computation time for reduced memory usage.
 
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
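
For illustration, a minimal sketch of how the new flag could be enabled through `TrainingArguments`. The `output_dir` value and the FSDP sharding options are placeholders; only the `fsdp_config` keys (`forward_prefetch`, `limit_all_gathers`, `activation_checkpointing`) come from the documentation above, and an equivalent JSON file passed via `--fsdp_config` would work the same way.

```python
from transformers import TrainingArguments

# Sketch only (placeholder values): enable FSDP activation checkpointing
# through fsdp_config instead of gradient_checkpointing=True.
training_args = TrainingArguments(
    output_dir="outputs",                  # placeholder output directory
    fsdp="full_shard auto_wrap",           # turn FSDP on; sharding strategy is illustrative
    fsdp_config={
        "forward_prefetch": True,          # prefetch the next all-gather during the forward pass
        "limit_all_gathers": True,         # sync the CPU thread to cap in-flight all-gathers
        "activation_checkpointing": True,  # recompute activations in backward to save memory
    },
    # Do NOT also set gradient_checkpointing=True: the Trainer now raises a
    # ValueError when both flags are enabled (see the trainer.py hunk above).
)
```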