Resolve issue with wrapped ORTModule load_state_dict #7847

Merged 4 commits on May 27, 2021
7 changes: 7 additions & 0 deletions orttraining/orttraining/python/training/ortmodule/_utils.py
@@ -97,3 +97,10 @@ def _create_iobinding(io_binding, inputs, model, device):

for value_info in model.graph.output:
io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device))

class _PytorchModuleMetadata():
"""Encapsulates modules and allows easy access as required"""

def __init__(self, original_module, flattened_module):
self.original_module = original_module
self.flattened_module = flattened_module
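
To make the new holder concrete, here is a minimal, self-contained sketch of how the metadata object is meant to be used. `_FlattenedWrapper` and the `torch.nn.Linear` user module are stand-ins invented for illustration only; the real code pairs the user module with `_io._FlattenedModule`, as the `ortmodule.py` hunk below shows.

    import torch

    class _PytorchModuleMetadata():
        """Encapsulates modules and allows easy access as required"""
        def __init__(self, original_module, flattened_module):
            self.original_module = original_module
            self.flattened_module = flattened_module

    class _FlattenedWrapper(torch.nn.Module):
        """Illustrative stand-in for _io._FlattenedModule: it registers the user
        module under the attribute name _original_module, which is where the
        extra state_dict key prefix comes from."""
        def __init__(self, original_module):
            super().__init__()
            self._original_module = original_module

        def forward(self, *args):
            return self._original_module(*args)

    user_module = torch.nn.Linear(4, 2)
    metadata = _PytorchModuleMetadata(user_module, _FlattenedWrapper(user_module))

    # Keys resolved through original_module carry no wrapper prefix ...
    print(list(metadata.original_module.state_dict().keys()))
    # -> ['weight', 'bias']
    # ... while the flattened module prepends '_original_module.' to every key,
    # which is exactly what the state_dict/load_state_dict overrides avoid.
    print(list(metadata.flattened_module.state_dict().keys()))
    # -> ['_original_module.weight', '_original_module.bias']
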
114 changes: 85 additions & 29 deletions orttraining/orttraining/python/training/ortmodule/ortmodule.py
@@ -5,12 +5,13 @@

from . import _io
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
from ._utils import _PytorchModuleMetadata

from onnxruntime.training import register_custom_ops_pytorch_exporter

import functools
import torch
from typing import Iterator, Optional, Tuple, TypeVar
from typing import Iterator, Optional, Tuple, TypeVar, Set, Callable

# Needed to override PyTorch methods
T = TypeVar('T', bound='Module')
@@ -51,12 +52,11 @@ def _forward(self, *inputs, **kwargs):
register_custom_ops_pytorch_exporter.register_custom_op(is_ortmodule=True)

# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
# along with the module that flattens both input and output of the user module
# inside _PytorchModuleMetadata
self._module_metadata = _PytorchModuleMetadata(module, _io._FlattenedModule(module))

# Get the module that flattens both input and output
self._flattened_module = _io._FlattenedModule(self._original_module)

self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)
self._execution_manager = GraphExecutionManagerFactory(self._module_metadata.flattened_module)

# IMPORTANT: DO NOT add code here
# This declaration is for automatic document generation purposes only
@@ -65,57 +65,82 @@ def forward(self, *inputs, **kwargs):
'''Dummy documentation for forward method'''
...

def _apply(self, fn):
"""Override original method to delegate execution to the flattened PyTorch user module"""

# Delegation must happen to self._module_metadata.flattened_module since methods depend on
# _apply to recursively apply the internal setting changes
self._module_metadata.flattened_module._apply(fn)
return self

def apply(self: T, fn: Callable[['Module'], None]) -> T:
"""Override original method to delegate execution to the flattened PyTorch user module"""

# Delegation must happen to self._module_metadata.flattened_module since methods depend on
# apply to recursively apply the internal setting changes
self._module_metadata.flattened_module.apply(fn)
return self

def _is_training(self):
return self._flattened_module.training and torch.is_grad_enabled()
return self.training and torch.is_grad_enabled()

def train(self: T, mode: bool = True) -> T:
"""Override original method to delegate execution to the flattened PyTorch user module"""

# Since _modules is empty, the task needs to be delegated to self._module_metadata.flattened_module.train
# which will recursively update the original_module
self.training = mode
self._module_metadata.flattened_module.train(mode)
return self

def state_dict(self, destination=None, prefix='', keep_vars=False):
"""Override original method to delegate execution to the base module"""
"""Override original method to delegate execution to the original PyTorch user module"""

# Override the state_dict() method so that the state dict key names
# do not contain the _flattened_module._original_module prefix
return self._original_module.state_dict(
# do not contain the flattened_module._original_module prefix
return self._module_metadata.original_module.state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
strict: bool = True):
"""Override original method to delegate execution to the base module"""
"""Override original method to delegate execution to the original PyTorch user module"""

# Override the load_state_dict() method so that the loaded state dict
# key names do not need to contain the _flattened_module._original_module prefix
return self._original_module.load_state_dict(
# key names do not need to contain the _module_metadata.flattened_module._original_module prefix
return self._module_metadata.original_module.load_state_dict(
state_dict, strict=strict)

def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
"""Override original method to delegate execution to the base module"""
self._original_module.register_buffer(name, tensor, persistent=persistent)
"""Override original method to delegate execution to the original PyTorch user module"""
self._module_metadata.original_module.register_buffer(name, tensor, persistent=persistent)

def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
"""Override original method to delegate execution to the base module"""
self._original_module.register_parameter(name, param)
"""Override original method to delegate execution to the original PyTorch user module"""
self._module_metadata.original_module.register_parameter(name, param)

def get_parameter(self, target: str) -> torch.nn.Parameter:
"""Override original method to delegate execution to the base module"""
return self._original_module.get_parameter(target)
"""Override original method to delegate execution to the original PyTorch user module"""
return self._module_metadata.original_module.get_parameter(target)

def get_buffer(self, target: str) -> torch.Tensor:
"""Override original method to delegate execution to the base module"""
return self._original_module.get_buffer(target)
"""Override original method to delegate execution to the original PyTorch user module"""
return self._module_metadata.original_module.get_buffer(target)

def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
"""Override original method to delegate execution to the base module"""
yield from self._original_module.parameters(recurse=recurse)
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.parameters(recurse=recurse)

def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""Override original method to delegate execution to the base module"""
yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse)
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_parameters(prefix=prefix, recurse=recurse)

def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
"""Override original method to delegate execution to the base module"""
yield from self._original_module.buffers(recurse=recurse)
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.buffers(recurse=recurse)

def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
"""Override original method to delegate execution to the base module"""
yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse)
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_buffers(prefix=prefix, recurse=recurse)

def _replicate_for_data_parallel(self):
"""Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel
@@ -135,3 +160,34 @@ def _replicate_for_data_parallel(self):

raise NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. "
"Please use torch.nn.parallel.DistributedDataParallel instead.")

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Override original method to delegate execution to the original PyTorch user module"""

# PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules.
# Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules.
# For the scenario where an ORTModule is a sub-module of another module, loading of the state
# dictionary requires the _load_from_state_dict to be overridden to prevent an error.
self._module_metadata.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)

def named_children(self) -> Iterator[Tuple[str, 'Module']]:
"""Override original method to delegate execution to the original PyTorch user module"""

yield from self._module_metadata.original_module.named_children()

def modules(self) -> Iterator['Module']:
"""Override original method to delegate execution to the original PyTorch user module"""

yield from self._module_metadata.original_module.modules()

def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
"""Override original method to delegate execution to the original PyTorch user module"""

yield from self._module_metadata.original_module.named_modules(memo, prefix)

def add_module(self, name: str, module: Optional['Module']) -> None:
"""Raises a NotImplementedError exception since ORTModule does not support adding modules to it"""

raise NotImplementedError("ORTModule does not support adding modules to it.")
@@ -1666,26 +1666,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next():
model.fc1.requires_grad_(True)
model = ORTModule(model)
x = torch.randn(N, D_in, device=device)
assert model._original_module.fc1.weight.grad is None
assert model._original_module.fc1.bias.grad is None
assert model._module_metadata.original_module.fc1.weight.grad is None
assert model._module_metadata.original_module.fc1.bias.grad is None

# Make sure no exception is raised
output = model(x)
loss = torch.sum(output)
loss.backward()
training_session1 = model._execution_manager(model._is_training())._execution_agent
weight_grad_2 = model._original_module.fc1.weight.grad
bias_grad_2 = model._original_module.fc1.bias.grad
weight_grad_2 = model._module_metadata.original_module.fc1.weight.grad
bias_grad_2 = model._module_metadata.original_module.fc1.bias.grad
assert weight_grad_2 is not None
assert bias_grad_2 is not None

model._original_module.fc1.requires_grad_(False)
model._module_metadata.original_module.fc1.requires_grad_(False)
output = model(x)
loss = torch.sum(output)
loss.backward()
training_session2 = model._execution_manager(model._is_training())._execution_agent
weight_grad_3 = model._original_module.fc1.weight.grad
bias_grad_3 = model._original_module.fc1.bias.grad
weight_grad_3 = model._module_metadata.original_module.fc1.weight.grad
bias_grad_3 = model._module_metadata.original_module.fc1.bias.grad

assert training_session1 != training_session2
assert torch.equal(weight_grad_2, weight_grad_3)
@@ -2619,3 +2619,31 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model):
{})

assert not training_manager._reinitialize_graph_builder(input_info)

def test_load_state_dict_for_wrapped_ortmodule():
class WrapperModule(torch.nn.Module):
def __init__(self, ortmodule):
super(WrapperModule, self).__init__()
self._ortmodule = ortmodule

def forward(self, x):
return self._ortmodule(x)

device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
model = ORTModule(copy.deepcopy(model))
wrapper_module = WrapperModule(model)
x = torch.randn(N, D_in, device=device)
_ = wrapper_module(x)

state_dict1 = wrapper_module.state_dict()
list(next(iter(state_dict1.items())))[1] += 10
wrapper_module.load_state_dict(state_dict1)
state_dict2 = wrapper_module.state_dict()

assert state_dict1
assert len(state_dict1.keys()) == len(state_dict2.keys())
for param_name, param_value in state_dict1.items():
assert param_name in state_dict2
assert torch.equal(param_value, state_dict2[param_name])