Introduce shard-merging util for FSDP #2772

Merged
merged 15 commits on May 16, 2024
2 changes: 2 additions & 0 deletions docs/source/package_reference/fsdp.md
@@ -15,4 +15,6 @@ rendered properly in your Markdown viewer.

# Utilities for Fully Sharded Data Parallelism

[[autodoc]] utils.merge_fsdp_weights

[[autodoc]] utils.FullyShardedDataParallelPlugin
16 changes: 16 additions & 0 deletions docs/source/usage_guides/fsdp.md
@@ -161,6 +161,22 @@ When using transformers `save_pretrained`, pass `state_dict=accelerator.get_state_dict(model)`

You can then pass `state` into the `save_pretrained` method. There are several modes for `StateDictType` and `FullStateDictConfig` that you can use to control the behavior of `state_dict`. For more information, see the [PyTorch documentation](https://pytorch.org/docs/stable/fsdp.html).
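
For example, a minimal sketch of this pattern (assuming a `transformers` model that was prepared with `accelerator.prepare` and a hypothetical output directory):

```py
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    "output_dir",  # hypothetical output directory
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
    state_dict=accelerator.get_state_dict(model),
)
```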

If you choose to use `StateDictType.SHARDED_STATE_DICT`, the model weights saved during `Accelerator.save_state` are split into `n` files, one for each shard of the model. To merge them back into a single state dictionary that can be loaded back into the model after training, use the `merge_fsdp_weights` utility:

```py
from accelerate.utils import merge_fsdp_weights

# Our weights are usually saved in a `pytorch_model_fsdp_{model_number}` folder
merge_fsdp_weights("pytorch_model_fsdp_0", "output_path", safe_serialization=True)
```
The final output will be saved to either `model.safetensors` or `pytorch_model.bin` (if `safe_serialization=False` is passed).
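
Once merged, the file can be loaded back like any regular checkpoint. A minimal sketch, assuming `safe_serialization=True` and a freshly instantiated model of the same architecture (paths are hypothetical):

```py
from safetensors.torch import load_file

# "output_path" is the directory that was passed to `merge_fsdp_weights`
state_dict = load_file("output_path/model.safetensors")
model.load_state_dict(state_dict)
```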

This can also be called using the CLI:
```bash
accelerate merge-weights pytorch_model_fsdp_0/ output_path
```
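
The CLI also exposes the optional flags defined by this PR's argument parser; for instance (hypothetical paths):

```bash
# Save as `pytorch_model.bin` instead of safetensors and delete the sharded checkpoint afterwards
accelerate merge-weights pytorch_model_fsdp_0/ output_path --unsafe_serialization --remove_checkpoint_dir
```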


## Mapping between FSDP sharding strategies and DeepSpeed ZeRO Stages
* `FULL_SHARD` maps to the DeepSpeed `ZeRO Stage-3`. Shards optimizer states, gradients and parameters.
1 change: 1 addition & 0 deletions setup.py
@@ -65,6 +65,7 @@
"accelerate-config=accelerate.commands.config:main",
"accelerate-estimate-memory=accelerate.commands.estimate:main",
"accelerate-launch=accelerate.commands.launch:main",
"accelerate-merge-weights=accelerate.commands.merge:main",
]
},
python_requires=">=3.8.0",
2 changes: 2 additions & 0 deletions src/accelerate/commands/accelerate_cli.py
@@ -18,6 +18,7 @@
from accelerate.commands.env import env_command_parser
from accelerate.commands.estimate import estimate_command_parser
from accelerate.commands.launch import launch_command_parser
from accelerate.commands.merge import merge_command_parser
from accelerate.commands.test import test_command_parser
from accelerate.commands.tpu import tpu_command_parser
from accelerate.commands.utils import CustomArgumentParser
@@ -32,6 +33,7 @@ def main():
estimate_command_parser(subparsers=subparsers)
env_command_parser(subparsers=subparsers)
launch_command_parser(subparsers=subparsers)
merge_command_parser(subparsers=subparsers)
tpu_command_parser(subparsers=subparsers)
test_command_parser(subparsers=subparsers)

69 changes: 69 additions & 0 deletions src/accelerate/commands/merge.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from accelerate.commands.utils import CustomArgumentParser
from accelerate.utils import merge_fsdp_weights


description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.

This is a CPU-bound process and requires enough RAM to load the entire model state dict."""


def merge_command(args):
merge_fsdp_weights(
args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir
)


def merge_command_parser(subparsers=None):
if subparsers is not None:
parser = subparsers.add_parser("merge-weights", description=description)
else:
parser = CustomArgumentParser(description=description)

parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
parser.add_argument(
"output_path",
type=str,
help="The path to save the merged weights. Defaults to the current directory. ",
)
parser.add_argument(
"--unsafe_serialization",
action="store_false",
default=True,
help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
)
parser.add_argument(
"--remove_checkpoint_dir",
action="store_true",
help="Whether to remove the checkpoint directory after merging.",
default=False,
)

if subparsers is not None:
parser.set_defaults(func=merge_command)
return parser


def main():
parser = merge_command_parser()
args = parser.parse_args()
merge_command(args)


if __name__ == "__main__":
main()
160 changes: 160 additions & 0 deletions src/accelerate/test_utils/scripts/test_merge_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType
from torch.utils.data import DataLoader

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from accelerate.commands.merge import merge_command, merge_command_parser
from accelerate.state import AcceleratorState
from accelerate.test_utils.training import RegressionDataset
from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model


logging.basicConfig(level=logging.INFO)

parser = merge_command_parser()


class TinyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(16, 16)
self.activation = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(16, 16)
self.softmax = torch.nn.Softmax()

def forward(self, x):
return self.linear2(self.activation(self.linear1(x)))


def setup():
if AcceleratorState._shared_state != {}:
AcceleratorState()._reset_state()
plugin = FullyShardedDataParallelPlugin(
sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT
)
model = TinyModel()
with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):
plugin.set_auto_wrap_policy(model)
accelerator = Accelerator(fsdp_plugin=plugin)
model = accelerator.prepare(model)
return model, plugin, accelerator


def mock_training(accelerator, model):
train_set = RegressionDataset(length=128, seed=42)
train_dl = DataLoader(train_set, batch_size=16, shuffle=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
accelerator.backward(loss)
optimizer.step()
return model


def check_weights(operation, state_1, state_2):
for weight_1, weight_2 in zip(state_1.values(), state_2.values()):
if str(weight_1.device) != "cuda":
weight_1 = weight_1.to("cuda")
if str(weight_2.device) != "cuda":
weight_2 = weight_2.to("cuda")
if operation == "same":
assert torch.allclose(weight_1, weight_2)
else:
assert not torch.allclose(weight_1, weight_2)


def check_safetensors_weights(path, model):
safe_state_dict = load_file(path / "model.safetensors")
safe_loaded_model = TinyModel()
check_weights("diff", model.state_dict(), safe_loaded_model.state_dict())
safe_loaded_model.load_state_dict(safe_state_dict)
check_weights("same", model.state_dict(), safe_loaded_model.state_dict())


def check_pytorch_weights(path, model):
nonsafe_state_dict = torch.load(path / "pytorch_model.bin")
nonsafe_loaded_model = TinyModel()
check_weights("diff", model.state_dict(), nonsafe_loaded_model.state_dict())
nonsafe_loaded_model.load_state_dict(nonsafe_state_dict)
check_weights("same", model.state_dict(), nonsafe_loaded_model.state_dict())


def test_merge_weights_safetensors(model, path):
    # Should now be saved at `path/model.safetensors`
merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=True)
check_safetensors_weights(path, model)


def test_merge_weights_command_safetensors(model, path):
args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path)])
merge_command(args)
check_safetensors_weights(path, model)


def test_merge_weights_pytorch(model, path):
    # Should now be saved at `path/pytorch_model.bin`
merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=False)
check_pytorch_weights(path, model)


def test_merge_weights_command_pytorch(model, path):
args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"])
merge_command(args)
check_pytorch_weights(path, model)


if __name__ == "__main__":
# Note this test requires at least two accelerators!
model, plugin, accelerator = setup()
if accelerator.num_processes > 1:
try:
# Initial setup for things
out_path = Path("test_merge_weights_fsdp_weights")
if not out_path.exists():
out_path.mkdir(parents=True, exist_ok=True)

            # Train briefly so the weights differ from the initial values
model = mock_training(accelerator, model)
accelerator.wait_for_everyone()

gc.collect() # Needed for some lingering refs after training
save_fsdp_model(plugin, accelerator, model, out_path)
accelerator.wait_for_everyone()

# Finally we can test
test_merge_weights_safetensors(model, out_path)
test_merge_weights_command_safetensors(model, out_path)
test_merge_weights_pytorch(model, out_path)
test_merge_weights_command_pytorch(model, out_path)
finally:
# Cleanup in case of any failures
if accelerator.is_main_process:
shutil.rmtree(out_path)
accelerator.wait_for_everyone()
2 changes: 1 addition & 1 deletion src/accelerate/utils/__init__.py
@@ -179,7 +179,7 @@
)

from .bnb import has_4bit_bnb_layers, load_and_quantize_model
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, merge_fsdp_weights, save_fsdp_model, save_fsdp_optimizer
from .launch import (
PrepareForLaunch,
_filter_args,
63 changes: 62 additions & 1 deletion src/accelerate/utils/fsdp_utils.py
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from pathlib import Path

import torch

from ..logging import get_logger
from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME
from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from .imports import is_torch_distributed_available
from .modeling import is_peft_model
from .other import save
from .versions import is_torch_version


if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -207,3 +211,60 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
logger.info(f"Optimizer loaded from {ckpt_dir}")
flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state)
optimizer.load_state_dict(flattened_osd)


def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: str, safe_serialization: bool = True):
"""
    Loads a sharded distributed checkpoint into a single consolidated state dict, mirroring
    `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`.

Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
"""
state_dict = {}
save_path = Path(save_path)
dist_cp_format_utils._load_state_dict(
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),
no_dist=True,
)
save_path = save_path / SAFE_WEIGHTS_NAME if safe_serialization else save_path / WEIGHTS_NAME

    # Handle a state dict nested under a single top-level key, e.g. {"model": {...}}
if len(state_dict.keys()) == 1:
state_dict = state_dict[list(state_dict)[0]]
save(state_dict, save_path, safe_serialization=safe_serialization)
return save_path


def merge_fsdp_weights(
checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.

Note: this is a CPU-bound process.

Args:
checkpoint_dir (`str`):
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
"""
from accelerate.state import PartialState

    # To set up `save` to work
state = PartialState()
if state.is_main_process:
logger.info(f"Merging FSDP weights from {checkpoint_dir}")
save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization)
logger.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
logger.info(f"Removing old checkpoint directory {checkpoint_dir}")
shutil.rmtree(checkpoint_dir)
state.wait_for_everyone()