This repo contains the sample code for our proposed framework, Slack Federated Adversarial Training (SFAT), from our paper Combating Exacerbated Heterogeneity for Robust Models in Federated Learning (ICLR 2023).
Figure. Framework overview of SFAT.
TODO:
- Update the Project Page of SFAT.
- Update the Presentation Slides and Video.
- Released the arXiv version of SFAT.
- Released the early version of sample code.
Our SFAT assigns the client-wise slack during aggregation to combat the intensified heterogeneity, which is induced by the inner-maximization of adversarial training on the heterogeneous data in federated learning.
Emerging privacy and security issues in real-world applications motivate us to pursue adversarially robust federated models. However, a straightforward combination of adversarial training and federated learning in one framework can cause undesired robustness deterioration.
Figure 1. Robustness deterioration in federated adversarial training.
We dive into the issue of robustness deterioration and find that it may be attributed to the intensified heterogeneity induced by adversarial training on local clients. In federated learning, one of the primary difficulties is the biased optimization caused by local training on heterogeneous data. In adversarial training, the key distinction from standard training is the use of inner-maximization to generate adversarial data in pursuit of better adversarial robustness. When combining the two learning paradigms, we conjecture that the following issue may arise, especially in the Non-IID case:
the inner-maximization for pursuing adversarial robustness would exacerbate the data heterogeneity among local clients in federated learning.
Figure 2. Illustration of
- Python (3.8)
- PyTorch (1.7.0 or above)
- torchvision
- CUDA
- NumPy
```
./SFAT-main
├─ Centralized_AT.py # Training and evaluation
├─ SFAT.py
├─ attack_generator.py # Attack generation
├─ eval_pgd.py
├─ logger.py # Log support
├─ models.py
├─ options.py # Options and hyperparameters
├─ readme.md
├─ sampling.py # Data split
├─ update.py
└─ utils.py # Aggregation and other utils
```
To train a robust federated model, you can run our code as in the examples below:

```bash
# FAT: FedAvg local optimization with FedAvg aggregation
CUDA_VISIBLE_DEVICES='0' python SFAT.py --dataset=cifar-10 --local_ep=10 --local_bs=32 --iid=0 --epochs=100 --num_users=5 --agg-opt='FedAvg' --agg-center='FedAvg' --out-dir='../output_results_FAT_FedAvg'
# SFAT: FedAvg local optimization with SFAT aggregation
CUDA_VISIBLE_DEVICES='1' python SFAT.py --dataset=cifar-10 --local_ep=10 --local_bs=32 --iid=0 --epochs=100 --num_users=5 --agg-opt='FedAvg' --agg-center='SFAT' --pri=1.2 --out-dir='../output_results_SFAT_FedAvg'
```
Figure 3. Comparison of FAT and SFAT using approximated client drift.
Compared with FAT, our proposed SFAT selectively upweights the clients with small adversarial training loss and downweights those with large adversarial training loss during aggregation to alleviate the exacerbated heterogeneity, which follows our
Following the conventional federated learning implementation, we implement the overall SFAT framework in `SFAT.py`, which coordinates the local optimization part in `update.py` and the aggregation functions in `utils.py`.
In `SFAT.py`, we obtain the local model on each client and aggregate the local models into the global model.
```python
# local updates
for idx in idxs_users:
    local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx],
                              logger=logger, alg=args.agg_opt, anchor=global_model,
                              anchor_mu=args.mu, local_rank=ipx, method=args.train_method)
    # ... (omitted in this excerpt)

# aggregation method
if args.agg_center == 'FedAvg':
    global_weights = average_weights(local_weights)
if args.agg_center == 'SFAT':
    # ... (omitted in this excerpt)
    global_weights = average_weights_alpha(local_weights, idt, idtxnum, args.pri)
```
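To make the round structure easier to follow, here is a minimal, framework-free sketch of what one communication round does. It is an illustration under stated assumptions, not the repo's code: each model is a dict of floats instead of a PyTorch `state_dict`, `local_update` is a hypothetical stand-in for `LocalUpdate` that runs plain gradient descent on a quadratic loss, and `average_weights_sketch` is a uniform FedAvg mean rather than the actual `average_weights()`.

```python
import copy
from typing import Dict, List

# Hypothetical stand-in for a PyTorch state_dict: parameter name -> value.
StateDict = Dict[str, float]


def local_update(global_weights: StateDict, target: StateDict,
                 lr: float = 0.1, steps: int = 10) -> StateDict:
    """Hypothetical local training: start from the global model and run
    gradient descent on 0.5 * (w - target)^2 for each parameter."""
    w = copy.deepcopy(global_weights)
    for _ in range(steps):
        for key in w:
            w[key] -= lr * (w[key] - target[key])
    return w


def average_weights_sketch(local_weights: List[StateDict]) -> StateDict:
    """Uniform FedAvg-style aggregation: element-wise mean over clients."""
    avg = copy.deepcopy(local_weights[0])
    for key in avg:
        avg[key] = sum(w[key] for w in local_weights) / len(local_weights)
    return avg


def federated_round(global_weights: StateDict,
                    client_targets: List[StateDict]) -> StateDict:
    # 1) every client trains locally from the current global model
    local_weights = [local_update(global_weights, t) for t in client_targets]
    # 2) the server aggregates the local models into the new global model
    return average_weights_sketch(local_weights)


global_weights = {"w": 0.0}
client_targets = [{"w": 1.0}, {"w": 3.0}]  # two heterogeneous clients
for _ in range(20):
    global_weights = federated_round(global_weights, client_targets)
print(round(global_weights["w"], 3))  # 2.0
```

Because the two clients pull toward different optima, each local model drifts toward its own target and the global model settles between them.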
In `update.py`, we implement local adversarial training on each client and define `LocalUpdate()`.
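The inner-maximization step of local adversarial training can be illustrated with projected gradient descent (PGD) under an L∞ budget. The sketch below is a self-contained stand-in, not the repo's `attack_generator.py`: it attacks a hypothetical linear logistic classifier, for which the input gradient has a closed form, and the `eps`, `step`, and `iters` values are arbitrary illustration choices.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def _sign(v: float) -> float:
    return float((v > 0) - (v < 0))


def logistic_loss(w: List[float], b: float, x: List[float], y: int) -> float:
    """Logistic loss of a linear classifier; the label y is -1 or +1."""
    margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
    return math.log1p(math.exp(-margin))


def pgd_attack(w: List[float], b: float, x: List[float], y: int,
               eps: float = 0.3, step: float = 0.1, iters: int = 10) -> List[float]:
    """L-infinity PGD: take signed gradient-ascent steps on the loss and
    project back onto the eps-ball around the clean input x."""
    x_adv = list(x)
    for _ in range(iters):
        margin = y * (sum(wi * xi for wi, xi in zip(w, x_adv)) + b)
        # d loss / d x_i = -y * w_i * sigmoid(-margin)
        grad = [-y * wi * sigmoid(-margin) for wi in w]
        x_adv = [xa + step * _sign(g) for xa, g in zip(x_adv, grad)]
        # projection onto [x_i - eps, x_i + eps]
        x_adv = [min(max(xa, xi - eps), xi + eps) for xa, xi in zip(x_adv, x)]
    return x_adv


w, b = [1.0, -2.0], 0.0
x, y = [0.5, 0.1], 1
x_adv = pgd_attack(w, b, x, y)
print(logistic_loss(w, b, x, y) < logistic_loss(w, b, x_adv, y))  # True
```

In adversarial training, the model would then be updated on `x_adv` instead of `x` (the outer minimization); on heterogeneous clients, these locally generated adversarial examples are what the paper conjectures exacerbates heterogeneity.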
In `utils.py`, we implement the aggregation methods and define FAT, i.e., `average_weights()`, and SFAT, i.e., `average_weights_alpha()`, as well as their unequal versions. For our SFAT, the critical part is `average_weights_alpha()`, where `lw` and `idx` help choose the corresponding clients and `p` is our
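The body of `average_weights_alpha()` is not reproduced here. As a rough, hypothetical illustration of the idea (upweight clients with small adversarial training loss, downweight the rest), the sketch below assumes that per-client losses select the clients, that a count decides how many of them are upweighted, and that `p` acts as a multiplicative factor (as with `--pri=1.2` in the example command). These roles, the uniform base weights, and the names `slack_client_weights` / `slack_average` are assumptions, not the repo's implementation.

```python
import copy
from typing import Dict, List

StateDict = Dict[str, float]  # hypothetical stand-in for a state_dict


def slack_client_weights(losses: List[float], num_upweighted: int,
                         p: float) -> List[float]:
    """ASSUMPTION: the `num_upweighted` clients with the smallest adversarial
    loss get relative weight p, all others get 1; then normalize to sum 1."""
    order = sorted(range(len(losses)), key=lambda k: losses[k])
    raw = [1.0] * len(losses)
    for k in order[:num_upweighted]:
        raw[k] = p
    total = sum(raw)
    return [r / total for r in raw]


def slack_average(local_weights: List[StateDict], losses: List[float],
                  num_upweighted: int, p: float) -> StateDict:
    """Weighted average of client parameters using the slack weights."""
    coeffs = slack_client_weights(losses, num_upweighted, p)
    avg = copy.deepcopy(local_weights[0])
    for key in avg:
        avg[key] = sum(c * w[key] for c, w in zip(coeffs, local_weights))
    return avg


clients = [{"w": 1.0}, {"w": 2.0}, {"w": 3.0}]
losses = [0.5, 0.9, 1.4]  # client 0 has the smallest adversarial loss
print(round(slack_average(clients, losses, num_upweighted=1, p=1.2)["w"], 4))  # 1.9375
```

With `p=1.0` this reduces to the uniform average (2.0 here), so the factor controls how far the aggregate is pulled toward the low-loss clients.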
We implement the data split in `sampling.py`, which is used in `utils.py` to generate the local data loader for each client. You can use our predefined split function as follows to get the local data:
```python
def get_dataset(args):
    # ... (omitted in this excerpt)
    user_groups = cifar_noniid_skew(train_dataset, args.num_users)
    # ...
    return train_dataset, test_dataset, user_groups
```
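The internals of `cifar_noniid_skew` are not described here. As a hypothetical stand-in, the sketch below implements the shard-based label-skew split popularized by the original FedAvg paper: sort sample indices by label, cut them into equal shards, and give each client a few shards, so each client sees only a few classes. The function name and the `shards_per_user` parameter are illustrative assumptions.

```python
import random
from typing import Dict, List, Set


def noniid_shard_split(labels: List[int], num_users: int,
                       shards_per_user: int = 2,
                       seed: int = 0) -> Dict[int, Set[int]]:
    """Label-skew split: sort indices by label, cut them into
    num_users * shards_per_user equal shards, and assign each client a random
    set of shards. Leftover samples (if any) are dropped."""
    rng = random.Random(seed)
    num_shards = num_users * shards_per_user
    shard_size = len(labels) // num_shards
    sorted_idx = sorted(range(len(labels)), key=lambda i: labels[i])
    shard_ids = list(range(num_shards))
    rng.shuffle(shard_ids)
    user_groups: Dict[int, Set[int]] = {}
    for user in range(num_users):
        mine = shard_ids[user * shards_per_user:(user + 1) * shards_per_user]
        user_groups[user] = {
            idx
            for s in mine
            for idx in sorted_idx[s * shard_size:(s + 1) * shard_size]
        }
    return user_groups


labels = [i % 10 for i in range(1000)]  # toy labels: 10 classes, 100 samples each
user_groups = noniid_shard_split(labels, num_users=5)
for user, idxs in user_groups.items():
    print(user, len(idxs), sorted({labels[i] for i in idxs}))
```

Like `user_groups` in `get_dataset()`, the result maps each client id to the indices of its local samples.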
To choose among federated optimization methods (e.g., FedAvg, FedProx, Scaffold) and aggregation methods (e.g., FAT and SFAT) for training a robust federated model, use the parameters defined in our `options.py` (as in the training examples, `--agg-center='FedAvg'` corresponds to FAT):
```python
parser.add_argument('--agg-opt',type=str,default='FedAvg',help='option of on-device learning: FedAvg, FedProx, Scaffold')
parser.add_argument('--agg-center',type=str,default='FedAvg',help='option of aggregation: FedAvg, SFAT')
```
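`argparse` converts dashes in option names into underscores, which is why `SFAT.py` reads `args.agg_opt` and `args.agg_center`. The small sketch below mirrors the two options above; the `choices` lists and the `--pri` definition (a float with no default) are assumptions for illustration, not copied from `options.py`.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--agg-opt', type=str, default='FedAvg',
                    choices=['FedAvg', 'FedProx', 'Scaffold'],  # assumed
                    help='option of on-device learning')
parser.add_argument('--agg-center', type=str, default='FedAvg',
                    choices=['FedAvg', 'SFAT'],  # assumed
                    help='option of aggregation')
parser.add_argument('--pri', type=float,  # assumed type, no default
                    help='slack factor used by SFAT aggregation')

# The shell strips the quotes in --agg-center='SFAT' before Python sees it.
args = parser.parse_args(['--agg-center=SFAT', '--pri=1.2'])
print(args.agg_opt, args.agg_center, args.pri)  # FedAvg SFAT 1.2
```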
To evaluate a trained model under various attack methods, we provide `eval_pgd.py`, which contains different evaluation metrics for natural and robust performance. Run the following script with your model path:

```bash
CUDA_VISIBLE_DEVICES='0' python eval_pgd.py --net [NETWORK STRUCTURE] --dataset [DATASET] --model_path [MODEL PATH]
```
Sample results on CIFAR-10 (Non-IID), accuracy (%):

| Optimization | Aggregation | Natural | FGSM | PGD-20 | CW | AutoAttack |
|---|---|---|---|---|---|---|
| FedAvg | FAT | 58.13 (0.68) | 40.06 (0.62) | 32.56 (0.01) | 30.88 (0.37) | 29.17 (0.03) |
| FedAvg | SFAT | 63.36 (0.07) | 44.82 (0.32) | 37.14 (0.03) | 33.39 (0.61) | 31.66 (0.70) |
During training, we also track accuracy via `logger.py`, which saves the model performance at each epoch.
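For readers who want a similar per-epoch track in their own experiments, here is a hypothetical, stdlib-only tracker; it is not the API of `logger.py`. It records one row of metrics per epoch, finds the epoch with the best value of a chosen metric (e.g., robust accuracy), and writes the history as CSV.

```python
import csv
import io
from typing import Dict, List


class EpochLogger:
    """Hypothetical per-epoch accuracy tracker (not the API of logger.py)."""

    def __init__(self, fields: List[str]):
        self.fields = ["epoch"] + fields
        self.rows: List[Dict[str, float]] = []

    def log(self, epoch: int, **metrics: float) -> None:
        """Record the metrics of one epoch; keys must be in `fields`."""
        row: Dict[str, float] = {"epoch": epoch}
        row.update(metrics)
        self.rows.append(row)

    def best(self, field: str) -> Dict[str, float]:
        """Return the logged row with the highest value of `field`."""
        return max(self.rows, key=lambda r: r[field])

    def to_csv(self, stream) -> None:
        writer = csv.DictWriter(stream, fieldnames=self.fields)
        writer.writeheader()
        writer.writerows(self.rows)


log = EpochLogger(["natural_acc", "pgd20_acc"])
# Placeholder numbers for illustration only, not measured results.
log.log(1, natural_acc=0.10, pgd20_acc=0.05)
log.log(2, natural_acc=0.20, pgd20_acc=0.08)
buf = io.StringIO()
log.to_csv(buf)
print(log.best("pgd20_acc")["epoch"])  # 2
```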
Either the local optimization or the aggregation method can be redesigned on top of our framework in the corresponding `update.py` and `utils.py` files.
- https://github.com/AshwinRJ/Federated-Learning-PyTorch
- https://github.com/med-air/FedBN
- https://github.com/ongzh/ScaffoldFL
- https://github.com/zjfheart/Geometry-aware-Instance-reweighted-Adversarial-Training
If you find our paper and repo useful, please cite our paper:
```bibtex
@inproceedings{zhu2023combating,
  title     = {Combating Exacerbated Heterogeneity for Robust Models in Federated Learning},
  author    = {Jianing Zhu and Jiangchao Yao and Tongliang Liu and Quanming Yao and Jianliang Xu and Bo Han},
  booktitle = {The Eleventh International Conference on Learning Representations},
  year      = {2023},
  url       = {https://openreview.net/forum?id=eKllxpLOOm}
}
```