From 1f111f12c8cbf41fa437308b3fa7c4ae375745b1 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 20 Sep 2024 10:48:54 -0400 Subject: [PATCH 1/8] Add the hello-pt-resnet example. --- .../hello-world/hello-pt-resnet/README.md | 18 ++++ .../fedavg_script_runner_pt.py | 37 +++++++ .../hello-pt-resnet/requirements.txt | 3 + .../src/hello-pt_cifar10_fl.py | 96 +++++++++++++++++++ .../hello-pt-resnet/src/simple_network.py | 56 +++++++++++ 5 files changed, 210 insertions(+) create mode 100644 examples/hello-world/hello-pt-resnet/README.md create mode 100644 examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py create mode 100644 examples/hello-world/hello-pt-resnet/requirements.txt create mode 100644 examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py create mode 100644 examples/hello-world/hello-pt-resnet/src/simple_network.py diff --git a/examples/hello-world/hello-pt-resnet/README.md b/examples/hello-world/hello-pt-resnet/README.md new file mode 100644 index 0000000000..f023e81013 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/README.md @@ -0,0 +1,18 @@ +# Hello PyTorch ResNet + +Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier +using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) +and [PyTorch](https://pytorch.org/) as the deep learning training framework. Comparing with the Hello PyTorch example, it uses the torchvision ResNet, +instead of the SimpleNetwork. + +> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the client train code. + +The Job API only supports the object instance created directly out of the Python Class. It does not support +the object instance created through using the Python function. Comparing with the hello-pt example, +if we replace the SimpleNetwork() object with the resnet18(num_classes=10), +the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. The job API can +only use the "ResNet" object instance for generating the job config. It can not detect the object creating function logic in the "resnet18". + +This example demonstrates how to wrap up the resnet18 Python function into a Resnet18 Python class. Then uses the Resnet18(num_classes=10) +object instance in the job API. After replacing the SimpleNetwork() with the Resnet18(num_classes=10), +you can follow the exact same steps in the hello-pt example to run the fedavg_script_runner_pt.py. diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py new file mode 100644 index 0000000000..2442d1d151 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob +from nvflare.job_config.script_runner import ScriptRunner +from src.simple_network import Resnet18 + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 2 + train_script = "src/hello-pt_cifar10_fl.py" + + job = FedAvgJob( + name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, + initial_model=Resnet18(num_classes=10) + ) + + # Add clients + for i in range(n_clients): + executor = ScriptRunner( + script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + ) + job.to(executor, f"site-{i+1}") + + # job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0") diff --git a/examples/hello-world/hello-pt-resnet/requirements.txt b/examples/hello-world/hello-pt-resnet/requirements.txt new file mode 100644 index 0000000000..919cc32ba2 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/requirements.txt @@ -0,0 +1,3 @@ +nvflare~=2.5.0rc +torch +torchvision diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py new file mode 100644 index 0000000000..2a68cc511c --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from torch import nn +from torch.optim import SGD +from torch.utils.data.dataloader import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + +import nvflare.client as flare +from nvflare.client.tracking import SummaryWriter +from simple_network import Resnet18 + +DATASET_PATH = "/tmp/nvflare/data" + + +def main(): + batch_size = 4 + epochs = 5 + lr = 0.01 + model = Resnet18(num_classes=10) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + loss = nn.CrossEntropyLoss() + optimizer = SGD(model.parameters(), lr=lr, momentum=0.9) + transforms = Compose( + [ + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + flare.init() + sys_info = flare.system_info() + client_name = sys_info["site_name"] + + train_dataset = CIFAR10( + root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True + ) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + summary_writer = SummaryWriter() + while flare.is_running(): + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + model.load_state_dict(input_model.params) + model.to(device) + + steps = epochs * len(train_loader) + for epoch in range(epochs): + running_loss = 0.0 + for i, batch in enumerate(train_loader): + images, labels = batch[0].to(device), batch[1].to(device) + optimizer.zero_grad() + + predictions = model(images) + cost = loss(predictions, labels) + cost.backward() + optimizer.step() + + running_loss += cost.cpu().detach().numpy() / images.size()[0] + if i % 3000 == 0: + print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}") + global_step = input_model.current_round * steps + epoch * len(train_loader) + i + summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) + running_loss = 0.0 + + print("Finished Training") + + PATH = "./cifar_net.pth" + torch.save(model.state_dict(), PATH) + + output_model = flare.FLModel( + params=model.cpu().state_dict(), + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/simple_network.py new file mode 100644 index 0000000000..4e2e1cfedd --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/src/simple_network.py @@ -0,0 +1,56 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import ResNet +from torchvision.models._utils import _ovewrite_named_param +from torchvision.models.resnet import BasicBlock, ResNet18_Weights + + +class SimpleNetwork(nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + +class Resnet18(ResNet): + + def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any): + self.num_classes = num_classes + + weights = ResNet18_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs) + + if weights is not None: + super().load_state_dict(weights.get_state_dict(progress=progress)) From d5c210a19ec2bc6caffb363af5a4a507024461ff Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 20 Sep 2024 10:54:25 -0400 Subject: [PATCH 2/8] Removed the no use SimpleNetwork. --- .../hello-pt-resnet/src/simple_network.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/simple_network.py index 4e2e1cfedd..988b52ee2a 100644 --- a/examples/hello-world/hello-pt-resnet/src/simple_network.py +++ b/examples/hello-world/hello-pt-resnet/src/simple_network.py @@ -13,33 +13,11 @@ # limitations under the License. from typing import Optional, Any -import torch -import torch.nn as nn -import torch.nn.functional as F from torchvision.models import ResNet from torchvision.models._utils import _ovewrite_named_param from torchvision.models.resnet import BasicBlock, ResNet18_Weights -class SimpleNetwork(nn.Module): - def __init__(self): - super(SimpleNetwork, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = torch.flatten(x, 1) # flatten all dimensions except batch - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - class Resnet18(ResNet): def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any): From e442d1be4875649288e3e6dfe59f98e8ddaf896e Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 20 Sep 2024 11:10:38 -0400 Subject: [PATCH 3/8] codestyle fix for hello-pt-resnet example. --- .../hello-pt-resnet/fedavg_script_runner_pt.py | 9 ++++++--- .../hello-pt-resnet/src/hello-pt_cifar10_fl.py | 2 +- .../hello-world/hello-pt-resnet/src/simple_network.py | 3 +-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py index 2442d1d151..7d1788014b 100644 --- a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from src.simple_network import Resnet18 + from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob from nvflare.job_config.script_runner import ScriptRunner -from src.simple_network import Resnet18 if __name__ == "__main__": n_clients = 2 @@ -22,8 +23,10 @@ train_script = "src/hello-pt_cifar10_fl.py" job = FedAvgJob( - name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, - initial_model=Resnet18(num_classes=10) + name="hello-pt_cifar10_fedavg", + n_clients=n_clients, + num_rounds=num_rounds, + initial_model=Resnet18(num_classes=10), ) # Add clients diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 2a68cc511c..20d43fa574 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -15,6 +15,7 @@ import os import torch +from simple_network import Resnet18 from torch import nn from torch.optim import SGD from torch.utils.data.dataloader import DataLoader @@ -23,7 +24,6 @@ import nvflare.client as flare from nvflare.client.tracking import SummaryWriter -from simple_network import Resnet18 DATASET_PATH = "/tmp/nvflare/data" diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/simple_network.py index 988b52ee2a..3420fdd741 100644 --- a/examples/hello-world/hello-pt-resnet/src/simple_network.py +++ b/examples/hello-world/hello-pt-resnet/src/simple_network.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Any +from typing import Any, Optional from torchvision.models import ResNet from torchvision.models._utils import _ovewrite_named_param @@ -19,7 +19,6 @@ class Resnet18(ResNet): - def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any): self.num_classes = num_classes From 104dcdc0e99a1e412c88d33f028f4f88b2310bba Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 26 Sep 2024 15:20:34 -0400 Subject: [PATCH 4/8] renamed the simple_network.py -> resnet_18.py. And the resnet18 link to ReadMe. --- examples/hello-world/hello-pt-resnet/README.md | 4 +++- .../hello-world/hello-pt-resnet/fedavg_script_runner_pt.py | 2 +- .../hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py | 2 +- .../hello-pt-resnet/src/{simple_network.py => resnet_18.py} | 0 4 files changed, 5 insertions(+), 3 deletions(-) rename examples/hello-world/hello-pt-resnet/src/{simple_network.py => resnet_18.py} (100%) diff --git a/examples/hello-world/hello-pt-resnet/README.md b/examples/hello-world/hello-pt-resnet/README.md index f023e81013..4b1a0f51a6 100644 --- a/examples/hello-world/hello-pt-resnet/README.md +++ b/examples/hello-world/hello-pt-resnet/README.md @@ -10,7 +10,9 @@ instead of the SimpleNetwork. The Job API only supports the object instance created directly out of the Python Class. It does not support the object instance created through using the Python function. Comparing with the hello-pt example, if we replace the SimpleNetwork() object with the resnet18(num_classes=10), -the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. The job API can +the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. +As shown in the [torchvision reset](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L684-L705), +the resnet18 is a Python function, which creates and returns a ResNet object. The job API can only use the "ResNet" object instance for generating the job config. It can not detect the object creating function logic in the "resnet18". This example demonstrates how to wrap up the resnet18 Python function into a Resnet18 Python class. Then uses the Resnet18(num_classes=10) diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py index 7d1788014b..a4955ed5a0 100644 --- a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from src.simple_network import Resnet18 +from src.resnet_18 import Resnet18 from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob from nvflare.job_config.script_runner import ScriptRunner diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 20d43fa574..cfd12ba947 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -15,7 +15,7 @@ import os import torch -from simple_network import Resnet18 +from resnet_18 import Resnet18 from torch import nn from torch.optim import SGD from torch.utils.data.dataloader import DataLoader diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/resnet_18.py similarity index 100% rename from examples/hello-world/hello-pt-resnet/src/simple_network.py rename to examples/hello-world/hello-pt-resnet/src/resnet_18.py From 5a50e13920e29e2b3a28a72a31bcb0b910983611 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 26 Sep 2024 15:22:17 -0400 Subject: [PATCH 5/8] updated license year. --- examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index cfd12ba947..7395466c68 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 68266b5963a98cab09096ed62572afff77b1e26c Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 26 Sep 2024 17:00:24 -0400 Subject: [PATCH 6/8] codestyle fix. --- .../fedavg_script_runner_pt.py | 3 ++- .../src/hello-pt_cifar10_fl.py | 21 +++++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py index a4955ed5a0..948d64d519 100644 --- a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -32,7 +32,8 @@ # Add clients for i in range(n_clients): executor = ScriptRunner( - script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + script=train_script, + script_args="", # f"--batch_size 32 --data_path /tmp/data/site-{i}" ) job.to(executor, f"site-{i+1}") diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 7395466c68..5dd0a49cc9 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -48,7 +48,10 @@ def main(): client_name = sys_info["site_name"] train_dataset = CIFAR10( - root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True + root=os.path.join(DATASET_PATH, client_name), + transform=transforms, + download=True, + train=True, ) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) @@ -74,9 +77,19 @@ def main(): running_loss += cost.cpu().detach().numpy() / images.size()[0] if i % 3000 == 0: - print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}") - global_step = input_model.current_round * steps + epoch * len(train_loader) + i - summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) + print( + f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}" + ) + global_step = ( + input_model.current_round * steps + + epoch * len(train_loader) + + i + ) + summary_writer.add_scalar( + tag="loss_for_each_batch", + scalar=running_loss, + global_step=global_step, + ) running_loss = 0.0 print("Finished Training") From 14f325a9c9d1085cf9a036242d5e92de4056a36d Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 26 Sep 2024 17:08:31 -0400 Subject: [PATCH 7/8] black codestyle fix. --- .../hello-pt-resnet/src/hello-pt_cifar10_fl.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 5dd0a49cc9..7ea890c0f9 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -77,14 +77,8 @@ def main(): running_loss += cost.cpu().detach().numpy() / images.size()[0] if i % 3000 == 0: - print( - f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}" - ) - global_step = ( - input_model.current_round * steps - + epoch * len(train_loader) - + i - ) + print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}") + global_step = input_model.current_round * steps + epoch * len(train_loader) + i summary_writer.add_scalar( tag="loss_for_each_batch", scalar=running_loss, From e3ab1dbb03d81103d8fa05e4c62e758be606688d Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 26 Sep 2024 17:20:21 -0400 Subject: [PATCH 8/8] codestyle fix. --- examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py | 2 +- examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py index 948d64d519..ece630dcdb 100644 --- a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -35,7 +35,7 @@ script=train_script, script_args="", # f"--batch_size 32 --data_path /tmp/data/site-{i}" ) - job.to(executor, f"site-{i+1}") + job.to(executor, f"site-{i + 1}") # job.export_job("/tmp/nvflare/jobs/job_config") job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0") diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 7ea890c0f9..860e6f0cac 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -77,7 +77,7 @@ def main(): running_loss += cost.cpu().detach().numpy() / images.size()[0] if i % 3000 == 0: - print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}") + print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}") global_step = input_model.current_round * steps + epoch * len(train_loader) + i summary_writer.add_scalar( tag="loss_for_each_batch",