Skip to content

Commit

Permalink
Add lightning example
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Jul 22, 2023
1 parent 015035f commit 1df030b
Show file tree
Hide file tree
Showing 12 changed files with 373 additions and 126 deletions.
2 changes: 2 additions & 0 deletions examples/advanced/lightning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This example showcases how to use the NVFlare client Lightning API to turn local, standalone training code written in Lightning into federated learning (FL) code.

Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"format_version": 2,

"executors": [
{
"tasks": ["train"],
"executor": {
"name": "PTFilePipeLauncherExecutor",
"args": {
"launcher_id": "launcher",
"heartbeat_timeout": 60
}
}
}
],
"task_result_filters": [
],
"task_data_filters": [
],
"components": [
{
"id": "launcher",
"name": "SubprocessLauncher",
"args": {
"script": "python custom/train.py"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"format_version": 2,

"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"components": [
{
"id": "shareable_generator",
"path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
"args": {}
},
{
"id": "aggregator",
"path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator",
"args": {
"expected_data_kind": "WEIGHTS"
}
}
],
"workflows": [
{
"id": "scatter_and_gather",
"name": "ScatterAndGather",
"args": {
"min_clients" : 2,
"num_rounds" : 5,
"start_round": 0,
"wait_time_after_min_received": 0,
"aggregator_id": "aggregator",
"shareable_generator_id": "shareable_generator",
"train_task_name": "train",
"train_timeout": 0
}
}
]
}
197 changes: 197 additions & 0 deletions examples/advanced/lightning/jobs/autoencoder/app/custom/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNIST autoencoder example.
Launched by the NVFlare job as: python custom/train.py (LightningCLI options such as --trainer.max_epochs=50 may be appended)
"""

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from lightning_utilities.core.rank_zero import rank_zero_only
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, callbacks, cli_lightning_logo
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from torch import nn
from torch.utils.data import DataLoader, random_split

import nvflare.client.lightning as flare

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

DATASETS_PATH = "/tmp/nvflare/mnist"


class ImageSampler(callbacks.Callback):
    """Callback that saves image grids of inputs and their reconstructions.

    At the end of every training epoch the first ``num_samples`` images of
    ``trainer.datamodule.mnist_val`` are passed through the module and the
    output is written to ``grid_generated_<epoch>.png``. The untouched inputs
    are written only at epoch 0, to ``grid_ori_0.png``.
    """

    def __init__(
        self,
        num_samples: int = 3,
        nrow: int = 8,
        padding: int = 2,
        normalize: bool = True,
        norm_range: Optional[Tuple[int, int]] = None,
        scale_each: bool = False,
        pad_value: int = 0,
    ) -> None:
        """
        Args:
            num_samples: Number of images displayed in the grid. Default: ``3``.
            nrow: Number of images displayed in each row of the grid.
                The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
            padding: Amount of padding. Default: ``2``.
            normalize: If ``True``, shift the image to the range (0, 1),
                by the min and max values specified by ``norm_range``. Default: ``True``.
            norm_range: Tuple (min, max) where min and max are numbers,
                then these numbers are used to normalize the image. By default, min and max
                are computed from the tensor.
            scale_each: If ``True``, scale each image in the batch of
                images separately rather than the (min, max) over all images. Default: ``False``.
            pad_value: Value for the padded pixels. Default: ``0``.

        Raises:
            ModuleNotFoundError: If ``torchvision`` is not installed.
        """
        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")

        super().__init__()
        self.num_samples = num_samples
        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.norm_range = norm_range
        self.scale_each = scale_each
        self.pad_value = pad_value

    def _to_grid(self, images):
        """Arrange a batch of images into a single grid tensor."""
        # `make_grid` renamed its `range` keyword to `value_range`; the old
        # spelling is rejected by current torchvision releases.
        return torchvision.utils.make_grid(
            tensor=images,
            nrow=self.nrow,
            padding=self.padding,
            normalize=self.normalize,
            value_range=self.norm_range,
            scale_each=self.scale_each,
            pad_value=self.pad_value,
        )

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Save a grid of reconstructed (and, at epoch 0, original) validation images."""
        if not _TORCHVISION_AVAILABLE:
            return

        # Only the first batch is needed; the loader is not shuffled, so the
        # same images are sampled every epoch.
        images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples)))
        images_flattened = images.view(images.size(0), -1)

        # generate images
        with torch.no_grad():
            pl_module.eval()
            images_generated = pl_module(images_flattened.to(pl_module.device))
            pl_module.train()

        if trainer.current_epoch == 0:
            save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png")
        save_image(
            self._to_grid(images_generated.reshape(images.shape)),
            f"grid_generated_{trainer.current_epoch}.png",
        )


class LitAutoEncoder(LightningModule):
    """Fully connected autoencoder for flattened 28x28 images with a 3-unit bottleneck.

    >>> LitAutoEncoder()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    LitAutoEncoder(
      (encoder): ...
      (decoder): ...
    )
    """

    def __init__(self, hidden_dim: int = 64, learning_rate=10e-3):
        """Build the encoder/decoder stacks and record the hyperparameters.

        Args:
            hidden_dim: Width of the single hidden layer in each stack.
            learning_rate: Learning rate handed to the Adam optimizer.
        """
        super().__init__()
        self.save_hyperparameters()
        # Layers are created encoder-first, in this exact order, so that
        # parameter names and seeded initialisation stay the same.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3),
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 28 * 28),
        )

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back."""
        return self.decoder(self.encoder(x))

    def training_step(self, batch, batch_idx):
        """Return the reconstruction loss for one training batch."""
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss for one validation batch (returns nothing)."""
        self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Log the reconstruction loss for one test batch (returns nothing)."""
        self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the reconstruction of the flattened inputs in ``batch``."""
        flat_inputs = self._prepare_batch(batch)
        return self(flat_inputs)

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters using the saved learning rate."""
        lr = self.hparams.learning_rate
        return torch.optim.Adam(self.parameters(), lr=lr)

    def _prepare_batch(self, batch):
        """Drop the second element of the ``(inputs, _)`` pair and flatten each sample."""
        inputs, _unused = batch
        batch_size = inputs.size(0)
        return inputs.view(batch_size, -1)

    def _common_step(self, batch, batch_idx, stage: str):
        """Compute, log as ``<stage>_loss``, and return the MSE reconstruction loss."""
        flat_inputs = self._prepare_batch(batch)
        reconstruction = self(flat_inputs)
        loss = F.mse_loss(flat_inputs, reconstruction)
        self.log(f"{stage}_loss", loss, on_step=True)
        return loss


class MyDataModule(LightningDataModule):
    """MNIST data module with a fixed 55000/5000 train/validation split.

    Both MNIST splits are requested with ``download=True`` under
    ``DATASETS_PATH`` when the module is constructed.
    """

    def __init__(self, batch_size: int = 32):
        """
        Args:
            batch_size: Batch size used by every dataloader of this module.
        """
        super().__init__()
        to_tensor = transforms.ToTensor()
        full_train = MNIST(DATASETS_PATH, train=True, download=True, transform=to_tensor)
        self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
        # The validation set is carved out of the official training split.
        train_part, val_part = random_split(full_train, [55000, 5000])
        self.mnist_train = train_part
        self.mnist_val = val_part
        self.batch_size = batch_size

    def _loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        """Loader over the 55000-sample training subset."""
        return self._loader(self.mnist_train)

    def val_dataloader(self):
        """Loader over the 5000-sample validation subset."""
        return self._loader(self.mnist_val)

    def test_dataloader(self):
        """Loader over the MNIST test split."""
        return self._loader(self.mnist_test)

    def predict_dataloader(self):
        """Loader over the MNIST test split, reused for prediction."""
        return self._loader(self.mnist_test)


def cli_main():
    """Fit, test and predict with the autoencoder through LightningCLI.

    The NVFlare client is initialised and ``LitAutoEncoder`` is patched before
    the CLI builds the model. After fitting, the "best" checkpoint and the
    model returned by ``get_fl_model()`` are both tested, and the first batch
    of predictions from the "best" checkpoint is printed.
    """
    flare.init()
    flare.patch(LitAutoEncoder)

    trainer_overrides = {"callbacks": ImageSampler(), "max_epochs": 1}
    cli = LightningCLI(
        LitAutoEncoder,
        MyDataModule,
        seed_everything_default=1234,
        run=False,  # build everything, but leave fit/test/predict to the calls below
        trainer_defaults=trainer_overrides,
        save_config_kwargs={"overwrite": True},
    )

    trainer = cli.trainer
    datamodule = cli.datamodule

    trainer.fit(cli.model, datamodule=datamodule)
    trainer.test(ckpt_path="best", datamodule=datamodule)
    # NOTE(review): get_fl_model() is not defined in this file; presumably it is
    # added by flare.patch and returns the model received from the FL server -- confirm.
    trainer.test(cli.model.get_fl_model(), datamodule=datamodule)
    predictions = trainer.predict(ckpt_path="best", datamodule=datamodule)
    print(predictions[0])


# Script entry point: runs only when the file is executed directly
# (the job config launches it as ``python custom/train.py``).
if __name__ == "__main__":
    # Helper imported from pytorch_lightning; presumably prints the Lightning banner.
    cli_lightning_logo()
    cli_main()
11 changes: 11 additions & 0 deletions examples/advanced/lightning/jobs/autoencoder/meta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"name": "lightning",
"resource_spec": {},
"min_clients" : 2,
"deploy_map": {
"app": [
"@ALL"
]
}
}

3 changes: 3 additions & 0 deletions examples/advanced/lightning/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
nvflare[PT]>=2.4.0
jsonargparse[signatures]>=4.17.0
pytorch_lightning
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def train(weights=None, total_epochs=2, lr=0.001, device="cuda:0"):
print("Finished Training")
torch.save(net.state_dict(), PATH)

# (note) needs to return trained weights
return net.cpu().state_dict()


Expand All @@ -106,6 +107,7 @@ def evaluate(input_weights=None, device="cuda:0"):
correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
# (note) needs to return accuracy
return 100 * correct // total


Expand Down
10 changes: 3 additions & 7 deletions nvflare/app_common/executors/file_pipe_launcher_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from nvflare.apis.fl_context import FLContext
from nvflare.apis.utils.decomposers import flare_decomposers
from nvflare.app_common.abstract.launcher import Launcher
from nvflare.app_common.decomposers import common_decomposers
from nvflare.app_common.executors.launcher_executor import LauncherExecutor
from nvflare.fuel.utils.constants import Mode
Expand Down Expand Up @@ -84,13 +83,10 @@ def __init__(
self._data_exchange_path = data_exchange_path

def initialize(self, fl_ctx: FLContext) -> None:
self._init_launcher(fl_ctx)
self._init_converter(fl_ctx)

engine = fl_ctx.get_engine()
# init launcher
launcher: Launcher = engine.get_component(self._launcher_id)
if launcher is not None:
check_object_type(self._launcher_id, launcher, Launcher)
launcher.initialize(fl_ctx)
self.launcher = launcher

# gets pipe
if self._pipe_id:
Expand Down
41 changes: 23 additions & 18 deletions nvflare/app_common/executors/launcher_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,11 @@ def __init__(
self._to_nvflare_converter: Optional[ParamsConverter] = None

def initialize(self, fl_ctx: FLContext) -> None:
engine = fl_ctx.get_engine()
# init launcher
launcher: Launcher = engine.get_component(self._launcher_id)
if launcher is not None:
check_object_type(self._launcher_id, launcher, Launcher)
launcher.initialize(fl_ctx)
self.launcher = launcher

# init converter
from_nvflare_converter: ParamsConverter = engine.get_component(self._from_nvflare_converter_id)
if from_nvflare_converter is not None:
check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter)
self._from_nvflare_converter = from_nvflare_converter

to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id)
if to_nvflare_converter is not None:
check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter)
self._to_nvflare_converter = to_nvflare_converter
self._init_launcher(fl_ctx)
self._init_converter(fl_ctx)

# gets pipe
engine = fl_ctx.get_engine()
pipe: Pipe = engine.get_component(self._pipe_id)
check_object_type(self._pipe_id, pipe, Pipe)

Expand Down Expand Up @@ -152,6 +137,26 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort

return result

def _init_launcher(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
launcher: Launcher = engine.get_component(self._launcher_id)
if launcher is not None:
check_object_type(self._launcher_id, launcher, Launcher)
launcher.initialize(fl_ctx)
self.launcher = launcher

def _init_converter(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
from_nvflare_converter: ParamsConverter = engine.get_component(self._from_nvflare_converter_id)
if from_nvflare_converter is not None:
check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter)
self._from_nvflare_converter = from_nvflare_converter

to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id)
if to_nvflare_converter is not None:
check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter)
self._to_nvflare_converter = to_nvflare_converter

def _launch(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> bool:
if self.launcher:
return self.launcher.launch_task(task_name, shareable, fl_ctx, abort_signal)
Expand Down
Loading

0 comments on commit 1df030b

Please sign in to comment.