Generalize Optimizer validation to accommodate both FSDP 1.x and 2.x #16733

Merged (14 commits) on Mar 28, 2023
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))


-


8 changes: 6 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
@@ -348,6 +348,10 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUO


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
-    from torch.distributed.fsdp import FlatParameter
+    _FSDP_FLATTENED = "_fsdp_flattened"
+    if _TORCH_GREATER_EQUAL_1_13:
+        return any(getattr(param, _FSDP_FLATTENED, False) for param in optimizer.param_groups[0]["params"])
+    else:
+        from torch.distributed.fsdp import FlatParameter

-    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
+        return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
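For context on the change above: with newer PyTorch (and in particular FSDP's `use_orig_params=True`), the optimizer can see the original parameter objects rather than `FlatParameter` instances, so the check falls back to the `_fsdp_flattened` attribute that FSDP sets on flattened parameters. Below is a minimal, self-contained sketch of the two detection paths; the toy parameter and the manual `setattr` only imitate what FSDP does internally and are not part of the PR.

```python
import torch
from torch.optim import SGD

# Toy parameter and optimizer; in real use these come from an FSDP-wrapped module.
param = torch.nn.Parameter(torch.zeros(4))
optimizer = SGD([param], lr=0.1)

# Imitate FSDP (torch >= 1.13) marking a flattened parameter (illustrative assumption).
setattr(param, "_fsdp_flattened", True)

# New-style check: attribute-based, works even when the optimizer holds plain parameters.
print(any(getattr(p, "_fsdp_flattened", False) for p in optimizer.param_groups[0]["params"]))  # True

# Old-style check: relies on the parameter type itself.
from torch.distributed.fsdp import FlatParameter

print(any(isinstance(p, FlatParameter) for p in optimizer.param_groups[0]["params"]))  # False here
```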
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))


- Changes to the `NeptuneLogger` ([#16761](https://github.com/Lightning-AI/lightning/pull/16761)):
* It now supports neptune-client 0.16.16 and neptune >=1.0, and we have replaced the `log()` method with `append()` and `extend()`.
* It now accepts a namespace `Handler` as an alternative to `Run` for the `run` argument. This means that you can call it like `NeptuneLogger(run=run["some/namespace"])` to log everything to the `some/namespace/` location of the run.
74 changes: 59 additions & 15 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -1,5 +1,6 @@
import os
-from typing import Any, Dict, Optional
+from functools import partial
+from typing import Any, Callable, Dict, Optional
from unittest import mock
from unittest.mock import ANY, Mock

@@ -18,7 +19,14 @@

if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
-    from torch.distributed.fsdp.wrap import wrap
+    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, wrap
+else:
+    size_based_auto_wrap_policy = object
+
+if _TORCH_GREATER_EQUAL_2_0:
+    from torch.distributed.fsdp.wrap import _FSDPPolicy
+else:
+    _FSDPPolicy = object


class TestFSDPModel(BoringModel):
@@ -117,17 +125,18 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):

    _assert_save_equality(trainer, model_path, cls=model.__class__)

-    # Test entry point
-    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`
+    with torch.inference_mode():
+        # Test entry point
+        trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`

-    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
-    trainer.test(ckpt_path=model_path)
+        # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
+        trainer.test(ckpt_path=model_path)

-    # Predict entry point
-    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`
+        # Predict entry point
+        trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

-    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
-    trainer.predict(ckpt_path=model_path)
+        # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
+        trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
@@ -200,6 +209,20 @@ def test_fsdp_strategy_checkpoint(tmpdir, precision):
    _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


class CustomWrapPolicy(_FSDPPolicy):
    """This is a wrapper around :func:`_module_wrap_policy`."""

    def __init__(self, min_num_params: int):
        self._policy: Callable = partial(size_based_auto_wrap_policy, min_num_params=min_num_params)

    @property
    def policy(self):
        return self._policy


custom_fsdp_policy = CustomWrapPolicy(min_num_params=2)


if _TORCH_GREATER_EQUAL_2_0:

    def custom_auto_wrap_policy(
@@ -221,19 +244,40 @@ def custom_auto_wrap_policy(

@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize(
-    "model, strategy",
+    "model, strategy, strategy_cfg",
    [
-        (TestFSDPModel(), "fsdp"),
-        (TestFSDPModelAutoWrapped(), FSDPStrategy),
+        pytest.param(TestFSDPModel(), "fsdp", None, id="manually_wrapped"),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_auto_wrap_policy},
+            marks=RunIf(max_torch="2.0.0"),
+            id="autowrap_1x",
+        ),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_auto_wrap_policy},
+            marks=RunIf(min_torch="2.0.0"),
+            id="autowrap_2x",
+        ),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_fsdp_policy, "use_orig_params": True},
+            marks=RunIf(min_torch="2.0.0"),
+            id="autowrap_use_orig_params",
+        ),
    ],
)
-def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy):
+def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
    """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""

    ck = ModelCheckpoint(save_last=True)

+    strategy_cfg = strategy_cfg or {}
    if not isinstance(strategy, str):
-        strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy)
+        strategy = strategy(**strategy_cfg)

    trainer = Trainer(
        default_root_dir=tmpdir,
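For a usage-level view of what the new `autowrap_use_orig_params` case exercises, here is a hedged sketch of passing a size-based wrap policy together with `use_orig_params=True` to the strategy. It assumes a Lightning 2.0 style install with torch >= 2.0 and at least two GPUs; the model and the policy below are placeholders, not the test's fixtures.

```python
from functools import partial

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import FSDPStrategy

# Wrap any submodule with at least 2 parameters (mirrors the test's min_num_params=2).
policy = partial(size_based_auto_wrap_policy, min_num_params=2)

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=FSDPStrategy(auto_wrap_policy=policy, use_orig_params=True),
    max_epochs=1,
)
trainer.fit(BoringModel())
```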