Del ort_model._modules to forward access to torch_model._modules (#14563)

A missing '_modules' attribute in ORTModule causes load_state_dict to fail for a wrapped ORTModule.

Reference: #7847
guyang3532 authored Mar 3, 2023
1 parent 8d87fdc commit c49f250
Showing 2 changed files with 6 additions and 1 deletion.
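
The fix relies on a basic Python attribute-lookup rule: '__getattr__' is only called after normal lookup (instance __dict__, then the class) fails, so deleting the instance's own '_modules' re-routes every '_modules' access through the forwarding hook. Below is a minimal sketch of that mechanism, with hypothetical Inner/Wrapper names standing in for the wrapped torch model and ORTModule:

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

class Wrapper(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()  # puts '_modules' into self.__dict__
        # Bypass nn.Module.__setattr__ so 'inner' is not registered in the
        # wrapper's own '_modules' (which is about to be deleted).
        object.__setattr__(self, "_inner", inner)
        del self._modules  # '_modules' lookups now fall through to __getattr__

    def __getattr__(self, name):
        # Called only when 'name' is not found on the instance or the class.
        return getattr(self._inner, name)

w = Wrapper(Inner())
print(list(w._modules.keys()))  # ['fc'] -- forwarded to the inner module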
@@ -104,6 +104,10 @@ def __init__(self, module, debug_options=None):
         # else, they will be assigned to self._torch_module.original_module instead.
         self._is_initialized = True

+        # Delete ort._modules so that all references to ort._modules will be
+        # forwarded to the underlying torch_model through '__getattr__'.
+        del self._modules
+
         # IMPORTANT: DO NOT add code here
         # This declaration is for automatic document generation purposes only
         # The actual forward implementation is bound during ORTModule initialization
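
This matters because nn.Module.load_state_dict and state_dict recurse over each child's '_modules' dictionary directly, rather than calling any overridden methods on the child; when an ORTModule is itself a submodule of a larger model, its '_modules' must therefore expose the wrapped model's real submodules for the state-dict keys to resolve. Continuing the hypothetical Wrapper/Inner sketch above:

outer = torch.nn.Sequential(w)  # a model that wraps the forwarding Wrapper
sd = outer.state_dict()         # keys like '0.fc.weight', found via forwarding
outer.load_state_dict(sd)       # recursion over w._modules reaches 'fc' again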
@@ -4192,7 +4192,8 @@ def forward(self, x):
     x = torch.randn(N, D_in, device=device)
     _ = wrapper_module(x)

-    state_dict1 = wrapper_module.state_dict()
+    # Must copy the state_dict, or else both dicts share the same tensor memory.
+    state_dict1 = copy.deepcopy(wrapper_module.state_dict())
     list(next(iter(state_dict1.items())))[1] += 10
     wrapper_module.load_state_dict(state_dict1)
     state_dict2 = wrapper_module.state_dict()
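
The deep copy guards against tensor aliasing: state_dict() returns references to the module's own parameter tensors (detached, but sharing storage), so an in-place '+= 10' on an un-copied entry would silently mutate the module as well, making the load_state_dict round trip vacuous. A small standalone illustration with a plain nn.Linear (not the test's wrapper_module):

import copy

import torch

m = torch.nn.Linear(2, 2)

sd = m.state_dict()
sd["weight"] += 10  # aliases m.weight: this mutates the module in place
print(torch.equal(sd["weight"], m.weight))  # True -- same storage

sd2 = copy.deepcopy(m.state_dict())
sd2["weight"] += 10  # independent copy: the module is untouched
print(torch.equal(sd2["weight"], m.weight))  # False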
