Introduce shard-merging util for FSDP (#2772)
* Initial commit

* Now to test

* Store false

* Slight tweaks

* Fix naming

* Got it all working with tests

* Use not for safetensors arg

* rm change

* Add docs

* Adjust based on Marc's feedback

* Specify just weights

* Update tests to include CLI and swap namings

* Fin

* Rm unused

* Rm again
muellerzr authored May 16, 2024
1 parent 91e8a3c commit 4ba436e
Showing 9 changed files with 321 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/package_reference/fsdp.md
@@ -15,4 +15,6 @@ rendered properly in your Markdown viewer.

# Utilities for Fully Sharded Data Parallelism

[[autodoc]] utils.merge_fsdp_weights

[[autodoc]] utils.FullyShardedDataParallelPlugin
16 changes: 16 additions & 0 deletions docs/source/usage_guides/fsdp.md
@@ -161,6 +161,22 @@ When using transformers `save_pretrained`, pass `state_dict=accelerator.get_stat

You can then pass `state` into the `save_pretrained` method. There are several modes for `StateDictType` and `FullStateDictConfig` that you can use to control the behavior of `state_dict`. For more information, see the [PyTorch documentation](https://pytorch.org/docs/stable/fsdp.html).
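
For example, a minimal sketch of this pattern (assuming `model` is a `transformers` model that was prepared with `accelerator`; the output directory name is a placeholder):

```py
# Gather the full state dict and save it with `save_pretrained`
accelerator.wait_for_everyone()
state = accelerator.get_state_dict(model)  # gathers the (possibly sharded) weights
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    "my_model_dir",  # hypothetical output directory
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
    state_dict=state,
)
```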

If you choose to use `StateDictType.SHARDED_STATE_DICT`, the model weights saved during `Accelerator.save_state` will be split into `n` files, one for each sub-split of the model. To merge them back into a single dictionary that can be loaded into the model after training, use the `merge_fsdp_weights` utility:

```py
from accelerate.utils import merge_fsdp_weights

# Our weights are usually saved in a `pytorch_model_fsdp_{model_number}` folder
merge_fsdp_weights("pytorch_model_fsdp_0", "output_path", safe_serialization=True)
```
The merged weights will then be saved to `model.safetensors`, or to `pytorch_model.bin` if `safe_serialization=False` is passed.

This can also be called using the CLI:
```bash
accelerate merge-weights pytorch_model_fsdp_0/ output_path
```
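
Once merged, the checkpoint behaves like any regular single-file state dict. A minimal sketch of loading it back (assuming the default `safe_serialization=True` output and an unwrapped model of the same architecture):

```py
from safetensors.torch import load_file

# `output_path` is the directory passed to `merge_fsdp_weights` above
state_dict = load_file("output_path/model.safetensors")
model.load_state_dict(state_dict)
```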


## Mapping between FSDP sharding strategies and DeepSpeed ZeRO Stages
* `FULL_SHARD` maps to the DeepSpeed `ZeRO Stage-3`. Shards optimizer states, gradients and parameters.
1 change: 1 addition & 0 deletions setup.py
@@ -65,6 +65,7 @@
"accelerate-config=accelerate.commands.config:main",
"accelerate-estimate-memory=accelerate.commands.estimate:main",
"accelerate-launch=accelerate.commands.launch:main",
"accelerate-merge-weights=accelerate.commands.merge:main",
]
},
python_requires=">=3.8.0",
2 changes: 2 additions & 0 deletions src/accelerate/commands/accelerate_cli.py
@@ -18,6 +18,7 @@
from accelerate.commands.env import env_command_parser
from accelerate.commands.estimate import estimate_command_parser
from accelerate.commands.launch import launch_command_parser
from accelerate.commands.merge import merge_command_parser
from accelerate.commands.test import test_command_parser
from accelerate.commands.tpu import tpu_command_parser
from accelerate.commands.utils import CustomArgumentParser
@@ -32,6 +33,7 @@ def main():
estimate_command_parser(subparsers=subparsers)
env_command_parser(subparsers=subparsers)
launch_command_parser(subparsers=subparsers)
merge_command_parser(subparsers=subparsers)
tpu_command_parser(subparsers=subparsers)
test_command_parser(subparsers=subparsers)

69 changes: 69 additions & 0 deletions src/accelerate/commands/merge.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from accelerate.commands.utils import CustomArgumentParser
from accelerate.utils import merge_fsdp_weights


description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.
This is a CPU-bound process and requires enough RAM to load the entire model state dict."""


def merge_command(args):
merge_fsdp_weights(
args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir
)


def merge_command_parser(subparsers=None):
if subparsers is not None:
parser = subparsers.add_parser("merge-weights", description=description)
else:
parser = CustomArgumentParser(description=description)

parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
parser.add_argument(
"output_path",
type=str,
help="The path to save the merged weights. Defaults to the current directory. ",
)
parser.add_argument(
"--unsafe_serialization",
action="store_false",
default=True,
help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
)
parser.add_argument(
"--remove_checkpoint_dir",
action="store_true",
help="Whether to remove the checkpoint directory after merging.",
default=False,
)

if subparsers is not None:
parser.set_defaults(func=merge_command)
return parser


def main():
parser = merge_command_parser()
args = parser.parse_args()
merge_command(args)


if __name__ == "__main__":
main()
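
For reference, the same parser can also be driven programmatically, mirroring how the test script below exercises the CLI; a rough sketch with hypothetical paths:

```py
from accelerate.commands.merge import merge_command, merge_command_parser

# Equivalent to `accelerate merge-weights pytorch_model_fsdp_0 merged_output`
parser = merge_command_parser()
args = parser.parse_args(["pytorch_model_fsdp_0", "merged_output"])
merge_command(args)
```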
160 changes: 160 additions & 0 deletions src/accelerate/test_utils/scripts/test_merge_weights.py
@@ -0,0 +1,160 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType
from torch.utils.data import DataLoader

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from accelerate.commands.merge import merge_command, merge_command_parser
from accelerate.state import AcceleratorState
from accelerate.test_utils.training import RegressionDataset
from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model


logging.basicConfig(level=logging.INFO)

parser = merge_command_parser()


class TinyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(16, 16)
self.activation = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(16, 16)
self.softmax = torch.nn.Softmax()

def forward(self, x):
return self.linear2(self.activation(self.linear1(x)))


def setup():
if AcceleratorState._shared_state != {}:
AcceleratorState()._reset_state()
plugin = FullyShardedDataParallelPlugin(
sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT
)
model = TinyModel()
with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):
plugin.set_auto_wrap_policy(model)
accelerator = Accelerator(fsdp_plugin=plugin)
model = accelerator.prepare(model)
return model, plugin, accelerator


def mock_training(accelerator, model):
train_set = RegressionDataset(length=128, seed=42)
train_dl = DataLoader(train_set, batch_size=16, shuffle=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
accelerator.backward(loss)
optimizer.step()
return model


def check_weights(operation, state_1, state_2):
for weight_1, weight_2 in zip(state_1.values(), state_2.values()):
if str(weight_1.device) != "cuda":
weight_1 = weight_1.to("cuda")
if str(weight_2.device) != "cuda":
weight_2 = weight_2.to("cuda")
if operation == "same":
assert torch.allclose(weight_1, weight_2)
else:
assert not torch.allclose(weight_1, weight_2)


def check_safetensors_weights(path, model):
safe_state_dict = load_file(path / "model.safetensors")
safe_loaded_model = TinyModel()
check_weights("diff", model.state_dict(), safe_loaded_model.state_dict())
safe_loaded_model.load_state_dict(safe_state_dict)
check_weights("same", model.state_dict(), safe_loaded_model.state_dict())


def check_pytorch_weights(path, model):
nonsafe_state_dict = torch.load(path / "pytorch_model.bin")
nonsafe_loaded_model = TinyModel()
check_weights("diff", model.state_dict(), nonsafe_loaded_model.state_dict())
nonsafe_loaded_model.load_state_dict(nonsafe_state_dict)
check_weights("same", model.state_dict(), nonsafe_loaded_model.state_dict())


def test_merge_weights_safetensors(model, path):
# Should now be saved at `path/model.safetensors`
merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=True)
check_safetensors_weights(path, model)


def test_merge_weights_command_safetensors(model, path):
args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path)])
merge_command(args)
check_safetensors_weights(path, model)


def test_merge_weights_pytorch(model, path):
# Should now be saved at `path/pytorch_model.bin`
merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=False)
check_pytorch_weights(path, model)


def test_merge_weights_command_pytorch(model, path):
args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"])
merge_command(args)
check_pytorch_weights(path, model)


if __name__ == "__main__":
# Note this test requires at least two accelerators!
model, plugin, accelerator = setup()
if accelerator.num_processes > 1:
try:
# Initial setup for things
out_path = Path("test_merge_weights_fsdp_weights")
if not out_path.exists():
out_path.mkdir(parents=True, exist_ok=True)

# Train briefly so the weights differ from the baseline initialization
model = mock_training(accelerator, model)
accelerator.wait_for_everyone()

gc.collect() # Needed for some lingering refs after training
save_fsdp_model(plugin, accelerator, model, out_path)
accelerator.wait_for_everyone()

# Finally we can test
test_merge_weights_safetensors(model, out_path)
test_merge_weights_command_safetensors(model, out_path)
test_merge_weights_pytorch(model, out_path)
test_merge_weights_command_pytorch(model, out_path)
except Exception:
raise
finally:
# Cleanup in case of any failures
if accelerator.is_main_process:
shutil.rmtree(out_path)
accelerator.wait_for_everyone()
2 changes: 1 addition & 1 deletion src/accelerate/utils/__init__.py
@@ -179,7 +179,7 @@
)

from .bnb import has_4bit_bnb_layers, load_and_quantize_model
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, merge_fsdp_weights, save_fsdp_model, save_fsdp_optimizer
from .launch import (
PrepareForLaunch,
_filter_args,
63 changes: 62 additions & 1 deletion src/accelerate/utils/fsdp_utils.py
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from pathlib import Path

import torch

from ..logging import get_logger
from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME
from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from .imports import is_torch_distributed_available
from .modeling import is_peft_model
from .other import save
from .versions import is_torch_version


if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -207,3 +211,60 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
logger.info(f"Optimizer loaded from {ckpt_dir}")
flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state)
optimizer.load_state_dict(flattened_osd)


def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: str, safe_serialization: bool = True):
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`.

Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
"""
state_dict = {}
save_path = Path(save_path)
dist_cp_format_utils._load_state_dict(
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),
no_dist=True,
)
save_path = save_path / SAFE_WEIGHTS_NAME if safe_serialization else save_path / WEIGHTS_NAME

# To handle if state is a dict like {model: {...}}
if len(state_dict.keys()) == 1:
state_dict = state_dict[list(state_dict)[0]]
save(state_dict, save_path, safe_serialization=safe_serialization)
return save_path


def merge_fsdp_weights(
checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.

Note: this is a CPU-bound process.

Args:
checkpoint_dir (`str`):
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
"""
from accelerate.state import PartialState

# Set up `PartialState` so that `save` works properly
state = PartialState()
if state.is_main_process:
logger.info(f"Merging FSDP weights from {checkpoint_dir}")
save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization)
logger.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
logger.info(f"Removing old checkpoint directory {checkpoint_dir}")
shutil.rmtree(checkpoint_dir)
state.wait_for_everyone()
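
As the helper's docstring notes, the conversion leans on `torch.distributed.checkpoint`. On a sufficiently recent PyTorch, a rough equivalent that produces a plain `torch.save` file is the sketch below (assuming `dcp_to_torch_save` is available in your torch version; paths are placeholders):

```py
# Rough equivalent using PyTorch's own converter; output is a plain `torch.save` file,
# without the safetensors option that `merge_fsdp_weights` provides.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save("pytorch_model_fsdp_0", "merged/pytorch_model.bin")
```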
