Check all param groups for flat parameters in FSDP #17914

Merged
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -135,6 +135,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788))


+- Fixed check for FSDP's flat parameters in all parameter groups ([#17914](https://github.com/Lightning-AI/lightning/pull/17914))


## [2.0.3] - 2023-06-07

- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
6 changes: 4 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
@@ -637,11 +637,13 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload":
 def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
     _FSDP_FLATTENED = "_fsdp_flattened"
     if _TORCH_GREATER_EQUAL_1_13:
-        return any(getattr(param, _FSDP_FLATTENED, False) for param in optimizer.param_groups[0]["params"])
+        return any(
+            getattr(param, _FSDP_FLATTENED, False) for group in optimizer.param_groups for param in group["params"]
+        )

     from torch.distributed.fsdp import FlatParameter

-    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
+    return any(isinstance(param, FlatParameter) for group in optimizer.param_groups for param in group["params"])


def _get_sharded_state_dict_context(module: "FullyShardedDataParallel") -> _GeneratorContextManager:
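Before this change, `_optimizer_has_flat_params` only inspected `optimizer.param_groups[0]`, so an optimizer whose FSDP-flattened parameters live in a later parameter group was treated as if it referenced no FSDP parameters at all. The following sketch is illustrative only (not part of the diff): it mimics the `_fsdp_flattened` marker that FSDP sets on its parameters under torch >= 1.13, rather than spinning up a real distributed FSDP run, to show why every group has to be scanned.

```python
# Illustrative sketch of the failure mode this PR fixes. We simulate FSDP's
# `_fsdp_flattened` attribute on a plain parameter instead of wrapping a model.
import torch

plain_param = torch.nn.Parameter(torch.zeros(1))

flat_param = torch.nn.Parameter(torch.zeros(4))
flat_param._fsdp_flattened = True  # mimic what FSDP marks on flattened params

optimizer = torch.optim.SGD(
    [
        {"params": [plain_param]},  # group 0: no flat parameters
        {"params": [flat_param]},   # group 1: contains a "flat" parameter
    ],
    lr=0.1,
)

# Old check: looks only at the first param group and misses the flat parameter.
old_check = any(
    getattr(p, "_fsdp_flattened", False) for p in optimizer.param_groups[0]["params"]
)
# New check: scans every param group and finds it.
new_check = any(
    getattr(p, "_fsdp_flattened", False)
    for group in optimizer.param_groups
    for p in group["params"]
)
print(old_check, new_check)  # False True
```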
2 changes: 1 addition & 1 deletion tests/README.md
@@ -12,7 +12,7 @@ To setup a local development environment, install both local and test dependencies
git clone https://github.com/Lightning-AI/lightning.git
cd lightning

-# install required depedencies
+# install required dependencies
export PACKAGE_NAME=pytorch
python -m pip install ".[dev, examples]"
# install pre-commit (optional)