Resolve issue with wrapped ORTModule load_state_dict #7847

Merged 4 commits into master from bmeswani/load-state-dict_wrapped-module on May 27, 2021

Conversation

@baijumeswani baijumeswani commented May 26, 2021

ORTModule registered two submodules, _original_module and _flattened_module, and _original_module is itself a child of _flattened_module. Any operation that traversed _modules therefore visited the same module more than once.

This breaks operations such as load_state_dict on a model that has an ORTModule encapsulated within it. load_state_dict does not call load_state_dict recursively on its children; instead it defines a local function, load (inside load_state_dict), that walks the children itself. As a result, some keys are not found in the state_dict (those belonging to _flattened_module and _original_module).

To resolve this, ORTModule's _modules is now an empty OrderedDict, and all submodules are held by a separate object called ModuleAccessor. Every other method that depends on child modules must be overridden to reference the modules inside ModuleAccessor.
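The failure mode and the fix can be sketched without PyTorch or ONNX Runtime. Everything below is a hypothetical stand-in written for illustration: `ToyModule` only imitates the part of `torch.nn.Module` described above (a local `load` helper that walks `_modules`), and `WrapperBefore` / `WrapperAfter` are not the real ORTModule. The sketch also assumes that the wrapper's `state_dict` forwards to the original module, so saved keys carry no `_original_module.` prefix, and that the wrapped module is a leaf.

```python
from collections import OrderedDict


class ToyModule:
    """Simplified stand-in for torch.nn.Module (hypothetical, not the real class)."""

    def __init__(self):
        self._parameters = OrderedDict()
        self._modules = OrderedDict()

    def state_dict(self, prefix=""):
        out = OrderedDict()
        for name, value in self._parameters.items():
            out[prefix + name] = value
        for name, child in self._modules.items():
            out.update(child.state_dict(prefix + name + "."))
        return out

    def _load_from_state_dict(self, state, prefix, missing):
        # Load only this module's own parameters; record keys that are absent.
        for name in self._parameters:
            key = prefix + name
            if key in state:
                self._parameters[name] = state[key]
            else:
                missing.append(key)

    def load_state_dict(self, state):
        missing = []

        # Recurse with a local helper that walks `_modules` directly,
        # rather than calling each child's own load_state_dict().
        def load(module, prefix=""):
            module._load_from_state_dict(state, prefix, missing)
            for name, child in module._modules.items():
                load(child, prefix + name + ".")

        load(self)
        return missing


class Linear(ToyModule):
    def __init__(self):
        super().__init__()
        self._parameters["weight"] = 0.0


class Flattened(ToyModule):
    def __init__(self, original):
        super().__init__()
        self._modules["_original_module"] = original


class WrapperBefore(ToyModule):
    """Old layout: both helpers are registered children."""

    def __init__(self, original):
        super().__init__()
        self._modules["_original_module"] = original
        self._modules["_flattened_module"] = Flattened(original)

    def state_dict(self, prefix=""):
        # Assumption: keys are reported as if the wrapper were transparent.
        return self._modules["_original_module"].state_dict(prefix)


class WrapperAfter(ToyModule):
    """New layout: `_modules` stays empty; the original lives on a plain attribute."""

    def __init__(self, original):
        super().__init__()
        self._original = original  # deliberately not registered in _modules

    def state_dict(self, prefix=""):
        return self._original.state_dict(prefix)

    def _load_from_state_dict(self, state, prefix, missing):
        # Delegate to the original module using the wrapper's own prefix.
        self._original._load_from_state_dict(state, prefix, missing)


class Outer(ToyModule):
    def __init__(self, wrapper):
        super().__init__()
        self._modules["ort"] = wrapper


def roundtrip(wrapper_cls):
    """Save an outer model's state and load it back; return (keys, missing keys)."""
    model = Outer(wrapper_cls(Linear()))
    saved = model.state_dict()
    return list(saved), model.load_state_dict(saved)


keys_before, missing_before = roundtrip(WrapperBefore)
keys_after, missing_after = roundtrip(WrapperAfter)
print(keys_before, missing_before)
print(keys_after, missing_after)
```

In this toy, the saved state has the single key `ort.weight` in both layouts. With `WrapperBefore`, the local `load` helper descends into both registered children and looks for `ort._original_module.weight` and `ort._flattened_module._original_module.weight`, neither of which exists. With `WrapperAfter`, there are no children to descend into, and the overridden `_load_from_state_dict` finds `ort.weight` directly.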

…t erroneous iteration over children while loading the state dictionary
ytaous
ytaous previously approved these changes May 27, 2021
@ytaous ytaous left a comment

LGTM from local test.

mrry
mrry previously approved these changes May 27, 2021
@baijumeswani baijumeswani merged commit ddf4aaa into master May 27, 2021
@baijumeswani baijumeswani deleted the bmeswani/load-state-dict_wrapped-module branch May 27, 2021 23:11
xzhu1900 pushed a commit that referenced this pull request May 28, 2021
* Encapsulate children modules inside a ModuleAccessor object to prevent erroneous iteration over children while loading the state dictionary

* Add named_modules, modules, apply methods, change ModuleAccessor to ModuleMetadata and modify unit tests

* Change ModuleMetadata module getter logic, raise NotImplementedError for add_modules

* Add comment explaining why overriding _load_from_state_dict method is needed
xzhu1900 added a commit that referenced this pull request May 28, 2021
* Fix bug in Transpose CUDA kernel (#7329)

* Fix permission error for ORTModule lock file (#7814)

* fix topo sort in quant tool (#7833)

* fix topo sort in quant tool

* add unit test and make the topo sort stable

* Relax tol for Conv1D fp16 test (#7844)

* Relax tol for Conv1D fp16 test

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>

* Resolve issue with wrapped ORTModule load_state_dict (#7847)

* Encapsulate children modules inside a ModuleAccessor object to prevent erroneous iteration over children while loading the state dictionary

* Add named_modules, modules, apply methods, change ModuleAccessor to ModuleMetadata and modify unit tests

* Change ModuleMetadata module getter logic, raise NotImplementedError for add_modules

* Add comment explaining why overriding _load_from_state_dict method is needed

* fixed bugs in packed mode and enable pack mode tests in ci (#7848)

* fixed bugs in packed mode and enable pack mode tests in ci

* removed unnecessary space

* pr comments

* pr comments

* disable an average pool test

* try disabling another avg pool

* disable more avg pool tests

* disable maxpool tests

* add environment variable to control default training package's local version (#7849)

* [js] update documents (#7852)

* [js] update documents

* escape double quotes

* update operators.md

* resolve comments

* Support bool type for Pad CPU (#7856)

* Initial commit

* update

* nit

* Include ORT C/C++ API headers in the ORT Mobile AAR package (#7858)

* Add header files of ort c/c++ api to aar package

* Move header file selection to cmake based on EP choice

* fix duplicated node name (#7865)

* Clean up CPU kernel definition for opset 13 Pad (#7867)

Co-authored-by: Hariharan Seshadri <shariharan91@gmail.com>
Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
Co-authored-by: Sherlock <baihan.huang@gmail.com>
Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: baijumeswani <bmeswani@microsoft.com>
Co-authored-by: Tixxx <tix@microsoft.com>
Co-authored-by: liqunfu <liqfu@microsoft.com>
Co-authored-by: Yulong Wang <yulongw@microsoft.com>
Co-authored-by: Guoyu Wang <62914304+gwang-msft@users.noreply.github.com>
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
raviskolli added a commit to microsoft/huggingface-transformers that referenced this pull request May 28, 2021
Update to import ORTModule from torch_ort as torch-ort is now public
Don't need to explicitly use _original_module for ORT with this [PR](microsoft/onnxruntime#7847)
guyang3532 added a commit that referenced this pull request Mar 3, 2023
#14563)

A missing '_modules' attribute in ORTModule causes load_state_dict for
wrapped_ortmodule to fail.

reference:#7847
mszhanyi pushed a commit that referenced this pull request Mar 9, 2023
#14563)

A missing '_modules' attribute in ORTModule causes load_state_dict for
wrapped_ortmodule to fail.

reference:#7847