Skip to content

Commit

Permalink
Forward access of ort_model._modules to torch_model._modules
Browse files Browse the repository at this point in the history
  • Loading branch information
guyang3532 committed Feb 3, 2023
1 parent c6c1103 commit 5775009
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def __init__(self, module, debug_options=None):
# ORTModule.
_utils.check_for_name_collisions_and_bind_methods_to_ortmodule(self, module)

self._modules = module._modules

except ORTModuleFallbackException as e:
# Although backend is switched to PyTorch here,
# it is up to _FallbackManager to actually terminate execution or fallback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4164,7 +4164,8 @@ def forward(self, x):
x = torch.randn(N, D_in, device=device)
_ = wrapper_module(x)

state_dict1 = wrapper_module.state_dict()
# Must copy the state_dict or else they are sharing the same memory
state_dict1 = copy.deepcopy(wrapper_module.state_dict())
list(next(iter(state_dict1.items())))[1] += 10
wrapper_module.load_state_dict(state_dict1)
state_dict2 = wrapper_module.state_dict()
Expand Down

0 comments on commit 5775009

Please sign in to comment.