Add support for init_meta_context, materialize_module (#9920)
tchaton authored Oct 21, 2021
1 parent 4ea72a9 commit 454e93b
Showing 7 changed files with 412 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -195,6 +195,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))


- Added `init_meta_context`, `materialize_module` utilities ([#9920](https://github.com/PyTorchLightning/pytorch-lightning/pull/9920))


- Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/10020))


@@ -221,6 +224,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))



### Changed

- Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -426,7 +426,7 @@ def _setup_model_and_optimizer(
def init_deepspeed(self):
# check that `configure_gradient_clipping` hook isn't overriden since deepspeed handles
# gradient clipping internally
if is_overridden("configure_gradient_clipping", self.lightning_module):
if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
rank_zero_warn(
"Since deepspeed handles gradient clipping internally, this hook will"
" be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -89,6 +89,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import (
@@ -1349,6 +1350,7 @@ def _call_setup_hook(self) -> None:

def _call_configure_sharded_model(self) -> None:
with self.accelerator.model_sharded_context():
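            # Materialize any modules created under `init_meta_context` before the sharding hooks run.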
materialize_module(self.lightning_module)
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")

1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -93,6 +93,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_POPTORCH_AVAILABLE = _module_available("poptorch")
_RICH_AVAILABLE = _module_available("rich") and _compare_version("rich", operator.ge, "10.2.2")
_TORCH_META_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210922")
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
323 changes: 323 additions & 0 deletions pytorch_lightning/utilities/meta.py
@@ -0,0 +1,323 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import threading
from contextlib import contextmanager
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_META_AVAILABLE

if _TORCH_META_AVAILABLE:
from torch._C import _DisableTorchDispatch # type: ignore[attr-defined]

    ####################################################################
    # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

@contextmanager
def enable_python_mode(cls) -> Iterator[None]:
if not hasattr(cls, "__torch_dispatch__"):
raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
torch._C._enter_python_mode(cls)
try:
yield
finally:
torch._C._exit_python_mode()

    # thread-local flag used by ``init_meta`` to avoid re-entering the python mode for nested module creation
    _tls = threading.local()
    _tls.in_call = False

@contextmanager
def _no_dispatch() -> Iterator[None]:
"""Temporarily disables the Python dispatch mode."""
guard = _DisableTorchDispatch() # noqa F841
try:
yield
finally:
del guard

def _handle_arange(func, args, kwargs):
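        # Compute ``arange`` on CPU, then mimic the result as an empty tensor on the "meta" device.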
kwargs["device"] = torch.device("cpu")
return torch.empty_like(func(*args, **kwargs), device="meta")

def _handle_tril(func, args, kwargs):
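        # Mimic ``tril`` by returning an empty "meta" tensor shaped like its first (tensor) argument.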
if args and isinstance(args[0], Tensor):
return torch.empty_like(args[0], device="meta")

return NotImplemented

class _MetaContext(Tensor):
_op_handlers: Dict[Callable, Callable] = {}

@classmethod
def _ensure_handlers_initialized(cls) -> None:
if cls._op_handlers:
return

cls._op_handlers.update(
{
torch.ops.aten.arange: _handle_arange,
torch.ops.aten.tril: _handle_tril,
}
)

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
cls._ensure_handlers_initialized()

op_handler: Optional[Callable]

try:
op_handler = cls._op_handlers[func]
except KeyError:
op_handler = None

with _no_dispatch():
if op_handler:
result = op_handler(func, args, kwargs)
if result is not NotImplemented:
return result

if "device" in kwargs:
kwargs["device"] = torch.device("meta")

return func(*args, **kwargs)

def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
def create_instance(module=None) -> Module:
if module:
module.__init__(*args, **kwargs)
return module
return module_fn(*args, **kwargs)

if _tls.in_call:
module = create_instance()
else:
_tls.in_call = True
try:
with enable_python_mode(_MetaContext):
module = create_instance()
finally:
_tls.in_call = False

        # Attach a ``materialize`` handle so the meta module can later be re-instantiated with real parameters.
        module.materialize = partial(create_instance, module=module)  # type: ignore[assignment]

return module

def is_meta_init() -> bool:
"""Indicates whether the module is being instantiated by ``init_meta()``."""
return _tls.in_call

    ####################################################################
    # ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

else:

def init_meta(*_, **__):
if not _TORCH_META_AVAILABLE:
            raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")


# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls: Type[nn.Module]) -> Set[Type[nn.Module]]:
subclass_list = []

def recurse(cl):
for subclass in cl.__subclasses__():
subclass_list.append(subclass)
recurse(subclass)

recurse(cls)

return set(subclass_list)


# Replace the submodule at the dotted path ``prefix`` under ``root_module``.
# An integer leaf name indexes into containers such as ``nn.Sequential``.
def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None:
*path, name = prefix.split(".")
for p in path:
root_module = getattr(root_module, p)

try:
index = int(name)
root_module[index] = materialized_module
except ValueError:
setattr(root_module, name, materialized_module)


def materialize_module(root_module: nn.Module) -> nn.Module:
"""This utility performs an in-place operation by materialize a module and its children."""
if not _TORCH_META_AVAILABLE:
return root_module

materialize_fn = getattr(root_module, "materialize", None)
if materialize_fn and not isinstance(root_module, (Sequential, ModuleList, ModuleDict)):
return materialize_fn()

for name, child in root_module.named_children():
materialize_fn = getattr(child, "materialize", None)
if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
materialize_module(child)
else:
            setattr(root_module, name, materialize_fn())
return root_module


# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}

__CREATED_MODULES__ = set()


def _unset_meta_device(from_created: bool = False) -> None:
"""Replace all meta module by their original version."""
if not _TORCH_META_AVAILABLE:
raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

if from_created:
values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
else:
values = __STORAGE_META__.values()

for mods, subclass, _ in values:
for mod in mods:
setattr(mod, subclass.__name__, subclass)


def _set_meta_device_populated(from_created: bool = False) -> None:
"""Replace all meta module by their original version."""
if not _TORCH_META_AVAILABLE:
raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

if from_created:
values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
else:
values = __STORAGE_META__.values()

for mods, subclass, meta_class in values:
for mod in mods:
setattr(mod, subclass.__name__, meta_class)


def _set_meta_device() -> None:
"""Replace all torch.nn.Module by their meta replacement."""

if not _TORCH_META_AVAILABLE:
raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    # Author note: This can be optimized further by searching all subclasses at once.
    # Its time complexity is O(n * m), where n is the number of ``nn.Module`` subclasses (assuming no multiple
    # inheritance) and m is the number of modules in which each subclass is exposed.

for subclass in get_all_subclasses(torch.nn.modules.module.Module):

        if issubclass(subclass, (Sequential, ModuleList, ModuleDict)):
continue

        # if a subclass has already been stored, we should use the cache
        if subclass in __STORAGE_META__:
            # reset the class import package to its rightful state.
            mods, subclass, meta_class = __STORAGE_META__[subclass]
for mod in mods:
setattr(mod, subclass.__name__, meta_class)
continue

        # Create a class subclassing the current ``subclass`` and overriding its ``__new__`` method.
        # This enables us to use ``init_meta`` to create a "meta" version of the current subclass module.
class _MetaClass(subclass):
@classmethod
@contextmanager
def instantiation_context(cls, materialize: bool):
_unset_meta_device(from_created=True)
yield
_set_meta_device_populated(from_created=True)

@classmethod
def materialize(cls, materialize_fn: Callable):
with cls.instantiation_context(materialize=True):
obj = materialize_fn()
return obj

@staticmethod
def add_subclasses(subclass):
"""This is used to unrol the instantion tree while creating the modules."""
__CREATED_MODULES__.add(subclass)
if subclass.__bases__[0] != torch.nn.modules.module.Module:
_MetaClass.add_subclasses(subclass.__bases__[0])

def __new__(cls, *args, **kwargs):
subclass = cls.__bases__[0]
cls.add_subclasses(subclass)
with cls.instantiation_context(materialize=False):
obj = init_meta(subclass, *args, **kwargs)

                # Wrap ``materialize`` so the original classes are restored while the real parameters are created.
                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
return obj

def search(mod: ModuleType) -> List[ModuleType]:
out = []
for _, obj in inspect.getmembers(mod):
if obj == subclass:
out.append(mod)
return out

submodules = subclass.__module__.split(".")
mod = importlib.import_module(submodules[0])

        # An nn.Module class can be imported at different levels and all of them need to be mocked.
        # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear.
        # Therefore, torch.nn.Linear, torch.nn.modules.Linear and torch.nn.modules.linear.Linear
        # all need to be replaced by the ``_MetaClass`` created above.
out = []
out.append(search(mod))
for name in submodules[1:]:
mod = getattr(mod, name)
out.append(search(mod))

# drop empty module
mods = [mod for mod in chain(*out) if mod]

# store the modules search so it doesn't have to be performed again for this class
__STORAGE_META__[subclass] = (mods, subclass, _MetaClass)

        # replace the subclass with its meta form everywhere it is exposed
for mod in mods:
setattr(mod, subclass.__name__, _MetaClass)


@contextmanager
def init_meta_context() -> Generator:
rank_zero_warn(
"Be aware this feature is highly experimental and there are a number of weird edge cases "
"where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11."
)
_set_meta_device()
yield
_unset_meta_device()
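
For reference, a minimal usage sketch of the two new utilities (a hypothetical snippet, not part of this commit; it assumes PyTorch >= 1.10.0 so that `_TORCH_META_AVAILABLE` is true, and uses `nn.Linear` as a stand-in for any `nn.Module`):

from torch import nn

from pytorch_lightning.utilities.meta import init_meta_context, materialize_module

# Inside the context, nn.Module subclasses are instantiated on the "meta" device, so no parameter storage is allocated.
with init_meta_context():
    module = nn.Linear(in_features=16, out_features=4)
assert module.weight.device.type == "meta"

# Materialization re-runs the constructor with real parameters; the root module is updated in-place and returned.
module = materialize_module(module)
assert module.weight.device.type == "cpu"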