Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.5] Add the hello-pt-resnet example #2955

Merged
merged 6 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions examples/hello-world/hello-pt-resnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Hello PyTorch ResNet
yhwen marked this conversation as resolved.
Show resolved Hide resolved

Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier
using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629))
and [PyTorch](https://pytorch.org/) as the deep learning training framework. Comparing with the Hello PyTorch example, it uses the torchvision ResNet,
instead of the SimpleNetwork.

> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the client train code.

The Job API only supports the object instance created directly out of the Python Class. It does not support
the object instance created through using the Python function. Comparing with the hello-pt example,
if we replace the SimpleNetwork() object with the resnet18(num_classes=10),
yhwen marked this conversation as resolved.
Show resolved Hide resolved
the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function.
As shown in the [torchvision reset](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L684-L705),
the resnet18 is a Python function, which creates and returns a ResNet object. The job API can
only use the "ResNet" object instance for generating the job config. It can not detect the object creating function logic in the "resnet18".

This example demonstrates how to wrap up the resnet18 Python function into a Resnet18 Python class. Then uses the Resnet18(num_classes=10)
object instance in the job API. After replacing the SimpleNetwork() with the Resnet18(num_classes=10),
you can follow the exact same steps in the hello-pt example to run the fedavg_script_runner_pt.py.
40 changes: 40 additions & 0 deletions examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.resnet_18 import Resnet18

from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
n_clients = 2
num_rounds = 2
train_script = "src/hello-pt_cifar10_fl.py"

job = FedAvgJob(
name="hello-pt_cifar10_fedavg",
n_clients=n_clients,
num_rounds=num_rounds,
initial_model=Resnet18(num_classes=10),
)

# Add clients
for i in range(n_clients):
executor = ScriptRunner(
script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
)
job.to(executor, f"site-{i+1}")

# job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")
3 changes: 3 additions & 0 deletions examples/hello-world/hello-pt-resnet/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
nvflare~=2.5.0rc
torch
torchvision
96 changes: 96 additions & 0 deletions examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from resnet_18 import Resnet18
from torch import nn
from torch.optim import SGD
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter

DATASET_PATH = "/tmp/nvflare/data"


def main():
batch_size = 4
epochs = 5
lr = 0.01
model = Resnet18(num_classes=10)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
transforms = Compose(
[
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)

flare.init()
sys_info = flare.system_info()
client_name = sys_info["site_name"]

train_dataset = CIFAR10(
root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

summary_writer = SummaryWriter()
while flare.is_running():
input_model = flare.receive()
print(f"current_round={input_model.current_round}")

model.load_state_dict(input_model.params)
model.to(device)

steps = epochs * len(train_loader)
for epoch in range(epochs):
running_loss = 0.0
for i, batch in enumerate(train_loader):
images, labels = batch[0].to(device), batch[1].to(device)
optimizer.zero_grad()

predictions = model(images)
cost = loss(predictions, labels)
cost.backward()
optimizer.step()

running_loss += cost.cpu().detach().numpy() / images.size()[0]
if i % 3000 == 0:
print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}")
global_step = input_model.current_round * steps + epoch * len(train_loader) + i
summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step)
running_loss = 0.0

print("Finished Training")

PATH = "./cifar_net.pth"
torch.save(model.state_dict(), PATH)

output_model = flare.FLModel(
params=model.cpu().state_dict(),
meta={"NUM_STEPS_CURRENT_ROUND": steps},
)

flare.send(output_model)


if __name__ == "__main__":
main()
33 changes: 33 additions & 0 deletions examples/hello-world/hello-pt-resnet/src/resnet_18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

from torchvision.models import ResNet
from torchvision.models._utils import _ovewrite_named_param
from torchvision.models.resnet import BasicBlock, ResNet18_Weights


class Resnet18(ResNet):
def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any):
self.num_classes = num_classes

weights = ResNet18_Weights.verify(weights)

if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)

if weights is not None:
super().load_state_dict(weights.get_state_dict(progress=progress))
Loading