Enable activation checkpointing
Summary:
Pull Request resolved: facebookresearch#573

Enable activation checkpointing from PyTorch Distributed in d2go.

Reviewed By: rohan-varma

Differential Revision: D45681009

fbshipit-source-id: c03f27af61e0374b9e5991d82070edbe41edde6d
Anthony Chen authored and facebook-github-bot committed Jun 14, 2023
1 parent 3fce52c commit 0389f4e
Showing 6 changed files with 299 additions and 94 deletions.
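
For orientation, here is a minimal usage sketch (not part of the diff) that exercises only the config keys and registered policy name introduced below. The layer name "TransformerEncoderLayer" is a placeholder, and wiring the hook into a runner (normally via d2go's modeling-hook mechanism) is assumed rather than shown in this commit.

```python
# Sketch only: set the new activation-checkpointing options on a config.
# "TransformerEncoderLayer" is a placeholder layer-class name.
from d2go.config import CfgNode as CN
from d2go.trainer.activation_checkpointing import add_activation_checkpoint_configs

cfg = CN()
add_activation_checkpoint_configs(cfg)  # attaches cfg.ACTIVATION_CHECKPOINT with its defaults

# Override the defaults: non-reentrant checkpointing, wrap only the named layer classes.
cfg.merge_from_list([
    "ACTIVATION_CHECKPOINT.REENTRANT", False,
    "ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY", "layer_based_auto_wrap_policy",
    "ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS", ["TransformerEncoderLayer"],
])
```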
3 changes: 3 additions & 0 deletions d2go/runner/config_defaults.py
@@ -16,6 +16,7 @@
from d2go.modeling.subclass import add_subclass_configs
from d2go.quantization.modeling import add_quantization_default_configs
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from d2go.trainer.activation_checkpointing import add_activation_checkpoint_configs
from d2go.trainer.fsdp import add_fsdp_configs
from d2go.utils.gpu_memory_profiler import add_memory_profiler_configs
from d2go.utils.visualization import add_tensorboard_default_configs
@@ -87,6 +88,8 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
add_distillation_configs(_C)
# _C.FSDP
add_fsdp_configs(_C)
# _C.ACTIVATION_CHECKPOINT
add_activation_checkpoint_configs(_C)

# Set find_unused_parameters for DistributedDataParallel.
_C.MODEL.DDP_FIND_UNUSED_PARAMETERS = False
63 changes: 63 additions & 0 deletions d2go/trainer/activation_checkpointing.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import logging
from functools import partial

import torch.nn as nn
from d2go.config import CfgNode as CN
from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import MODELING_HOOK_REGISTRY
from d2go.trainer.helper import D2GO_WRAP_POLICY_REGISTRY
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
)


logger = logging.getLogger(__name__)


def add_activation_checkpoint_configs(_C: CN):
_C.ACTIVATION_CHECKPOINT = CN()
_C.ACTIVATION_CHECKPOINT.REENTRANT = False
# Find autowrap policy at D2GO_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
_C.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "always_wrap_policy"
# A list of layer cls names to wrap, case sensitive
_C.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = []


@MODELING_HOOK_REGISTRY.register()
class ActivationCheckpointModelingHook(mh.ModelingHook):
"""Modeling hook that wraps model in activation checkpoint based on config"""

def apply(self, model: nn.Module) -> nn.Module:
logger.info("Activation Checkpointing is used")
wrapper_fn = partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT
if not self.cfg.ACTIVATION_CHECKPOINT.REENTRANT
else CheckpointImpl.REENTRANT,
)
policy_name = self.cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY
assert (
policy_name != "size_based_auto_wrap_policy"
), "ActivationCheckpointing should always be wrapped at module boundary"
policy_kwargs = {
"layer_names": self.cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS,
}
auto_wrap_policy = (
D2GO_WRAP_POLICY_REGISTRY.get(policy_name)(model, **policy_kwargs)
if policy_name != ""
else lambda _: True
)

apply_activation_checkpointing(
model, checkpoint_wrapper_fn=wrapper_fn, auto_wrap_policy=auto_wrap_policy
)
return model

def unapply(self, model: nn.Module) -> nn.Module:
raise NotImplementedError(
"ActivationCheckpointModelingHook.unapply() not implemented: can't unwrap an activation checkpoint module"
)
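
A rough usage sketch of the new hook on a toy module (not from the commit). It assumes the ModelingHook base class stores the cfg passed to its constructor as self.cfg, which is how apply() reads the settings above; with the default always_wrap_policy, every submodule gets a checkpoint wrapper.

```python
# Sketch: apply the hook to a toy model and run a forward/backward pass.
# Assumes ModelingHook.__init__(cfg) stores cfg as self.cfg (not shown in this diff).
import torch
import torch.nn as nn
from d2go.config import CfgNode as CN
from d2go.trainer.activation_checkpointing import (
    ActivationCheckpointModelingHook,
    add_activation_checkpoint_configs,
)

cfg = CN()
add_activation_checkpoint_configs(cfg)  # defaults: non-reentrant, always_wrap_policy

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
model = ActivationCheckpointModelingHook(cfg).apply(model)

# Submodules are now wrapped in CheckpointWrapper: their activations are
# recomputed during backward instead of being kept alive after the forward pass.
loss = model(torch.randn(2, 16)).sum()
loss.backward()
```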
97 changes: 4 additions & 93 deletions d2go/trainer/fsdp.py
@@ -3,16 +3,14 @@
import contextlib
import logging
from enum import Enum
from functools import partial
from typing import Callable, Generator, Iterable, Optional
from typing import Generator, Optional

import torch
import torch.nn as nn
from d2go.config import CfgNode as CN
from d2go.modeling.modeling_hook import ModelingHook
from d2go.registry.builtin import MODELING_HOOK_REGISTRY
from d2go.trainer.helper import parse_precision_from_string
from detectron2.utils.registry import Registry
from d2go.trainer.helper import D2GO_WRAP_POLICY_REGISTRY, parse_precision_from_string
from torch.ao.pruning import fqn_to_module
from torch.cuda.amp import GradScaler
from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -27,17 +25,10 @@
StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
always_wrap_policy as _always_wrap_policy,
size_based_auto_wrap_policy as _size_based_auto_wrap_policy,
transformer_auto_wrap_policy as _layer_based_auto_wrap_policy,
)


logger = logging.getLogger(__name__)

D2GO_FSDP_WRAP_POLICY_REGISTRY = Registry("D2GO_FSDP_WRAP_POLICY_REGISTRY")


def add_fsdp_configs(_C: CN):
_C.FSDP = CN()
@@ -49,7 +40,7 @@ def add_fsdp_configs(_C: CN):
# See docstring of CpuOffload and BackwardPrefetch in torch.distributed.fsdp.fully_sharded_data_parallel
_C.FSDP.CPU_OFFLOAD = False
_C.FSDP.BACKWARD_PREFETCH = True
# Find autowrap policy at D2GO_FSDP_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
# Find autowrap policy at D2GO_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
_C.FSDP.AUTO_WRAP_POLICY = "never_wrap_policy"
_C.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)
# A list of layer cls names to wrap, case sensitive
@@ -210,7 +201,7 @@ def build_fsdp(
)

auto_wrap_policy = (
D2GO_FSDP_WRAP_POLICY_REGISTRY.get(auto_wrap_policy_name)(
D2GO_WRAP_POLICY_REGISTRY.get(auto_wrap_policy_name)(
model, **auto_wrap_policy_kwargs
)
if auto_wrap_policy_name != ""
@@ -321,83 +312,3 @@ def unapply(self, model: FSDPWrapper) -> nn.Module:
raise NotImplementedError(
"FSDPModelingHook.unapply() not implemented: can't unwrap a FSDP module"
)


def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name. Code borrowed from HuggingFace
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
elif len(modules_children) == 0:
return
else:
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class


@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
def never_wrap_policy(model, **kwargs) -> Optional[Callable]:
"""
Don't wrap any child module, only wrap the root
"""

def never_wrap(*args, **kwargs):
return False

return never_wrap


@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
def always_wrap_policy(model, **kwargs) -> Optional[Callable]:
"""
Wrapper for always_wrap_policy() from torch.distributed.fsdp.wrap
"""
return _always_wrap_policy


@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
def size_based_auto_wrap_policy(
model, min_num_params=1e4, **kwargs
) -> Optional[Callable]:
"""
Wrapper for size_based_auto_wrap_policy() from torch.distributed.fsdp.wrap
"""
# Note: be careful when using auto wrap with shared parameters.
# Errors will be thrown if shared parameters reside in different FSDP units
return partial(
_size_based_auto_wrap_policy,
min_num_params=min_num_params,
)


@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
def layer_based_auto_wrap_policy(
model, layer_names: Iterable[str], **kwargs
) -> Optional[Callable]:
"""
Wrapper for transformer_auto_wrap_policy() from torch.distributed.fsdp.wrap
Args:
layer_names: a list of layer names
"""
assert (
len(layer_names) > 0
), "FSDP.AUTO_WRAP_LAYER_CLS should be a nonempty list of layer names contained in the model"
layer_cls = []
for name in layer_names:
closure = get_module_class_from_name(model, name)
if closure is None:
raise Exception(
f"Could not find the layer class {name} to wrap in the model."
)
layer_cls.append(closure)
return partial(
_layer_based_auto_wrap_policy,
transformer_layer_cls=layer_cls,
)
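
With this change, build_fsdp resolves its auto-wrap policy from the shared D2GO_WRAP_POLICY_REGISTRY in d2go/trainer/helper.py instead of the former FSDP-only registry. A minimal sketch of that lookup follows (not part of the commit); the toy model, policy name, and kwargs are placeholders mirroring the FSDP config defaults.

```python
# Sketch of the registry lookup build_fsdp performs after this change.
import torch.nn as nn
from d2go.trainer.helper import D2GO_WRAP_POLICY_REGISTRY

model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 8))

# Same call shape as build_fsdp: look the policy up by name, then call it with
# the model and any policy-specific kwargs taken from the config.
auto_wrap_policy = D2GO_WRAP_POLICY_REGISTRY.get("size_based_auto_wrap_policy")(
    model, min_num_params=int(1e4)
)
# build_fsdp then passes the resulting policy to FullyShardedDataParallel.
```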
104 changes: 103 additions & 1 deletion d2go/trainer/helper.py
@@ -1,7 +1,18 @@
from typing import Union
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union

import torch

from detectron2.utils.registry import Registry
from torch.distributed.fsdp.wrap import (
always_wrap_policy as _always_wrap_policy,
size_based_auto_wrap_policy as _size_based_auto_wrap_policy,
transformer_auto_wrap_policy as _layer_based_auto_wrap_policy,
)


D2GO_WRAP_POLICY_REGISTRY = Registry("D2GO_WRAP_POLICY_REGISTRY")


def parse_precision_from_string(
precision: str, lightning=False
@@ -19,3 +30,94 @@ def parse_precision_from_string(
return torch.bfloat16 if not lightning else "bf16"
else:
raise ValueError(f"Invalid precision dtype {precision}")


def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name. Code borrowed from HuggingFace
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
elif len(modules_children) == 0:
return
else:
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class


def get_layer_cls_from_names(
model: Any, layer_names: Iterable[str]
) -> List[torch.nn.Module]:
"""
Get a list of layers from a model that match a list of layer names.
"""
layer_cls = []
for name in layer_names:
closure = get_module_class_from_name(model, name)
if closure is None:
raise Exception(
f"Could not find the layer class {name} to wrap in the model."
)
layer_cls.append(closure)

return layer_cls


@D2GO_WRAP_POLICY_REGISTRY.register()
def never_wrap_policy(model, **kwargs) -> Optional[Callable]:
"""
Don't wrap any child module, only wrap the root
"""

def never_wrap(*args, **kwargs):
return False

return never_wrap


@D2GO_WRAP_POLICY_REGISTRY.register()
def always_wrap_policy(model, **kwargs) -> Optional[Callable]:
"""
Wrapper for always_wrap_policy() from torch.distributed.fsdp.wrap
"""
return _always_wrap_policy


@D2GO_WRAP_POLICY_REGISTRY.register()
def size_based_auto_wrap_policy(
model, min_num_params=1e4, **kwargs
) -> Optional[Callable]:
"""
Wrapper for size_based_auto_wrap_policy() from torch.distributed.fsdp.wrap
"""
# Note: be careful when using auto wrap with shared parameters.
# Errors will be thrown if shared parameters reside in different FSDP units
return partial(
_size_based_auto_wrap_policy,
min_num_params=min_num_params,
)


@D2GO_WRAP_POLICY_REGISTRY.register()
def layer_based_auto_wrap_policy(
model, layer_names: Iterable[str], **kwargs
) -> Optional[Callable]:
"""
Wrapper for transformer_auto_wrap_policy() from torch.distributed.fsdp.wrap
Args:
layer_names: a list of layer names
"""
assert (
len(layer_names) > 0
), "layer_names should be a nonempty list of layer names contained in the model"
layer_cls = get_layer_cls_from_names(model, layer_names)
return partial(
_layer_based_auto_wrap_policy,
transformer_layer_cls=layer_cls,
)
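
A small sketch of the relocated helpers in action (not part of the commit); Block is a toy stand-in for a real layer class that would be named in AUTO_WRAP_LAYER_CLS.

```python
# Sketch: resolve layer classes by name and build a layer-based wrap policy.
import torch.nn as nn
from d2go.trainer.helper import (
    get_layer_cls_from_names,
    layer_based_auto_wrap_policy,
)

class Block(nn.Module):
    """Toy layer standing in for a real transformer block."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = nn.Sequential(Block(), Block())

# Name-based lookup walks the module tree and returns the matching classes.
assert get_layer_cls_from_names(model, ["Block"]) == [Block]

# The policy is a partial of torch's transformer_auto_wrap_policy keyed on those
# classes; both FSDP and apply_activation_checkpointing can consume it.
policy = layer_based_auto_wrap_policy(model, layer_names=["Block"])
```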
Empty file added tests/trainer/__init__.py
(1 more changed file not shown)