diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 27ec226acc0b0..64ea5fc408344 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -135,6 +135,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788)) +- Fixed check for FSDP's flat parameters in all parameter groups ([#17914](https://github.com/Lightning-AI/lightning/pull/17914)) + + ## [2.0.3] - 2023-06-07 - Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756)) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index c09e91fa87f2a..09c24a444fcdf 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -637,11 +637,13 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUO def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: _FSDP_FLATTENED = "_fsdp_flattened" if _TORCH_GREATER_EQUAL_1_13: - return any(getattr(param, _FSDP_FLATTENED, False) for param in optimizer.param_groups[0]["params"]) + return any( + getattr(param, _FSDP_FLATTENED, False) for group in optimizer.param_group for param in group["params"] + ) from torch.distributed.fsdp import FlatParameter - return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]) + return any(isinstance(param, FlatParameter) for group in optimizer.param_groups for param in group["params"]) def _get_sharded_state_dict_context(module: "FullyShardedDataParallel") -> _GeneratorContextManager: diff --git a/tests/README.md b/tests/README.md index 637c7ccbc0834..b042044e12e09 100644 --- a/tests/README.md +++ b/tests/README.md @@ -12,7 +12,7 @@ To setup a local development environment, install both local and test dependenci git clone https://github.com/Lightning-AI/lightning.git cd lightning -# install required depedencies +# install required dependencies export PACKAGE_NAME=pytorch python -m pip install ".[dev, examples]" # install pre-commit (optional)