项目链接:
- 源代码: GitHub | gitee
- Docker 镜像: wenh06/fl-sim
- 文档(正在完善):
本仓库迁移自 fl_seminar, 主体部分是一个基于 PyTorch 的简单的联邦学习仿真框架。
可以在命令行中使用以下命令安装:
pip install git+https://github.com/wenh06/fl-sim.git
或者,可以先将仓库克隆到本地,然后在仓库根目录下使用以下命令安装:
pip install -e .
使用者也可以使用 Docker 镜像 运行本项目。该镜像是使用 Docker Image CI action 构建的。可以使用以下命令拉取镜像:
docker pull wenh06/fl-sim
通过如下的命令可以交互式地运行镜像:
docker run -it wenh06/fl-sim bash
关于 Docker 镜像更多的使用方法,请参考 Docker 官方文档。
点击展开
以下代码片段展示了如何使用框架在 FedProxFEMNIST
数据集上使用 FedProx
算法训练模型。
from fl_sim.data_processing.fedprox_femnist import FedProxFEMNIST
from fl_sim.algorithms.fedprox import (
FedProxServer,
FedProxClientConfig,
FedProxServerConfig,
)
# create a FedProxFEMNIST dataset
ds = FedProxFEMNIST()
# choose a model
model = ds.candidate_models["cnn_femmist_tiny"]
# set up the server and client configurations
server_config = FedProxServerConfig(200, ds.DEFAULT_TRAIN_CLIENTS_NUM, 0.7)
client_config = FedProxClientConfig(ds.DEFAULT_BATCH_SIZE, 30)
# create a FedProxServer object
s = FedProxServer(model, ds, server_config, client_config)
# normal centralized training
s.train_centralized()
# federated training
s.train_federated()
算法 | 文章 | 源仓库 | Action 状态 | 标准测试用例上的效果 |
---|---|---|---|---|
FedAvg1 | AISTATS2017 | N/A | ✔️ | |
FedOpt2 | arXiv:2003.00295 | N/A | ✔️ | |
FedProx | MLSys2020 | GitHub | ✔️ ❓ | |
pFedMe | NeurIPS2020 | GitHub | ||
FedSplit | NeurIPS2020 | N/A | ✔️ ❓ | |
FedDR | NeurIPS2021 | GitHub | ||
FedPD | IEEE Trans. Signal Process | GitHub | ||
SCAFFOLD | PMLR | N/A | ✔️ ❓ | |
ProxSkip | PMLR | N/A | ✔️ ❓ | |
Ditto | PMLR | GitHub | ✔️ | |
IFCA | NeurIPS2020 | GitHub | ✔️ | |
pFedMac | arXiv:2107.05330 | N/A | ||
FedDyn | ICLR2021 | N/A | ❓ | |
APFL | arXiv:2003.13461 | N/A | ❓ |
标准测试效果图:
Client sample ratio 10% Client sample ratio 30% Client sample ratio 70% Client sample ratio 100%
- ✔️ 算法在标准测试用例上的效果符合预期。
- ✔️ ❓ 算法在标准测试用例上的效果 低于 预期。
- ❓ 算法暂未在标准测试用例上进行测试。
⁉️ 算法在标准测试用例上的 发散 ,相关的算法实现需要进一步检查。
点击展开
Node
类是本仿真框架的核心。Node
有两个子类: Server
和 Client
。
Server
类是所有联邦学习算法中心节点的基类,它在训练过程中充当协调者和状态变量维护者的角色。
Client
类是所有联邦学习算法子节点的基类。
抽象基类 Node
提供了以下基本功能:
get_detached_model_parameters
: 获取节点上模型参数的副本。compute_gradients
: 计算指定模型参数(默认为节点上的当前模型参数)在指定数据(默认为节点上的训练数据)上的梯度。get_gradients
: 获取当前节点上模型当前的梯度,或者梯度的范数。get_norm
: 计算一个 tensor 或者 array 的范数。set_parameters
: 设置节点上模型参数。aggregate_results_from_csv_log
: 从 csv 日志文件中聚合实验结果。aggregate_results_from_json_log
: 从 json 日志文件中聚合实验结果。
以及需要子类实现的抽象方法或属性:
communicate
: 在每一轮训练中,与另一种类型节点进行通信的方法 (子节点 -> 中心节点 或者 中心节点 -> 子节点)。update
: 在每一轮训练中,更新节点状态的方法。required_config_fields
(property): 需要在配置类中指定的必要字段,用于在_post_init
方法中检查配置的有效性。_post_init
: 在__init__
方法的最后调用的后初始化方法,用于在__init__
方法中检查配置的有效性。
Server
类的签名(signature)为
Server(
model: torch.nn.modules.module.Module,
dataset: fl_sim.data_processing.fed_dataset.FedDataset,
config: fl_sim.nodes.ServerConfig,
client_config: fl_sim.nodes.ClientConfig,
lazy: bool = False,
) -> None
Server
类提供以下额外的方法(method)或属性(property):
_setup_clients
: 初始化客户端,为客户端分配计算资源。_sample_clients
: 从所有子节点中随机抽取一定数量的子节点。_communicate
: 执行子节点的communicate
方法,并更新全局通信计数器(_num_communications
)。_update
: 检查从子节点接收到消息(_received_messages
)的有效性,执行中心节点的update
方法,最后清除所有从子节点接收到的消息。train
: 联邦训练主循环,根据传入的mode
参数调用train_centralized
或者train_federated
或者train_local
方法。train_centralized
: 中心化训练,主要用于对比。train_federated
: 联邦训练,调用_communicate
方法(与子节点通信),等待子节点执行_update
和_communicate
方法,最后调用_update
方法(更新中心节点)。train_local
: 本地训练,调用子节点的train
方法,不 与中心节点通信。add_parameters
: 中心节点模型参数的增量更新。avg_parameters
: 将从子节点接收到的模型参数进行平均。update_gradients
: 使用从子节点接收到的梯度更新中心节点模型的梯度。get_client_data
: 获取特定子节点的数据。get_client_model
: 获取特定子节点的模型。get_cached_metrics
: 获取中心节点缓存的每一次训练循环的模型评估指标。_reset
: 将中心节点重置为初始状态。在执行新的训练过程之前,将检查_complete_experiment
标志。如果为True
,将调用此方法重置中心节点。is_convergent
(property): 检查训练过程是否收敛。目前,此属性 未 完全实现。
以及 需要子类实现的抽象属性:
client_cls
: the client class used when initializing the clients via_setup_clients
.config_cls
: a dictionary of configuration classes for the server and clients, used in__init__
method.doi
: the DOI of the paper that proposes the algorithm.
Client
类的签名为
Client(
client_id: int,
device: torch.device,
model: torch.nn.modules.module.Module,
dataset: fl_sim.data_processing.fed_dataset.FedDataset,
config: fl_sim.nodes.ClientConfig,
) -> None
providing the following additional functionalities:
Client
类还提供以下额外的方法:
_communicate
: 执行子节点的communicate
方法,并更新子节点上的通信计数器(_num_communications
),并清除缓存的(上一循环)子节点上模型评测结果。_update
: 执行子节点的update
方法,并清除从中心节点接收到的消息。evaluate
: 利用子节点上的测试数据,评测子节点上的模型。get_all_data
: 获取子节点上的所有数据。
以及 需要子类实现的抽象方法:
train
: 子节点的训练循环。
配置类(config class)是用于存储服务器和客户端配置的类。这两个类类似于 dataclass
,但是可以接受任意额外的字段。ServerConfig
的签名为
ServerConfig(
algorithm: str,
num_iters: int,
num_clients: int,
clients_sample_ratio: float,
txt_logger: bool = True,
json_logger: bool = True,
eval_every: int = 1,
verbose: int = 1,
**kwargs: Any,
) -> None
ClientConfig
的签名为
ClientConfig(
algorithm: str,
optimizer: str,
batch_size: int,
num_epochs: int,
lr: float,
verbose: int = 1,
**kwargs: Any,
) -> None
实现 新的联邦学习算法 的方法:需要实现 Server
和 Client
的子类,以及 ServerConfig
和 ClientConfig
的子类。如下的例子,是取自 FedProx
算法的实现:
点击展开
import warnings
from copy import deepcopy
from typing import List, Dict, Any
import torch
from torch_ecg.utils.misc import add_docstring
from tqdm.auto import tqdm
from fl_sim.nodes import Server, Client, ServerConfig, ClientConfig, ClientMessage
from fl_sim.algorithms import register_algorithm
@register_algorithm("FedProx")
class FedProxServerConfig(ServerConfig):
"""Server config for the FedProx algorithm.
Parameters
----------
num_iters : int
The number of (outer) iterations.
num_clients : int
The number of clients.
clients_sample_ratio : float
The ratio of clients to sample for each iteration.
vr : bool, default False
Whether to use variance reduction.
**kwargs : dict, optional
Additional keyword arguments:
- ``log_dir`` : str or Path, optional
The log directory.
If not specified, will use the default log directory.
If not absolute, will be relative to the default log directory.
- ``txt_logger`` : bool, default True
Whether to use txt logger.
- ``json_logger`` : bool, default True
Whether to use json logger.
- ``eval_every`` : int, default 1
The number of iterations to evaluate the model.
- ``visiable_gpus`` : Sequence[int], optional
Visable GPU IDs for allocating devices for clients.
Defaults to use all GPUs if available.
- ``extra_observes`` : List[str], optional
Extra attributes to observe during training.
- ``seed`` : int, default 0
The random seed.
- ``tag`` : str, optional
The tag of the experiment.
- ``verbose`` : int, default 1
The verbosity level.
- ``gpu_proportion`` : float, default 0.2
The proportion of clients to use GPU.
Used to similate the system heterogeneity of the clients.
Not used in the current version, reserved for future use.
"""
__name__ = "FedProxServerConfig"
def __init__(
self,
num_iters: int,
num_clients: int,
clients_sample_ratio: float,
vr: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
"FedProx",
num_iters,
num_clients,
clients_sample_ratio,
vr=vr,
**kwargs,
)
@register_algorithm("FedProx")
class FedProxClientConfig(ClientConfig):
"""Client config for the FedProx algorithm.
Parameters
----------
batch_size : int
The batch size.
num_epochs : int
The number of epochs.
lr : float, default 1e-2
The learning rate.
mu : float, default 0.01
Coefficient for the proximal term.
vr : bool, default False
Whether to use variance reduction.
**kwargs : dict, optional
Additional keyword arguments:
- ``scheduler`` : dict, optional
The scheduler config.
None for no scheduler, using constant learning rate.
- ``extra_observes`` : List[str], optional
Extra attributes to observe during training,
which would be recorded in evaluated metrics,
sent to the server, and written to the log file.
- ``verbose`` : int, default 1
The verbosity level.
- ``latency`` : float, default 0.0
The latency of the client.
Not used in the current version, reserved for future use.
"""
__name__ = "FedProxClientConfig"
def __init__(
self,
batch_size: int,
num_epochs: int,
lr: float = 1e-2,
mu: float = 0.01,
vr: bool = False,
**kwargs: Any,
) -> None:
optimizer = "FedProx" if not vr else "FedProx_VR"
if kwargs.pop("algorithm", None) is not None:
warnings.warn(
"The `algorithm` argument fixed to `FedProx`.", RuntimeWarning
)
if kwargs.pop("optimizer", None) is not None:
warnings.warn(
"The `optimizer` argument fixed to `FedProx` or `FedProx_VR`.",
RuntimeWarning,
)
super().__init__(
"FedProx",
optimizer,
batch_size,
num_epochs,
lr,
mu=mu,
vr=vr,
**kwargs,
)
@register_algorithm("FedProx")
@add_docstring(
Server.__doc__.replace(
"The class to simulate the server node.",
"Server node for the FedProx algorithm.",
)
.replace("ServerConfig", "FedProxServerConfig")
.replace("ClientConfig", "FedProxClientConfig")
)
class FedProxServer(Server):
"""Server node for the FedProx algorithm."""
__name__ = "FedProxServer"
def _post_init(self) -> None:
"""
check if all required field in the config are set,
and check compatibility of server and client configs
"""
super()._post_init()
assert self.config.vr == self._client_config.vr
@property
def client_cls(self) -> type:
return FedProxClient
@property
def required_config_fields(self) -> List[str]:
return []
def communicate(self, target: "FedProxClient") -> None:
target._received_messages = {"parameters": self.get_detached_model_parameters()}
if target.config.vr:
target._received_messages["gradients"] = [
p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
for p in target.model.parameters()
]
def update(self) -> None:
# sum of received parameters, with self.model.parameters() as its container
self.avg_parameters()
if self.config.vr:
self.update_gradients()
@property
def config_cls(self) -> Dict[str, type]:
return {
"server": FedProxServerConfig,
"client": FedProxClientConfig,
}
@property
def doi(self) -> List[str]:
return ["10.48550/ARXIV.1812.06127"]
@register_algorithm("FedProx")
@add_docstring(
Client.__doc__.replace(
"The class to simulate the client node.",
"Client node for the FedProx algorithm.",
).replace("ClientConfig", "FedProxClientConfig")
)
class FedProxClient(Client):
"""Client node for the FedProx algorithm."""
__name__ = "FedProxClient"
def _post_init(self) -> None:
"""
check if all required field in the config are set,
and set attributes for maintaining itermidiate states
"""
super()._post_init()
if self.config.vr:
self._gradient_buffer = [
torch.zeros_like(p) for p in self.model.parameters()
]
else:
self._gradient_buffer = None
@property
def required_config_fields(self) -> List[str]:
return ["mu"]
def communicate(self, target: "FedProxServer") -> None:
message = {
"client_id": self.client_id,
"parameters": self.get_detached_model_parameters(),
"train_samples": len(self.train_loader.dataset),
"metrics": self._metrics,
}
if self.config.vr:
message["gradients"] = [
p.grad.detach().clone() for p in self.model.parameters()
]
target._received_messages.append(ClientMessage(**message))
def update(self) -> None:
try:
self._cached_parameters = deepcopy(self._received_messages["parameters"])
except KeyError:
warnings.warn("No parameters received from server")
warnings.warn("Using current model parameters as initial parameters")
self._cached_parameters = self.get_detached_model_parameters()
except Exception as err:
raise err
self._cached_parameters = [p.to(self.device) for p in self._cached_parameters]
if (
self.config.vr
and self._received_messages.get("gradients", None) is not None
):
self._gradient_buffer = [
gd.clone().to(self.device)
for gd in self._received_messages["gradients"]
]
self.solve_inner() # alias of self.train()
def train(self) -> None:
self.model.train()
with tqdm(
range(self.config.num_epochs),
total=self.config.num_epochs,
mininterval=1.0,
disable=self.config.verbose < 2,
) as pbar:
for epoch in pbar: # local update
self.model.train()
for X, y in self.train_loader:
X, y = X.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
output = self.model(X)
loss = self.criterion(output, y)
loss.backward()
self.optimizer.step(
local_weights=self._cached_parameters,
variance_buffer=self._gradient_buffer,
)
👉 返回目录
点击展开
data_processing 模块包含数据预处理、IO 等代码,其中包含以下数据集:
FedCIFAR
FedCIFAR100
FedEMNIST
FedMNIST
FedShakespeare
FedSynthetic
FedProxFEMNIST
FedProxMNIST
FedProxSent140
以上每一个数据集都被封装在一个类中,提供以下功能:
- 数据集的自动下载和预处理
- 数据集的切分(分配给子节点)方法
get_dataloader
- 预置了一系列候选 模型,可以通过
candidate_models
属性获取 - 基于模型预测值的
evaluate
方法,可以评测模型在数据集上的性能 - 一些辅助方法,用于数据可视化和参考文献的获取(biblatex 格式)
此外, LIBSVM
数据集列表可以通过如下语句获取
pd.read_html("https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/")[0]
更新: 一部分计算机视觉数据集的训练集支持动态数据增强。基类 FedVisionDataset
的签名为
FedVisionDataset(
datadir: Union[str, pathlib.Path, NoneType] = None,
transform: Union[str, Callable, NoneType] = "none",
) -> None
通过将 transform
参数设置为 "none"
(这也是 transform
参数的默认值),训练集将被封装在一个静态的 TensorDataset
中。通过将 transform
参数设置为 None
,训练集将使用内置的动态数据增强,例如 FedCIFAR100
使用 torchvision.transforms.RandAugment
。
注意,大部分计算机视觉的联邦数据集包含的数据都是经过预处理后的而不是原始像素值,因此不支持使用 torchvision.transforms
进行动态数据增强。
👉 返回目录
点击展开
models 模块包含预定义的(神经网络)模型,其中大部分结构都非常简单:
MLP
FedPDMLP
CNNMnist
CNNFEMnist
CNNFEMnist_Tiny
CNNCifar
RNN_OriginalFedAvg
RNN_StackOverFlow
RNN_Sent140
ResNet18
ResNet10
LogisticRegression
SVC
SVR
以上大部分模型都是之前文献中使用过的,或是基于此进行修改的。
通过调用 model_size
或 model_size_
属性可以获取模型的大小(参数数量和内存占用)。
👉 返回目录
点击展开
optimizers 模块包含用于解决联邦优化问题内循环(子节点上的)优化问题的优化器。除了 torch
和 torch_optimizers
中的优化器外,本模块实现了以下优化器:
ProxSGD
FedPD_SGD
FedPD_VR
PSGD
PSVRG
pFedMe
FedProx
FedDR
其中大部分都是基于 ProxSGD
的变体,即目标是带有临近项的优化问题。
👉 返回目录
点击展开
regularizers 模块包含用于对模型参数进行正则化的正则化项(用类来实现)。正则化项的目的是防止模型过拟合,从而提高模型的泛化能力。本模块实现了以下正则化项:
L1Norm
L2Norm
L2NormSquared
NullRegularizer
以上的正则化项都是基类 Regularizer
的子类,可以通过将正则化项的名称传递给函数 get_regularizer
来获取。正则化项都有 eval
和 prox_eval
两个方法,分别用于计算正则化项的值和其临近项的值。
👉 返回目录
点击展开
utils 模块包含了一些工具函数,例如 数据下载、 日志记录、 可视化 等。
TxTLogger
: 用于将训练指标记录到文本文件中,同时也会在控制台以适合人类阅读习惯的格式打印出来。CSVLogger
: 用于将训练指标记录到 CSV 文件中。不推荐使用,因为存储消耗较大。JsonLogger
: 用于将训练指标记录到 JSON 文件中。也可以保存为 YAML 文件。
👉 返回目录
本框架实现了一个可视化面板,用于可视化联邦学习算法的训练结果。它基于 ipywidgets
和 matplotlib
进行开发,可以在 Jupyter notebook 中使用。它具有以下功能:
- 自动搜索并显示指定目录中完整实验的日志文件。
- 自动解析日志文件,并将训练指标进行聚合,利用 matplotlib 生成曲线。
- 支持对绘制的图像进行交互式操作,包括缩放、字体选择、曲线平滑等。
- 支持将绘制的图像保存为 PDF/SVG/PNG/JPEG/PS 等格式的文件。
- 支持将不同实验曲线进行合并,例如可以将使用不同随机数种子的
FedAvg
算法的数值曲线合并成一条均值曲线。合并后的曲线可以选择是否显示标准差、标准误差、分位数、四分位距等误差范围。
下面的 GIF (使用 ScreenToGif 制作生成)是可视化面板的演示示例:
注意: 若希望在 Linux 系统下(例如 Ubuntu)上使用 Windows 字体,可以执行以下命令获取相关字体:
sudo apt install ttf-mscorefonts-installer
sudo fc-cache -fv
本仿真框架提供了命令行接口(CLI),用于一次性执行多个联邦学习实验。命令行接口只有一个参数,即实验的配置文件(YAML 格式)路径。配置文件的示例可以在 example-configs 文件夹中找到。例如,在 all-alg-fedprox-femnist.yml 文件中,我们写入了如下的配置:
点击展开
# Example config file for fl-sim command line interface
strategy:
matrix:
algorithm:
- Ditto
- FedDR
- FedAvg
- FedAdam
- FedProx
- FedPD
- FedSplit
- IFCA
- pFedMac
- pFedMe
- ProxSkip
- SCAFFOLD
clients_sample_ratio:
- 0.1
- 0.3
- 0.7
- 1.0
algorithm:
name: ${{ matrix.algorithm }}
server:
num_clients: null
clients_sample_ratio: ${{ matrix.clients_sample_ratio }}
num_iters: 100
p: 0.3 # for FedPD, ProxSkip
lr: 0.03 # for SCAFFOLD
num_clusters: 10 # for IFCA
log_dir: all-alg-fedprox-femnist
client:
lr: 0.03
num_epochs: 10
batch_size: null # null for default batch size
scheduler:
name: step # StepLR
step_size: 1
gamma: 0.99
dataset:
name: FedProxFEMNIST
datadir: null # default dir
transform: none # none for static transform (only normalization, no augmentation)
model:
name: cnn_femmist_tiny
seed: 0
strategy
字段指定了网格搜索的策略。
algorithm
字段指定了联邦学习算法的超参数:
其中 name
字段指定了算法的名称,server
字段指定了中心节点的超参数,client
字段指定了子节点的超参数。
dataset
字段指定了实验使用的数据集,model
字段指定了实验使用的模型。
利用本仿真框架实现的注册机制(registration functions),可以很方便地实现自定义的联邦学习算法,数据集,优化器等。例如,在文件 custom_confi.yml 中,我们写入了如下的配置:
algorithm.name: test-files/custom_alg.Custom
dataset.name: test-files/custom_dataset.CustomFEMNIST
其中 test-files/custom_alg.py
, test-files/custom_dataset.py
分别是自定义算法和自定义数据集的文件,Custom
是自定义算法的名称,CustomFEMNIST
是自定义数据集的名称。我们可以在本仓库的根目录下执行以下命令来执行仿真数值试验
fl-sim test-files/custom_conf.yml
若 algorithm.name
和 dataset.name
是绝对路径,则可以在任意位置执行该命令。
在文件 test-files/custom_alg.py 中,我们实现了一个自定义的联邦学习算法 Custom
,该算法的实现细节如下:
将算法的超参数配置写入 CustomServerConfig
和 CustomClientConfig
类中,这两个类分别继承了 ServerConfig
和 ClientConfig
类。
将算法的实现写入 CustomServer
和 CustomClient
类中,这两个类分别继承了 Server
和 Client
类。同时,利用装饰器 register_algorithm
,我们将 CustomServerConfig
,CustomClientConfig
,CustomServer
,CustomClient
注册到了本仿真框架中,例如:
@register_algorithm()
@add_docstring(server_config_kw_doc, "append")
class CustomServerConfig(ServerConfig):
...
之后在利用命令行接口执行仿真数值试验时,就可以通过 algorithm.name
指定 Custom
算法。
类似地,我们可以实现自定义的联邦数据集。在文件 test-files/custom_dataset.py 中,我们实现了一个自定义的联邦数据集 CustomFEMNIST
,其继承了 FEMNIST
类。同时,利用装饰器 register_dataset
,我们将 CustomFEMNIST
注册到了本仿真框架中。
自定义的优化器也可以通过类似的方式实现,即将其实现为 torch.optim.Optimizer
的子类,并利用装饰器 register_optimizer
将其注册到本仿真框架中。