diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 98d553bcac0c..751c5f1a46dd 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -97,3 +97,10 @@ def _create_iobinding(io_binding, inputs, model, device): for value_info in model.graph.output: io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device)) + +class _PytorchModuleMetadata(): + """Encapsulates modules and allows easy access as required""" + + def __init__(self, original_module, flattened_module): + self.original_module = original_module + self.flattened_module = flattened_module diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 62d1c7ee4627..bfdc1c5631a8 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -5,12 +5,13 @@ from . import _io from ._graph_execution_manager_factory import GraphExecutionManagerFactory +from ._utils import _PytorchModuleMetadata from onnxruntime.training import register_custom_ops_pytorch_exporter import functools import torch -from typing import Iterator, Optional, Tuple, TypeVar +from typing import Iterator, Optional, Tuple, TypeVar, Set, Callable # Needed to override PyTorch methods T = TypeVar('T', bound='Module') @@ -51,12 +52,11 @@ def _forward(self, *inputs, **kwargs): register_custom_ops_pytorch_exporter.register_custom_op(is_ortmodule=True) # User module is wrapped to use its initializers and save computed gradients - self._original_module = module + # along with the module that flattens both input and output of the user module + # inside _PytorchModuleMetadata + self._module_metadata = _PytorchModuleMetadata(module, _io._FlattenedModule(module)) - # Get the module that flattens both input and output - self._flattened_module = _io._FlattenedModule(self._original_module) - - self._execution_manager = GraphExecutionManagerFactory(self._flattened_module) + self._execution_manager = GraphExecutionManagerFactory(self._module_metadata.flattened_module) # IMPORTANT: DO NOT add code here # This declaration is for automatic document generation purposes only @@ -65,57 +65,82 @@ def forward(self, *inputs, **kwargs): '''Dummy documentation for forward method''' ... + def _apply(self, fn): + """Override original method to delegate execution to the flattened PyTorch user module""" + + # Delegation must happen to _flattened_module since methods depend on + # _apply to recursively apply the internal setting changes + self._module_metadata.flattened_module._apply(fn) + return self + + def apply(self: T, fn: Callable[['Module'], None]) -> T: + """Override original method to delegate execution to the flattened PyTorch user module""" + + # Delegation must happen to _flattened_module since methods depend on + # apply to recursively apply the internal setting changes + self._module_metadata.flattened_module.apply(fn) + return self + def _is_training(self): - return self._flattened_module.training and torch.is_grad_enabled() + return self.training and torch.is_grad_enabled() + + def train(self: T, mode: bool = True) -> T: + """Override original method to delegate execution to the flattened PyTorch user module""" + + # Since _modules is empty, the task needs to be delegated to _module.flattened_module.train + # which will recursively update the original_module + self.training = mode + self._module_metadata.flattened_module.train(mode) + return self def state_dict(self, destination=None, prefix='', keep_vars=False): - """Override original method to delegate execution to the base module""" + """Override original method to delegate execution to the original PyTorch user module""" # Override the state_dict() method so that the state dict key names - # do not contain the _flattened_module._original_module prefix - return self._original_module.state_dict( + # do not contain the flattened_module._original_module prefix + return self._module_metadata.original_module.state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True): - """Override original method to delegate execution to the base module""" + """Override original method to delegate execution to the original PyTorch user module""" # Override the load_state_dict() method so that the loaded state dict - # key names does not need to contain the _flattened_module._original_module prefix - return self._original_module.load_state_dict( + # key names does not need to contain the _module.flattened_module._original_module prefix + return self._module_metadata.original_module.load_state_dict( state_dict, strict=strict) def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: - """Override original method to delegate execution to the base module""" - self._original_module.register_buffer(name, tensor, persistent=persistent) + """Override original method to delegate execution to the original PyTorch user module""" + self._module_metadata.original_module.register_buffer(name, tensor, persistent=persistent) def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: - """Override original method to delegate execution to the base module""" - self._original_module.register_parameter(name, param) + """Override original method to delegate execution to the original PyTorch user module""" + self._module_metadata.original_module.register_parameter(name, param) def get_parameter(self, target: str) -> torch.nn.Parameter: - """Override original method to delegate execution to the base module""" - return self._original_module.get_parameter(target) + """Override original method to delegate execution to the original PyTorch user module""" + return self._module_metadata.original_module.get_parameter(target) def get_buffer(self, target: str) -> torch.Tensor: - """Override original method to delegate execution to the base module""" - return self._original_module.get_buffer(target) + """Override original method to delegate execution to the original PyTorch user module""" + return self._module_metadata.original_module.get_buffer(target) def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: - """Override original method to delegate execution to the base module""" - yield from self._original_module.parameters(recurse=recurse) + """Override original method to delegate execution to the original PyTorch user module""" + yield from self._module_metadata.original_module.parameters(recurse=recurse) def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: - """Override original method to delegate execution to the base module""" - yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse) + """Override original method to delegate execution to the original PyTorch user module""" + yield from self._module_metadata.original_module.named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: - """Override original method to delegate execution to the base module""" - yield from self._original_module.buffers(recurse=recurse) + """Override original method to delegate execution to the original PyTorch user module""" + yield from self._module_metadata.original_module.buffers(recurse=recurse) def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: - """Override original method to delegate execution to the base module""" - yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse) + """Override original method to delegate execution to the original PyTorch user module""" + yield from self._module_metadata.original_module.named_buffers(prefix=prefix, recurse=recurse) def _replicate_for_data_parallel(self): """Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel @@ -135,3 +160,34 @@ def _replicate_for_data_parallel(self): raise NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. " "Please use torch.nn.parallel.DistributedDataParallel instead.") + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Override original method to delegate execution to the original PyTorch user module""" + + # PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules. + # Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules. + # For the scenario where an ORTModule is a sub-module of another module, loading of the state + # dictionary requires the _load_from_state_dict to be overridden to prevent an error. + self._module_metadata.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def named_children(self) -> Iterator[Tuple[str, 'Module']]: + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._module_metadata.original_module.named_children() + + def modules(self) -> Iterator['Module']: + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._module_metadata.original_module.modules() + + def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''): + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._module_metadata.original_module.named_modules(memo, prefix) + + def add_module(self, name: str, module: Optional['Module']) -> None: + """Raises a NotImplementedError exception since ORTModule does not support adding modules to it""" + + raise NotImplementedError("ORTModule does not support adding modules to it.") diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 3a6607895d4c..a66d609809ba 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1666,26 +1666,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next(): model.fc1.requires_grad_(True) model = ORTModule(model) x = torch.randn(N, D_in, device=device) - assert model._original_module.fc1.weight.grad is None - assert model._original_module.fc1.bias.grad is None + assert model._module_metadata.original_module.fc1.weight.grad is None + assert model._module_metadata.original_module.fc1.bias.grad is None # Make sure no exception is raised output = model(x) loss = torch.sum(output) loss.backward() training_session1 = model._execution_manager(model._is_training())._execution_agent - weight_grad_2 = model._original_module.fc1.weight.grad - bias_grad_2 = model._original_module.fc1.bias.grad + weight_grad_2 = model._module_metadata.original_module.fc1.weight.grad + bias_grad_2 = model._module_metadata.original_module.fc1.bias.grad assert weight_grad_2 is not None assert bias_grad_2 is not None - model._original_module.fc1.requires_grad_(False) + model._module_metadata.original_module.fc1.requires_grad_(False) output = model(x) loss = torch.sum(output) loss.backward() training_session2 = model._execution_manager(model._is_training())._execution_agent - weight_grad_3 = model._original_module.fc1.weight.grad - bias_grad_3 = model._original_module.fc1.bias.grad + weight_grad_3 = model._module_metadata.original_module.fc1.weight.grad + bias_grad_3 = model._module_metadata.original_module.fc1.bias.grad assert training_session1 != training_session2 assert torch.equal(weight_grad_2, weight_grad_3) @@ -2619,3 +2619,31 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model): {}) assert not training_manager._reinitialize_graph_builder(input_info) + +def test_load_state_dict_for_wrapped_ortmodule(): + class WrapperModule(torch.nn.Module): + def __init__(self, ortmodule): + super(WrapperModule, self).__init__() + self._ortmodule = ortmodule + + def forward(self, x): + return self._ortmodule(x) + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + model = ORTModule(copy.deepcopy(model)) + wrapper_module = WrapperModule(model) + x = torch.randn(N, D_in, device=device) + _ = wrapper_module(x) + + state_dict1 = wrapper_module.state_dict() + list(next(iter(state_dict1.items())))[1] += 10 + wrapper_module.load_state_dict(state_dict1) + state_dict2 = wrapper_module.state_dict() + + assert state_dict1 + assert len(state_dict1.keys()) == len(state_dict2.keys()) + for param_name, param_value in state_dict1.items(): + assert param_name in state_dict2 + assert torch.equal(param_value, state_dict2[param_name])