Introduce shard-merging util for FSDP (#2772)
* Initial commit
* Now to test
* Store false
* Slight tweaks
* Fix naming
* Got it all working with tests
* Use not for safetensors arg
* rm change
* Add docs
* Adjust based on Marc's feedback
* Specify just weights
* Update tests to include CLI and swap namings
* Fin
* Rm unused
* Rm again
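For context, a minimal usage sketch of the new Python helper (assumption: the checkpoint and output paths below are illustrative; the `safe_serialization` keyword and the `model.safetensors` output name are taken from the tests in this commit):

    from accelerate.utils import merge_fsdp_weights

    # Merge shards produced with StateDictType.SHARDED_STATE_DICT into a single
    # checkpoint; with safe_serialization=True the result is `model.safetensors`
    # inside the output directory.
    merge_fsdp_weights(
        "checkpoints/pytorch_model_fsdp_0",  # directory of FSDP shards (illustrative)
        "merged_checkpoint",                 # output directory (illustrative)
        safe_serialization=True,
    )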
Showing 9 changed files with 321 additions and 2 deletions.
src/accelerate/commands/merge.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from accelerate.commands.utils import CustomArgumentParser
from accelerate.utils import merge_fsdp_weights


description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.
This is a CPU-bound process and requires enough RAM to load the entire model state dict."""
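
# Example invocation (assumption: the subcommand name comes from
# `merge_command_parser` below; how it is exposed on the top-level
# `accelerate` CLI lives in this PR's other changed files):
#
#   accelerate merge-weights pytorch_model_fsdp_0/ output_dir/
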

def merge_command(args):
    # `--unsafe_serialization` selects `.bin` output, so `safe_serialization` is its inverse
    merge_fsdp_weights(
        args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir
    )


def merge_command_parser(subparsers=None):
    # Either register as a subcommand (when called from the main CLI) or build a standalone parser
    if subparsers is not None:
        parser = subparsers.add_parser("merge-weights", description=description)
    else:
        parser = CustomArgumentParser(description=description)

    parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
    parser.add_argument(
        "output_path",
        type=str,
        help="The path to save the merged weights to.",
    )
    parser.add_argument(
        "--unsafe_serialization",
        action="store_true",
        default=False,
        help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
    )
    parser.add_argument(
        "--remove_checkpoint_dir",
        action="store_true",
        help="Whether to remove the checkpoint directory after merging.",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=merge_command)
    return parser


def main():
    parser = merge_command_parser()
    args = parser.parse_args()
    merge_command(args)


if __name__ == "__main__":
    main()
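For completeness, a minimal sketch of how the `subparsers` branch above would be exercised (assumption: the real wiring into the `accelerate` entry point is in this PR's other changed files; the paths here are illustrative):

    import argparse

    from accelerate.commands.merge import merge_command_parser

    # Build a toy top-level parser and register the new subcommand on it.
    parser = argparse.ArgumentParser("accelerate")
    subparsers = parser.add_subparsers()
    merge_command_parser(subparsers=subparsers)  # adds the "merge-weights" subcommand

    # `set_defaults(func=merge_command)` lets the top-level CLI dispatch generically.
    args = parser.parse_args(["merge-weights", "pytorch_model_fsdp_0/", "merged/"])
    args.func(args)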
src/accelerate/test_utils/scripts/test_merge_weights.py (160 additions, 0 deletions)
@@ -0,0 +1,160 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType
from torch.utils.data import DataLoader

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from accelerate.commands.merge import merge_command, merge_command_parser
from accelerate.state import AcceleratorState
from accelerate.test_utils.training import RegressionDataset
from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model


logging.basicConfig(level=logging.INFO)

parser = merge_command_parser()


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 16)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(16, 16)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        return self.linear2(self.activation(self.linear1(x)))


def setup():
    # Reset any lingering global accelerator state before building the FSDP-wrapped model
    if AcceleratorState._shared_state != {}:
        AcceleratorState()._reset_state()
    plugin = FullyShardedDataParallelPlugin(
        sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT
    )
    model = TinyModel()
    with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):
        plugin.set_auto_wrap_policy(model)
    accelerator = Accelerator(fsdp_plugin=plugin)
    model = accelerator.prepare(model)
    return model, plugin, accelerator


def mock_training(accelerator, model):
    train_set = RegressionDataset(length=128, seed=42)
    train_dl = DataLoader(train_set, batch_size=16, shuffle=False)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()
    return model


def check_weights(operation, state_1, state_2):
    # Compare two state dicts on the same device, elementwise
    for weight_1, weight_2 in zip(state_1.values(), state_2.values()):
        if str(weight_1.device) != "cuda":
            weight_1 = weight_1.to("cuda")
        if str(weight_2.device) != "cuda":
            weight_2 = weight_2.to("cuda")
        if operation == "same":
            assert torch.allclose(weight_1, weight_2)
        else:
            assert not torch.allclose(weight_1, weight_2)


def check_safetensors_weights(path, model):
    safe_state_dict = load_file(path / "model.safetensors")
    safe_loaded_model = TinyModel()
    check_weights("diff", model.state_dict(), safe_loaded_model.state_dict())
    safe_loaded_model.load_state_dict(safe_state_dict)
    check_weights("same", model.state_dict(), safe_loaded_model.state_dict())


def check_pytorch_weights(path, model):
    nonsafe_state_dict = torch.load(path / "pytorch_model.bin")
    nonsafe_loaded_model = TinyModel()
    check_weights("diff", model.state_dict(), nonsafe_loaded_model.state_dict())
    nonsafe_loaded_model.load_state_dict(nonsafe_state_dict)
    check_weights("same", model.state_dict(), nonsafe_loaded_model.state_dict())


def test_merge_weights_safetensors(model, path):
    # Should now be saved at `path/model.safetensors`
    merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=True)
    check_safetensors_weights(path, model)


def test_merge_weights_command_safetensors(model, path):
    args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path)])
    merge_command(args)
    check_safetensors_weights(path, model)


def test_merge_weights_pytorch(model, path):
    # Should now be saved at `path/pytorch_model.bin`
    merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=False)
    check_pytorch_weights(path, model)


def test_merge_weights_command_pytorch(model, path):
    args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"])
    merge_command(args)
    check_pytorch_weights(path, model)


if __name__ == "__main__":
    # Note: this test requires at least two accelerators!
    model, plugin, accelerator = setup()
    if accelerator.num_processes > 1:
        try:
            # Initial setup
            out_path = Path("test_merge_weights_fsdp_weights")
            if not out_path.exists():
                out_path.mkdir(parents=True, exist_ok=True)

            # Train briefly so the weights aren't the baseline
            model = mock_training(accelerator, model)
            accelerator.wait_for_everyone()

            gc.collect()  # Needed for some lingering refs after training
            save_fsdp_model(plugin, accelerator, model, out_path)
            accelerator.wait_for_everyone()

            # Finally we can test
            test_merge_weights_safetensors(model, out_path)
            test_merge_weights_command_safetensors(model, out_path)
            test_merge_weights_pytorch(model, out_path)
            test_merge_weights_command_pytorch(model, out_path)
        finally:
            # Cleanup in case of any failures
            if accelerator.is_main_process:
                shutil.rmtree(out_path)
            accelerator.wait_for_everyone()
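Because the `num_processes > 1` guard above makes the script a no-op on a single process, it has to be run under a distributed launcher; a minimal sketch (assumption: the two-process count and the script path are illustrative):

    import subprocess

    # Run the test under `accelerate launch` so the script sees more than one process.
    subprocess.run(
        [
            "accelerate",
            "launch",
            "--num_processes",
            "2",
            "src/accelerate/test_utils/scripts/test_merge_weights.py",
        ],
        check=True,
    )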