From 74440970a073be9be9ceea4b4055cd07f249d323 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 20 Sep 2024 10:48:54 -0400 Subject: [PATCH 1/5] Add the hello-pt-resnet example. --- .../hello-world/hello-pt-resnet/README.md | 18 ++++ .../fedavg_script_runner_pt.py | 37 +++++++ .../hello-pt-resnet/requirements.txt | 3 + .../src/hello-pt_cifar10_fl.py | 96 +++++++++++++++++++ .../hello-pt-resnet/src/simple_network.py | 56 +++++++++++ 5 files changed, 210 insertions(+) create mode 100644 examples/hello-world/hello-pt-resnet/README.md create mode 100644 examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py create mode 100644 examples/hello-world/hello-pt-resnet/requirements.txt create mode 100644 examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py create mode 100644 examples/hello-world/hello-pt-resnet/src/simple_network.py diff --git a/examples/hello-world/hello-pt-resnet/README.md b/examples/hello-world/hello-pt-resnet/README.md new file mode 100644 index 0000000000..f023e81013 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/README.md @@ -0,0 +1,18 @@ +# Hello PyTorch ResNet + +Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier +using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) +and [PyTorch](https://pytorch.org/) as the deep learning training framework. Comparing with the Hello PyTorch example, it uses the torchvision ResNet, +instead of the SimpleNetwork. + +> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the client train code. + +The Job API only supports the object instance created directly out of the Python Class. It does not support +the object instance created through using the Python function. 
Comparing with the hello-pt example, +if we replace the SimpleNetwork() object with the resnet18(num_classes=10), +the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. The job API can +only use the "ResNet" object instance for generating the job config. It can not detect the object creating function logic in the "resnet18". + +This example demonstrates how to wrap up the resnet18 Python function into a Resnet18 Python class. Then uses the Resnet18(num_classes=10) +object instance in the job API. After replacing the SimpleNetwork() with the Resnet18(num_classes=10), +you can follow the exact same steps in the hello-pt example to run the fedavg_script_runner_pt.py. diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py new file mode 100644 index 0000000000..2442d1d151 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob +from nvflare.job_config.script_runner import ScriptRunner +from src.simple_network import Resnet18 + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 2 + train_script = "src/hello-pt_cifar10_fl.py" + + job = FedAvgJob( + name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, + initial_model=Resnet18(num_classes=10) + ) + + # Add clients + for i in range(n_clients): + executor = ScriptRunner( + script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + ) + job.to(executor, f"site-{i+1}") + + # job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0") diff --git a/examples/hello-world/hello-pt-resnet/requirements.txt b/examples/hello-world/hello-pt-resnet/requirements.txt new file mode 100644 index 0000000000..919cc32ba2 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/requirements.txt @@ -0,0 +1,3 @@ +nvflare~=2.5.0rc +torch +torchvision diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py new file mode 100644 index 0000000000..2a68cc511c --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import torch +from torch import nn +from torch.optim import SGD +from torch.utils.data.dataloader import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + +import nvflare.client as flare +from nvflare.client.tracking import SummaryWriter +from simple_network import Resnet18 + +DATASET_PATH = "/tmp/nvflare/data" + + +def main(): + batch_size = 4 + epochs = 5 + lr = 0.01 + model = Resnet18(num_classes=10) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + loss = nn.CrossEntropyLoss() + optimizer = SGD(model.parameters(), lr=lr, momentum=0.9) + transforms = Compose( + [ + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + flare.init() + sys_info = flare.system_info() + client_name = sys_info["site_name"] + + train_dataset = CIFAR10( + root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True + ) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + summary_writer = SummaryWriter() + while flare.is_running(): + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + model.load_state_dict(input_model.params) + model.to(device) + + steps = epochs * len(train_loader) + for epoch in range(epochs): + running_loss = 0.0 + for i, batch in enumerate(train_loader): + images, labels = batch[0].to(device), batch[1].to(device) + optimizer.zero_grad() + + predictions = model(images) + cost = loss(predictions, labels) + cost.backward() + optimizer.step() + + running_loss += cost.cpu().detach().numpy() / images.size()[0] + if i % 3000 == 0: + print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}") + global_step = input_model.current_round * steps + epoch * len(train_loader) + i + summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) + running_loss = 0.0 + + print("Finished Training") + + PATH = 
"./cifar_net.pth" + torch.save(model.state_dict(), PATH) + + output_model = flare.FLModel( + params=model.cpu().state_dict(), + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/simple_network.py new file mode 100644 index 0000000000..4e2e1cfedd --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/src/simple_network.py @@ -0,0 +1,56 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import ResNet +from torchvision.models._utils import _ovewrite_named_param +from torchvision.models.resnet import BasicBlock, ResNet18_Weights + + +class SimpleNetwork(nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + +class Resnet18(ResNet): + + def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any): + self.num_classes = num_classes + + weights = ResNet18_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs) + + if weights is not None: + super().load_state_dict(weights.get_state_dict(progress=progress)) From eb94997091150fadbeb8dff718af8c6494110d54 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 20 Sep 2024 10:54:25 -0400 Subject: [PATCH 2/5] Removed the no use SimpleNetwork. --- .../hello-pt-resnet/src/simple_network.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/simple_network.py index 4e2e1cfedd..988b52ee2a 100644 --- a/examples/hello-world/hello-pt-resnet/src/simple_network.py +++ b/examples/hello-world/hello-pt-resnet/src/simple_network.py @@ -13,33 +13,11 @@ # limitations under the License. 
from typing import Optional, Any -import torch -import torch.nn as nn -import torch.nn.functional as F from torchvision.models import ResNet from torchvision.models._utils import _ovewrite_named_param from torchvision.models.resnet import BasicBlock, ResNet18_Weights -class SimpleNetwork(nn.Module): - def __init__(self): - super(SimpleNetwork, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = torch.flatten(x, 1) # flatten all dimensions except batch - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - class Resnet18(ResNet): def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any): From ca25de8c1208f927b05f0b228d6f7d0ea124fe28 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 20 Sep 2024 11:10:38 -0400 Subject: [PATCH 3/5] codestyle fix for hello-pt-resnet example. --- .../hello-pt-resnet/fedavg_script_runner_pt.py | 9 ++++++--- .../hello-pt-resnet/src/hello-pt_cifar10_fl.py | 2 +- .../hello-world/hello-pt-resnet/src/simple_network.py | 3 +-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py index 2442d1d151..7d1788014b 100644 --- a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from src.simple_network import Resnet18 + from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob from nvflare.job_config.script_runner import ScriptRunner -from src.simple_network import Resnet18 if __name__ == "__main__": n_clients = 2 @@ -22,8 +23,10 @@ train_script = "src/hello-pt_cifar10_fl.py" job = FedAvgJob( - name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, - initial_model=Resnet18(num_classes=10) + name="hello-pt_cifar10_fedavg", + n_clients=n_clients, + num_rounds=num_rounds, + initial_model=Resnet18(num_classes=10), ) # Add clients diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 2a68cc511c..20d43fa574 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -15,6 +15,7 @@ import os import torch +from simple_network import Resnet18 from torch import nn from torch.optim import SGD from torch.utils.data.dataloader import DataLoader @@ -23,7 +24,6 @@ import nvflare.client as flare from nvflare.client.tracking import SummaryWriter -from simple_network import Resnet18 DATASET_PATH = "/tmp/nvflare/data" diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/simple_network.py index 988b52ee2a..3420fdd741 100644 --- a/examples/hello-world/hello-pt-resnet/src/simple_network.py +++ b/examples/hello-world/hello-pt-resnet/src/simple_network.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Any +from typing import Any, Optional from torchvision.models import ResNet from torchvision.models._utils import _ovewrite_named_param @@ -19,7 +19,6 @@ class Resnet18(ResNet): - def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any): self.num_classes = num_classes From 16d1209703b77a7a641b42fba35b37714438e0d8 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 26 Sep 2024 15:20:34 -0400 Subject: [PATCH 4/5] renamed the simple_network.py -> resnet_18.py. And the resnet18 link to ReadMe. --- examples/hello-world/hello-pt-resnet/README.md | 4 +++- .../hello-world/hello-pt-resnet/fedavg_script_runner_pt.py | 2 +- .../hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py | 2 +- .../hello-pt-resnet/src/{simple_network.py => resnet_18.py} | 0 4 files changed, 5 insertions(+), 3 deletions(-) rename examples/hello-world/hello-pt-resnet/src/{simple_network.py => resnet_18.py} (100%) diff --git a/examples/hello-world/hello-pt-resnet/README.md b/examples/hello-world/hello-pt-resnet/README.md index f023e81013..4b1a0f51a6 100644 --- a/examples/hello-world/hello-pt-resnet/README.md +++ b/examples/hello-world/hello-pt-resnet/README.md @@ -10,7 +10,9 @@ instead of the SimpleNetwork. The Job API only supports the object instance created directly out of the Python Class. It does not support the object instance created through using the Python function. Comparing with the hello-pt example, if we replace the SimpleNetwork() object with the resnet18(num_classes=10), -the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. The job API can +the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. 
+As shown in the [torchvision resnet](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L684-L705), +the resnet18 is a Python function, which creates and returns a ResNet object. The job API can only use the "ResNet" object instance for generating the job config. It can not detect the object creating function logic in the "resnet18". This example demonstrates how to wrap up the resnet18 Python function into a Resnet18 Python class. Then uses the Resnet18(num_classes=10) diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py index 7d1788014b..a4955ed5a0 100644 --- a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from src.simple_network import Resnet18 +from src.resnet_18 import Resnet18 from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob from nvflare.job_config.script_runner import ScriptRunner diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 20d43fa574..cfd12ba947 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -15,7 +15,7 @@ import os import torch -from simple_network import Resnet18 +from resnet_18 import Resnet18 from torch import nn from torch.optim import SGD from torch.utils.data.dataloader import DataLoader diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/resnet_18.py similarity index 100% rename from examples/hello-world/hello-pt-resnet/src/simple_network.py rename to examples/hello-world/hello-pt-resnet/src/resnet_18.py From 1aad21de410c1a6033bd9aed51e740ce66d084fa Mon Sep
17 00:00:00 2001 From: Yuhong Wen Date: Thu, 26 Sep 2024 15:22:17 -0400 Subject: [PATCH 5/5] updated license year. --- examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index cfd12ba947..7395466c68 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.