Consolidate ZeRO state before checkpoint saving #2623

Closed
danieltudosiu opened this issue Jul 20, 2022 · 4 comments · Fixed by #2642


@danieltudosiu

Is your feature request related to a problem? Please describe.
When checkpoint saving occurs, there should be a check whether the object that state_dict() is called on is a ZeroRedundancyOptimizer instance.

Describe the solution you'd like
Prior to the state_dict() call, a consolidate_state_dict() call should be issued. This call needs to be issued on all ranks and point to the same consolidating rank.

There are two design options here: either instantiate a Checkpoint handler on all ranks and have only the designated rank save the checkpoint, or create a handler that runs before the Checkpoint handler in order to consolidate.

Special care must be taken with the PyTorch version, as the naming scheme of ZeRO's method arguments has changed between recent versions.
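For illustration, a minimal sketch of such a version guard (assuming the keyword was renamed from an older recipient_rank to the current to; inspecting the signature avoids hard-coding a version cutoff):

import inspect

from torch.distributed.optim import ZeroRedundancyOptimizer


def consolidate(opt: ZeroRedundancyOptimizer, rank: int = 0) -> None:
    # Inspect the method signature instead of hard-coding a PyTorch version,
    # since the keyword argument was renamed between releases.
    params = inspect.signature(opt.consolidate_state_dict).parameters
    if "to" in params:
        opt.consolidate_state_dict(to=rank)  # current keyword
    else:
        opt.consolidate_state_dict(recipient_rank=rank)  # assumed older keyword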

Describe alternatives you've considered
The only alternative is for users to write such handlers themselves, segregating the checkpoint-saving logic. It could be written as:

from torch.distributed.optim import ZeroRedundancyOptimizer

from ignite.engine import Engine, Events


class ConsolidateZeROHandler:
    """
    Handler that consolidates the ZeroRedundancyOptimizer state prior to checkpoint saving.

    Args:
        zero_optimizer (ZeroRedundancyOptimizer): The optimizer to be consolidated.
        call_every (int): Consolidate every N epochs or iterations.
        recipient_rank (int): The rank on which the consolidation will happen. Defaults to 0.
        epoch_level (bool): Call every N epochs or every N iterations. `True` is epoch level,
            `False` is iteration level. Defaults to True.
    """

    def __init__(
        self,
        zero_optimizer: ZeroRedundancyOptimizer,
        call_every: int,
        recipient_rank: int = 0,
        epoch_level: bool = True,
    ):
        self.zero_optimizer = zero_optimizer
        self.recipient_rank = recipient_rank
        self.epoch_level = epoch_level
        self.call_every = call_every

    def __call__(self, engine: Engine) -> None:
        # Must run on all ranks, all targeting the same recipient rank.
        self.zero_optimizer.consolidate_state_dict(to=self.recipient_rank)

    def attach(self, engine: Engine) -> None:
        if self.epoch_level:
            engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.call_every), self)
        else:
            engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.call_every), self)
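
A hypothetical wiring sketch (trainer, zero_optimizer and checkpoint_handler are assumed to exist elsewhere); the consolidator is attached first so that it runs before the checkpoint handler on the same event:

# Hypothetical usage: consolidate on rank 0, then let the Checkpoint handler save.
consolidator = ConsolidateZeROHandler(zero_optimizer, call_every=1, recipient_rank=0)
consolidator.attach(trainer)  # attached first, so it runs before the checkpoint
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)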
@vfdev-5

vfdev-5 commented Jul 20, 2022

Thanks for the feature request @danieltudosiu !

Yes, from the user side, as you suggested, one possible alternative could be to attach a specific handler for that:

to_save = {"zero": zero_optimizer}
checkpoint = Checkpoint(to_save, ...)

recipient_rank = 0
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: zero_optimizer.consolidate_state_dict(to=recipient_rank))
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)
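
Handlers attached to the same event run in the order they were added, so the consolidation lambda above executes before the checkpoint handler.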

Describe the solution you'd like
Prior to the state_dict() call, a consolidate_state_dict() call should be issued. This call needs to be issued on all ranks and point to the same consolidating rank.

Do you think we should add this specific logic hard-coded into Checkpoint? For example, we could assume that consolidate_state_dict is an API method, and if an object has it we can call obj.consolidate_state_dict(to=self.consolidate_recipient_rank) before setting up the state dict:

def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
    checkpoint = {}
    if self.to_save is not None:
        for k, obj in self.to_save.items():
            if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                obj = obj.module
            checkpoint[k] = obj.state_dict()
    return checkpoint

The Checkpoint class would get a new consolidate_recipient_rank argument:

class Checkpoint:
    def __init__(self, ..., consolidate_recipient_rank=0 ...):
        self.consolidate_recipient_rank = consolidate_recipient_rank
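
Putting the two together, a hedged sketch of the proposal (the duck-typed hasattr check and the consolidate_recipient_rank name are just the suggestion above, not Ignite's actual API):

def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
    checkpoint = {}
    if self.to_save is not None:
        for k, obj in self.to_save.items():
            if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                obj = obj.module
            # Duck-typed check from the proposal: consolidate ZeRO state onto
            # the recipient rank before reading the state dict.
            if hasattr(obj, "consolidate_state_dict"):
                obj.consolidate_state_dict(to=self.consolidate_recipient_rank)
            checkpoint[k] = obj.state_dict()
    return checkpoint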

@danieltudosiu

Hi @vfdev-5,

The whole thing boils down to a design decision: do you want the user to take care of this specific checkpointing logic, or do you want it to be seamlessly integrated into Ignite?

Do you think we should add this specific logic hard-coded into Checkpoint? For example, we could assume that consolidate_state_dict is an API method, and if an object has it we can call obj.consolidate_state_dict(to=self.consolidate_recipient_rank) before setting up the state dict:
The Checkpoint class would get a new consolidate_recipient_rank argument:

class Checkpoint:
    def __init__(self, ..., consolidate_recipient_rank=0 ...):
        self.consolidate_recipient_rank = consolidate_recipient_rank

Be careful, because this is not the complete solution: consolidate_state_dict MUST be called on all ranks, which means all ranks need to go through the checkpoint-saving logic, and the checkpoint saver must be aware of the rank it is called on so that we do not save the checkpoint on every rank.

Besides that, yes, I would lean toward integrating the ZeRO consolidation logic into the Checkpoint class so that the user does not need to fuss around with it.

@vfdev-5

vfdev-5 commented Jul 20, 2022

Hi @danieltudosiu, thanks for your pointers!

Today, Checkpoint is intended to be called on all ranks (due to XLA, which requires that):

# Wrong:
# if idist.get_rank() == 0:
#     handler = Checkpoint(...)
#     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

# Correct:
handler = Checkpoint(...)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

and internally DiskSaver uses torch.save only on rank zero.
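
A minimal sketch of that rank-zero pattern (an assumed simplification, not DiskSaver's actual code; checkpoint and checkpoint_path are hypothetical):

import torch

import ignite.distributed as idist

# All ranks reach the save call, but only rank zero writes to disk.
if idist.get_rank() == 0:
    torch.save(checkpoint, checkpoint_path)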

Besides that, yes, I would lean toward integrating the ZeRO consolidation logic into the Checkpoint class so that the user does not need to fuss around with it.

Sounds good to me as well. I'll send a PR and ask you to check it. Hope it works for you :)

@danieltudosiu

Hi @vfdev-5,

Thanks for moving so quickly.

For my personal use I have already coded the handler, so I am OK ;) I just wanted Ignite to be even more user-friendly <3

Cheers,

Dan
