Del ort_model._modules to forward its accessing to torch_model._modules #14563

Merged (1 commit) on Mar 3, 2023

Conversation

guyang3532 (Contributor) commented Feb 3, 2023

General Description

A missing '_modules' attribute in ORTModule causes load_state_dict to fail for a wrapped ORTModule.
The unit test 'test_load_state_dict_for_wrapped_ortmodule' did not catch this problem because it did not copy the state_dict,
so the two state_dicts shared the same memory.
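
To illustrate why an uncopied state_dict can hide a failing load, here is a minimal sketch in plain Python. `TinyModel` and `broken_load_state_dict` are hypothetical stand-ins, and the exact shape of the real unit test is not shown in this thread, so treat this as an assumption about the general failure mode rather than a reproduction of the test.

```python
import copy


class TinyModel:
    """Toy model (hypothetical); a list stands in for tensor storage."""

    def __init__(self):
        self.weight = [1.0, 2.0]

    def state_dict(self):
        # Returns a reference to the live storage, not a copy.
        return {"weight": self.weight}

    def broken_load_state_dict(self, state_dict):
        # Deliberately loads nothing, to mimic a load that silently fails.
        pass


model = TinyModel()
aliased = model.state_dict()                  # shares memory with the model
snapshot = copy.deepcopy(model.state_dict())  # independent copy

model.weight[0] = 99.0  # the model changes after the state was "saved"

model.broken_load_state_dict(aliased)
aliased_check_passes = aliased == model.state_dict()

model.broken_load_state_dict(snapshot)
snapshot_check_passes = snapshot == model.state_dict()

print(aliased_check_passes)   # True: the aliased dict changed along with the model
print(snapshot_check_passes)  # False: the copy exposes that nothing was loaded
```

With the aliased dict, the comparison passes even though nothing was loaded; only the deep copy reveals the problem.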

Motivation and Context

Reference: #7847

@guyang3532 guyang3532 changed the title Forward access of ort_model._modules to torch_model._modules [draft]Forward access of ort_model._modules to torch_model._modules Feb 3, 2023
@guyang3532 guyang3532 changed the title [draft]Forward access of ort_model._modules to torch_model._modules [draft]set ort_model._modules to torch_model._modules Feb 3, 2023
guyang3532 (Contributor, Author) commented Feb 3, 2023

I think a better solution would be to forward access to ORTModule._modules to TorchModule._modules, so the two stay consistent, rather than just copying it. But I have not figured out a good implementation. Do you have any suggestions? @baijumeswani @pengwa
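
The difference between copying and forwarding can be sketched in plain Python. The classes below are hypothetical and do not use ORTModule or torch; the sketch assumes that "copying" means taking a snapshot of the mapping, and that forwarding is done through `__getattr__`, which Python only calls when normal attribute lookup fails.

```python
class Inner:
    """Hypothetical stand-in for the wrapped torch module."""

    def __init__(self):
        self._modules = {"fc": "fc-layer"}


class CopyingWrapper:
    """Keeps its own snapshot of the inner module's children."""

    def __init__(self, inner):
        self._inner = inner
        self._modules = dict(inner._modules)


class ForwardingWrapper:
    """Has no _modules of its own, so lookups fall through to the inner module."""

    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, name):
        # Guard against infinite recursion if _inner is not set yet.
        if name == "_inner":
            raise AttributeError(name)
        return getattr(self._inner, name)


inner = Inner()
copying = CopyingWrapper(inner)
forwarding = ForwardingWrapper(inner)

inner._modules["extra"] = "new-layer"  # the inner module changes later

print("extra" in copying._modules)            # False: the snapshot is stale
print("extra" in forwarding._modules)         # True: lookups reach the inner dict
print(forwarding._modules is inner._modules)  # True
```

Because `__getattr__` is consulted only when the attribute is absent from the instance, a wrapper that already has its own `_modules` would need that attribute removed for the forwarding to take effect.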

@baijumeswani baijumeswani added the training issues related to ONNX Runtime training; typically submitted using template label Feb 3, 2023
baijumeswani (Contributor) commented Feb 3, 2023

ORTModule.load_state_dict already forwards the call to the underlying torch model. Does that not work?

guyang3532 (Contributor, Author) commented

> ORTModule.load_state_dict already forwards the call to the underlying torch model. Does that not work?

As you described in #7847, it does not work, because load_state_dict does not recursively call load_state_dict on its children. Instead, it defines its own inner function, load (inside load_state_dict), which does this task.
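
A toy model of that traversal shows why an overridden load_state_dict on a child is bypassed. `Module` below is a simplified stand-in written for this illustration, not the real torch.nn.Module, and `Wrapper` is a hypothetical class that keeps its own empty `_modules`.

```python
class Module:
    """Simplified stand-in for torch.nn.Module (an assumption for illustration)."""

    def __init__(self):
        self._parameters = {}
        self._modules = {}

    def load_state_dict(self, state_dict):
        # The inner function walks _modules directly; it never calls
        # load_state_dict on the children.
        def load(module, prefix=""):
            for name in module._parameters:
                module._parameters[name] = state_dict[prefix + name]
            for name, child in module._modules.items():
                load(child, prefix + name + ".")

        load(self)


class Leaf(Module):
    def __init__(self):
        super().__init__()
        self._parameters["weight"] = 0.0


class Wrapper(Module):
    """Hypothetical wrapper: forwards load_state_dict, but its _modules is empty."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner
        self.override_calls = 0

    def load_state_dict(self, state_dict):
        self.override_calls += 1
        self.inner.load_state_dict(state_dict)


class Parent(Module):
    def __init__(self, child):
        super().__init__()
        self._modules["wrapped"] = child


leaf = Leaf()
wrapper = Wrapper(leaf)
parent = Parent(wrapper)

# Loading through the parent never reaches the wrapper's override.
parent.load_state_dict({"wrapped.weight": 1.0})
calls_after_parent_load = wrapper.override_calls
weight_after_parent_load = leaf._parameters["weight"]

# Calling the wrapper directly does forward to the inner module.
wrapper.load_state_dict({"weight": 2.0})

print(calls_after_parent_load, weight_after_parent_load)  # 0 0.0
print(wrapper.override_calls, leaf._parameters["weight"])  # 1 2.0
```

The real method also reports missing and unexpected keys, which this toy omits. The point is only that what the parent sees through the child's `_modules` decides which weights get loaded.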

@guyang3532 guyang3532 changed the title [draft]set ort_model._modules to torch_model._modules Del ort_model._modules to forward it to torch_model._modules Feb 7, 2023
@guyang3532 guyang3532 changed the title Del ort_model._modules to forward it to torch_model._modules Del ort_model._modules to forward its accessing to torch_model._modules Feb 7, 2023
baijumeswani previously approved these changes Feb 7, 2023
baijumeswani previously approved these changes Feb 11, 2023
@guyang3532 guyang3532 merged commit c49f250 into microsoft:main Mar 3, 2023
mszhanyi pushed a commit that referenced this pull request Mar 9, 2023