diff --git a/parlai/agents/transformer/modules/decoder.py b/parlai/agents/transformer/modules/decoder.py
index 52ee3a80cc1..04e0589f039 100644
--- a/parlai/agents/transformer/modules/decoder.py
+++ b/parlai/agents/transformer/modules/decoder.py
@@ -26,6 +26,7 @@
 from parlai.utils.misc import warn_once
 from parlai.utils.torch import PipelineHelper
 from parlai.utils.fsdp import fsdp_wrap
+from parlai.nn.checkpoint import checkpoint_wrapper
 
 
 @swappable(
@@ -286,6 +287,8 @@ def build_layers(self) -> nn.ModuleList:
                 activation=self.activation,
                 variant=self.variant,
             )
+            if self.opt.get('checkpoint_activations'):
+                layer = checkpoint_wrapper(layer)
             layers.append(fsdp_wrap(layer))  # type: ignore
         return layers
 
diff --git a/parlai/agents/transformer/modules/encoder.py b/parlai/agents/transformer/modules/encoder.py
index 441d13112f9..5f0fc2a3f0e 100644
--- a/parlai/agents/transformer/modules/encoder.py
+++ b/parlai/agents/transformer/modules/encoder.py
@@ -26,6 +26,7 @@
 from parlai.utils.misc import warn_once
 from parlai.utils.torch import PipelineHelper
 from parlai.utils.fsdp import fsdp_wrap
+from parlai.nn.checkpoint import checkpoint_wrapper
 
 
 @swappable(self_attention=MultiHeadAttention, feedforward=TransformerFFN)
@@ -236,6 +237,8 @@ def build_layers(self) -> nn.ModuleList:
                 variant=self.variant,
                 activation=self.activation,
             )
+            if self.opt.get('checkpoint_activations'):
+                layer = checkpoint_wrapper(layer)
             layers.append(fsdp_wrap(layer))
         return layers
 
diff --git a/parlai/agents/transformer/transformer.py b/parlai/agents/transformer/transformer.py
index 643b1209896..25cb65c2fd7 100644
--- a/parlai/agents/transformer/transformer.py
+++ b/parlai/agents/transformer/transformer.py
@@ -141,6 +141,12 @@ def add_common_cmdline_args(parser):
         default=False,
         help='Shard the layers across multiple GPUs.',
     )
+    parser.add_argument(
+        '--checkpoint-activations',
+        type='bool',
+        default=False,
+        help='Recompute activations on backward pass to conserve memory.',
+    )
 
 
 class Transformer(Agent):
diff --git a/parlai/nn/checkpoint.py b/parlai/nn/checkpoint.py
new file mode 100644
index 00000000000..ff4bc58c23e
--- /dev/null
+++ b/parlai/nn/checkpoint.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+try:
+    from fairscale.nn import checkpoint_wrapper
+except ImportError:
+
+    def checkpoint_wrapper(module):
+        """
+        Dummy checkpoint wrapper that raises an error.
+        """
+        raise ImportError(
+            'Please install fairscale with `pip install fairscale` to use '
+            '--checkpoint-activations true.'
+        )
diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index 0829ab32292..893b77af60c 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -284,6 +284,23 @@ def _overfit_train(self, **args):
         args.update(args)
         return testing_utils.train_model(args)
 
+    def test_checkpoint(self):
+        """
+        Checks --checkpoint-activations true
+        """
+        valid, test = testing_utils.train_model(
+            dict(
+                task='integration_tests:overfit',
+                model='transformer/generator',
+                dict_file='zoo:unittest/transformer_generator2/model.dict',
+                batchsize=4,
+                skip_generation=True,
+                validation_metric='ppl',
+                max_train_steps=10,
+                checkpoint_activations=True,
+            )
+        )
+
     def test_greedysearch(self):
         """
         Test greedy search.
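
For reference, here is a minimal sketch (not part of the patch) of the behavior the new `--checkpoint-activations true` flag enables: `checkpoint_wrapper` frees a layer's intermediate activations after the forward pass and recomputes them during backward, trading extra compute for lower memory. It assumes fairscale is installed, since `parlai.nn.checkpoint.checkpoint_wrapper` otherwise falls back to the `ImportError` stub added above; the `nn.Sequential` stand-in for a transformer layer is purely illustrative.

```python
import torch
import torch.nn as nn

# Re-exports fairscale.nn.checkpoint_wrapper; raises ImportError if fairscale is missing.
from parlai.nn.checkpoint import checkpoint_wrapper

# An illustrative stand-in for a transformer encoder/decoder layer.
layer = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
layer = checkpoint_wrapper(layer)

x = torch.randn(4, 512, requires_grad=True)
out = layer(x)        # intermediate activations are freed after forward
out.sum().backward()  # the layer's forward runs again here to rebuild them
```

Note that in the diff, each layer is checkpoint-wrapped before being passed to `fsdp_wrap`, so activation recomputation composes with FSDP sharding at the same per-layer granularity.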