forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
42 changed files
with
2,015 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
.. _notes_on_large_models: | ||
|
||
Large Models | ||
============ | ||
As the federated learning tasks become more and more complex, their model sizes increase. Some model sizes may go beyond 2GB and even reach hundreds of GB. NVIDIA FLARE supports | ||
large models as long as the system memory of servers and clients is capable of handling it. However, it requires special consideration on NVIDIA FLARE configuration and the system because | ||
the network bandwidth and thus the time to transmit such large amount of data during NVIDIA FLARE job runtime varies significantly. Here we describe 128GB model training jobs to highlight | ||
the configuration and system that users should consider for sucessful large model federated learning jobs. | ||
|
||
System Deployment | ||
***************** | ||
Our successful experiments of 128GB model training were running on one NVIDIA FLARE server and two clients. The server was deployed in Azure west-us region. One of those two clients | ||
was deployed in AWS west-us-2 region and the other was in AWS ap-south-1 region. The system was deployed in such cross-region and cross-cloud-service-provider manner so that we can test | ||
NVIDIA FLARE system with various conditions on the network bandwidth. | ||
The Azure VM size of the NVIDIA FLARE server was M32-8ms, which has 875GB memory. The AWS EC2 instance type of NVIDIA FLARE clients was r5a.16xlarge with 512GB memory. We also enabled | ||
128GB swap space on all machines. | ||
|
||
Job of 128GB Models | ||
******************* | ||
We slightly modified the hello-numpy example to generate a model, which was a dictionary of 64 keys. Each key held a 2GB numpy array. The local training task was to add a small number to | ||
those numpy arrays. The aggregator on the server side was not changed. This job required at least two clients and ran 3 rounds to finish. | ||
|
||
Please note if your model contains leaf nodes that are larger than 4GB, the type of those nodes must be bytes. In that case, the outgoing model will need a conversion similar to the following: | ||
|
||
.. code:: python | ||
for k in np_data: | ||
self.log_info(fl_ctx, f"converting {k=}") | ||
tmp_file = io.BytesIO() | ||
np.save(tmp_file, np_data[k]) | ||
np_data[k] = tmp_file.getvalue() | ||
tmp_file.close() | ||
self.log_info(fl_ctx, f"done converting {k=}") | ||
outgoing_dxo = DXO(data_kind=incoming_dxo.data_kind, data=np_data, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: 1}) | ||
Additionally, the receiving side needs to convert the bytes back to numpy array with codes similar to the following: | ||
|
||
.. code:: python | ||
for k in np_data: | ||
self.log_info(fl_ctx, f"converting and adding delta for {k=}") | ||
np_data[k] = np.load(io.BytesIO(np_data[k])) | ||
Configuration | ||
******************* | ||
We measured the bandwidth between the server and west-us-2 client. It took around 2300 seconds to transfer the model from the client to the server and around 2000 seconds from the server to the client. | ||
On the ap-south-1 client, it took about 11000 seconds from the client to the server and 11500 seconds from the server to the client. We updated the following values to accommodate such differences. | ||
|
||
- streaming_read_timeout to 3000 | ||
- streaming_ack_wait to 6000 | ||
- communication_timeout to 6000 | ||
|
||
|
||
The `streaming_read_timeout` is used to check when a chunck of data is received but is not read out by the upper layer. The `streaming_ack_wait` is how long the sender should wait for acknowledgement returned by the receiver for one chunck. | ||
|
||
|
||
The `communication_timeout` is used on three consecutive stages for a single request and response. When sending a large request (submit_update), the sender starts a timer with timeout = `communication_timeout`. | ||
When this timer expires, the sender checks if any progress is made during this period. If yes, the sender resets the timer with the same timeout value and waits again. If not, this request and response returns with timeout. | ||
After sending completes, the sender cancels the previous timer and starts a `remote processing` timer with timeout = `communication_timeout`. This is to wait for the first returned byte from the receiver. On | ||
large models, the server requires much longer time to prepare the task when the clients send `get_task` requests. After receiving the first returned byte, the sender cancel the `remote processing` timer and starts | ||
a new timer. It checks the receiving progress just like sending. | ||
|
||
|
||
Since the experiment was based on hello-numpy, one of the arguments, `train_timeout` in the ScatterAndGather class had to be updated. This timeout is used to check the scheduling of training tasks. We | ||
changed this argument to 60000 for this experiment. | ||
|
||
Memory Usage | ||
******************* | ||
During the experiment, the server could use more than 512GB, ie 128GB * 2 clients * 2 (model and runtime space). The following figure shows the CPU and memory usage of the server. | ||
|
||
.. image:: ../resources/128GB_server.png | ||
:height: 350px | ||
|
||
Although most of the time, the server was using less than 512GB, there were a few peaks that reached 700GB or more. | ||
|
||
The followings are clients, west-us-2 and ap-south-1. | ||
|
||
.. image:: ../resources/128GB_site1.png | ||
:height: 350px | ||
|
||
|
||
.. image:: ../resources/128GB_site2.png | ||
:height: 350px | ||
|
||
|
||
The west-us-2 client, with its fast bandwidth with the server, received and sent the models in about 100 minutes and entered nearly idle state with little cpu and memory usage. Both | ||
clients used about 256GB, ie 128GB * 2 (model and runtime space), but at the end of receiving large models and at the beginning of sending large models, these two clients required more than | ||
378GB, ie 128GB * 3. | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
291 changes: 291 additions & 0 deletions
291
examples/hello-world/step-by-step/cifar10/code/fl/executor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,291 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os.path | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torchvision | ||
from net import Net | ||
from torchvision import transforms | ||
|
||
from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable | ||
from nvflare.apis.event_type import EventType | ||
from nvflare.apis.executor import Executor | ||
from nvflare.apis.fl_constant import ReservedKey, ReturnCode | ||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.apis.shareable import Shareable, make_reply | ||
from nvflare.apis.signal import Signal | ||
from nvflare.app_common.abstract.model import make_model_learnable, model_learnable_to_dxo | ||
from nvflare.app_common.app_constant import AppConstants | ||
from nvflare.app_opt.pt.model_persistence_format_manager import PTModelPersistenceFormatManager | ||
|
||
|
||
class CIFAR10Executor(Executor): | ||
def __init__( | ||
self, | ||
epochs: int = 2, | ||
lr: float = 1e-2, | ||
momentum: float = 0.9, | ||
batch_size: int = 4, | ||
num_workers: int = 1, | ||
dataset_path: str = "/tmp/nvflare/data/cifar10", | ||
model_path: str = "/tmp/nvflare/data/cifar10/cifar_net.pth", | ||
device="cuda:0", | ||
pre_train_task_name=AppConstants.TASK_GET_WEIGHTS, | ||
train_task_name=AppConstants.TASK_TRAIN, | ||
submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL, | ||
validate_task_name=AppConstants.TASK_VALIDATION, | ||
exclude_vars=None, | ||
): | ||
"""Cifar10 Executor handles train, validate, and submit_model tasks. During train_task, it trains a | ||
simple network on CIFAR10 dataset. For validate task, it evaluates the input model on the test set. | ||
For submit_model task, it sends the locally trained model (if present) to the server. | ||
Args: | ||
epochs (int, optional): Epochs. Defaults to 2 | ||
lr (float, optional): Learning rate. Defaults to 0.01 | ||
momentum (float, optional): Momentum. Defaults to 0.9 | ||
batch_size: batch size for training and validation. | ||
num_workers: number of workers for data loaders. | ||
dataset_path: path to dataset | ||
model_path: path to save model | ||
device: (optional) We change to use GPU to speed things up. if you want to use CPU, change DEVICE="cpu" | ||
pre_train_task_name: Task name for pre train task, i.e., sending initial model weights. | ||
train_task_name (str, optional): Task name for train task. Defaults to "train". | ||
submit_model_task_name (str, optional): Task name for submit model. Defaults to "submit_model". | ||
validate_task_name (str, optional): Task name for validate task. Defaults to "validate". | ||
exclude_vars (list): List of variables to exclude during model loading. | ||
""" | ||
super().__init__() | ||
self.epochs = epochs | ||
self.lr = lr | ||
self.momentum = momentum | ||
self.batch_size = batch_size | ||
self.num_workers = num_workers | ||
self.dataset_path = dataset_path | ||
self.model_path = model_path | ||
|
||
self.pre_train_task_name = pre_train_task_name | ||
self.train_task_name = train_task_name | ||
self.submit_model_task_name = submit_model_task_name | ||
self.validate_task_name = validate_task_name | ||
self.device = device | ||
self.exclude_vars = exclude_vars | ||
|
||
self.train_dataset = None | ||
self.valid_dataset = None | ||
self.train_loader = None | ||
self.valid_loader = None | ||
self.net = None | ||
self.optimizer = None | ||
self.criterion = None | ||
self.persistence_manager = None | ||
self.best_acc = 0.0 | ||
|
||
def handle_event(self, event_type: str, fl_ctx: FLContext): | ||
if event_type == EventType.START_RUN: | ||
self.initialize() | ||
|
||
def initialize(self): | ||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | ||
|
||
self.trainset = torchvision.datasets.CIFAR10( | ||
root=self.dataset_path, train=True, download=True, transform=transform | ||
) | ||
self.trainloader = torch.utils.data.DataLoader( | ||
self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers | ||
) | ||
self._n_iterations = len(self.trainloader) | ||
|
||
self.testset = torchvision.datasets.CIFAR10( | ||
root=self.dataset_path, train=False, download=True, transform=transform | ||
) | ||
self.testloader = torch.utils.data.DataLoader( | ||
self.testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers | ||
) | ||
|
||
self.net = Net() | ||
|
||
self.criterion = nn.CrossEntropyLoss() | ||
self.optimizer = optim.SGD(self.net.parameters(), lr=self.lr, momentum=self.momentum) | ||
|
||
self._default_train_conf = {"train": {"model": type(self.net).__name__}} | ||
self.persistence_manager = PTModelPersistenceFormatManager( | ||
data=self.net.state_dict(), default_train_conf=self._default_train_conf | ||
) | ||
|
||
def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: | ||
try: | ||
if task_name == self.pre_train_task_name: | ||
# Get the new state dict and send as weights | ||
return self._get_model_weights() | ||
if task_name == self.train_task_name: | ||
# Get model weights | ||
try: | ||
dxo = from_shareable(shareable) | ||
except: | ||
self.log_error(fl_ctx, "Unable to extract dxo from shareable.") | ||
return make_reply(ReturnCode.BAD_TASK_DATA) | ||
|
||
# Ensure data kind is weights. | ||
if not dxo.data_kind == DataKind.WEIGHTS: | ||
self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.") | ||
return make_reply(ReturnCode.BAD_TASK_DATA) | ||
|
||
# Convert weights to tensor. Run training | ||
torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()} | ||
self._local_train(fl_ctx, torch_weights) | ||
|
||
# Check the abort_signal after training. | ||
if abort_signal.triggered: | ||
return make_reply(ReturnCode.TASK_ABORTED) | ||
|
||
# Save the local model after training. | ||
self._save_local_model(fl_ctx) | ||
|
||
# Get the new state dict and send as weights | ||
return self._get_model_weights() | ||
if task_name == self.validate_task_name: | ||
model_owner = "?" | ||
try: | ||
try: | ||
dxo = from_shareable(shareable) | ||
except: | ||
self.log_error(fl_ctx, "Error in extracting dxo from shareable.") | ||
return make_reply(ReturnCode.BAD_TASK_DATA) | ||
|
||
# Ensure data_kind is weights. | ||
if not dxo.data_kind == DataKind.WEIGHTS: | ||
self.log_exception(fl_ctx, f"DXO is of type {dxo.data_kind} but expected type WEIGHTS.") | ||
return make_reply(ReturnCode.BAD_TASK_DATA) | ||
|
||
# Extract weights and ensure they are tensor. | ||
model_owner = shareable.get_header(AppConstants.MODEL_OWNER, "?") | ||
weights = {k: torch.as_tensor(v, device=self.device) for k, v in dxo.data.items()} | ||
|
||
# Get validation accuracy | ||
val_accuracy = self._local_validate(fl_ctx, weights) | ||
if abort_signal.triggered: | ||
return make_reply(ReturnCode.TASK_ABORTED) | ||
|
||
self.log_info( | ||
fl_ctx, | ||
f"Accuracy when validating {model_owner}'s model on" | ||
f" {fl_ctx.get_identity_name()}" | ||
f"s data: {val_accuracy}", | ||
) | ||
|
||
dxo = DXO(data_kind=DataKind.METRICS, data={"val_acc": val_accuracy}) | ||
return dxo.to_shareable() | ||
except: | ||
self.log_exception(fl_ctx, f"Exception in validating model from {model_owner}") | ||
return make_reply(ReturnCode.EXECUTION_EXCEPTION) | ||
elif task_name == self.submit_model_task_name: | ||
# Load local model | ||
ml = self._load_local_model(fl_ctx) | ||
|
||
# Get the model parameters and create dxo from it | ||
dxo = model_learnable_to_dxo(ml) | ||
return dxo.to_shareable() | ||
else: | ||
return make_reply(ReturnCode.TASK_UNKNOWN) | ||
except Exception as e: | ||
self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.") | ||
return make_reply(ReturnCode.EXECUTION_EXCEPTION) | ||
|
||
def _get_model_weights(self) -> Shareable: | ||
# Get the new state dict and send as weights | ||
weights = {k: v.cpu().numpy() for k, v in self.net.state_dict().items()} | ||
|
||
outgoing_dxo = DXO( | ||
data_kind=DataKind.WEIGHTS, data=weights, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: self._n_iterations} | ||
) | ||
return outgoing_dxo.to_shareable() | ||
|
||
def _local_train(self, fl_ctx, input_weights): | ||
self.net.load_state_dict(input_weights) | ||
# (optional) use GPU to speed things up | ||
self.net.to(self.device) | ||
|
||
for epoch in range(self.epochs): # loop over the dataset multiple times | ||
|
||
running_loss = 0.0 | ||
for i, data in enumerate(self.trainloader, 0): | ||
# get the inputs; data is a list of [inputs, labels] | ||
# (optional) use GPU to speed things up | ||
inputs, labels = data[0].to(self.device), data[1].to(self.device) | ||
|
||
# zero the parameter gradients | ||
self.optimizer.zero_grad() | ||
|
||
# forward + backward + optimize | ||
outputs = self.net(inputs) | ||
loss = self.criterion(outputs, labels) | ||
loss.backward() | ||
self.optimizer.step() | ||
|
||
# print statistics | ||
running_loss += loss.item() | ||
if i % 2000 == 1999: # print every 2000 mini-batches | ||
self.log_info(fl_ctx, f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") | ||
running_loss = 0.0 | ||
|
||
self.log_info(fl_ctx, "Finished Training") | ||
|
||
def _local_validate(self, fl_ctx, input_weights): | ||
self.net.load_state_dict(input_weights) | ||
# (optional) use GPU to speed things up | ||
self.net.to(self.device) | ||
|
||
correct = 0 | ||
total = 0 | ||
# since we're not training, we don't need to calculate the gradients for our outputs | ||
with torch.no_grad(): | ||
for data in self.testloader: | ||
# (optional) use GPU to speed things up | ||
images, labels = data[0].to(self.device), data[1].to(self.device) | ||
# calculate outputs by running images through the network | ||
outputs = self.net(images) | ||
# the class with the highest energy is what we choose as prediction | ||
_, predicted = torch.max(outputs.data, 1) | ||
total += labels.size(0) | ||
correct += (predicted == labels).sum().item() | ||
|
||
val_accuracy = 100 * correct // total | ||
self.log_info(fl_ctx, f"Accuracy of the network on the 10000 test images: {val_accuracy} %") | ||
return val_accuracy | ||
|
||
def _save_local_model(self, fl_ctx: FLContext): | ||
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM)) | ||
models_dir = os.path.join(run_dir, "models") | ||
if not os.path.exists(models_dir): | ||
os.makedirs(models_dir) | ||
|
||
ml = make_model_learnable(self.net.state_dict(), {}) | ||
self.persistence_manager.update(ml) | ||
torch.save(self.persistence_manager.to_persistence_dict(), self.model_path) | ||
|
||
def _load_local_model(self, fl_ctx: FLContext): | ||
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM)) | ||
models_dir = os.path.join(run_dir, "models") | ||
if not os.path.exists(models_dir): | ||
return None | ||
|
||
self.persistence_manager = PTModelPersistenceFormatManager( | ||
data=torch.load(self.model_path), default_train_conf=self._default_train_conf | ||
) | ||
ml = self.persistence_manager.to_model_learnable(exclude_vars=self.exclude_vars) | ||
return ml |
Oops, something went wrong.