Check all param groups for flat parameters in FSDP #17914

Merged
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -135,6 +135,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788))


+- Fixed check for FSDP's flat parameters in all parameter groups ([#17914](https://github.com/Lightning-AI/lightning/pull/17914))
+
+
 ## [2.0.3] - 2023-06-07

 - Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
6 changes: 4 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
@@ -637,11 +637,13 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUO
 def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
     _FSDP_FLATTENED = "_fsdp_flattened"
     if _TORCH_GREATER_EQUAL_1_13:
-        return any(getattr(param, _FSDP_FLATTENED, False) for param in optimizer.param_groups[0]["params"])
+        return any(
+            getattr(param, _FSDP_FLATTENED, False) for group in optimizer.param_groups for param in group["params"]
+        )

     from torch.distributed.fsdp import FlatParameter

-    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
+    return any(isinstance(param, FlatParameter) for group in optimizer.param_groups for param in group["params"])


 def _get_sharded_state_dict_context(module: "FullyShardedDataParallel") -> _GeneratorContextManager:
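For context, here is a minimal, self-contained sketch (not part of this PR) of why the old single-group check can return a false negative. It simulates FSDP's `_fsdp_flattened` marker by setting the attribute by hand on a parameter rather than wrapping a module in `FullyShardedDataParallel`, which would require a process group; the parameter names and values are illustrative only.

```python
# Sketch: why inspecting only optimizer.param_groups[0] can miss FSDP flat parameters.
# The "_fsdp_flattened" attribute is set manually here to stand in for a parameter
# produced by FSDP's flattening.
import torch

regular = torch.nn.Parameter(torch.zeros(4))   # plain parameter in the first group
flat = torch.nn.Parameter(torch.zeros(8))
flat._fsdp_flattened = True                    # simulate FSDP's marker attribute

optimizer = torch.optim.SGD(
    [{"params": [regular]}, {"params": [flat], "lr": 0.1}],
    lr=0.01,
)

# Old check: inspects only the first param group and misses the flat parameter.
old_check = any(getattr(p, "_fsdp_flattened", False) for p in optimizer.param_groups[0]["params"])
# New check: scans every param group and finds it.
new_check = any(getattr(p, "_fsdp_flattened", False) for g in optimizer.param_groups for p in g["params"])
print(old_check, new_check)  # False True
```

The fix keeps the existing attribute and `FlatParameter` type checks unchanged and only widens the iteration from the first parameter group to all of them.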
2 changes: 1 addition & 1 deletion tests/README.md
@@ -12,7 +12,7 @@ To setup a local development environment, install both local and test dependenci
 git clone https://github.com/Lightning-AI/lightning.git
 cd lightning

-# install required depedencies
+# install required dependencies
 export PACKAGE_NAME=pytorch
 python -m pip install ".[dev, examples]"
 # install pre-commit (optional)