PyTorch Fully Sharded Data Parallel (FSDP) Integration #147

Open
wants to merge 23 commits into
base: master
from

Conversation

parthraut
Collaborator

Integrates PyTorch Fully Sharded Data Parallel (FSDP) into Zeus. GlobalPowerLimitOptimizer now performs distributed operations (all-reduce) so that it makes the correct power limit decision.

  • Implemented generic zeus.framework.all_reduce, which currently invokes torch.distributed.all_reduce if torch is the framework.
  • Added train_fsdp.py example and relevant documentation
  • Added relevant use cases to GlobalPowerLimitOptimizer.__init__
  • Implemented zeus.framework.is_distributed and added a warning for when multiple GPUs are being monitored in a distributed context
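
For readers unfamiliar with the pattern, the generic all-reduce described in the list above can be sketched roughly as follows. This is an illustrative sketch, not the code in this PR: the name `all_reduce_sketch` is hypothetical, it only covers the PyTorch path, and it assumes that with the NCCL backend the current CUDA device has already been set for the rank.

```python
from typing import Literal, Sequence


def all_reduce_sketch(values: Sequence[float], operation: Literal["sum", "max"]) -> float:
    """Reduce `values` locally, then across ranks if torch.distributed is initialized.

    Hypothetical helper for illustration; not the zeus.framework.all_reduce code.
    """
    if operation not in ("sum", "max"):
        raise ValueError(f"Unsupported operation: {operation}")

    # Step 1: reduce the values owned by this process.
    local = sum(values) if operation == "sum" else max(values)

    try:
        import torch
        import torch.distributed as dist
    except ImportError:
        return local  # PyTorch is not installed: nothing to reduce across.

    if not (dist.is_available() and dist.is_initialized()):
        return local  # Single-process run: the local result is the global result.

    # Step 2: reduce the per-process result across all ranks.
    op = dist.ReduceOp.SUM if operation == "sum" else dist.ReduceOp.MAX
    # Assumption: NCCL needs a CUDA tensor; other backends (e.g. gloo) accept CPU tensors.
    device = "cuda" if dist.get_backend() == "nccl" else "cpu"
    tensor = torch.tensor(local, dtype=torch.float64, device=device)
    dist.all_reduce(tensor, op=op)  # In-place collective; every rank ends with the same value.
    return tensor.item()
```

In a single-process run the function degenerates to a plain `sum` or `max`, which is why non-distributed users would see no behavior change under this sketch.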

@jaywonchung
Member

Please rebase to the current master and push. It's impossible to review with all the changes from past commits.

Member

@jaywonchung jaywonchung left a comment


Thanks @parthraut! I'm requesting some changes. Please let me know if anything is unclear.

examples/power_limit_optimizer/README.md (outdated, resolved)
examples/power_limit_optimizer/README.md (outdated, resolved)
examples/power_limit_optimizer/README.md (outdated, resolved)
examples/power_limit_optimizer/train_dp.py (outdated, resolved)
zeus/monitor/energy.py (outdated, resolved)
zeus/utils/framework.py (outdated, resolved)
Comment on lines 140 to 154
    if jax_is_available():
        # JAX cross-device all-reduce not yet implemented
        return sum(object) if operation == "sum" else max(object)

    raise RuntimeError("No framework is available.")


def is_distributed() -> bool:
    """Check if the current execution is distributed across multiple devices."""
    if torch_is_available(ensure_cuda=False):
        torch = MODULE_CACHE["torch"]
        return torch.distributed.is_available() and torch.distributed.is_initialized()
    if jax_is_available():
        return False  # JAX not yet implemented
    return False
Member


  1. If this was going to be left unimplemented, it should have raised a NotImplementedError instead of silently doing the wrong thing.
  2. This PR will be merged after JAX counterparts are implemented. No need to have a full JAX training script; I'm fine with it being tested manually with a quick script that imports and uses all_reduce and is_distributed with JAX.
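
To illustrate point 1, a branch that is not yet supported can fail loudly rather than return a locally reduced value that looks valid. The sketch below is not code from this PR: `all_reduce_or_fail` and its `framework` string argument are hypothetical stand-ins for the real availability checks.

```python
def all_reduce_or_fail(values: list[float], operation: str, framework: str) -> float:
    """Reduce `values`, refusing to guess for frameworks without a real implementation.

    Hypothetical helper: `framework` stands in for checks like jax_is_available().
    """
    if framework == "local":
        # Single-process case: a local reduction is the correct answer.
        return sum(values) if operation == "sum" else max(values)
    # Any other framework has no cross-device reduction here, so raise instead of
    # silently returning a per-process result that callers would treat as global.
    raise NotImplementedError(
        f"all_reduce is not implemented for framework '{framework}'."
    )
```

The caller then sees the gap immediately instead of making a power limit decision from partial data.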

Collaborator Author


I will include JAX impl in this PR

Member


Thanks!

Member


Line 265 is now broken, because the previous implementation assumed that len(zeus_monitor.gpu_indices) gives the current world size. Let's just switch the default optimum_selector to MaxSlowdownConstraint(factor=1.1).

Collaborator Author


Or we could use torch.distributed.get_world_size (and something analogous for jax) by defining a generic framework function zeus.framework.get_world_size. What do you think?

Member


Nah, I wouldn't bother for this one. Now I think MaxSlowdownConstraint is a better default; the original one is from 2022.
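
For context on why a slowdown constraint sidesteps the world-size question, here is a minimal sketch of that kind of selection rule. It assumes the selector picks the lowest power limit whose measured time stays within `factor` times the fastest measured time; the `Measurement` type and `select_with_max_slowdown` function are hypothetical, and Zeus's actual MaxSlowdownConstraint may differ in detail.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Measurement:
    """One profiling result for a candidate power limit (hypothetical type)."""

    power_limit: int  # Watts
    energy: float  # Joules consumed during the profiling window
    time: float  # Seconds taken by the profiling window


def select_with_max_slowdown(measurements: list[Measurement], factor: float = 1.1) -> int:
    """Return the lowest power limit that is at most `factor` times slower than the fastest."""
    if not measurements:
        raise ValueError("At least one measurement is required.")
    fastest = min(m.time for m in measurements)
    # Keep only the power limits that respect the slowdown budget.
    feasible = [m for m in measurements if m.time <= factor * fastest]
    # The fastest measurement is always feasible, so `feasible` is never empty.
    return min(m.power_limit for m in feasible)


profile = [
    Measurement(power_limit=300, energy=100.0, time=10.0),
    Measurement(power_limit=250, energy=90.0, time=10.5),
    Measurement(power_limit=200, energy=85.0, time=11.5),
]
chosen = select_with_max_slowdown(profile, factor=1.1)
```

In this sketch the rule only compares measured times, so it needs no world size argument at all.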

zeus/utils/framework.py (outdated, resolved)
zeus/optimizer/power_limit.py (outdated, resolved)
return sum(object) if operation == "sum" else max(object)
# Check if not distributed
jax = MODULE_CACHE["jax"]
if jax.process_count() == 1:
Member


Collaborator Author

@parthraut parthraut Dec 21, 2024


yes, it should be. Fixed that, thanks
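
As a rough illustration of the check discussed in this thread, a framework-agnostic `is_distributed` could combine the PyTorch test from the earlier snippet with `jax.process_count()`. This is a sketch under assumptions, not the merged code: the name `is_distributed_sketch` is hypothetical, and it imports the frameworks directly instead of going through Zeus's MODULE_CACHE.

```python
def is_distributed_sketch() -> bool:
    """Return True if this process appears to be part of a multi-process run.

    Hypothetical helper; imports frameworks directly rather than via MODULE_CACHE.
    """
    try:
        import torch.distributed as dist

        # PyTorch: a process group must be both supported and initialized.
        if dist.is_available() and dist.is_initialized():
            return True
    except ImportError:
        pass  # PyTorch is not installed; fall through to the JAX check.

    try:
        import jax

        # JAX: more than one process means a multi-host (distributed) run.
        return jax.process_count() > 1
    except ImportError:
        return False  # Neither framework is installed.
```

In an ordinary single-process script this returns False for both frameworks, which matches the `jax.process_count() == 1` non-distributed check in the snippet above.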

2 participants