-
Notifications
You must be signed in to change notification settings - Fork 532
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Revise the implementation of set_activation_checkpointing (#866)
- v0.6.0-rc1
- v0.5.0
- v0.5.0-rc2
- v0.5.0-rc1
- v0.4.0
- v0.4.0-rc4
- v0.4.0-rc3
- v0.4.0-rc2
- v0.4.0-rc1
- v0.3.1
- v0.3.1-rc3
- v0.3.1-rc2
- v0.3.1-rc1
- v0.3.0
- v0.3.0-rc7
- v0.3.0-rc6
- v0.3.0-rc5
- v0.3.0-rc4
- v0.3.0-rc3
- v0.3.0-rc2
- v0.3.0-rc1
- v0.2.1
- v0.2.1-rc3
- v0.2.1-rc2
- v0.2.1-rc1
- v0.2.0
- v0.2.0-rc5
- v0.2.0-rc4
- v0.2.0-rc3
- v0.2.0-rc2
- v0.2.0-rc1
1 parent
1ae3bfd
commit ea3d4ea
Showing
2 changed files
with
57 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import pytest | ||
import torch.nn as nn | ||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | ||
CheckpointWrapper, | ||
) | ||
from torchtune.utils import set_activation_checkpointing | ||
|
||
|
||
class TestSetActivationCheckpointing: | ||
@pytest.fixture | ||
def model(self) -> int: | ||
return nn.Sequential( | ||
nn.Linear(10, 10), | ||
nn.Linear(10, 10), | ||
nn.Linear(10, 10), | ||
nn.ReLU(), | ||
nn.Dropout(0.5), | ||
) | ||
|
||
def _verify(self, model): | ||
for submodule in model.modules(): | ||
if isinstance(submodule, CheckpointWrapper): | ||
assert isinstance(submodule._checkpoint_wrapped_module, nn.Linear) | ||
|
||
def test_activation_checkpoint_set_policy(self, model): | ||
set_activation_checkpointing(model=model, auto_wrap_policy={nn.Linear}) | ||
self._verify(model) | ||
|
||
def test_activation_checkpoint_custom_policy(self, model): | ||
def custom_policy(module: nn.Module, recurse: bool, **kwargs) -> bool: | ||
if recurse: | ||
return True | ||
return isinstance(module, nn.Linear) | ||
|
||
set_activation_checkpointing(model=model, auto_wrap_policy=custom_policy) | ||
self._verify(model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters