From 754f098ae483e2ae124df6df70fcc0a4a3af2183 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 26 May 2021 22:03:52 +0000 Subject: [PATCH 1/4] Encapsulate children modules inside a ModuleAccessor object to prevent erroneuos iteration over children while loading the state dictionary --- .../python/training/ortmodule/_utils.py | 17 +++++ .../python/training/ortmodule/ortmodule.py | 76 +++++++++++++++---- .../python/orttraining_test_ortmodule_api.py | 34 +++++++-- 3 files changed, 105 insertions(+), 22 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 98d553bcac0c..5c1cfd5bf373 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -97,3 +97,20 @@ def _create_iobinding(io_binding, inputs, model, device): for value_info in model.graph.output: io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device)) + +class ModuleAccessor(): + """Encapsulates modules and allows easy access as required""" + + def __init__(self, original_module, flattened_module): + self._original_module = original_module + self._flattened_module = flattened_module + + def original_module(self): + """Returns the original PyTorch module""" + + return self._original_module + + def flattened_module(self): + """Returns the flattened input and output PyTorch module""" + + return self._flattened_module diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 62d1c7ee4627..18c20032ba54 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -5,6 +5,7 @@ from . import _io from ._graph_execution_manager_factory import GraphExecutionManagerFactory +from ._utils import ModuleAccessor from onnxruntime.training import register_custom_ops_pytorch_exporter @@ -51,10 +52,9 @@ def _forward(self, *inputs, **kwargs): register_custom_ops_pytorch_exporter.register_custom_op(is_ortmodule=True) # User module is wrapped to use its initializers and save computed gradients - self._original_module = module - - # Get the module that flattens both input and output - self._flattened_module = _io._FlattenedModule(self._original_module) + # along with the module that flattens both input and output of the user module + # inside ModuleAccessor + self._module = ModuleAccessor(module, _io._FlattenedModule(module)) self._execution_manager = GraphExecutionManagerFactory(self._flattened_module) @@ -65,15 +65,32 @@ def forward(self, *inputs, **kwargs): '''Dummy documentation for forward method''' ... 
+ def _apply(self, fn): + """Override original method to delegate execution to the base module""" + + # Delegation must happen to _flattened_module since methods depend on + # _apply to recursively apply the internal setting changes + self._flattened_module._apply(fn) + return self + def _is_training(self): - return self._flattened_module.training and torch.is_grad_enabled() + return self.training and torch.is_grad_enabled() + + def train(self: T, mode: bool = True) -> T: + """Override original method to delegate execution to the base module""" + + # Since _modules is empty, the task needs to be delegated to _flattened_module.train + # which will recursively update the original_module + self.training = mode + self._flattened_module.train(mode) + return self def state_dict(self, destination=None, prefix='', keep_vars=False): """Override original method to delegate execution to the base module""" # Override the state_dict() method so that the state dict key names # do not contain the _flattened_module._original_module prefix - return self._original_module.state_dict( + return self.original_module.state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', @@ -82,40 +99,40 @@ def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', # Override the load_state_dict() method so that the loaded state dict # key names does not need to contain the _flattened_module._original_module prefix - return self._original_module.load_state_dict( + return self.original_module.load_state_dict( state_dict, strict=strict) def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: """Override original method to delegate execution to the base module""" - self._original_module.register_buffer(name, tensor, persistent=persistent) + self.original_module.register_buffer(name, tensor, persistent=persistent) def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: """Override original method to delegate execution to the base module""" - self._original_module.register_parameter(name, param) + self.original_module.register_parameter(name, param) def get_parameter(self, target: str) -> torch.nn.Parameter: """Override original method to delegate execution to the base module""" - return self._original_module.get_parameter(target) + return self.original_module.get_parameter(target) def get_buffer(self, target: str) -> torch.Tensor: """Override original method to delegate execution to the base module""" - return self._original_module.get_buffer(target) + return self.original_module.get_buffer(target) def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: """Override original method to delegate execution to the base module""" - yield from self._original_module.parameters(recurse=recurse) + yield from self.original_module.parameters(recurse=recurse) def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: """Override original method to delegate execution to the base module""" - yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse) + yield from self.original_module.named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: """Override original method to delegate execution to the base module""" - yield from self._original_module.buffers(recurse=recurse) + yield from self.original_module.buffers(recurse=recurse) def 
named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: """Override original method to delegate execution to the base module""" - yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse) + yield from self.original_module.named_buffers(prefix=prefix, recurse=recurse) def _replicate_for_data_parallel(self): """Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel @@ -135,3 +152,32 @@ def _replicate_for_data_parallel(self): raise NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. " "Please use torch.nn.parallel.DistributedDataParallel instead.") + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Override original method to delegate execution to the base module""" + + self.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def named_children(self): + """Override original method to delegate execution to the base module""" + + yield from self.original_module.named_children() + + @property + def original_module(self): + """Accessor for the original PyTorch module + + Users can retrieve the original module by using this property + ort_model = ORTModule(model) + original_model = ort_model.original_module + """ + + return self._module.original_module() + + @property + def _flattened_module(self): + """Accessor for the flattened module""" + + return self._module.flattened_module() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 3a6607895d4c..ff6c9ce13a76 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1666,26 +1666,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next(): model.fc1.requires_grad_(True) model = ORTModule(model) x = torch.randn(N, D_in, device=device) - assert model._original_module.fc1.weight.grad is None - assert model._original_module.fc1.bias.grad is None + assert model.original_module.fc1.weight.grad is None + assert model.original_module.fc1.bias.grad is None # Make sure no exception is raised output = model(x) loss = torch.sum(output) loss.backward() training_session1 = model._execution_manager(model._is_training())._execution_agent - weight_grad_2 = model._original_module.fc1.weight.grad - bias_grad_2 = model._original_module.fc1.bias.grad + weight_grad_2 = model.original_module.fc1.weight.grad + bias_grad_2 = model.original_module.fc1.bias.grad assert weight_grad_2 is not None assert bias_grad_2 is not None - model._original_module.fc1.requires_grad_(False) + model.original_module.fc1.requires_grad_(False) output = model(x) loss = torch.sum(output) loss.backward() training_session2 = model._execution_manager(model._is_training())._execution_agent - weight_grad_3 = model._original_module.fc1.weight.grad - bias_grad_3 = model._original_module.fc1.bias.grad + weight_grad_3 = model.original_module.fc1.weight.grad + bias_grad_3 = model.original_module.fc1.bias.grad assert training_session1 != training_session2 assert torch.equal(weight_grad_2, weight_grad_3) @@ -2619,3 +2619,23 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model): {}) assert not training_manager._reinitialize_graph_builder(input_info) + +def 
test_load_state_dict_for_wrapped_ortmodule(): + class WrapperModule(torch.nn.Module): + def __init__(self, ortmodule): + super(WrapperModule, self).__init__() + self._ortmodule = ortmodule + + def forward(self, x): + return self._ortmodule(x) + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + model = ORTModule(copy.deepcopy(model)) + wrapper_module = WrapperModule(model) + x = torch.randn(N, D_in, device=device) + _ = wrapper_module(x) + + state_dict = wrapper_module.state_dict() + wrapper_module.load_state_dict(state_dict) From 63f07073cfdab39744b5a920e571c1b6f9a09be2 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 27 May 2021 00:37:12 +0000 Subject: [PATCH 2/4] Add named_models, models, apply methods, change ModuleAccessor to ModuleMetadata and modify unit tests --- .../python/training/ortmodule/_utils.py | 2 +- .../python/training/ortmodule/ortmodule.py | 104 +++++++++--------- .../python/orttraining_test_ortmodule_api.py | 26 +++-- 3 files changed, 73 insertions(+), 59 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 5c1cfd5bf373..543d8fe3c8d3 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -98,7 +98,7 @@ def _create_iobinding(io_binding, inputs, model, device): for value_info in model.graph.output: io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device)) -class ModuleAccessor(): +class ModuleMetadata(): """Encapsulates modules and allows easy access as required""" def __init__(self, original_module, flattened_module): diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 18c20032ba54..97ca7ee77191 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -5,13 +5,13 @@ from . import _io from ._graph_execution_manager_factory import GraphExecutionManagerFactory -from ._utils import ModuleAccessor +from ._utils import ModuleMetadata from onnxruntime.training import register_custom_ops_pytorch_exporter import functools import torch -from typing import Iterator, Optional, Tuple, TypeVar +from typing import Iterator, Optional, Tuple, TypeVar, Set, Callable # Needed to override PyTorch methods T = TypeVar('T', bound='Module') @@ -53,10 +53,10 @@ def _forward(self, *inputs, **kwargs): # User module is wrapped to use its initializers and save computed gradients # along with the module that flattens both input and output of the user module - # inside ModuleAccessor - self._module = ModuleAccessor(module, _io._FlattenedModule(module)) + # inside ModuleMetadata + self._module = ModuleMetadata(module, _io._FlattenedModule(module)) - self._execution_manager = GraphExecutionManagerFactory(self._flattened_module) + self._execution_manager = GraphExecutionManagerFactory(self._module.flattened_module()) # IMPORTANT: DO NOT add code here # This declaration is for automatic document generation purposes only @@ -66,73 +66,81 @@ def forward(self, *inputs, **kwargs): ... 
def _apply(self, fn): - """Override original method to delegate execution to the base module""" + """Override original method to delegate execution to the flattened PyTorch user module""" # Delegation must happen to _flattened_module since methods depend on # _apply to recursively apply the internal setting changes - self._flattened_module._apply(fn) + self._module.flattened_module()._apply(fn) + return self + + def apply(self: T, fn: Callable[['Module'], None]) -> T: + """Override original method to delegate execution to the flattened PyTorch user module""" + + # Delegation must happen to _flattened_module since methods depend on + # apply to recursively apply the internal setting changes + self._module.flattened_module().apply(fn) return self def _is_training(self): return self.training and torch.is_grad_enabled() def train(self: T, mode: bool = True) -> T: - """Override original method to delegate execution to the base module""" + """Override original method to delegate execution to the flattened PyTorch user module""" - # Since _modules is empty, the task needs to be delegated to _flattened_module.train + # Since _modules is empty, the task needs to be delegated to _module.flattened_module().train # which will recursively update the original_module self.training = mode - self._flattened_module.train(mode) + self._module.flattened_module().train(mode) return self def state_dict(self, destination=None, prefix='', keep_vars=False): - """Override original method to delegate execution to the base module""" + """Override original method to delegate execution to the original PyTorch user module""" # Override the state_dict() method so that the state dict key names - # do not contain the _flattened_module._original_module prefix - return self.original_module.state_dict( + # do not contain the flattened_module._original_module prefix + return self._module.original_module().state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True): - """Override original method to delegate execution to the base module""" + """Override original method to delegate execution to the original PyTorch user module""" # Override the load_state_dict() method so that the loaded state dict - # key names does not need to contain the _flattened_module._original_module prefix - return self.original_module.load_state_dict( + # key names does not need to contain the _module.flattened_module()._original_module prefix + return self._module.original_module().load_state_dict( state_dict, strict=strict) def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: - """Override original method to delegate execution to the base module""" - self.original_module.register_buffer(name, tensor, persistent=persistent) + """Override original method to delegate execution to the original PyTorch user module""" + self._module.original_module().register_buffer(name, tensor, persistent=persistent) def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: - """Override original method to delegate execution to the base module""" - self.original_module.register_parameter(name, param) + """Override original method to delegate execution to the original PyTorch user module""" + self._module.original_module().register_parameter(name, param) def get_parameter(self, target: str) -> torch.nn.Parameter: - """Override original method to delegate execution to the base module""" - return 
self.original_module.get_parameter(target) + """Override original method to delegate execution to the original PyTorch user module""" + return self._module.original_module().get_parameter(target) def get_buffer(self, target: str) -> torch.Tensor: - """Override original method to delegate execution to the base module""" - return self.original_module.get_buffer(target) + """Override original method to delegate execution to the original PyTorch user module""" + return self._module.original_module().get_buffer(target) def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: - """Override original method to delegate execution to the base module""" - yield from self.original_module.parameters(recurse=recurse) + """Override original method to delegate execution to the original PyTorch user module""" + yield from self._module.original_module().parameters(recurse=recurse) def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: - """Override original method to delegate execution to the base module""" - yield from self.original_module.named_parameters(prefix=prefix, recurse=recurse) + """Override original method to delegate execution to the original PyTorch user module""" + yield from self._module.original_module().named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: - """Override original method to delegate execution to the base module""" - yield from self.original_module.buffers(recurse=recurse) + """Override original method to delegate execution to the original PyTorch user module""" + yield from self._module.original_module().buffers(recurse=recurse) def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: - """Override original method to delegate execution to the base module""" - yield from self.original_module.named_buffers(prefix=prefix, recurse=recurse) + """Override original method to delegate execution to the original PyTorch user module""" + yield from self._module.original_module().named_buffers(prefix=prefix, recurse=recurse) def _replicate_for_data_parallel(self): """Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel @@ -155,29 +163,27 @@ def _replicate_for_data_parallel(self): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - """Override original method to delegate execution to the base module""" + """Override original method to delegate execution to the original PyTorch user module""" - self.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, + self._module.original_module()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - def named_children(self): - """Override original method to delegate execution to the base module""" + def named_children(self) -> Iterator[Tuple[str, 'Module']]: + """Override original method to delegate execution to the original PyTorch user module""" - yield from self.original_module.named_children() + yield from self._module.original_module().named_children() - @property - def original_module(self): - """Accessor for the original PyTorch module + def modules(self) -> Iterator['Module']: + """Override original method to delegate execution to the original PyTorch user module""" - Users can retrieve the original module by using this property - ort_model = ORTModule(model) - 
original_model = ort_model.original_module - """ + yield from self._module.original_module().modules() + + def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''): + """Override original method to delegate execution to the original PyTorch user module""" - return self._module.original_module() + yield from self._module.original_module().named_modules(memo, prefix) - @property - def _flattened_module(self): - """Accessor for the flattened module""" + def add_module(self, name: str, module: Optional['Module']) -> None: + """Override original method to delegate execution to the original PyTorch user module""" - return self._module.flattened_module() + self._module.original_module().add_module(name, module) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index ff6c9ce13a76..a95c712356d0 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1666,26 +1666,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next(): model.fc1.requires_grad_(True) model = ORTModule(model) x = torch.randn(N, D_in, device=device) - assert model.original_module.fc1.weight.grad is None - assert model.original_module.fc1.bias.grad is None + assert model._module.original_module().fc1.weight.grad is None + assert model._module.original_module().fc1.bias.grad is None # Make sure no exception is raised output = model(x) loss = torch.sum(output) loss.backward() training_session1 = model._execution_manager(model._is_training())._execution_agent - weight_grad_2 = model.original_module.fc1.weight.grad - bias_grad_2 = model.original_module.fc1.bias.grad + weight_grad_2 = model._module.original_module().fc1.weight.grad + bias_grad_2 = model._module.original_module().fc1.bias.grad assert weight_grad_2 is not None assert bias_grad_2 is not None - model.original_module.fc1.requires_grad_(False) + model._module.original_module().fc1.requires_grad_(False) output = model(x) loss = torch.sum(output) loss.backward() training_session2 = model._execution_manager(model._is_training())._execution_agent - weight_grad_3 = model.original_module.fc1.weight.grad - bias_grad_3 = model.original_module.fc1.bias.grad + weight_grad_3 = model._module.original_module().fc1.weight.grad + bias_grad_3 = model._module.original_module().fc1.bias.grad assert training_session1 != training_session2 assert torch.equal(weight_grad_2, weight_grad_3) @@ -2637,5 +2637,13 @@ def forward(self, x): x = torch.randn(N, D_in, device=device) _ = wrapper_module(x) - state_dict = wrapper_module.state_dict() - wrapper_module.load_state_dict(state_dict) + state_dict1 = wrapper_module.state_dict() + list(next(iter(state_dict1.items())))[1] += 10 + wrapper_module.load_state_dict(state_dict1) + state_dict2 = wrapper_module.state_dict() + + assert state_dict1 + assert len(state_dict1.keys()) == len(state_dict2.keys()) + for param_name, param_value in state_dict1.items(): + assert param_name in state_dict2 + assert torch.equal(param_value, state_dict2[param_name]) From 741f1022711fa7a636cc435278039dbf7224a36b Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 27 May 2021 17:23:41 +0000 Subject: [PATCH 3/4] Change ModuleMetadata module getter logic, raise NotImplementedError for add_modules --- .../python/training/ortmodule/_utils.py | 14 +----- .../python/training/ortmodule/ortmodule.py | 44 +++++++++---------- 
.../python/orttraining_test_ortmodule_api.py | 14 +++--- 3 files changed, 31 insertions(+), 41 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 543d8fe3c8d3..eb3679530b99 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -102,15 +102,5 @@ class ModuleMetadata(): """Encapsulates modules and allows easy access as required""" def __init__(self, original_module, flattened_module): - self._original_module = original_module - self._flattened_module = flattened_module - - def original_module(self): - """Returns the original PyTorch module""" - - return self._original_module - - def flattened_module(self): - """Returns the flattened input and output PyTorch module""" - - return self._flattened_module + self.original_module = original_module + self.flattened_module = flattened_module diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 97ca7ee77191..9070a4fc5b1a 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -56,7 +56,7 @@ def _forward(self, *inputs, **kwargs): # inside ModuleMetadata self._module = ModuleMetadata(module, _io._FlattenedModule(module)) - self._execution_manager = GraphExecutionManagerFactory(self._module.flattened_module()) + self._execution_manager = GraphExecutionManagerFactory(self._module.flattened_module) # IMPORTANT: DO NOT add code here # This declaration is for automatic document generation purposes only @@ -70,7 +70,7 @@ def _apply(self, fn): # Delegation must happen to _flattened_module since methods depend on # _apply to recursively apply the internal setting changes - self._module.flattened_module()._apply(fn) + self._module.flattened_module._apply(fn) return self def apply(self: T, fn: Callable[['Module'], None]) -> T: @@ -78,7 +78,7 @@ def apply(self: T, fn: Callable[['Module'], None]) -> T: # Delegation must happen to _flattened_module since methods depend on # apply to recursively apply the internal setting changes - self._module.flattened_module().apply(fn) + self._module.flattened_module.apply(fn) return self def _is_training(self): @@ -87,10 +87,10 @@ def _is_training(self): def train(self: T, mode: bool = True) -> T: """Override original method to delegate execution to the flattened PyTorch user module""" - # Since _modules is empty, the task needs to be delegated to _module.flattened_module().train + # Since _modules is empty, the task needs to be delegated to _module.flattened_module.train # which will recursively update the original_module self.training = mode - self._module.flattened_module().train(mode) + self._module.flattened_module.train(mode) return self def state_dict(self, destination=None, prefix='', keep_vars=False): @@ -98,7 +98,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): # Override the state_dict() method so that the state dict key names # do not contain the flattened_module._original_module prefix - return self._module.original_module().state_dict( + return self._module.original_module.state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', @@ -106,41 +106,41 @@ def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', """Override original 
method to delegate execution to the original PyTorch user module""" # Override the load_state_dict() method so that the loaded state dict - # key names does not need to contain the _module.flattened_module()._original_module prefix - return self._module.original_module().load_state_dict( + # key names does not need to contain the _module.flattened_module._original_module prefix + return self._module.original_module.load_state_dict( state_dict, strict=strict) def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: """Override original method to delegate execution to the original PyTorch user module""" - self._module.original_module().register_buffer(name, tensor, persistent=persistent) + self._module.original_module.register_buffer(name, tensor, persistent=persistent) def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: """Override original method to delegate execution to the original PyTorch user module""" - self._module.original_module().register_parameter(name, param) + self._module.original_module.register_parameter(name, param) def get_parameter(self, target: str) -> torch.nn.Parameter: """Override original method to delegate execution to the original PyTorch user module""" - return self._module.original_module().get_parameter(target) + return self._module.original_module.get_parameter(target) def get_buffer(self, target: str) -> torch.Tensor: """Override original method to delegate execution to the original PyTorch user module""" - return self._module.original_module().get_buffer(target) + return self._module.original_module.get_buffer(target) def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module().parameters(recurse=recurse) + yield from self._module.original_module.parameters(recurse=recurse) def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module().named_parameters(prefix=prefix, recurse=recurse) + yield from self._module.original_module.named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module().buffers(recurse=recurse) + yield from self._module.original_module.buffers(recurse=recurse) def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module().named_buffers(prefix=prefix, recurse=recurse) + yield from self._module.original_module.named_buffers(prefix=prefix, recurse=recurse) def _replicate_for_data_parallel(self): """Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel @@ -165,25 +165,25 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): """Override original method to delegate execution to the original PyTorch user module""" - self._module.original_module()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + self._module.original_module._load_from_state_dict(state_dict, prefix, 
local_metadata, strict, missing_keys, unexpected_keys, error_msgs) def named_children(self) -> Iterator[Tuple[str, 'Module']]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module().named_children() + yield from self._module.original_module.named_children() def modules(self) -> Iterator['Module']: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module().modules() + yield from self._module.original_module.modules() def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''): """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module().named_modules(memo, prefix) + yield from self._module.original_module.named_modules(memo, prefix) def add_module(self, name: str, module: Optional['Module']) -> None: - """Override original method to delegate execution to the original PyTorch user module""" + """Raises a NotImplementedError exception since ORTModule does not support adding modules to it""" - self._module.original_module().add_module(name, module) + raise NotImplementedError("ORTModule does not support adding modules to it.") diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index a95c712356d0..a6a7661fa00b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1666,26 +1666,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next(): model.fc1.requires_grad_(True) model = ORTModule(model) x = torch.randn(N, D_in, device=device) - assert model._module.original_module().fc1.weight.grad is None - assert model._module.original_module().fc1.bias.grad is None + assert model._module.original_module.fc1.weight.grad is None + assert model._module.original_module.fc1.bias.grad is None # Make sure no exception is raised output = model(x) loss = torch.sum(output) loss.backward() training_session1 = model._execution_manager(model._is_training())._execution_agent - weight_grad_2 = model._module.original_module().fc1.weight.grad - bias_grad_2 = model._module.original_module().fc1.bias.grad + weight_grad_2 = model._module.original_module.fc1.weight.grad + bias_grad_2 = model._module.original_module.fc1.bias.grad assert weight_grad_2 is not None assert bias_grad_2 is not None - model._module.original_module().fc1.requires_grad_(False) + model._module.original_module.fc1.requires_grad_(False) output = model(x) loss = torch.sum(output) loss.backward() training_session2 = model._execution_manager(model._is_training())._execution_agent - weight_grad_3 = model._module.original_module().fc1.weight.grad - bias_grad_3 = model._module.original_module().fc1.bias.grad + weight_grad_3 = model._module.original_module.fc1.weight.grad + bias_grad_3 = model._module.original_module.fc1.bias.grad assert training_session1 != training_session2 assert torch.equal(weight_grad_2, weight_grad_3) From 3291659c8f9cb8b51dfa9edd2ffa6b6457fca0dc Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 27 May 2021 19:17:15 +0000 Subject: [PATCH 4/4] Add comment explaining why overriding _load_from_state_dict method is needed --- .../python/training/ortmodule/_utils.py | 2 +- .../python/training/ortmodule/ortmodule.py | 46 ++++++++++--------- 
.../python/orttraining_test_ortmodule_api.py | 14 +++--- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index eb3679530b99..751c5f1a46dd 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -98,7 +98,7 @@ def _create_iobinding(io_binding, inputs, model, device): for value_info in model.graph.output: io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device)) -class ModuleMetadata(): +class _PytorchModuleMetadata(): """Encapsulates modules and allows easy access as required""" def __init__(self, original_module, flattened_module): diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 9070a4fc5b1a..bfdc1c5631a8 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -5,7 +5,7 @@ from . import _io from ._graph_execution_manager_factory import GraphExecutionManagerFactory -from ._utils import ModuleMetadata +from ._utils import _PytorchModuleMetadata from onnxruntime.training import register_custom_ops_pytorch_exporter @@ -53,10 +53,10 @@ def _forward(self, *inputs, **kwargs): # User module is wrapped to use its initializers and save computed gradients # along with the module that flattens both input and output of the user module - # inside ModuleMetadata - self._module = ModuleMetadata(module, _io._FlattenedModule(module)) + # inside _PytorchModuleMetadata + self._module_metadata = _PytorchModuleMetadata(module, _io._FlattenedModule(module)) - self._execution_manager = GraphExecutionManagerFactory(self._module.flattened_module) + self._execution_manager = GraphExecutionManagerFactory(self._module_metadata.flattened_module) # IMPORTANT: DO NOT add code here # This declaration is for automatic document generation purposes only @@ -70,7 +70,7 @@ def _apply(self, fn): # Delegation must happen to _flattened_module since methods depend on # _apply to recursively apply the internal setting changes - self._module.flattened_module._apply(fn) + self._module_metadata.flattened_module._apply(fn) return self def apply(self: T, fn: Callable[['Module'], None]) -> T: @@ -78,7 +78,7 @@ def apply(self: T, fn: Callable[['Module'], None]) -> T: # Delegation must happen to _flattened_module since methods depend on # apply to recursively apply the internal setting changes - self._module.flattened_module.apply(fn) + self._module_metadata.flattened_module.apply(fn) return self def _is_training(self): @@ -90,7 +90,7 @@ def train(self: T, mode: bool = True) -> T: # Since _modules is empty, the task needs to be delegated to _module.flattened_module.train # which will recursively update the original_module self.training = mode - self._module.flattened_module.train(mode) + self._module_metadata.flattened_module.train(mode) return self def state_dict(self, destination=None, prefix='', keep_vars=False): @@ -98,7 +98,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): # Override the state_dict() method so that the state dict key names # do not contain the flattened_module._original_module prefix - return self._module.original_module.state_dict( + return self._module_metadata.original_module.state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) def 
load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', @@ -107,40 +107,40 @@ def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', # Override the load_state_dict() method so that the loaded state dict # key names does not need to contain the _module.flattened_module._original_module prefix - return self._module.original_module.load_state_dict( + return self._module_metadata.original_module.load_state_dict( state_dict, strict=strict) def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: """Override original method to delegate execution to the original PyTorch user module""" - self._module.original_module.register_buffer(name, tensor, persistent=persistent) + self._module_metadata.original_module.register_buffer(name, tensor, persistent=persistent) def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: """Override original method to delegate execution to the original PyTorch user module""" - self._module.original_module.register_parameter(name, param) + self._module_metadata.original_module.register_parameter(name, param) def get_parameter(self, target: str) -> torch.nn.Parameter: """Override original method to delegate execution to the original PyTorch user module""" - return self._module.original_module.get_parameter(target) + return self._module_metadata.original_module.get_parameter(target) def get_buffer(self, target: str) -> torch.Tensor: """Override original method to delegate execution to the original PyTorch user module""" - return self._module.original_module.get_buffer(target) + return self._module_metadata.original_module.get_buffer(target) def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module.parameters(recurse=recurse) + yield from self._module_metadata.original_module.parameters(recurse=recurse) def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module.named_parameters(prefix=prefix, recurse=recurse) + yield from self._module_metadata.original_module.named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module.buffers(recurse=recurse) + yield from self._module_metadata.original_module.buffers(recurse=recurse) def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module.named_buffers(prefix=prefix, recurse=recurse) + yield from self._module_metadata.original_module.named_buffers(prefix=prefix, recurse=recurse) def _replicate_for_data_parallel(self): """Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel @@ -165,23 +165,27 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): """Override original method to delegate execution to the original PyTorch user module""" - self._module.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, + # PyTorch 
load_state_dict implementation does not recursively call load_state_dict on its sub-modules. + # Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules. + # For the scenario where an ORTModule is a sub-module of another module, loading of the state + # dictionary requires the _load_from_state_dict to be overridden to prevent an error. + self._module_metadata.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) def named_children(self) -> Iterator[Tuple[str, 'Module']]: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module.named_children() + yield from self._module_metadata.original_module.named_children() def modules(self) -> Iterator['Module']: """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module.modules() + yield from self._module_metadata.original_module.modules() def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''): """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module.original_module.named_modules(memo, prefix) + yield from self._module_metadata.original_module.named_modules(memo, prefix) def add_module(self, name: str, module: Optional['Module']) -> None: """Raises a NotImplementedError exception since ORTModule does not support adding modules to it""" diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index a6a7661fa00b..a66d609809ba 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1666,26 +1666,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next(): model.fc1.requires_grad_(True) model = ORTModule(model) x = torch.randn(N, D_in, device=device) - assert model._module.original_module.fc1.weight.grad is None - assert model._module.original_module.fc1.bias.grad is None + assert model._module_metadata.original_module.fc1.weight.grad is None + assert model._module_metadata.original_module.fc1.bias.grad is None # Make sure no exception is raised output = model(x) loss = torch.sum(output) loss.backward() training_session1 = model._execution_manager(model._is_training())._execution_agent - weight_grad_2 = model._module.original_module.fc1.weight.grad - bias_grad_2 = model._module.original_module.fc1.bias.grad + weight_grad_2 = model._module_metadata.original_module.fc1.weight.grad + bias_grad_2 = model._module_metadata.original_module.fc1.bias.grad assert weight_grad_2 is not None assert bias_grad_2 is not None - model._module.original_module.fc1.requires_grad_(False) + model._module_metadata.original_module.fc1.requires_grad_(False) output = model(x) loss = torch.sum(output) loss.backward() training_session2 = model._execution_manager(model._is_training())._execution_agent - weight_grad_3 = model._module.original_module.fc1.weight.grad - bias_grad_3 = model._module.original_module.fc1.bias.grad + weight_grad_3 = model._module_metadata.original_module.fc1.weight.grad + bias_grad_3 = model._module_metadata.original_module.fc1.bias.grad assert training_session1 != training_session2 assert torch.equal(weight_grad_2, weight_grad_3)