From caae1948d7638c0a84cdb544c661ef65620b07bb Mon Sep 17 00:00:00 2001 From: guyang3532 Date: Fri, 3 Feb 2023 22:26:48 +0800 Subject: [PATCH] Forward access of ort_model._modules to torch_model._modules --- .../orttraining/python/training/ortmodule/ortmodule.py | 4 ++++ .../orttraining/test/python/orttraining_test_ortmodule_api.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 18000e0462d00..90f88459fc077 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -104,6 +104,10 @@ def __init__(self, module, debug_options=None): # else, they will be assigned to self._torch_module.original_module instead. self._is_initialized = True + # del the ort._modules so that all reference to ort._modules will be forward to the underlying torch_model + # through '__getattr__' + del self._modules + # IMPORTANT: DO NOT add code here # This declaration is for automatic document generation purposes only # The actual forward implementation is bound during ORTModule initialization diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index f668776d0221f..f39d29a17882a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -4192,7 +4192,8 @@ def forward(self, x): x = torch.randn(N, D_in, device=device) _ = wrapper_module(x) - state_dict1 = wrapper_module.state_dict() + # Must copy the state_dict or else they are sharing the same memory + state_dict1 = copy.deepcopy(wrapper_module.state_dict()) list(next(iter(state_dict1.items())))[1] += 10 wrapper_module.load_state_dict(state_dict1) state_dict2 = wrapper_module.state_dict()