[Feature] Support torch ZeroRedundancyOptimizer #551
Conversation
Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>
Codecov Report: Base: 78.07% // Head: 78.04% // Decreases project coverage by -0.04%.
Additional details and impacted files:
@@ Coverage Diff @@
## main #551 +/- ##
==========================================
- Coverage 78.07% 78.04% -0.04%
==========================================
Files 125 126 +1
Lines 8991 9009 +18
Branches 1845 1846 +1
==========================================
+ Hits 7020 7031 +11
- Misses 1659 1666 +7
Partials 312 312
I found that it has a problem when saving the optimizer's state_dict.
Solved the problem by calling consolidate_state_dict() before saving.
Thank you for your contributions! It could be better to add a unit test that uses the ZeroRedundancyOptimizer.
This reverts commit dd64538.
if ZeroRedundancyOptimizer is None:
    self.skipTest('ZeroRedundancyOptimizer is not available.')
Is this line duplicated with the following decorator?
@unittest.skipIf(
    digit_version(TORCH_VERSION) < digit_version('1.8.0'),
    reason='ZeRO needs Pytorch 1.8 or higher')
https://github.com/open-mmlab/mmengine/actions/runs/3134972777/jobs/5090129146#step:8:132
I found that importing ZeroRedundancyOptimizer failed in the Windows CPU CI with torch 1.8.1. (The import failure made _ZeroRedundancyOptimizer fall back to object.) So I added the duplicated skip code.
I'll check it again.
I found that it has another condition: torch.distributed.rpc should be available. I removed the duplicated lines and clarified this condition.
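A minimal sketch of the guarded import plus skip pattern described in this exchange; it is not the exact PR code, and the mmengine import paths for digit_version and TORCH_VERSION are assumptions.

import unittest

from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    # On some builds (e.g. the Windows CPU wheel of torch 1.8.x) this import
    # fails because torch.distributed.rpc is unavailable.
    from torch.distributed.optim import ZeroRedundancyOptimizer
except ImportError:
    ZeroRedundancyOptimizer = None


@unittest.skipIf(
    digit_version(TORCH_VERSION) < digit_version('1.8.0'),
    reason='ZeRO needs Pytorch 1.8 or higher')
class TestZeroRedundancyOptimizer(unittest.TestCase):

    def setUp(self):
        # Runtime guard for platforms where the import above failed even on
        # a recent enough torch version.
        if ZeroRedundancyOptimizer is None:
            self.skipTest('ZeroRedundancyOptimizer is not available.')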
I found another bug. Currently, saving the state_dict of ZeroRedundancyOptimizer has to run on every rank because the consolidation is a collective operation. This constraint makes it necessary to refactor the checkpoint-saving logic.
Hi @HAOCHENYE @C1rN09, I changed the … Please let me know if the potential problem is expected.
#553 also removes the …
    '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
    'available when pytorch version >= 1.8.')
assert is_available(), 'torch.distributed.rpc is not available.'
optimizer_class = getattr(torch.optim, optimizer_type)
Can it support custom Optimizer classes?
I'm still figuring it out. So far it does not seem to have a specific dependency on torch's optimizers, so it may be possible to support custom Optimizer classes.
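For illustration, a hypothetical sketch of how the lookup quoted above could fall back to a registry so that custom optimizers become resolvable too; OPTIMIZERS here is mmengine's optimizer registry, and this fallback is an assumption about a possible extension, not the merged behaviour.

import torch
from mmengine.registry import OPTIMIZERS


def resolve_optimizer_class(optimizer_type: str):
    """Look up an optimizer class by name, trying torch.optim first."""
    if hasattr(torch.optim, optimizer_type):
        # Built-in torch optimizers such as 'SGD' or 'AdamW'.
        return getattr(torch.optim, optimizer_type)
    # Fall back to classes registered with mmengine, e.g. custom optimizers.
    return OPTIMIZERS.get(optimizer_type)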
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Hi @nijkah, the lint failed.
Hi @C1rN09 @zhouzaida, I'm concerned about the … option. The public API docs demonstrate: … So if someone wants to use …, would it be better to fix it as …?
def state_dict(self):
    """Consolidate `state_dict`s from ranks to save the `state_dict`."""
    self.consolidate_state_dict()
    state_dict = super().state_dict() if is_main_process() else dict()
state_dict['loss_scaler'] = self.loss_scaler.state_dict()
Due to this line, using ZeroRedundancyOptimizer with AmpOptimWrapper gave an error like TypeError: 'NoneType' object does not support item assignment in <mmengine.hooks.checkpoint_hook.CheckpointHook object at XXXXXX>. So I modified it to return dict() instead of None when it is not the main process.
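To make the failure mode concrete, a simplified illustration (not the actual mmengine classes) of why the optimizer must return a dict, even an empty one, on every rank:

def amp_optim_wrapper_state_dict(optimizer, loss_scaler):
    # AmpOptimWrapper-style saving: take the optimizer's state_dict, then
    # attach the loss scaler state. When the optimizer returned None on
    # non-main ranks, the item assignment below raised the TypeError quoted
    # above; an empty dict keeps it safe on every rank.
    state_dict = optimizer.state_dict()
    state_dict['loss_scaler'] = loss_scaler.state_dict()
    return state_dict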
Hi! If there is no easy solution, I think it's acceptable to fix …
@C1rN09 In some versions (e.g. 1.8.0), …
Hi @nijkah, I tested your branch on my cluster and got some different results, shown below:
What I'm confused about is the memory consumption. Since this model is only ~250MB, it seems that the maximum memory reduction that … My cluster & configurations, in case I missed something important:
optimizer = dict(type='ZeroRedundancyOptimizer', optimizer_type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
Hi @C1rN09, I apologize for the confusion. 😞 I was confused because I just tried to compare my result with the log provided on GitHub. After running the experiments again, I could also get the same result as you. I'll fix the description.
nit: Hi @C1rN09. Since the SGD optimizer with momentum only stores the model params from the previous step, …
Yes, I think you are right! From the experiment result, I guess it only shards optimizer states instead of optimizer states + grads, which I used to think it might do.
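As a rough sanity check of that reading (the numbers are approximations taken from this thread, and the 4-GPU count assumes the 4xb4 setting mentioned elsewhere in the conversation):

# ~250 MB of parameters; SGD momentum keeps one buffer of the same size, and
# sharding only that optimizer state across ranks bounds the saving per GPU.
param_mb = 250
momentum_state_mb = param_mb
num_gpus = 4
max_saving_per_gpu = momentum_state_mb * (num_gpus - 1) / num_gpus
print(f'at most ~{max_saving_per_gpu:.0f} MB saved per GPU')  # ~188 MB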
* [Feature] Support torch ZeRORedundancyOptimizer
  Co-authored-by: Junhwa Song <ethan9867@gmail.com>
  Signed-off-by: Junhwa Song <ethan9867@gmail.com>
  Signed-off-by: Hakjin Lee <nijkah@gmail.com>
* lint
* Fix saving optimizer state_dict
* Fix handling import error
* Add test case
* fix UT
* Revert "fix UT"
  This reverts commit dd64538.
* fix handling import in UT
* Fix saving zero checkpoint and delete redundant master_only
* lint
* test unittest
* Fix handling impor error
* Fix UT condition
* Edit docstrings
* Fix typo
* Skip redundant procudure in checkpoint hook
* fix typo again
* Update mmengine/optim/optimizer/zero_optimizer.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Add api info
* lint
* Fix lint
* Handling AmpOptimWrapper case
* handling overlap_with_ddp
* Fix error

Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>
Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
How can I use ZeroRedundancyOptimizer with AmpOptimWrapper in the training config?
I haven't tested it yet. Is there any specific reason to use …?
This is what I want: …
Did you test that?
@twmht Yes, it should work! 😄
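For reference, a config sketch for that combination, following the config style shown earlier in this thread; the field values simply mirror the earlier example and are not a recommendation.

optimizer = dict(
    type='ZeroRedundancyOptimizer',
    optimizer_type='SGD',
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005)
# Swap the plain OptimWrapper for AmpOptimWrapper to enable mixed precision.
optim_wrapper = dict(type='AmpOptimWrapper', optimizer=optimizer)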
upernet_r50_4xb4-80k_ade20k-512x512
Time may depend on the environment.
Co-authored-by: Junhwa Song <ethan9867@gmail.com> @KKIEEK
Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>