-
Notifications
You must be signed in to change notification settings - Fork 182
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[2.5] Add the hello-pt-resnet example (#2955)
* Add the hello-pt-resnet example. * Removed the no use SimpleNetwork. * codestyle fix for hello-pt-resnet example. * renamed the simple_network.py -> resnet_18.py. And the resnet18 link to ReadMe. * updated license year. --------- Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
- Loading branch information
1 parent
9db7ce0
commit 23c471a
Showing
5 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Hello PyTorch ResNet | ||
|
||
Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier | ||
using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) | ||
and [PyTorch](https://pytorch.org/) as the deep learning training framework. Compared with the Hello PyTorch example, it uses the torchvision ResNet | ||
instead of the SimpleNetwork. | ||
|
||
> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the client train code. | ||
The Job API only supports object instances created directly from a Python class. It does not support | ||
object instances created through a Python function. Compared with the hello-pt example, | ||
if we replace the SimpleNetwork() object with resnet18(num_classes=10), | ||
the "resnet18(num_classes=10)" call creates a torchvision "ResNet" object instance out of the "resnet18" function. | ||
As shown in the [torchvision resnet](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L684-L705) source, | ||
resnet18 is a Python function that creates and returns a ResNet object. The Job API can | ||
only use the "ResNet" object instance for generating the job config. It cannot detect the object-creation logic inside the "resnet18" function. | ||
|
||
This example demonstrates how to wrap the resnet18 Python function in a Resnet18 Python class and then use the Resnet18(num_classes=10) | ||
object instance in the Job API. After replacing SimpleNetwork() with Resnet18(num_classes=10), | ||
you can follow the exact same steps as in the hello-pt example to run fedavg_script_runner_pt.py. |
40 changes: 40 additions & 0 deletions
40
examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from src.resnet_18 import Resnet18 | ||
|
||
from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob | ||
from nvflare.job_config.script_runner import ScriptRunner | ||
|
||
if __name__ == "__main__":
    # Simulation settings: number of client sites, number of FedAvg rounds,
    # and the training script every client executes.
    client_count = 2
    round_count = 2
    client_script = "src/hello-pt_cifar10_fl.py"

    # The initial global model is an instance of the class-based Resnet18
    # wrapper rather than the result of torchvision's resnet18() function.
    global_model = Resnet18(num_classes=10)
    job = FedAvgJob(
        name="hello-pt_cifar10_fedavg",
        n_clients=client_count,
        num_rounds=round_count,
        initial_model=global_model,
    )

    # Assign one ScriptRunner per site; sites are named site-1 .. site-N.
    for idx in range(client_count):
        # script_args is left empty here; it could carry per-site options such as
        # f"--batch_size 32 --data_path /tmp/data/site-{i}"
        runner = ScriptRunner(script=client_script, script_args="")
        job.to(runner, f"site-{idx+1}")

    # To export the generated job configuration instead of running it:
    # job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
nvflare~=2.5.0rc | ||
torch | ||
torchvision |
96 changes: 96 additions & 0 deletions
96
examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
import torch | ||
from resnet_18 import Resnet18 | ||
from torch import nn | ||
from torch.optim import SGD | ||
from torch.utils.data.dataloader import DataLoader | ||
from torchvision.datasets import CIFAR10 | ||
from torchvision.transforms import Compose, Normalize, ToTensor | ||
|
||
import nvflare.client as flare | ||
from nvflare.client.tracking import SummaryWriter | ||
|
||
# Root directory for the CIFAR-10 download; each client uses its own sub-directory.
DATASET_PATH = "/tmp/nvflare/data"


def main():
    """Run the client-side CIFAR-10 training loop of the federated job.

    Builds a Resnet18 model, an SGD optimizer and a CIFAR-10 training loader,
    then, for as long as ``flare.is_running()`` is true:

    1. receives the current global model via ``flare.receive()``,
    2. loads its parameters and trains locally for ``epochs`` epochs,
    3. saves the local weights to ``./cifar_net.pth``,
    4. sends the updated parameters back via ``flare.send()`` together with
       the number of local steps (``NUM_STEPS_CURRENT_ROUND``).
    """
    batch_size = 4
    epochs = 5
    lr = 0.01
    log_interval = 3000  # report progress every `log_interval` batches
    model = Resnet18(num_classes=10)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
    transforms = Compose(
        [
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    flare.init()
    sys_info = flare.system_info()
    client_name = sys_info["site_name"]

    # Each site downloads CIFAR-10 into its own directory under DATASET_PATH.
    train_dataset = CIFAR10(
        root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    summary_writer = SummaryWriter()
    while flare.is_running():
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        # Start this round from the received global weights.
        model.load_state_dict(input_model.params)
        model.to(device)

        steps = epochs * len(train_loader)
        for epoch in range(epochs):
            running_loss = 0.0
            batches_accumulated = 0  # batches summed into running_loss since the last report
            for i, batch in enumerate(train_loader):
                images, labels = batch[0].to(device), batch[1].to(device)
                optimizer.zero_grad()

                predictions = model(images)
                cost = loss(predictions, labels)
                cost.backward()
                optimizer.step()

                # .item() yields a plain Python float, so the accumulator stays a float.
                running_loss += cost.item() / images.size()[0]
                batches_accumulated += 1
                if i % log_interval == 0:
                    # Average over the batches actually accumulated: the report at i == 0
                    # covers a single batch, so dividing by a fixed 3000 would understate it.
                    print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/batches_accumulated}")
                    global_step = input_model.current_round * steps + epoch * len(train_loader) + i
                    # NOTE(review): the tracked scalar is the accumulated window sum, not the
                    # average printed above; kept as-is so existing metric curves are unchanged.
                    summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step)
                    running_loss = 0.0
                    batches_accumulated = 0

        print("Finished Training")

        PATH = "./cifar_net.pth"
        torch.save(model.state_dict(), PATH)

        # Parameters are moved to the CPU before being handed to flare.send().
        output_model = flare.FLModel(
            params=model.cpu().state_dict(),
            meta={"NUM_STEPS_CURRENT_ROUND": steps},
        )

        flare.send(output_model)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Optional | ||
|
||
from torchvision.models import ResNet | ||
from torchvision.models._utils import _ovewrite_named_param | ||
from torchvision.models.resnet import BasicBlock, ResNet18_Weights | ||
|
||
|
||
class Resnet18(ResNet):
    """Class-based wrapper for the ResNet-18 architecture.

    Builds the same network layout as torchvision's ``resnet18`` factory
    function (``BasicBlock`` with ``[2, 2, 2, 2]`` layers), but as a class that
    can be instantiated directly.

    Args:
        num_classes: number of output classes of the final layer.
        weights: optional pretrained ``ResNet18_Weights`` to load after construction.
        progress: forwarded to ``weights.get_state_dict`` when weights are loaded.
        **kwargs: additional keyword arguments forwarded to ``ResNet.__init__``.

    Raises:
        ValueError: from torchvision's ``_ovewrite_named_param`` when ``weights``
            is given and ``num_classes`` differs from the number of categories
            the weights were trained on.
    """

    def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any):
        weights = ResNet18_Weights.verify(weights)

        # Route num_classes through kwargs so ResNet.__init__ receives it exactly once.
        # Passing it both explicitly and via kwargs (which _ovewrite_named_param fills in)
        # raised "got multiple values for keyword argument 'num_classes'" whenever
        # weights were supplied.
        kwargs["num_classes"] = num_classes
        if weights is not None:
            _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

        super().__init__(BasicBlock, [2, 2, 2, 2], **kwargs)

        # Kept as an instance attribute under the constructor argument's name.
        # NOTE(review): presumably the Job API reads constructor arguments back from
        # the instance when generating the job config - confirm.
        self.num_classes = num_classes

        if weights is not None:
            self.load_state_dict(weights.get_state_dict(progress=progress))