Lightning 2.1.0 no longer supports saving/loading FSDP checkpoints with PyTorch < 2.0 #18230

Closed
speediedan opened this issue Aug 4, 2023 · 7 comments · Fixed by #18296
Labels: bug, strategy: fsdp, ver: 2.1.x

@speediedan
Contributor

speediedan commented Aug 4, 2023

Bug description

With the latest dev commit as of this writing (0aeeb60), Lightning's FSDP imports fail when saving/loading FSDP checkpoints with PyTorch < 2.0:

./tests/tests_pytorch/strategies/test_fsdp.py::test_fsdp_strategy_save_optimizer_states[2] Failed: [undefined]ModuleNotFoundError: No module named 'torch.distributed.fsdp.api'
tmpdir = local('/tmp/pytest-of-speediedan/pytest-807/test_fsdp_strategy_save_optimi0')
wrap_min_params = 2

    @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=False, min_torch="1.12")
    @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
    def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
        """Test to ensure that the full state dict and optimizer states is saved when using FSDP strategy.
    
        Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can
        be restored to DDP, it means that the optimizer states were saved correctly.
        """
        model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)
    
        strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
        trainer = Trainer(
            default_root_dir=tmpdir,
            accelerator="gpu",
            devices=2,
            strategy=strategy,
            precision="16-mixed",
            max_epochs=1,
            barebones=True,
        )
    
        trainer.fit(model)
        model_path = os.path.join(tmpdir, "last.ckpt")
        model_path = trainer.strategy.broadcast(model_path)
>       trainer.save_checkpoint(model_path)

tests/tests_pytorch/strategies/test_fsdp.py:577: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/lightning/pytorch/trainer/trainer.py:1360: in save_checkpoint
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
src/lightning/pytorch/trainer/connectors/checkpoint_connector.py:433: in dump_checkpoint
    "state_dict": self._get_lightning_module_state_dict(),
src/lightning/pytorch/trainer/connectors/checkpoint_connector.py:491: in _get_lightning_module_state_dict
    return self.trainer.strategy.lightning_module_state_dict()

self = <lightning.pytorch.strategies.fsdp.FSDPStrategy object at 0x7fd315bcb850>

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Gathers the full state dict by unsharding all the parameters.
    
        To avoid OOM, the returned parameters will only be returned on rank 0 and on CPU. All other ranks get an empty
        dict.
        """
        from torch.distributed.fsdp import FullyShardedDataParallel
>       from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType
E       ModuleNotFoundError: No module named 'torch.distributed.fsdp.api'

src/lightning/pytorch/strategies/fsdp.py:177: ModuleNotFoundError

Also note that both save and load code-paths use the state_dict_type context manager and attempt to import from FSDP PyTorch 2.0 locations even with PyTorch < 2.0.
https://github.com/Lightning-AI/lightning/blob/0aeeb60566cc0375df3cf1a4458592651f143717/src/lightning/fabric/strategies/fsdp.py#L792-L819

Finally, I don't believe FullOptimStateDictConfig is defined in the FSDP 1.x API so that may need to be worked around if support for 1.x FSDP continues.

I imagine the above challenges could be surmounted to continue providing support for saving/loading FSDP checkpoints with PyTorch < 2.0 but I wanted to ensure that was the intention. If deprecation of this FSDP functionality for PyTorch 1.x is expected I'll go ahead and begin deprecating this functionality in finetuning-scheduler.

Thanks again for all your invaluable contributions to the open-source ML ecosystem!

What version are you seeing the problem on?

master

How to reproduce the bug

To reproduce, install `torch==1.13.1` and run the following existing test:
./tests/tests_pytorch/strategies/test_fsdp.py::test_fsdp_strategy_save_optimizer_states[2]

Error messages and logs

Identical to the traceback shown in the bug description above.

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 4090
      • NVIDIA GeForce RTX 2070 SUPER
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.1.0.dev0
    • lightning-api-access: 0.0.5
    • lightning-cloud: 0.5.37
    • lightning-fabric: 2.0.6
    • lightning-utilities: 0.9.0
    • pt-lightning-sphinx-theme: 0.0.31
    • pytorch-lightning: 2.0.6
    • torch: 1.13.1
    • torchmetrics: 1.0.2
    • torchvision: 0.14.1
  • Packages:
    • absl-py: 1.4.0
    • aiobotocore: 2.5.2
    • aiohttp: 3.8.5
    • aioitertools: 0.11.0
    • aiosignal: 1.3.1
    • alabaster: 0.7.13
    • altair: 5.0.1
    • annotated-types: 0.5.0
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.7.1
    • apeye: 1.4.0
    • apeye-core: 1.1.4
    • argon2-cffi: 21.3.0
    • argon2-cffi-bindings: 21.2.0
    • arrow: 1.2.3
    • asttokens: 2.2.1
    • async-generator: 1.10
    • async-lru: 2.0.4
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • autodocsumm: 0.2.11
    • babel: 2.12.1
    • backcall: 0.2.0
    • backoff: 2.2.1
    • beautifulsoup4: 4.12.2
    • bleach: 6.0.0
    • blessed: 1.20.0
    • blinker: 1.6.2
    • bokeh: 3.2.1
    • botocore: 1.29.161
    • bracex: 2.3.post1
    • brotlipy: 0.7.0
    • cachecontrol: 0.13.1
    • cachetools: 5.3.1
    • certifi: 2023.7.22
    • cffi: 1.15.1
    • charset-normalizer: 2.0.4
    • click: 8.1.6
    • cloudpickle: 2.2.1
    • colorama: 0.4.6
    • coloredlogs: 15.0.1
    • comm: 0.1.4
    • contourpy: 1.1.0
    • coverage: 7.2.7
    • croniter: 1.4.1
    • cryptography: 41.0.2
    • cssutils: 2.7.1
    • cycler: 0.11.0
    • dateutils: 0.6.12
    • debugpy: 1.6.7
    • decorator: 5.1.1
    • deepdiff: 6.3.1
    • defusedxml: 0.7.1
    • dict2css: 0.3.0
    • docker: 6.1.3
    • docstring-parser: 0.15
    • docutils: 0.17.1
    • domdf-python-tools: 3.6.1
    • exceptiongroup: 1.1.2
    • executing: 1.2.0
    • fastapi: 0.100.1
    • fastjsonschema: 2.18.0
    • filelock: 3.12.2
    • fire: 0.5.0
    • flatbuffers: 23.5.26
    • fonttools: 4.42.0
    • fqdn: 1.5.1
    • frozenlist: 1.4.0
    • fsspec: 2023.6.0
    • gitdb: 4.0.10
    • gitpython: 3.1.32
    • google-auth: 2.22.0
    • google-auth-oauthlib: 1.0.0
    • greenlet: 2.0.2
    • grpcio: 1.56.2
    • h11: 0.14.0
    • html5lib: 1.1
    • httpcore: 0.17.3
    • httpx: 0.24.1
    • humanfriendly: 10.0
    • hydra-core: 1.3.2
    • idna: 3.4
    • imagesize: 1.4.1
    • importlib-metadata: 6.8.0
    • importlib-resources: 6.0.0
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • ipykernel: 6.25.0
    • ipython: 8.6.0
    • ipywidgets: 8.1.0
    • isoduration: 20.11.0
    • itsdangerous: 2.1.2
    • jedi: 0.19.0
    • jinja2: 3.0.3
    • jmespath: 1.0.1
    • joblib: 1.3.1
    • json5: 0.9.14
    • jsonargparse: 4.22.1
    • jsonpointer: 2.4
    • jsonschema: 4.18.6
    • jsonschema-specifications: 2023.7.1
    • jupyter-client: 8.3.0
    • jupyter-core: 5.3.1
    • jupyter-events: 0.7.0
    • jupyter-lsp: 2.2.0
    • jupyter-server: 2.7.0
    • jupyter-server-terminals: 0.4.4
    • jupyterlab: 4.0.4
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-server: 2.24.0
    • jupyterlab-widgets: 3.0.8
    • kiwisolver: 1.4.4
    • lightning: 2.1.0.dev0
    • lightning-api-access: 0.0.5
    • lightning-cloud: 0.5.37
    • lightning-fabric: 2.0.6
    • lightning-utilities: 0.9.0
    • linkify-it-py: 2.0.2
    • livereload: 2.6.3
    • lockfile: 0.12.2
    • markdown: 3.4.4
    • markdown-it-py: 2.2.0
    • markupsafe: 2.1.3
    • matplotlib: 3.7.2
    • matplotlib-inline: 0.1.6
    • mdit-py-plugins: 0.3.5
    • mdurl: 0.1.2
    • mistune: 3.0.1
    • mkl-fft: 1.3.6
    • mkl-random: 1.2.2
    • mkl-service: 2.4.0
    • mpmath: 1.3.0
    • msgpack: 1.0.5
    • multidict: 6.0.4
    • myst-parser: 0.18.1
    • natsort: 8.4.0
    • nbclient: 0.8.0
    • nbconvert: 7.7.3
    • nbformat: 5.9.2
    • nbsphinx: 0.8.9
    • nest-asyncio: 1.5.7
    • notebook: 7.0.1
    • notebook-shim: 0.2.3
    • numpy: 1.25.0
    • oauthlib: 3.2.2
    • omegaconf: 2.3.0
    • onnx: 1.12.0
    • onnxruntime: 1.15.1
    • ordered-set: 4.1.0
    • outcome: 1.2.0
    • overrides: 7.3.1
    • packaging: 23.1
    • pandas: 2.0.3
    • pandoc: 2.3
    • pandocfilters: 1.5.0
    • panel: 1.2.1
    • param: 1.13.0
    • parso: 0.8.3
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 9.4.0
    • pip: 23.2.1
    • platformdirs: 3.10.0
    • playwright: 1.35.0
    • pluggy: 1.2.0
    • plumbum: 1.8.2
    • ply: 3.11
    • prometheus-client: 0.17.1
    • prompt-toolkit: 3.0.39
    • protobuf: 3.20.1
    • psutil: 5.9.5
    • pt-lightning-sphinx-theme: 0.0.31
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • py: 1.11.0
    • pyarrow: 12.0.1
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pycparser: 2.21
    • pydantic: 2.0.3
    • pydantic-core: 2.3.0
    • pydeck: 0.8.0
    • pyee: 9.0.4
    • pygments: 2.15.1
    • pyjwt: 2.8.0
    • pympler: 1.0.1
    • pyopenssl: 23.2.0
    • pyparsing: 3.0.9
    • pysocks: 1.7.1
    • pytest: 7.4.0
    • pytest-asyncio: 0.21.1
    • pytest-cov: 4.1.0
    • pytest-doctestplus: 0.13.0
    • pytest-forked: 1.4.0
    • pytest-random-order: 1.1.0
    • pytest-rerunfailures: 10.3
    • pytest-timeout: 2.1.0
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-json-logger: 2.0.7
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.6
    • pytz: 2023.3
    • pytz-deprecation-shim: 0.1.0.post0
    • pyviz-comms: 2.3.2
    • pyyaml: 6.0.1
    • pyzmq: 25.1.0
    • readchar: 4.0.5
    • redis: 4.6.0
    • referencing: 0.30.0
    • requests: 2.31.0
    • requests-mock: 1.11.0
    • requests-oauthlib: 1.3.1
    • rfc3339-validator: 0.1.4
    • rfc3986-validator: 0.1.1
    • rich: 13.5.2
    • rpds-py: 0.9.2
    • rsa: 4.9
    • ruamel.yaml: 0.17.32
    • ruamel.yaml.clib: 0.2.7
    • s3fs: 2023.6.0
    • scikit-learn: 1.3.0
    • scipy: 1.11.1
    • send2trash: 1.8.2
    • setuptools: 57.5.0
    • six: 1.16.0
    • smmap: 5.0.0
    • sniffio: 1.3.0
    • snowballstemmer: 2.2.0
    • sortedcontainers: 2.4.0
    • soupsieve: 2.4.1
    • sphinx: 4.5.0
    • sphinx-autobuild: 2021.3.14
    • sphinx-autodoc-typehints: 1.19.1
    • sphinx-copybutton: 0.5.2
    • sphinx-jinja2-compat: 0.2.0
    • sphinx-multiproject: 1.0.0rc1
    • sphinx-paramlinks: 0.5.4
    • sphinx-prompt: 1.5.0
    • sphinx-rtd-dark-mode: 1.2.4
    • sphinx-rtd-theme: 1.2.2
    • sphinx-tabs: 3.4.0
    • sphinx-togglebutton: 0.3.2
    • sphinx-toolbox: 3.4.0
    • sphinxcontrib-applehelp: 1.0.4
    • sphinxcontrib-devhelp: 1.0.2
    • sphinxcontrib-fulltoc: 1.2.0
    • sphinxcontrib-htmlhelp: 2.0.1
    • sphinxcontrib-jquery: 4.1
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-mockautodoc: 0.0.1.dev20130518
    • sphinxcontrib-qthelp: 1.0.3
    • sphinxcontrib-serializinghtml: 1.1.5
    • sphinxcontrib-video: 0.2.0
    • stack-data: 0.6.2
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • streamlit: 1.25.0
    • sympy: 1.12
    • tabulate: 0.9.0
    • tenacity: 8.2.2
    • tensorboard: 2.13.0
    • tensorboard-data-server: 0.7.1
    • tensorboardx: 2.6.2
    • termcolor: 2.3.0
    • terminado: 0.17.1
    • threadpoolctl: 3.2.0
    • tinycss2: 1.2.1
    • toml: 0.10.2
    • tomli: 2.0.1
    • toolz: 0.12.0
    • torch: 1.13.1
    • torchmetrics: 1.0.2
    • torchvision: 0.14.1
    • tornado: 6.3.2
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • trio: 0.21.0
    • typeshed-client: 2.3.0
    • typing-extensions: 4.7.1
    • tzdata: 2023.3
    • tzlocal: 4.3.1
    • uc-micro-py: 1.0.2
    • uri-template: 1.3.0
    • urllib3: 1.26.16
    • uvicorn: 0.23.2
    • validators: 0.20.0
    • watchdog: 3.0.0
    • wcmatch: 8.4.1
    • wcwidth: 0.2.6
    • webcolors: 1.13
    • webencodings: 0.5.1
    • websocket-client: 1.6.1
    • websockets: 11.0.3
    • werkzeug: 2.3.6
    • wheel: 0.38.4
    • widgetsnbextension: 4.0.8
    • wrapt: 1.15.0
    • xyzservices: 2023.7.0
    • yarl: 1.9.2
    • zipp: 3.16.2
  • System:

More info

No response

cc @awaelchli @carmocca

@speediedan added the bug and needs triage labels Aug 4, 2023
@awaelchli
Contributor

Hi @speediedan
Thanks for opening the issue.
The PR that changed this was here: #17819
If I recall correctly, we used the latest APIs and conditioned it on >= 2.0 there because we knew that the previous loading logic wasn't really working / was incorrect (see the linked issues in that PR).

We could however try to import from here https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/__init__.py and maybe that's enough to make it torch 1.13 compatible.
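
A minimal sketch of that idea, assuming the package-level `torch.distributed.fsdp` namespace re-exports `FullStateDictConfig` and `StateDictType` on 1.13 as the linked `__init__.py` suggests. The helper name `import_first_available` is hypothetical and not part of Lightning or PyTorch; it simply tries each candidate module path in order and returns the first match.

```python
import importlib
from typing import Any, Sequence


def import_first_available(name: str, module_paths: Sequence[str]) -> Any:
    """Return attribute `name` from the first module in `module_paths` that provides it."""
    failures = []
    for path in module_paths:
        try:
            return getattr(importlib.import_module(path), name)
        except (ImportError, AttributeError) as exc:
            # ModuleNotFoundError is a subclass of ImportError, so a missing
            # submodule (e.g. `torch.distributed.fsdp.api` on 1.13) lands here.
            failures.append(f"{path}: {exc}")
    raise ImportError(f"could not import {name!r}; tried " + "; ".join(failures))


# Hypothetical usage (requires torch, so it is left commented out):
# StateDictType = import_first_available(
#     "StateDictType", ["torch.distributed.fsdp.api", "torch.distributed.fsdp"]
# )
```

This only covers names that exist in both versions. A class such as `FullOptimStateDictConfig`, which is reported above as missing from 1.x, would still raise `ImportError` and needs separate handling.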

@awaelchli added the strategy: fsdp label and removed the needs triage label Aug 4, 2023
@speediedan
Contributor Author

> We could however try to import from here https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/__init__.py and maybe that's enough to make it torch 1.13 compatible.

Yeah, I tinkered with conditionally using the old FullStateDictConfig and StateDictType import locations for a few minutes before I opened this issue, but when I noticed that config classes that don't exist in 1.13.1 (e.g. FullOptimStateDictConfig) were being used for the _get_*_state_dict_context context managers, I figured it was better to check with the team to see what you were planning before investigating further.

I think we could likely backport this functionality to 1.x if we:

  1. Used the old import locations for FullStateDictConfig/StateDictType in a 1.x context
  2. Enhanced the _get_*_state_dict_context context managers to use a customized, backported version of the 2.x state_dict_type context manager that extends the 1.x state_dict_type context manager to set StateDictType appropriately for the optim state dicts of all descendant modules.

However, that would be a fair amount of custom code in the backport and could prove a fairly ugly/brittle solution. As such, it may be worth considering the alternative of:

  1. Using the old import locations for FullStateDictConfig/StateDictType in a 1.x context only for lightning_module_state_dict imports
  2. In a 1.x context, for optimizer state dict save/load calls, issuing a warning and performing a no-op to indicate that optimizer state saving/loading is not supported for 1.x FSDP, after updating the Lightning documentation accordingly. Throwing an exception may be preferable to a warning and no-op.
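
The second alternative could look roughly like the sketch below. It is illustrative only: `torch_major` and `optimizer_state_or_noop` are hypothetical names, the version string is passed in rather than read from `torch.__version__`, and `collect` stands in for whatever callable gathers the optimizer state on PyTorch 2.x.

```python
import warnings
from typing import Any, Callable, Dict


def torch_major(version: str) -> int:
    """Extract the major version from strings such as '1.13.1' or '2.1.0.dev0'."""
    return int(version.split(".", 1)[0])


def optimizer_state_or_noop(
    torch_version: str,
    collect: Callable[[], Dict[str, Any]],
    strict: bool = False,
) -> Dict[str, Any]:
    """Collect optimizer state on torch >= 2.0; otherwise warn (or raise) and return {}."""
    if torch_major(torch_version) >= 2:
        return collect()
    message = (
        "Saving/loading FSDP optimizer state is not supported with "
        f"PyTorch {torch_version}; it requires PyTorch >= 2.0."
    )
    if strict:
        # The exception variant mentioned as a possible preference.
        raise NotImplementedError(message)
    # The warning-and-no-op variant: the caller receives an empty state dict.
    warnings.warn(message, UserWarning)
    return {}
```

With `strict=False` a 1.x run continues and produces a checkpoint without optimizer state, while `strict=True` fails fast; which behavior is better depends on how silently dropping optimizer state would affect resumed training.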

Open to other thoughts and suggestions of course (of which yours are so often awesome!). What do you think?

@awaelchli
Contributor

Thanks for the suggestion. Keeping the loading for the model state compatible with 1.13 seems feasible, and warning/error for optimizer state is probably the easiest for now. Would that work for you and the finetuning-scheduler as well?

@speediedan
Contributor Author

Absolutely, sounds great. finetuning-scheduler is overriding load_optimizer_state_dict and optimizer_state for other reasons already so I could just mirror the relevant Lightning warnings/errors in that context while continuing to rely upon Lightning's lightning_module_state_dict for the model state dict collection.

@carmocca
Contributor

carmocca commented Aug 8, 2023

@speediedan Do you plan to work on this? We'd want to fix this before the next release to avoid breaking these checkpoints.

@carmocca added this to the 2.1 milestone Aug 8, 2023
@speediedan
Contributor Author

> Do you plan to work on this? We'd want to fix this before the next release to avoid breaking these checkpoints.

Not sure if I'll have the bandwidth in the next few days and wouldn't want to hold this up since I know it'll be important to ensure it's in 2.1. Certainly go ahead and implement. Thanks for checking!

@speediedan
Contributor Author

@awaelchli I ran into #18277 today which is very close to this issue (#18230) in terms of modified code-path intersection so I figured it made sense to implement the discussed resolution to this issue in a PR that addresses both #18277 and #18230. Hope that's okay!
