Merge branch 'main' into fin_xgb
ZiyueXu77 authored Oct 9, 2023
2 parents 2d41ed8 + 6974f6d commit 86dc874
Showing 42 changed files with 2,015 additions and 161 deletions.
1 change: 1 addition & 0 deletions docs/real_world_fl.rst
@@ -28,4 +28,5 @@ to see the capabilities of the system and how it can be operated.
real_world_fl/job
real_world_fl/workspace
real_world_fl/cloud_deployment
real_world_fl/notes_on_large_models
user_guide/federated_authorization
90 changes: 90 additions & 0 deletions docs/real_world_fl/notes_on_large_models.rst
@@ -0,0 +1,90 @@
.. _notes_on_large_models:

Large Models
============
As federated learning tasks become more and more complex, their model sizes increase. Some models may go beyond 2GB and even reach hundreds of GB. NVIDIA FLARE supports
large models as long as the system memory of the server and the clients can handle them. However, this requires special consideration for the NVIDIA FLARE configuration and the system, because
the network bandwidth, and thus the time to transmit such a large amount of data during NVIDIA FLARE job runtime, varies significantly. Here we describe 128GB model training jobs to highlight
the configuration and system aspects that users should consider for successful large-model federated learning jobs.

System Deployment
*****************
Our successful experiments with 128GB model training ran on one NVIDIA FLARE server and two clients. The server was deployed in the Azure west-us region. One of the two clients
was deployed in the AWS west-us-2 region and the other in the AWS ap-south-1 region. The system was deployed in this cross-region and cross-cloud-service-provider manner so that we could test
the NVIDIA FLARE system under various network bandwidth conditions.
The Azure VM size of the NVIDIA FLARE server was M32-8ms, which has 875GB of memory. The AWS EC2 instance type of the NVIDIA FLARE clients was r5a.16xlarge with 512GB of memory. We also enabled
128GB of swap space on all machines.

Job of 128GB Models
*******************
We slightly modified the hello-numpy example to generate a model that was a dictionary of 64 keys, each holding a 2GB numpy array. The local training task added a small number to
those numpy arrays. The aggregator on the server side was not changed. This job required at least two clients and ran 3 rounds to finish.
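
For illustration, the following is a minimal sketch of how such a model could be constructed. The key names and array sizes here are hypothetical and are not the exact code used in the experiment:

.. code:: python

    import numpy as np

    NUM_KEYS = 64
    ELEMENTS_PER_ARRAY = 2 * 1024**3 // 4  # about 2GB worth of float32 values per array

    # The "model" is a dictionary of 64 keys, each holding a ~2GB numpy array.
    model = {f"layer_{i}": np.zeros(ELEMENTS_PER_ARRAY, dtype=np.float32) for i in range(NUM_KEYS)}

    # The local training task simply adds a small number to every array.
    for k in model:
        model[k] += 0.01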

Please note that if your model contains leaf nodes larger than 4GB, those nodes must be of type `bytes`. In that case, the outgoing model needs a conversion similar to the following:

.. code:: python

    for k in np_data:
        self.log_info(fl_ctx, f"converting {k=}")
        tmp_file = io.BytesIO()
        np.save(tmp_file, np_data[k])
        np_data[k] = tmp_file.getvalue()
        tmp_file.close()
        self.log_info(fl_ctx, f"done converting {k=}")
    outgoing_dxo = DXO(data_kind=incoming_dxo.data_kind, data=np_data, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: 1})

Additionally, the receiving side needs to convert the bytes back to numpy arrays with code similar to the following:

.. code:: python

    for k in np_data:
        self.log_info(fl_ctx, f"converting and adding delta for {k=}")
        np_data[k] = np.load(io.BytesIO(np_data[k]))

Configuration
*******************
We measured the bandwidth between the server and the west-us-2 client. It took around 2300 seconds to transfer the model from the client to the server and around 2000 seconds from the server to the client.
For the ap-south-1 client, it took about 11000 seconds from the client to the server and 11500 seconds from the server to the client. We updated the following values to accommodate such differences (a sketch of the corresponding configuration change follows the list).

- streaming_read_timeout to 3000
- streaming_ack_wait to 6000
- communication_timeout to 6000
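
The exact file where these values live depends on your NVIDIA FLARE version; assuming they are exposed as top-level keys of each site's communication configuration file (hypothetically `local/comm_config.json` here), the change might look like the following sketch:

.. code:: json

    {
        "streaming_read_timeout": 3000,
        "streaming_ack_wait": 6000,
        "communication_timeout": 6000
    }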


The `streaming_read_timeout` is used to detect the situation where a chunk of data has been received but has not been read out by the upper layer. The `streaming_ack_wait` is how long the sender waits for the acknowledgement returned by the receiver for one chunk.


The `communication_timeout` is applied to three consecutive stages of a single request and response. When sending a large request (submit_update), the sender starts a timer with timeout = `communication_timeout`.
When this timer expires, the sender checks whether any progress was made during that period. If yes, the sender resets the timer with the same timeout value and waits again. If not, the request and response return with a timeout.
After sending completes, the sender cancels the previous timer and starts a `remote processing` timer with timeout = `communication_timeout`. This timer waits for the first returned byte from the receiver. With
large models, the server needs much more time to prepare the task when the clients send `get_task` requests. After receiving the first returned byte, the sender cancels the `remote processing` timer and starts
a new timer, which checks the receiving progress in the same way as the sending progress.
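
The following is a minimal, purely illustrative sketch of this progress-based timeout pattern. It is not NVIDIA FLARE's actual implementation, and the function and parameter names are made up for illustration:

.. code:: python

    import time

    def wait_with_progress_timer(done, get_progress, communication_timeout, poll_interval=1.0):
        # Keep waiting as long as progress is observed within each
        # communication_timeout window; give up once a full window
        # passes with no progress at all.
        last_progress = get_progress()
        deadline = time.time() + communication_timeout
        while not done():
            if time.time() >= deadline:
                current = get_progress()
                if current > last_progress:
                    # Progress was made during this window: reset the timer.
                    last_progress = current
                    deadline = time.time() + communication_timeout
                else:
                    raise TimeoutError("no progress within communication_timeout")
            time.sleep(poll_interval)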


Since the experiment was based on hello-numpy, one of the arguments of the ScatterAndGather class, `train_timeout`, had to be updated. This timeout is used to check the scheduling of training tasks. We
changed this argument to 60000 for this experiment.
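
For example, in the server configuration of the hello-numpy based job (typically `config_fed_server.json`), the ScatterAndGather workflow arguments might be updated along the lines of the sketch below. Only the arguments relevant to this note are shown; the remaining arguments follow the standard hello-numpy example:

.. code:: json

    {
        "workflows": [
            {
                "id": "scatter_and_gather",
                "path": "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather",
                "args": {
                    "min_clients": 2,
                    "num_rounds": 3,
                    "train_timeout": 60000
                }
            }
        ]
    }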

Memory Usage
*******************
During the experiment, the server could use more than 512GB, i.e. 128GB * 2 clients * 2 (model and runtime space). The following figure shows the CPU and memory usage of the server.

.. image:: ../resources/128GB_server.png
:height: 350px

Although the server used less than 512GB most of the time, there were a few peaks that reached 700GB or more.

The following figures show the clients, west-us-2 and ap-south-1.

.. image:: ../resources/128GB_site1.png
:height: 350px


.. image:: ../resources/128GB_site2.png
:height: 350px


The west-us-2 client, with its higher bandwidth to the server, received and sent the models in about 100 minutes and then entered a nearly idle state with little CPU and memory usage. Both
clients used about 256GB, i.e. 128GB * 2 (model and runtime space), but at the end of receiving the large models and at the beginning of sending them, these two clients required more than
378GB, i.e. roughly 128GB * 3.

Binary file added docs/resources/128GB_server.png
Binary file added docs/resources/128GB_site1.png
Binary file added docs/resources/128GB_site2.png
291 changes: 291 additions & 0 deletions examples/hello-world/step-by-step/cifar10/code/fl/executor.py
@@ -0,0 +1,291 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from net import Net
from torchvision import transforms

from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.model import make_model_learnable, model_learnable_to_dxo
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_opt.pt.model_persistence_format_manager import PTModelPersistenceFormatManager


class CIFAR10Executor(Executor):
    def __init__(
        self,
        epochs: int = 2,
        lr: float = 1e-2,
        momentum: float = 0.9,
        batch_size: int = 4,
        num_workers: int = 1,
        dataset_path: str = "/tmp/nvflare/data/cifar10",
        model_path: str = "/tmp/nvflare/data/cifar10/cifar_net.pth",
        device="cuda:0",
        pre_train_task_name=AppConstants.TASK_GET_WEIGHTS,
        train_task_name=AppConstants.TASK_TRAIN,
        submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL,
        validate_task_name=AppConstants.TASK_VALIDATION,
        exclude_vars=None,
    ):
        """Cifar10 Executor handles train, validate, and submit_model tasks. During the train task, it trains a
        simple network on the CIFAR10 dataset. For the validate task, it evaluates the input model on the test set.
        For the submit_model task, it sends the locally trained model (if present) to the server.

        Args:
            epochs (int, optional): Epochs. Defaults to 2.
            lr (float, optional): Learning rate. Defaults to 0.01.
            momentum (float, optional): Momentum. Defaults to 0.9.
            batch_size: batch size for training and validation.
            num_workers: number of workers for data loaders.
            dataset_path: path to dataset.
            model_path: path to save model.
            device (optional): device to run on; defaults to "cuda:0" to use the GPU. Set to "cpu" to run on CPU.
            pre_train_task_name: Task name for pre train task, i.e., sending initial model weights.
            train_task_name (str, optional): Task name for train task. Defaults to "train".
            submit_model_task_name (str, optional): Task name for submit model. Defaults to "submit_model".
            validate_task_name (str, optional): Task name for validate task. Defaults to "validate".
            exclude_vars (list): List of variables to exclude during model loading.
        """
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.momentum = momentum
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_path = dataset_path
        self.model_path = model_path

        self.pre_train_task_name = pre_train_task_name
        self.train_task_name = train_task_name
        self.submit_model_task_name = submit_model_task_name
        self.validate_task_name = validate_task_name
        self.device = device
        self.exclude_vars = exclude_vars

        self.train_dataset = None
        self.valid_dataset = None
        self.train_loader = None
        self.valid_loader = None
        self.net = None
        self.optimizer = None
        self.criterion = None
        self.persistence_manager = None
        self.best_acc = 0.0

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self.initialize()

    def initialize(self):
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        self.trainset = torchvision.datasets.CIFAR10(
            root=self.dataset_path, train=True, download=True, transform=transform
        )
        self.trainloader = torch.utils.data.DataLoader(
            self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers
        )
        self._n_iterations = len(self.trainloader)

        self.testset = torchvision.datasets.CIFAR10(
            root=self.dataset_path, train=False, download=True, transform=transform
        )
        self.testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers
        )

        self.net = Net()

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.net.parameters(), lr=self.lr, momentum=self.momentum)

        self._default_train_conf = {"train": {"model": type(self.net).__name__}}
        self.persistence_manager = PTModelPersistenceFormatManager(
            data=self.net.state_dict(), default_train_conf=self._default_train_conf
        )

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        try:
            if task_name == self.pre_train_task_name:
                # Get the new state dict and send as weights
                return self._get_model_weights()
            if task_name == self.train_task_name:
                # Get model weights
                try:
                    dxo = from_shareable(shareable)
                except:
                    self.log_error(fl_ctx, "Unable to extract dxo from shareable.")
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # Ensure data kind is weights.
                if not dxo.data_kind == DataKind.WEIGHTS:
                    self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.")
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # Convert weights to tensor. Run training
                torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()}
                self._local_train(fl_ctx, torch_weights)

                # Check the abort_signal after training.
                if abort_signal.triggered:
                    return make_reply(ReturnCode.TASK_ABORTED)

                # Save the local model after training.
                self._save_local_model(fl_ctx)

                # Get the new state dict and send as weights
                return self._get_model_weights()
            if task_name == self.validate_task_name:
                model_owner = "?"
                try:
                    try:
                        dxo = from_shareable(shareable)
                    except:
                        self.log_error(fl_ctx, "Error in extracting dxo from shareable.")
                        return make_reply(ReturnCode.BAD_TASK_DATA)

                    # Ensure data_kind is weights.
                    if not dxo.data_kind == DataKind.WEIGHTS:
                        self.log_exception(fl_ctx, f"DXO is of type {dxo.data_kind} but expected type WEIGHTS.")
                        return make_reply(ReturnCode.BAD_TASK_DATA)

                    # Extract weights and ensure they are tensor.
                    model_owner = shareable.get_header(AppConstants.MODEL_OWNER, "?")
                    weights = {k: torch.as_tensor(v, device=self.device) for k, v in dxo.data.items()}

                    # Get validation accuracy
                    val_accuracy = self._local_validate(fl_ctx, weights)
                    if abort_signal.triggered:
                        return make_reply(ReturnCode.TASK_ABORTED)

                    self.log_info(
                        fl_ctx,
                        f"Accuracy when validating {model_owner}'s model on"
                        f" {fl_ctx.get_identity_name()}"
                        f"'s data: {val_accuracy}",
                    )

                    dxo = DXO(data_kind=DataKind.METRICS, data={"val_acc": val_accuracy})
                    return dxo.to_shareable()
                except:
                    self.log_exception(fl_ctx, f"Exception in validating model from {model_owner}")
                    return make_reply(ReturnCode.EXECUTION_EXCEPTION)
            elif task_name == self.submit_model_task_name:
                # Load local model
                ml = self._load_local_model(fl_ctx)

                # Get the model parameters and create dxo from it
                dxo = model_learnable_to_dxo(ml)
                return dxo.to_shareable()
            else:
                return make_reply(ReturnCode.TASK_UNKNOWN)
        except Exception as e:
            self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    def _get_model_weights(self) -> Shareable:
        # Get the new state dict and send as weights
        weights = {k: v.cpu().numpy() for k, v in self.net.state_dict().items()}

        outgoing_dxo = DXO(
            data_kind=DataKind.WEIGHTS, data=weights, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: self._n_iterations}
        )
        return outgoing_dxo.to_shareable()

    def _local_train(self, fl_ctx, input_weights):
        self.net.load_state_dict(input_weights)
        # (optional) use GPU to speed things up
        self.net.to(self.device)

        for epoch in range(self.epochs):  # loop over the dataset multiple times

            running_loss = 0.0
            for i, data in enumerate(self.trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                # (optional) use GPU to speed things up
                inputs, labels = data[0].to(self.device), data[1].to(self.device)

                # zero the parameter gradients
                self.optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.net(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    self.log_info(fl_ctx, f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                    running_loss = 0.0

        self.log_info(fl_ctx, "Finished Training")

    def _local_validate(self, fl_ctx, input_weights):
        self.net.load_state_dict(input_weights)
        # (optional) use GPU to speed things up
        self.net.to(self.device)

        correct = 0
        total = 0
        # since we're not training, we don't need to calculate the gradients for our outputs
        with torch.no_grad():
            for data in self.testloader:
                # (optional) use GPU to speed things up
                images, labels = data[0].to(self.device), data[1].to(self.device)
                # calculate outputs by running images through the network
                outputs = self.net(images)
                # the class with the highest energy is what we choose as prediction
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = 100 * correct // total
        self.log_info(fl_ctx, f"Accuracy of the network on the 10000 test images: {val_accuracy} %")
        return val_accuracy

    def _save_local_model(self, fl_ctx: FLContext):
        run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM))
        models_dir = os.path.join(run_dir, "models")
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)

        ml = make_model_learnable(self.net.state_dict(), {})
        self.persistence_manager.update(ml)
        torch.save(self.persistence_manager.to_persistence_dict(), self.model_path)

    def _load_local_model(self, fl_ctx: FLContext):
        run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM))
        models_dir = os.path.join(run_dir, "models")
        if not os.path.exists(models_dir):
            return None

        self.persistence_manager = PTModelPersistenceFormatManager(
            data=torch.load(self.model_path), default_train_conf=self._default_train_conf
        )
        ml = self.persistence_manager.to_model_learnable(exclude_vars=self.exclude_vars)
        return ml
