Revise the implementation of set_activation_checkpointing (#866)
rohan-varma authored Apr 25, 2024
1 parent 1ae3bfd commit ea3d4ea
Showing 2 changed files with 57 additions and 7 deletions.
42 changes: 42 additions & 0 deletions tests/torchtune/utils/test_memory.py
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
)
from torchtune.utils import set_activation_checkpointing


class TestSetActivationCheckpointing:
    @pytest.fixture
    def model(self) -> nn.Module:
        return nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def _verify(self, model):
        for submodule in model.modules():
            if isinstance(submodule, CheckpointWrapper):
                assert isinstance(submodule._checkpoint_wrapped_module, nn.Linear)

    def test_activation_checkpoint_set_policy(self, model):
        set_activation_checkpointing(model=model, auto_wrap_policy={nn.Linear})
        self._verify(model)

    def test_activation_checkpoint_custom_policy(self, model):
        def custom_policy(module: nn.Module, recurse: bool, **kwargs) -> bool:
            if recurse:
                return True
            return isinstance(module, nn.Linear)

        set_activation_checkpointing(model=model, auto_wrap_policy=custom_policy)
        self._verify(model)
22 changes: 15 additions & 7 deletions torchtune/utils/memory.py
@@ -7,7 +7,7 @@
 import gc
 import logging
 
-from typing import Any, Dict, Optional, Set
+from typing import Any, Callable, Dict, Set, Type, Union
 
 import torch
 
@@ -20,19 +20,27 @@
 
 _log: logging.Logger = get_logger()
 
+ACWrapPolicyType: Type = Union[Set[Type], Callable[[nn.Module, bool, int], bool]]
+
 
 def set_activation_checkpointing(
-    model: nn.Module, auto_wrap_policy: Optional[Set[nn.Module]] = None, **kwargs
+    model: nn.Module, auto_wrap_policy: ACWrapPolicyType, **kwargs
 ) -> None:
-    """Utility to setup activation checkpointing and wrap the model for checkpointing.
+    """Utility to apply activation checkpointing to the passed in model.
     Args:
         model (nn.Module): Model to setup activation checkpointing.
-        auto_wrap_policy (Optional[Set[nn.Module]]): Policy to wrap module.
-        **kwargs: additional arguments to pass to torch.distributed activation checkpointing.
+        auto_wrap_policy (ACWrapPolicyType): Policy to wrap module.
+            This can either be a set of ``nn.Module`` types, in which case, modules of the specified type(s)
+            will be wrapped individually with activation checkpointing, or a ``callable`` policy describing
+            how to wrap the model with activation checkpointing. For more information on authoring custom
+            policies, please see this tutorial:
+            https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html#transformer-wrapping-policy.
+        **kwargs: additional arguments to pass to ``torch.distributed`` activation checkpointing.
     """
-    wrap_policy = ModuleWrapPolicy(auto_wrap_policy or set())
-    apply_activation_checkpointing(model, auto_wrap_policy=wrap_policy, **kwargs)
+    if isinstance(auto_wrap_policy, set):
+        auto_wrap_policy = ModuleWrapPolicy(auto_wrap_policy)
+    apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
 
 
 def cleanup_before_training() -> None:
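For context (not part of the diff): a minimal usage sketch of the revised API, mirroring what the new tests exercise. The toy nn.Sequential models and the wrap_linears policy are illustrative assumptions, and the CheckpointWrapper import path is the same private one the test file uses, so it may change across PyTorch releases.

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
)
from torchtune.utils import set_activation_checkpointing

# Illustrative toy model; any nn.Module works.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))

# Option 1: a set of module types; every submodule of a listed type is
# wrapped individually with activation checkpointing.
set_activation_checkpointing(model=model, auto_wrap_policy={nn.Linear})

# Option 2: a callable policy with the FSDP-style auto-wrap signature:
# return True while recursing, then True only for modules to wrap.
def wrap_linears(module: nn.Module, recurse: bool, **kwargs) -> bool:
    if recurse:
        return True
    return isinstance(module, nn.Linear)

model2 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
set_activation_checkpointing(model=model2, auto_wrap_policy=wrap_linears)

# Either way, wrapped submodules become CheckpointWrapper instances that hold
# the original module in _checkpoint_wrapped_module.
assert any(isinstance(m, CheckpointWrapper) for m in model.modules())

The isinstance(auto_wrap_policy, set) branch in the diff is what makes both call styles work: a set is converted to a ModuleWrapPolicy, while a callable is forwarded to apply_activation_checkpointing unchanged.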
