Add support for init_meta_context, materialize_module #9920

Merged 42 commits from set_meta_device into master on Oct 21, 2021
Changes from 27 commits
Commits (42)
9a8954e  update (tchaton, Oct 13, 2021)
36bb238  update (tchaton, Oct 13, 2021)
f1890bc  remove credit (tchaton, Oct 13, 2021)
103c311  update (tchaton, Oct 14, 2021)
f346120  update (tchaton, Oct 14, 2021)
8f7fc11  update (tchaton, Oct 14, 2021)
3d7852f  add changelog (tchaton, Oct 14, 2021)
feb6c9c  update (tchaton, Oct 14, 2021)
0cdbec2  update on comments (tchaton, Oct 14, 2021)
8c0402b  update changelog (tchaton, Oct 14, 2021)
ad0f3ba  typo (tchaton, Oct 14, 2021)
73c0588  update (tchaton, Oct 14, 2021)
ff41479  update (tchaton, Oct 14, 2021)
402e6f6  update (tchaton, Oct 15, 2021)
e0d4c5b  update (tchaton, Oct 15, 2021)
57f4ec0  update changelog (tchaton, Oct 15, 2021)
1b5fb68  update (tchaton, Oct 15, 2021)
11a3eb9  add note (tchaton, Oct 15, 2021)
e116e78  update (tchaton, Oct 15, 2021)
0f8fb06  Merge branch 'set_meta_device' of https://github.com/PyTorchLightning… (tchaton, Oct 15, 2021)
ee15d11  update test name (tchaton, Oct 15, 2021)
f8d2e9e  wip (tchaton, Oct 15, 2021)
0318480  update (tchaton, Oct 15, 2021)
78744bc  add some typing (tchaton, Oct 15, 2021)
0bd6b72  update on comments (tchaton, Oct 15, 2021)
92b5a63  resolve bug (tchaton, Oct 15, 2021)
7661b1b  add layernorm (tchaton, Oct 15, 2021)
f78db68  update (tchaton, Oct 15, 2021)
5eeec6a  revert back (tchaton, Oct 15, 2021)
a03cd69  replace the in_place (tchaton, Oct 15, 2021)
f28673c  remove extra lines (tchaton, Oct 15, 2021)
43b62ee  update (tchaton, Oct 15, 2021)
0595843  remove list (tchaton, Oct 15, 2021)
8b27b15  update (tchaton, Oct 15, 2021)
0850f1e  update (tchaton, Oct 15, 2021)
e3f991b  update (tchaton, Oct 16, 2021)
cfb42a2  add a warning about unstability (tchaton, Oct 16, 2021)
50357b2  add a warning about unstability (tchaton, Oct 16, 2021)
df531aa  update test (tchaton, Oct 16, 2021)
50e9d65  Merge branch 'master' into set_meta_device (tchaton, Oct 19, 2021)
0afb695  revert on previous api based on can comments (tchaton, Oct 20, 2021)
2d8c0a1  Merge branch 'master' into set_meta_device (tchaton, Oct 20, 2021)
CHANGELOG.md (3 additions, 0 deletions)

@@ -187,6 +187,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))


- Added `init_meta_context, materialize_module` utilities ([#9920](https://github.com/PyTorchLightning/pytorch-lightning/pull/9920))


### Changed

- Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867))
pytorch_lightning/plugins/training_type/deepspeed.py (1 addition, 1 deletion)

@@ -380,7 +380,7 @@ def pre_dispatch(self):
    def init_deepspeed(self):
        # check that `configure_gradient_clipping` hook isn't overriden since deepspeed handles
        # gradient clipping internally
-       if is_overridden("configure_gradient_clipping", self.lightning_module):
+       if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
            rank_zero_warn(
                "Since deepspeed handles gradient clipping internally, this hook will"
                " be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
pytorch_lightning/trainer/trainer.py (2 additions, 0 deletions)

@@ -89,6 +89,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import (
@@ -1349,6 +1350,7 @@ def _call_setup_hook(self) -> None:

    def _call_configure_sharded_model(self) -> None:
        with self.accelerator.model_sharded_context():
            materialize_module(self.lightning_module)
            self.call_hook("configure_sharded_model")
            self.call_hook("on_configure_sharded_model")

pytorch_lightning/utilities/imports.py (1 addition, 0 deletions)

@@ -96,6 +96,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_BFLOAT_AVAILABLE = _compare_version(
    "torch", operator.ge, "1.10.0.dev20210902"
)  # todo: swap to 1.10.0 once released
_TORCH_META_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210922")
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version(
    "torch", operator.ge, "1.10.0.dev20210809"
pytorch_lightning/utilities/meta.py (new file, 302 additions)

@@ -0,0 +1,302 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib
import inspect
import threading
from contextlib import contextmanager
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_META_AVAILABLE

if _TORCH_META_AVAILABLE:
    from torch._C import _DisableTorchDispatch  # type: ignore[attr-defined]

    ####################################################################
    # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

    @contextlib.contextmanager
    def enable_python_mode(cls) -> Iterator[None]:
        if not hasattr(cls, "__torch_dispatch__"):
            raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
        if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
            raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
        torch._C._enter_python_mode(cls)
        try:
            yield
        finally:
            torch._C._exit_python_mode()

    _tls = threading.local()
    _tls.in_call = False

    @contextmanager
    def _no_dispatch() -> Iterator[None]:
        """Temporarily disables the Python dispatch mode."""
        guard = _DisableTorchDispatch()  # noqa F841
        try:
            yield
        finally:
            del guard

    def _handle_arange(func, args, kwargs):
        kwargs["device"] = torch.device("cpu")
        return torch.empty_like(func(*args, **kwargs), device="meta")

    def _handle_tril(func, args, kwargs):
        if args and isinstance(args[0], Tensor):
            return torch.empty_like(args[0], device="meta")

        return NotImplemented

    class _MetaContext(Tensor):

        _op_handlers: Dict[Callable, Callable] = {}

        @classmethod
        def _ensure_handlers_initialized(cls) -> None:
            if cls._op_handlers:
                return

            cls._op_handlers.update(
                {
                    torch.ops.aten.arange: _handle_arange,
                    torch.ops.aten.tril: _handle_tril,
                }
            )

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            cls._ensure_handlers_initialized()

            op_handler: Optional[Callable]

            try:
                op_handler = cls._op_handlers[func]
            except KeyError:
                op_handler = None

            with _no_dispatch():
                if op_handler:
                    result = op_handler(func, args, kwargs)
                    if result is not NotImplemented:
                        return result

                if "device" in kwargs:
                    kwargs["device"] = torch.device("meta")

                return func(*args, **kwargs)

    def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
        def create_instance() -> Module:
            return module_fn(*args, **kwargs)

        if _tls.in_call:
            module = create_instance()
        else:
            _tls.in_call = True
            try:
                with enable_python_mode(_MetaContext):
                    module = create_instance()
            finally:
                _tls.in_call = False

        module.materialize = create_instance  # type: ignore[assignment]

        return module

    def is_meta_init() -> bool:
        """Indicates whether the module is being instantiated by ``init_meta()``."""
        return _tls.in_call

    ####################################################################
    # ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################


# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls: Type[nn.Module]) -> Set[nn.Module]:
    subclass_list = []

    def recurse(cl):
        for subclass in cl.__subclasses__():
            subclass_list.append(subclass)
            recurse(subclass)

    recurse(cls)

    return set(subclass_list)


def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None:
    *path, name = prefix.split(".")
    for p in path:
        root_module = getattr(root_module, p)

    try:
        index = int(name)
        root_module[index] = materialized_module
    except ValueError:
        setattr(root_module, name, materialized_module)


def materialize_module(root_module: nn.Module) -> nn.Module:
    """This utility performs an in-place operation by materializing a module and its children."""
    if not _TORCH_META_AVAILABLE:
        return root_module
    memo = []
    modules = list(root_module.named_modules())
    for prefix, mod in modules:
        materialize_fn = getattr(mod, "materialize", None)
        if materialize_fn:
            memo.append((prefix, materialize_fn()))
    for prefix, materialized_module in memo:
        recursively_setattr(root_module, prefix, materialized_module)
    return root_module


# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}

__CREATED_MODULES__ = set()


def _unset_meta_device(from_created: bool = False) -> None:
    """Replace all meta modules with their original versions."""
    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
        values = __STORAGE_META__.values()

    for mods, subclass, _ in values:
        for mod in mods:
            setattr(mod, subclass.__name__, subclass)


def _set_meta_device_populated(from_created: bool = False) -> None:
    """Replace all modules with their meta versions."""
    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
        values = __STORAGE_META__.values()

    for mods, subclass, meta_class in values:
        for mod in mods:
            setattr(mod, subclass.__name__, meta_class)


def _set_meta_device() -> None:
    """Replace all torch.nn.Module subclasses with their meta replacements."""

    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    # Author note: This can be optimized further by searching all subclasses at once.
    # Find all the nn.Module subclasses
    for subclass in get_all_subclasses(torch.nn.modules.module.Module):

        # if a subclass has already been stored, we should use the cache
        if str(subclass) in __STORAGE_META__:
            # reset the class import package to its rightful state.
            mods, subclass, meta_class = __STORAGE_META__[str(subclass)]
            for mod in mods:
                setattr(mod, subclass.__name__, meta_class)
            continue

        # Create a class subclassing the current `subclass`, overriding its `__new__` method.
        # This will enable us to use `torch.distributed.nn.utils.init_meta` to create a `meta`
        # version of the current subclass module.
        class _MetaClass(subclass):
            @classmethod
            @contextmanager
            def instantiation_context(cls, materialize: bool):
                _unset_meta_device(from_created=True)
                yield
                _set_meta_device_populated(from_created=True)

            @classmethod
            def materialize(cls, materialize_fn: Callable):
                with cls.instantiation_context(materialize=True):
                    obj = materialize_fn()
                return obj

            @staticmethod
            def add_subclasses(subclass):
                """This is used to unroll the instantiation tree while creating the modules."""
                __CREATED_MODULES__.add(str(subclass))
                if subclass.__bases__[0] != torch.nn.modules.module.Module:
                    _MetaClass.add_subclasses(subclass.__bases__[0])

            def __new__(cls, *args, **kwargs):
                subclass = cls.__bases__[0]
                cls.add_subclasses(subclass)
                with cls.instantiation_context(materialize=False):
                    obj = init_meta(subclass, *args, **kwargs)

                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
                return obj

        def search(mod: ModuleType) -> List[ModuleType]:
            out = []
            for _, obj in inspect.getmembers(mod):
                if obj == subclass:
                    out.append(mod)
            return out

        submodules = subclass.__module__.split(".")
        mod = importlib.import_module(submodules[0])

        # nn.Module classes can be imported at different levels and they all need to be mocked.
        # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear,
        # so torch.nn.Linear, torch.nn.modules.Linear and torch.nn.modules.linear.Linear
        # all need to be replaced by the _MetaClass version.
        out = []
        out.append(search(mod))
        for name in submodules[1:]:
            mod = getattr(mod, name)
            out.append(search(mod))

        # drop empty modules
        mods = [mod for mod in chain(*out) if mod]

        # store the module search so it doesn't have to be performed again for this class
        __STORAGE_META__[str(subclass)] = (mods, subclass, _MetaClass)

        # replace each subclass with its meta form
        for mod in mods:
            setattr(mod, subclass.__name__, _MetaClass)


@contextmanager
def init_meta_context() -> Generator:
    _set_meta_device()
    yield
    _unset_meta_device()
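
For context when reviewing, here is a minimal usage sketch of the two utilities this PR adds, assuming a PyTorch build recent enough for `_TORCH_META_AVAILABLE` to be `True`. `TinyModel` is a hypothetical example module (not part of this PR) mirroring the single-linear-layer `BoringModel` used in the test below.

```python
from torch import nn

from pytorch_lightning.utilities.meta import init_meta_context, materialize_module


class TinyModel(nn.Module):
    """Hypothetical example module with one linear layer."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)


with init_meta_context():
    # Inside the context, nn.Module subclasses are swapped for meta-device versions,
    # so parameters carry shape/dtype information but no real storage is allocated.
    model = TinyModel()

assert model.layer.weight.device.type == "meta"

# Recursively replace the meta submodules with real ones, in place. The Trainer
# calls this on the LightningModule right before `configure_sharded_model` runs.
materialize_module(model)
assert model.layer.weight.device.type == "cpu"
```

Since the Trainer now wires `materialize_module` in automatically (see the trainer.py diff above), user code typically only needs `init_meta_context`.
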
tests/plugins/test_deepspeed_plugin.py (15 additions, 1 deletion)

@@ -17,7 +17,8 @@
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE, _TORCH_META_AVAILABLE
+from pytorch_lightning.utilities.meta import init_meta_context
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
@@ -1042,3 +1043,16 @@ def on_test_batch_start(
    )
    trainer.fit(model)
    trainer.test(model)


@pytest.mark.skipif(not _TORCH_META_AVAILABLE, reason="the meta device context is supported from PyTorch 1.10.")
@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_with_meta_device(tmpdir):
    with init_meta_context():
        model = BoringModel()
    assert model.layer.weight.device.type == "meta"
    trainer = Trainer(
        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
    )
    trainer.fit(model)
    assert model.layer.weight.device.type == "cpu"