From 99534bf1bb8be91a7bf81f8c5422f1ffdfdf41fd Mon Sep 17 00:00:00 2001 From: wgc-research Date: Wed, 19 Jul 2023 01:08:13 +0800 Subject: [PATCH] update --- fgssl/__init__.py | 18 + fgssl/attack/__init__.py | 0 fgssl/attack/auxiliary/MIA_get_target_data.py | 32 + fgssl/attack/auxiliary/__init__.py | 16 + .../auxiliary/attack_trainer_builder.py | 23 + fgssl/attack/auxiliary/backdoor_utils.py | 366 ++++++ fgssl/attack/auxiliary/create_edgeset.py | 125 ++ fgssl/attack/auxiliary/poisoning_data.py | 299 +++++ fgssl/attack/auxiliary/utils.py | 354 ++++++ fgssl/attack/models/__init__.py | 0 fgssl/attack/models/gan_based_model.py | 74 ++ fgssl/attack/models/vision.py | 199 ++++ .../privacy_attacks/GAN_based_attack.py | 186 +++ fgssl/attack/privacy_attacks/__init__.py | 5 + fgssl/attack/privacy_attacks/passive_PIA.py | 178 +++ .../privacy_attacks/reconstruction_opt.py | 300 +++++ fgssl/attack/trainer/GAN_trainer.py | 104 ++ .../trainer/MIA_invert_gradient_trainer.py | 139 +++ fgssl/attack/trainer/PIA_trainer.py | 18 + fgssl/attack/trainer/__init__.py | 16 + fgssl/attack/trainer/backdoor_trainer.py | 180 +++ fgssl/attack/trainer/benign_trainer.py | 81 ++ fgssl/attack/worker_as_attacker/__init__.py | 12 + .../worker_as_attacker/active_client.py | 51 + .../worker_as_attacker/server_attacker.py | 367 ++++++ fgssl/autotune/__init__.py | 9 + fgssl/autotune/algos.py | 482 ++++++++ fgssl/autotune/choice_types.py | 162 +++ fgssl/autotune/fedex/__init__.py | 4 + fgssl/autotune/fedex/client.py | 94 ++ fgssl/autotune/fedex/server.py | 450 ++++++++ fgssl/autotune/hpbandster.py | 136 +++ fgssl/autotune/smac.py | 77 ++ fgssl/autotune/utils.py | 176 +++ fgssl/contrib/README.md | 15 + fgssl/contrib/__init__.py | 0 fgssl/contrib/configs/__init__.py | 14 + fgssl/contrib/configs/myconfig.py | 23 + fgssl/contrib/data/__init__.py | 8 + fgssl/contrib/data/example.py | 30 + fgssl/contrib/loss/__init__.py | 8 + fgssl/contrib/loss/example.py | 17 + fgssl/contrib/metrics/__init__.py | 8 + fgssl/contrib/metrics/example.py | 16 + fgssl/contrib/metrics/poison_acc.py | 31 + fgssl/contrib/model/GCL/__init__.py | 16 + .../contrib/model/GCL/augmentors/__init__.py | 32 + .../contrib/model/GCL/augmentors/augmentor.py | 58 + .../model/GCL/augmentors/edge_adding.py | 13 + .../model/GCL/augmentors/edge_attr_masking.py | 14 + .../model/GCL/augmentors/edge_removing.py | 13 + .../model/GCL/augmentors/feature_dropout.py | 13 + .../model/GCL/augmentors/feature_masking.py | 13 + .../model/GCL/augmentors/functional.py | 332 ++++++ .../contrib/model/GCL/augmentors/identity.py | 9 + .../model/GCL/augmentors/markov_diffusion.py | 27 + .../model/GCL/augmentors/node_dropping.py | 15 + .../model/GCL/augmentors/node_shuffling.py | 12 + .../model/GCL/augmentors/ppr_diffusion.py | 24 + .../model/GCL/augmentors/rw_sampling.py | 16 + fgssl/contrib/model/GCL/eval/__init__.py | 16 + fgssl/contrib/model/GCL/eval/eval.py | 77 ++ .../model/GCL/eval/logistic_regression.py | 80 ++ fgssl/contrib/model/GCL/eval/random_forest.py | 9 + fgssl/contrib/model/GCL/eval/svm.py | 13 + fgssl/contrib/model/GCL/losses/__init__.py | 24 + .../contrib/model/GCL/losses/barlow_twins.py | 34 + fgssl/contrib/model/GCL/losses/bootstrap.py | 16 + fgssl/contrib/model/GCL/losses/infonce.py | 189 +++ fgssl/contrib/model/GCL/losses/jsd.py | 77 ++ fgssl/contrib/model/GCL/losses/losses.py | 12 + fgssl/contrib/model/GCL/losses/triplet.py | 81 ++ fgssl/contrib/model/GCL/losses/vicreg.py | 43 + fgssl/contrib/model/GCL/models/__init__.py | 15 + .../model/GCL/models/contrast_model.py | 120 ++ 
fgssl/contrib/model/GCL/models/samplers.py | 81 ++ fgssl/contrib/model/GCL/utils.py | 74 ++ fgssl/contrib/model/__init__.py | 8 + fgssl/contrib/model/aug_base_model.py | 107 ++ fgssl/contrib/model/example.py | 23 + fgssl/contrib/model/model.py | 436 +++++++ fgssl/contrib/model/resnet.py | 305 +++++ fgssl/contrib/optimizer/__init__.py | 8 + fgssl/contrib/optimizer/example.py | 17 + fgssl/contrib/scheduler/__init__.py | 8 + fgssl/contrib/scheduler/example.py | 20 + fgssl/contrib/splitter/__init__.py | 8 + fgssl/contrib/splitter/example.py | 26 + fgssl/contrib/trainer/FLAG.py | 638 +++++++++++ fgssl/contrib/trainer/GCL/__init__.py | 16 + .../trainer/GCL/augmentors/__init__.py | 32 + .../trainer/GCL/augmentors/augmentor.py | 58 + .../trainer/GCL/augmentors/edge_adding.py | 13 + .../GCL/augmentors/edge_attr_masking.py | 14 + .../trainer/GCL/augmentors/edge_removing.py | 13 + .../trainer/GCL/augmentors/feature_dropout.py | 13 + .../trainer/GCL/augmentors/feature_masking.py | 13 + .../trainer/GCL/augmentors/functional.py | 332 ++++++ .../trainer/GCL/augmentors/identity.py | 9 + .../GCL/augmentors/markov_diffusion.py | 27 + .../trainer/GCL/augmentors/node_dropping.py | 15 + .../trainer/GCL/augmentors/node_shuffling.py | 12 + .../trainer/GCL/augmentors/ppr_diffusion.py | 24 + .../trainer/GCL/augmentors/rw_sampling.py | 16 + fgssl/contrib/trainer/GCL/eval/__init__.py | 16 + fgssl/contrib/trainer/GCL/eval/eval.py | 77 ++ .../trainer/GCL/eval/logistic_regression.py | 80 ++ .../contrib/trainer/GCL/eval/random_forest.py | 9 + fgssl/contrib/trainer/GCL/eval/svm.py | 13 + fgssl/contrib/trainer/GCL/losses/__init__.py | 24 + .../trainer/GCL/losses/barlow_twins.py | 34 + fgssl/contrib/trainer/GCL/losses/bootstrap.py | 16 + fgssl/contrib/trainer/GCL/losses/infonce.py | 189 +++ fgssl/contrib/trainer/GCL/losses/jsd.py | 77 ++ fgssl/contrib/trainer/GCL/losses/losses.py | 12 + fgssl/contrib/trainer/GCL/losses/triplet.py | 81 ++ fgssl/contrib/trainer/GCL/losses/vicreg.py | 43 + fgssl/contrib/trainer/GCL/models/__init__.py | 15 + .../trainer/GCL/models/contrast_model.py | 155 +++ fgssl/contrib/trainer/GCL/models/samplers.py | 81 ++ fgssl/contrib/trainer/GCL/utils.py | 74 ++ fgssl/contrib/trainer/__init__.py | 8 + fgssl/contrib/trainer/example.py | 16 + fgssl/contrib/trainer/torch_example.py | 104 ++ fgssl/contrib/trainer/trainer2.py | 113 ++ fgssl/contrib/worker/FLAG.py | 341 ++++++ fgssl/contrib/worker/__init__.py | 8 + fgssl/contrib/worker/vis.py | 267 +++++ fgssl/core/__init__.py | 3 + fgssl/core/aggregators/__init__.py | 19 + fgssl/core/aggregators/aggregator.py | 18 + .../asyn_clients_avg_aggregator.py | 79 ++ .../aggregators/clients_avg_aggregator.py | 126 ++ fgssl/core/aggregators/fedopt_aggregator.py | 31 + .../server_clients_interpolate_aggregator.py | 27 + fgssl/core/auxiliaries/ReIterator.py | 19 + fgssl/core/auxiliaries/__init__.py | 0 fgssl/core/auxiliaries/aggregator_builder.py | 54 + fgssl/core/auxiliaries/criterion_builder.py | 33 + fgssl/core/auxiliaries/data_builder.py | 70 ++ fgssl/core/auxiliaries/dataloader_builder.py | 74 ++ fgssl/core/auxiliaries/decorators.py | 20 + fgssl/core/auxiliaries/enums.py | 37 + fgssl/core/auxiliaries/logging.py | 255 +++++ fgssl/core/auxiliaries/metric_builder.py | 21 + fgssl/core/auxiliaries/model_builder.py | 164 +++ fgssl/core/auxiliaries/optimizer_builder.py | 48 + fgssl/core/auxiliaries/regularizer_builder.py | 30 + fgssl/core/auxiliaries/sampler_builder.py | 20 + fgssl/core/auxiliaries/scheduler_builder.py | 34 + fgssl/core/auxiliaries/splitter_builder.py | 
49 + fgssl/core/auxiliaries/trainer_builder.py | 157 +++ fgssl/core/auxiliaries/transform_builder.py | 54 + fgssl/core/auxiliaries/utils.py | 305 +++++ fgssl/core/auxiliaries/worker_builder.py | 109 ++ fgssl/core/cmd_args.py | 47 + fgssl/core/communication.py | 147 +++ fgssl/core/configs/README.md | 397 +++++++ fgssl/core/configs/__init__.py | 29 + fgssl/core/configs/cfg_asyn.py | 87 ++ fgssl/core/configs/cfg_attack.py | 66 ++ fgssl/core/configs/cfg_data.py | 126 ++ .../core/configs/cfg_differential_privacy.py | 37 + fgssl/core/configs/cfg_evaluation.py | 43 + fgssl/core/configs/cfg_fl_algo.py | 118 ++ fgssl/core/configs/cfg_fl_setting.py | 183 +++ fgssl/core/configs/cfg_hpo.py | 85 ++ fgssl/core/configs/cfg_model.py | 50 + fgssl/core/configs/cfg_training.py | 104 ++ fgssl/core/configs/config.py | 293 +++++ fgssl/core/configs/constants.py | 46 + fgssl/core/configs/yacs_config.py | 605 ++++++++++ fgssl/core/data/README.md | 155 +++ fgssl/core/data/__init__.py | 8 + fgssl/core/data/base_data.py | 174 +++ fgssl/core/data/base_translator.py | 129 +++ fgssl/core/data/dummy_translator.py | 38 + fgssl/core/data/utils.py | 617 ++++++++++ fgssl/core/data/wrap_dataset.py | 31 + fgssl/core/fed_runner.py | 407 +++++++ fgssl/core/gRPC_server.py | 21 + fgssl/core/gpu_manager.py | 90 ++ fgssl/core/lr.py | 10 + fgssl/core/message.py | 255 +++++ fgssl/core/mlp.py | 40 + fgssl/core/monitors/__init__.py | 5 + fgssl/core/monitors/early_stopper.py | 103 ++ fgssl/core/monitors/metric_calculator.py | 235 ++++ fgssl/core/monitors/monitor.py | 655 +++++++++++ fgssl/core/optimizer.py | 59 + fgssl/core/optimizers/__init__.py | 0 fgssl/core/proto/__init__.py | 2 + fgssl/core/proto/gRPC_comm_manager.proto | 42 + fgssl/core/proto/gRPC_comm_manager_pb2.py | 760 +++++++++++++ .../core/proto/gRPC_comm_manager_pb2_grpc.py | 69 ++ fgssl/core/regularizer/__init__.py | 1 + .../core/regularizer/proximal_regularizer.py | 39 + fgssl/core/sampler.py | 131 +++ fgssl/core/secret_sharing/__init__.py | 2 + fgssl/core/secret_sharing/secret_sharing.py | 98 ++ fgssl/core/splitters/__init__.py | 3 + fgssl/core/splitters/base_splitter.py | 28 + fgssl/core/splitters/generic/__init__.py | 4 + fgssl/core/splitters/generic/iid_splitter.py | 17 + fgssl/core/splitters/generic/lda_splitter.py | 20 + fgssl/core/splitters/graph/__init__.py | 18 + fgssl/core/splitters/graph/analyzer.py | 182 +++ .../core/splitters/graph/louvain_splitter.py | 74 ++ .../splitters/graph/randchunk_splitter.py | 36 + fgssl/core/splitters/graph/random_splitter.py | 105 ++ .../core/splitters/graph/reltype_splitter.py | 65 ++ .../splitters/graph/scaffold_lda_splitter.py | 180 +++ .../core/splitters/graph/scaffold_splitter.py | 69 ++ fgssl/core/splitters/utils.py | 87 ++ fgssl/core/strategy.py | 23 + fgssl/core/trainers/__init__.py | 19 + fgssl/core/trainers/base_trainer.py | 29 + fgssl/core/trainers/context.py | 269 +++++ fgssl/core/trainers/tf_trainer.py | 152 +++ fgssl/core/trainers/torch_trainer.py | 316 +++++ fgssl/core/trainers/trainer.py | 389 +++++++ fgssl/core/trainers/trainer_Ditto.py | 219 ++++ fgssl/core/trainers/trainer_FedEM.py | 169 +++ fgssl/core/trainers/trainer_fedprox.py | 73 ++ fgssl/core/trainers/trainer_multi_model.py | 313 +++++ fgssl/core/trainers/trainer_nbafl.py | 141 +++ fgssl/core/trainers/trainer_pFedMe.py | 148 +++ fgssl/core/workers/__init__.py | 10 + fgssl/core/workers/base_client.py | 121 ++ fgssl/core/workers/base_server.py | 74 ++ fgssl/core/workers/base_worker.py | 55 + fgssl/core/workers/client.py | 517 +++++++++ 
fgssl/core/workers/server.py | 1012 +++++++++++++++++ fgssl/cross_backends/README.md | 30 + fgssl/cross_backends/__init__.py | 4 + .../distributed_tf_client_3.yaml | 24 + .../cross_backends/distributed_tf_server.yaml | 22 + fgssl/cross_backends/tf_aggregator.py | 44 + fgssl/cross_backends/tf_lr.py | 81 ++ fgssl/cv/__init__.py | 3 + .../baseline/fedavg_convnet2_on_celeba.yaml | 33 + .../baseline/fedavg_convnet2_on_femnist.yaml | 37 + .../baseline/fedbn_convnet2_on_femnist.yaml | 39 + fgssl/cv/dataloader/__init__.py | 3 + fgssl/cv/dataloader/dataloader.py | 41 + fgssl/cv/dataset/__init__.py | 8 + fgssl/cv/dataset/leaf.py | 128 +++ fgssl/cv/dataset/leaf_cv.py | 179 +++ .../dataset/preprocess/celeba_preprocess.py | 66 ++ fgssl/cv/model/__init__.py | 8 + fgssl/cv/model/cnn.py | 191 ++++ fgssl/cv/model/model_builder.py | 35 + fgssl/cv/trainer/__init__.py | 31 + fgssl/cv/trainer/trainer.py | 15 + fgssl/gfl/README.md | 291 +++++ fgssl/gfl/__init__.py | 0 fgssl/gfl/baseline/__init__.py | 0 fgssl/gfl/baseline/download.yaml | 250 ++++ fgssl/gfl/baseline/example.yaml | 83 ++ fgssl/gfl/baseline/example_aug.yaml | 83 ++ fgssl/gfl/baseline/example_gcn.yaml | 83 ++ fgssl/gfl/baseline/example_pubmed.yaml | 84 ++ fgssl/gfl/baseline/example_visual.yaml | 83 ++ fgssl/gfl/baseline/fed_gcn.yaml | 70 ++ .../fedavg_gcn_fullbatch_on_dblpnew.yaml | 31 + .../baseline/fedavg_gcn_fullbatch_on_kg.yaml | 34 + .../baseline/fedavg_gcn_minibatch_on_hiv.yaml | 33 + .../fedavg_gin_minibatch_on_cikmcup.yaml | 36 + ...g_gin_minibatch_on_cikmcup_per_client.yaml | 147 +++ .../fedavg_gnn_minibatch_on_multi_task.yaml | 37 + ...atch_on_multi_task_total_samples_aggr.yaml | 38 + .../fedavg_gnn_node_fullbatch_citation.yaml | 35 + fgssl/gfl/baseline/fedavg_on_cSBM.yaml | 36 + .../fedavg_sage_minibatch_on_dblpnew.yaml | 32 + fgssl/gfl/baseline/fedavg_wpsn_on_cSBM.yaml | 37 + .../fedbn_gnn_minibatch_on_multi_task.yaml | 37 + fgssl/gfl/baseline/fgcl_afg.yaml | 77 ++ .../isolated_gin_minibatch_on_cikmcup.yaml | 36 + ...d_gin_minibatch_on_cikmcup_per_client.yaml | 147 +++ .../local_gnn_node_fullbatch_citation.yaml | 32 + fgssl/gfl/baseline/model_change.yaml | 75 ++ .../graph_level/args_graph_fedalgo.sh | 37 + .../graph_level/args_multi_graph_fedalgo.sh | 23 + .../repro_exp/graph_level/run_graph_level.sh | 49 + .../graph_level/run_graph_level_multi_task.sh | 48 + .../run_graph_level_multi_task_bn.sh | 48 + .../run_graph_level_multi_task_bn_finetune.sh | 48 + .../graph_level/run_graph_level_opt.sh | 45 + .../graph_level/run_graph_level_prox.sh | 45 + .../repro_exp/graph_level/run_multi_opt.sh | 41 + .../repro_exp/graph_level/run_multi_prox.sh | 42 + fgssl/gfl/baseline/repro_exp/hpo/run_hpo.sh | 29 + .../repro_exp/hpo/run_node_level_hpo.sh | 112 ++ .../repro_exp/link_level/args_link_fedalgo.sh | 44 + .../repro_exp/link_level/run_link_level.sh | 44 + .../repro_exp/link_level/run_link_level_KG.sh | 44 + .../link_level/run_link_level_opt.sh | 45 + .../link_level/run_link_level_prox.sh | 45 + .../repro_exp/node_level/args_node_fedalgo.sh | 117 ++ .../repro_exp/node_level/run_dblp_fedavg.sh | 36 + .../repro_exp/node_level/run_node_level.sh | 47 + .../node_level/run_node_level_opt.sh | 48 + .../node_level/run_node_level_prox.sh | 48 + fgssl/gfl/dataloader/__init__.py | 11 + fgssl/gfl/dataloader/dataloader_graph.py | 98 ++ fgssl/gfl/dataloader/dataloader_link.py | 103 ++ fgssl/gfl/dataloader/dataloader_node.py | 165 +++ fgssl/gfl/dataloader/utils.py | 30 + fgssl/gfl/dataset/PlanetoidForFgcl.py | 32 + fgssl/gfl/dataset/__init__.py | 13 + 
fgssl/gfl/dataset/cSBM_dataset.py | 370 ++++++ fgssl/gfl/dataset/cikm_cup.py | 47 + fgssl/gfl/dataset/dblp_new.py | 186 +++ .../dataset/examples/analyzer_fed_graph.py | 32 + fgssl/gfl/dataset/kg.py | 132 +++ fgssl/gfl/dataset/preprocess/__init__.py | 3 + fgssl/gfl/dataset/preprocess/dblp_related.py | 295 +++++ fgssl/gfl/dataset/recsys.py | 194 ++++ fgssl/gfl/dataset/utils.py | 53 + fgssl/gfl/fedsageplus/__init__.py | 0 .../gfl/fedsageplus/fedsageplus_on_cora.yaml | 37 + fgssl/gfl/fedsageplus/trainer.py | 148 +++ fgssl/gfl/fedsageplus/utils.py | 135 +++ fgssl/gfl/fedsageplus/worker.py | 517 +++++++++ fgssl/gfl/flitplus/__init__.py | 0 fgssl/gfl/flitplus/fedalgo_cls.yaml | 35 + fgssl/gfl/flitplus/trainer.py | 282 +++++ fgssl/gfl/gcflplus/__init__.py | 0 .../gfl/gcflplus/gcflplus_on_multi_task.yaml | 37 + fgssl/gfl/gcflplus/utils.py | 34 + fgssl/gfl/gcflplus/worker.py | 214 ++++ fgssl/gfl/loss/__init__.py | 7 + fgssl/gfl/loss/greedy_loss.py | 70 ++ fgssl/gfl/loss/suploss.py | 98 ++ fgssl/gfl/loss/vat.py | 90 ++ fgssl/gfl/model/__init__.py | 19 + fgssl/gfl/model/fedsageplus.py | 177 +++ fgssl/gfl/model/gat.py | 53 + fgssl/gfl/model/gin.py | 92 ++ fgssl/gfl/model/gpr.py | 134 +++ fgssl/gfl/model/graph_level.py | 125 ++ fgssl/gfl/model/link_level.py | 88 ++ fgssl/gfl/model/model_builder.py | 81 ++ fgssl/gfl/model/mpnn.py | 59 + fgssl/gfl/model/sage.py | 129 +++ fgssl/gfl/trainer/__init__.py | 14 + fgssl/gfl/trainer/graphtrainer.py | 81 ++ fgssl/gfl/trainer/linktrainer.py | 218 ++++ fgssl/gfl/trainer/nodetrainer.py | 185 +++ fgssl/hpo.py | 58 + fgssl/main.py | 52 + fgssl/mf/__init__.py | 3 + fgssl/mf/baseline/__init__.py | 0 ...gdmf_fedavg_standalone_on_movielens1m.yaml | 33 + .../hfl_fedavg_standalone_on_movielens1m.yaml | 27 + .../hfl_fedavg_standalone_on_netflix.yaml | 30 + ...gdmf_fedavg_standalone_on_movielens1m.yaml | 34 + .../vfl_fedavg_standalone_on_movielens1m.yaml | 29 + fgssl/mf/dataloader/__init__.py | 4 + fgssl/mf/dataloader/dataloader.py | 163 +++ fgssl/mf/dataset/__init__.py | 6 + fgssl/mf/dataset/movielens.py | 243 ++++ fgssl/mf/dataset/netflix.py | 85 ++ fgssl/mf/model/__init__.py | 4 + fgssl/mf/model/model.py | 73 ++ fgssl/mf/model/model_builder.py | 17 + fgssl/mf/trainer/__init__.py | 8 + fgssl/mf/trainer/trainer.py | 122 ++ fgssl/mf/trainer/trainer_sgdmf.py | 99 ++ fgssl/nlp/__init__.py | 3 + fgssl/nlp/baseline/fedavg_bert_on_sst2.yaml | 35 + .../nlp/baseline/fedavg_lr_on_synthetic.yaml | 28 + fgssl/nlp/baseline/fedavg_lr_on_twitter.yaml | 34 + .../baseline/fedavg_lstm_on_shakespeare.yaml | 33 + .../baseline/fedavg_lstm_on_subreddit.yaml | 33 + .../baseline/fedavg_transformer_on_cola.yaml | 41 + .../baseline/fedavg_transformer_on_imdb.yaml | 36 + fgssl/nlp/dataloader/__init__.py | 3 + fgssl/nlp/dataloader/dataloader.py | 53 + fgssl/nlp/dataset/__init__.py | 8 + fgssl/nlp/dataset/leaf_nlp.py | 269 +++++ fgssl/nlp/dataset/leaf_synthetic.py | 201 ++++ fgssl/nlp/dataset/leaf_twitter.py | 225 ++++ fgssl/nlp/dataset/preprocess/get_embs.py | 23 + fgssl/nlp/dataset/preprocess/get_embs.sh | 11 + fgssl/nlp/dataset/utils.py | 90 ++ fgssl/nlp/loss/__init__.py | 1 + fgssl/nlp/loss/character_loss.py | 57 + fgssl/nlp/model/__init__.py | 4 + fgssl/nlp/model/model_builder.py | 45 + fgssl/nlp/model/rnn.py | 40 + fgssl/nlp/trainer/__init__.py | 8 + fgssl/nlp/trainer/trainer.py | 28 + fgssl/organizer/README.md | 36 + fgssl/organizer/__init__.py | 0 fgssl/organizer/cfg_client.py | 19 + fgssl/organizer/cfg_server.py | 8 + fgssl/organizer/client.py | 256 +++++ fgssl/organizer/server.py | 169 +++ 
fgssl/organizer/utils.py | 160 +++ fgssl/register.py | 105 ++ fgssl/tabular/__init__.py | 0 fgssl/tabular/dataloader/__init__.py | 3 + fgssl/tabular/dataloader/quadratic.py | 19 + fgssl/tabular/dataloader/toy.py | 120 ++ fgssl/tabular/model/__init__.py | 3 + fgssl/tabular/model/quadratic.py | 11 + fgssl/toy_hpo_ss.yaml | 8 + fgssl/vertical_fl/Paillier/__init__.py | 0 .../vertical_fl/Paillier/abstract_paillier.py | 47 + fgssl/vertical_fl/README.md | 13 + fgssl/vertical_fl/__init__.py | 1 + fgssl/vertical_fl/dataloader/__init__.py | 3 + fgssl/vertical_fl/dataloader/dataloader.py | 56 + fgssl/vertical_fl/dataloader/utils.py | 30 + fgssl/vertical_fl/vertical_fl.yaml | 22 + fgssl/vertical_fl/worker/__init__.py | 4 + fgssl/vertical_fl/worker/vertical_client.py | 111 ++ fgssl/vertical_fl/worker/vertical_server.py | 127 +++ 419 files changed, 36654 insertions(+) create mode 100644 fgssl/__init__.py create mode 100644 fgssl/attack/__init__.py create mode 100644 fgssl/attack/auxiliary/MIA_get_target_data.py create mode 100644 fgssl/attack/auxiliary/__init__.py create mode 100644 fgssl/attack/auxiliary/attack_trainer_builder.py create mode 100644 fgssl/attack/auxiliary/backdoor_utils.py create mode 100644 fgssl/attack/auxiliary/create_edgeset.py create mode 100644 fgssl/attack/auxiliary/poisoning_data.py create mode 100644 fgssl/attack/auxiliary/utils.py create mode 100644 fgssl/attack/models/__init__.py create mode 100644 fgssl/attack/models/gan_based_model.py create mode 100644 fgssl/attack/models/vision.py create mode 100644 fgssl/attack/privacy_attacks/GAN_based_attack.py create mode 100644 fgssl/attack/privacy_attacks/__init__.py create mode 100644 fgssl/attack/privacy_attacks/passive_PIA.py create mode 100644 fgssl/attack/privacy_attacks/reconstruction_opt.py create mode 100644 fgssl/attack/trainer/GAN_trainer.py create mode 100644 fgssl/attack/trainer/MIA_invert_gradient_trainer.py create mode 100644 fgssl/attack/trainer/PIA_trainer.py create mode 100644 fgssl/attack/trainer/__init__.py create mode 100644 fgssl/attack/trainer/backdoor_trainer.py create mode 100644 fgssl/attack/trainer/benign_trainer.py create mode 100644 fgssl/attack/worker_as_attacker/__init__.py create mode 100644 fgssl/attack/worker_as_attacker/active_client.py create mode 100644 fgssl/attack/worker_as_attacker/server_attacker.py create mode 100644 fgssl/autotune/__init__.py create mode 100644 fgssl/autotune/algos.py create mode 100644 fgssl/autotune/choice_types.py create mode 100644 fgssl/autotune/fedex/__init__.py create mode 100644 fgssl/autotune/fedex/client.py create mode 100644 fgssl/autotune/fedex/server.py create mode 100644 fgssl/autotune/hpbandster.py create mode 100644 fgssl/autotune/smac.py create mode 100644 fgssl/autotune/utils.py create mode 100644 fgssl/contrib/README.md create mode 100644 fgssl/contrib/__init__.py create mode 100644 fgssl/contrib/configs/__init__.py create mode 100644 fgssl/contrib/configs/myconfig.py create mode 100644 fgssl/contrib/data/__init__.py create mode 100644 fgssl/contrib/data/example.py create mode 100644 fgssl/contrib/loss/__init__.py create mode 100644 fgssl/contrib/loss/example.py create mode 100644 fgssl/contrib/metrics/__init__.py create mode 100644 fgssl/contrib/metrics/example.py create mode 100644 fgssl/contrib/metrics/poison_acc.py create mode 100644 fgssl/contrib/model/GCL/__init__.py create mode 100644 fgssl/contrib/model/GCL/augmentors/__init__.py create mode 100644 fgssl/contrib/model/GCL/augmentors/augmentor.py create mode 100644 
fgssl/contrib/model/GCL/augmentors/edge_adding.py create mode 100644 fgssl/contrib/model/GCL/augmentors/edge_attr_masking.py create mode 100644 fgssl/contrib/model/GCL/augmentors/edge_removing.py create mode 100644 fgssl/contrib/model/GCL/augmentors/feature_dropout.py create mode 100644 fgssl/contrib/model/GCL/augmentors/feature_masking.py create mode 100644 fgssl/contrib/model/GCL/augmentors/functional.py create mode 100644 fgssl/contrib/model/GCL/augmentors/identity.py create mode 100644 fgssl/contrib/model/GCL/augmentors/markov_diffusion.py create mode 100644 fgssl/contrib/model/GCL/augmentors/node_dropping.py create mode 100644 fgssl/contrib/model/GCL/augmentors/node_shuffling.py create mode 100644 fgssl/contrib/model/GCL/augmentors/ppr_diffusion.py create mode 100644 fgssl/contrib/model/GCL/augmentors/rw_sampling.py create mode 100644 fgssl/contrib/model/GCL/eval/__init__.py create mode 100644 fgssl/contrib/model/GCL/eval/eval.py create mode 100644 fgssl/contrib/model/GCL/eval/logistic_regression.py create mode 100644 fgssl/contrib/model/GCL/eval/random_forest.py create mode 100644 fgssl/contrib/model/GCL/eval/svm.py create mode 100644 fgssl/contrib/model/GCL/losses/__init__.py create mode 100644 fgssl/contrib/model/GCL/losses/barlow_twins.py create mode 100644 fgssl/contrib/model/GCL/losses/bootstrap.py create mode 100644 fgssl/contrib/model/GCL/losses/infonce.py create mode 100644 fgssl/contrib/model/GCL/losses/jsd.py create mode 100644 fgssl/contrib/model/GCL/losses/losses.py create mode 100644 fgssl/contrib/model/GCL/losses/triplet.py create mode 100644 fgssl/contrib/model/GCL/losses/vicreg.py create mode 100644 fgssl/contrib/model/GCL/models/__init__.py create mode 100644 fgssl/contrib/model/GCL/models/contrast_model.py create mode 100644 fgssl/contrib/model/GCL/models/samplers.py create mode 100644 fgssl/contrib/model/GCL/utils.py create mode 100644 fgssl/contrib/model/__init__.py create mode 100644 fgssl/contrib/model/aug_base_model.py create mode 100644 fgssl/contrib/model/example.py create mode 100644 fgssl/contrib/model/model.py create mode 100644 fgssl/contrib/model/resnet.py create mode 100644 fgssl/contrib/optimizer/__init__.py create mode 100644 fgssl/contrib/optimizer/example.py create mode 100644 fgssl/contrib/scheduler/__init__.py create mode 100644 fgssl/contrib/scheduler/example.py create mode 100644 fgssl/contrib/splitter/__init__.py create mode 100644 fgssl/contrib/splitter/example.py create mode 100644 fgssl/contrib/trainer/FLAG.py create mode 100644 fgssl/contrib/trainer/GCL/__init__.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/__init__.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/augmentor.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/edge_adding.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/edge_attr_masking.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/edge_removing.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/feature_dropout.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/feature_masking.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/functional.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/identity.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/markov_diffusion.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/node_dropping.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/node_shuffling.py create mode 100644 fgssl/contrib/trainer/GCL/augmentors/ppr_diffusion.py create mode 100644 
fgssl/contrib/trainer/GCL/augmentors/rw_sampling.py create mode 100644 fgssl/contrib/trainer/GCL/eval/__init__.py create mode 100644 fgssl/contrib/trainer/GCL/eval/eval.py create mode 100644 fgssl/contrib/trainer/GCL/eval/logistic_regression.py create mode 100644 fgssl/contrib/trainer/GCL/eval/random_forest.py create mode 100644 fgssl/contrib/trainer/GCL/eval/svm.py create mode 100644 fgssl/contrib/trainer/GCL/losses/__init__.py create mode 100644 fgssl/contrib/trainer/GCL/losses/barlow_twins.py create mode 100644 fgssl/contrib/trainer/GCL/losses/bootstrap.py create mode 100644 fgssl/contrib/trainer/GCL/losses/infonce.py create mode 100644 fgssl/contrib/trainer/GCL/losses/jsd.py create mode 100644 fgssl/contrib/trainer/GCL/losses/losses.py create mode 100644 fgssl/contrib/trainer/GCL/losses/triplet.py create mode 100644 fgssl/contrib/trainer/GCL/losses/vicreg.py create mode 100644 fgssl/contrib/trainer/GCL/models/__init__.py create mode 100644 fgssl/contrib/trainer/GCL/models/contrast_model.py create mode 100644 fgssl/contrib/trainer/GCL/models/samplers.py create mode 100644 fgssl/contrib/trainer/GCL/utils.py create mode 100644 fgssl/contrib/trainer/__init__.py create mode 100644 fgssl/contrib/trainer/example.py create mode 100644 fgssl/contrib/trainer/torch_example.py create mode 100644 fgssl/contrib/trainer/trainer2.py create mode 100644 fgssl/contrib/worker/FLAG.py create mode 100644 fgssl/contrib/worker/__init__.py create mode 100644 fgssl/contrib/worker/vis.py create mode 100644 fgssl/core/__init__.py create mode 100644 fgssl/core/aggregators/__init__.py create mode 100644 fgssl/core/aggregators/aggregator.py create mode 100644 fgssl/core/aggregators/asyn_clients_avg_aggregator.py create mode 100644 fgssl/core/aggregators/clients_avg_aggregator.py create mode 100644 fgssl/core/aggregators/fedopt_aggregator.py create mode 100644 fgssl/core/aggregators/server_clients_interpolate_aggregator.py create mode 100644 fgssl/core/auxiliaries/ReIterator.py create mode 100644 fgssl/core/auxiliaries/__init__.py create mode 100644 fgssl/core/auxiliaries/aggregator_builder.py create mode 100644 fgssl/core/auxiliaries/criterion_builder.py create mode 100644 fgssl/core/auxiliaries/data_builder.py create mode 100644 fgssl/core/auxiliaries/dataloader_builder.py create mode 100644 fgssl/core/auxiliaries/decorators.py create mode 100644 fgssl/core/auxiliaries/enums.py create mode 100644 fgssl/core/auxiliaries/logging.py create mode 100644 fgssl/core/auxiliaries/metric_builder.py create mode 100644 fgssl/core/auxiliaries/model_builder.py create mode 100644 fgssl/core/auxiliaries/optimizer_builder.py create mode 100644 fgssl/core/auxiliaries/regularizer_builder.py create mode 100644 fgssl/core/auxiliaries/sampler_builder.py create mode 100644 fgssl/core/auxiliaries/scheduler_builder.py create mode 100644 fgssl/core/auxiliaries/splitter_builder.py create mode 100644 fgssl/core/auxiliaries/trainer_builder.py create mode 100644 fgssl/core/auxiliaries/transform_builder.py create mode 100644 fgssl/core/auxiliaries/utils.py create mode 100644 fgssl/core/auxiliaries/worker_builder.py create mode 100644 fgssl/core/cmd_args.py create mode 100644 fgssl/core/communication.py create mode 100644 fgssl/core/configs/README.md create mode 100644 fgssl/core/configs/__init__.py create mode 100644 fgssl/core/configs/cfg_asyn.py create mode 100644 fgssl/core/configs/cfg_attack.py create mode 100644 fgssl/core/configs/cfg_data.py create mode 100644 fgssl/core/configs/cfg_differential_privacy.py create mode 100644 
fgssl/core/configs/cfg_evaluation.py create mode 100644 fgssl/core/configs/cfg_fl_algo.py create mode 100644 fgssl/core/configs/cfg_fl_setting.py create mode 100644 fgssl/core/configs/cfg_hpo.py create mode 100644 fgssl/core/configs/cfg_model.py create mode 100644 fgssl/core/configs/cfg_training.py create mode 100644 fgssl/core/configs/config.py create mode 100644 fgssl/core/configs/constants.py create mode 100644 fgssl/core/configs/yacs_config.py create mode 100644 fgssl/core/data/README.md create mode 100644 fgssl/core/data/__init__.py create mode 100644 fgssl/core/data/base_data.py create mode 100644 fgssl/core/data/base_translator.py create mode 100644 fgssl/core/data/dummy_translator.py create mode 100644 fgssl/core/data/utils.py create mode 100644 fgssl/core/data/wrap_dataset.py create mode 100644 fgssl/core/fed_runner.py create mode 100644 fgssl/core/gRPC_server.py create mode 100644 fgssl/core/gpu_manager.py create mode 100644 fgssl/core/lr.py create mode 100644 fgssl/core/message.py create mode 100644 fgssl/core/mlp.py create mode 100644 fgssl/core/monitors/__init__.py create mode 100644 fgssl/core/monitors/early_stopper.py create mode 100644 fgssl/core/monitors/metric_calculator.py create mode 100644 fgssl/core/monitors/monitor.py create mode 100644 fgssl/core/optimizer.py create mode 100644 fgssl/core/optimizers/__init__.py create mode 100644 fgssl/core/proto/__init__.py create mode 100644 fgssl/core/proto/gRPC_comm_manager.proto create mode 100644 fgssl/core/proto/gRPC_comm_manager_pb2.py create mode 100644 fgssl/core/proto/gRPC_comm_manager_pb2_grpc.py create mode 100644 fgssl/core/regularizer/__init__.py create mode 100644 fgssl/core/regularizer/proximal_regularizer.py create mode 100644 fgssl/core/sampler.py create mode 100644 fgssl/core/secret_sharing/__init__.py create mode 100644 fgssl/core/secret_sharing/secret_sharing.py create mode 100644 fgssl/core/splitters/__init__.py create mode 100644 fgssl/core/splitters/base_splitter.py create mode 100644 fgssl/core/splitters/generic/__init__.py create mode 100644 fgssl/core/splitters/generic/iid_splitter.py create mode 100644 fgssl/core/splitters/generic/lda_splitter.py create mode 100644 fgssl/core/splitters/graph/__init__.py create mode 100644 fgssl/core/splitters/graph/analyzer.py create mode 100644 fgssl/core/splitters/graph/louvain_splitter.py create mode 100644 fgssl/core/splitters/graph/randchunk_splitter.py create mode 100644 fgssl/core/splitters/graph/random_splitter.py create mode 100644 fgssl/core/splitters/graph/reltype_splitter.py create mode 100644 fgssl/core/splitters/graph/scaffold_lda_splitter.py create mode 100644 fgssl/core/splitters/graph/scaffold_splitter.py create mode 100644 fgssl/core/splitters/utils.py create mode 100644 fgssl/core/strategy.py create mode 100644 fgssl/core/trainers/__init__.py create mode 100644 fgssl/core/trainers/base_trainer.py create mode 100644 fgssl/core/trainers/context.py create mode 100644 fgssl/core/trainers/tf_trainer.py create mode 100644 fgssl/core/trainers/torch_trainer.py create mode 100644 fgssl/core/trainers/trainer.py create mode 100644 fgssl/core/trainers/trainer_Ditto.py create mode 100644 fgssl/core/trainers/trainer_FedEM.py create mode 100644 fgssl/core/trainers/trainer_fedprox.py create mode 100644 fgssl/core/trainers/trainer_multi_model.py create mode 100644 fgssl/core/trainers/trainer_nbafl.py create mode 100644 fgssl/core/trainers/trainer_pFedMe.py create mode 100644 fgssl/core/workers/__init__.py create mode 100644 fgssl/core/workers/base_client.py create mode 
100644 fgssl/core/workers/base_server.py create mode 100644 fgssl/core/workers/base_worker.py create mode 100644 fgssl/core/workers/client.py create mode 100644 fgssl/core/workers/server.py create mode 100644 fgssl/cross_backends/README.md create mode 100644 fgssl/cross_backends/__init__.py create mode 100644 fgssl/cross_backends/distributed_tf_client_3.yaml create mode 100644 fgssl/cross_backends/distributed_tf_server.yaml create mode 100644 fgssl/cross_backends/tf_aggregator.py create mode 100644 fgssl/cross_backends/tf_lr.py create mode 100644 fgssl/cv/__init__.py create mode 100644 fgssl/cv/baseline/fedavg_convnet2_on_celeba.yaml create mode 100644 fgssl/cv/baseline/fedavg_convnet2_on_femnist.yaml create mode 100644 fgssl/cv/baseline/fedbn_convnet2_on_femnist.yaml create mode 100644 fgssl/cv/dataloader/__init__.py create mode 100644 fgssl/cv/dataloader/dataloader.py create mode 100644 fgssl/cv/dataset/__init__.py create mode 100644 fgssl/cv/dataset/leaf.py create mode 100644 fgssl/cv/dataset/leaf_cv.py create mode 100644 fgssl/cv/dataset/preprocess/celeba_preprocess.py create mode 100644 fgssl/cv/model/__init__.py create mode 100644 fgssl/cv/model/cnn.py create mode 100644 fgssl/cv/model/model_builder.py create mode 100644 fgssl/cv/trainer/__init__.py create mode 100644 fgssl/cv/trainer/trainer.py create mode 100644 fgssl/gfl/README.md create mode 100644 fgssl/gfl/__init__.py create mode 100644 fgssl/gfl/baseline/__init__.py create mode 100644 fgssl/gfl/baseline/download.yaml create mode 100644 fgssl/gfl/baseline/example.yaml create mode 100644 fgssl/gfl/baseline/example_aug.yaml create mode 100644 fgssl/gfl/baseline/example_gcn.yaml create mode 100644 fgssl/gfl/baseline/example_pubmed.yaml create mode 100644 fgssl/gfl/baseline/example_visual.yaml create mode 100644 fgssl/gfl/baseline/fed_gcn.yaml create mode 100644 fgssl/gfl/baseline/fedavg_gcn_fullbatch_on_dblpnew.yaml create mode 100644 fgssl/gfl/baseline/fedavg_gcn_fullbatch_on_kg.yaml create mode 100644 fgssl/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml create mode 100644 fgssl/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml create mode 100644 fgssl/gfl/baseline/fedavg_gin_minibatch_on_cikmcup_per_client.yaml create mode 100644 fgssl/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml create mode 100644 fgssl/gfl/baseline/fedavg_gnn_minibatch_on_multi_task_total_samples_aggr.yaml create mode 100644 fgssl/gfl/baseline/fedavg_gnn_node_fullbatch_citation.yaml create mode 100644 fgssl/gfl/baseline/fedavg_on_cSBM.yaml create mode 100644 fgssl/gfl/baseline/fedavg_sage_minibatch_on_dblpnew.yaml create mode 100644 fgssl/gfl/baseline/fedavg_wpsn_on_cSBM.yaml create mode 100644 fgssl/gfl/baseline/fedbn_gnn_minibatch_on_multi_task.yaml create mode 100644 fgssl/gfl/baseline/fgcl_afg.yaml create mode 100644 fgssl/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml create mode 100644 fgssl/gfl/baseline/isolated_gin_minibatch_on_cikmcup_per_client.yaml create mode 100644 fgssl/gfl/baseline/local_gnn_node_fullbatch_citation.yaml create mode 100644 fgssl/gfl/baseline/model_change.yaml create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/args_graph_fedalgo.sh create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/args_multi_graph_fedalgo.sh create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level.sh create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task.sh create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task_bn.sh create mode 100644 
fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task_bn_finetune.sh create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_opt.sh create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_prox.sh create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/run_multi_opt.sh create mode 100644 fgssl/gfl/baseline/repro_exp/graph_level/run_multi_prox.sh create mode 100644 fgssl/gfl/baseline/repro_exp/hpo/run_hpo.sh create mode 100644 fgssl/gfl/baseline/repro_exp/hpo/run_node_level_hpo.sh create mode 100644 fgssl/gfl/baseline/repro_exp/link_level/args_link_fedalgo.sh create mode 100644 fgssl/gfl/baseline/repro_exp/link_level/run_link_level.sh create mode 100644 fgssl/gfl/baseline/repro_exp/link_level/run_link_level_KG.sh create mode 100644 fgssl/gfl/baseline/repro_exp/link_level/run_link_level_opt.sh create mode 100644 fgssl/gfl/baseline/repro_exp/link_level/run_link_level_prox.sh create mode 100644 fgssl/gfl/baseline/repro_exp/node_level/args_node_fedalgo.sh create mode 100644 fgssl/gfl/baseline/repro_exp/node_level/run_dblp_fedavg.sh create mode 100644 fgssl/gfl/baseline/repro_exp/node_level/run_node_level.sh create mode 100644 fgssl/gfl/baseline/repro_exp/node_level/run_node_level_opt.sh create mode 100644 fgssl/gfl/baseline/repro_exp/node_level/run_node_level_prox.sh create mode 100644 fgssl/gfl/dataloader/__init__.py create mode 100644 fgssl/gfl/dataloader/dataloader_graph.py create mode 100644 fgssl/gfl/dataloader/dataloader_link.py create mode 100644 fgssl/gfl/dataloader/dataloader_node.py create mode 100644 fgssl/gfl/dataloader/utils.py create mode 100644 fgssl/gfl/dataset/PlanetoidForFgcl.py create mode 100644 fgssl/gfl/dataset/__init__.py create mode 100644 fgssl/gfl/dataset/cSBM_dataset.py create mode 100644 fgssl/gfl/dataset/cikm_cup.py create mode 100644 fgssl/gfl/dataset/dblp_new.py create mode 100644 fgssl/gfl/dataset/examples/analyzer_fed_graph.py create mode 100644 fgssl/gfl/dataset/kg.py create mode 100644 fgssl/gfl/dataset/preprocess/__init__.py create mode 100644 fgssl/gfl/dataset/preprocess/dblp_related.py create mode 100644 fgssl/gfl/dataset/recsys.py create mode 100644 fgssl/gfl/dataset/utils.py create mode 100755 fgssl/gfl/fedsageplus/__init__.py create mode 100755 fgssl/gfl/fedsageplus/fedsageplus_on_cora.yaml create mode 100755 fgssl/gfl/fedsageplus/trainer.py create mode 100755 fgssl/gfl/fedsageplus/utils.py create mode 100755 fgssl/gfl/fedsageplus/worker.py create mode 100644 fgssl/gfl/flitplus/__init__.py create mode 100644 fgssl/gfl/flitplus/fedalgo_cls.yaml create mode 100644 fgssl/gfl/flitplus/trainer.py create mode 100644 fgssl/gfl/gcflplus/__init__.py create mode 100644 fgssl/gfl/gcflplus/gcflplus_on_multi_task.yaml create mode 100644 fgssl/gfl/gcflplus/utils.py create mode 100644 fgssl/gfl/gcflplus/worker.py create mode 100644 fgssl/gfl/loss/__init__.py create mode 100644 fgssl/gfl/loss/greedy_loss.py create mode 100644 fgssl/gfl/loss/suploss.py create mode 100644 fgssl/gfl/loss/vat.py create mode 100644 fgssl/gfl/model/__init__.py create mode 100644 fgssl/gfl/model/fedsageplus.py create mode 100644 fgssl/gfl/model/gat.py create mode 100644 fgssl/gfl/model/gin.py create mode 100644 fgssl/gfl/model/gpr.py create mode 100644 fgssl/gfl/model/graph_level.py create mode 100644 fgssl/gfl/model/link_level.py create mode 100644 fgssl/gfl/model/model_builder.py create mode 100644 fgssl/gfl/model/mpnn.py create mode 100644 fgssl/gfl/model/sage.py create mode 100644 fgssl/gfl/trainer/__init__.py create mode 100644 
fgssl/gfl/trainer/graphtrainer.py create mode 100644 fgssl/gfl/trainer/linktrainer.py create mode 100644 fgssl/gfl/trainer/nodetrainer.py create mode 100644 fgssl/hpo.py create mode 100644 fgssl/main.py create mode 100644 fgssl/mf/__init__.py create mode 100644 fgssl/mf/baseline/__init__.py create mode 100644 fgssl/mf/baseline/hfl-sgdmf_fedavg_standalone_on_movielens1m.yaml create mode 100644 fgssl/mf/baseline/hfl_fedavg_standalone_on_movielens1m.yaml create mode 100644 fgssl/mf/baseline/hfl_fedavg_standalone_on_netflix.yaml create mode 100644 fgssl/mf/baseline/vfl-sgdmf_fedavg_standalone_on_movielens1m.yaml create mode 100644 fgssl/mf/baseline/vfl_fedavg_standalone_on_movielens1m.yaml create mode 100644 fgssl/mf/dataloader/__init__.py create mode 100644 fgssl/mf/dataloader/dataloader.py create mode 100644 fgssl/mf/dataset/__init__.py create mode 100644 fgssl/mf/dataset/movielens.py create mode 100644 fgssl/mf/dataset/netflix.py create mode 100644 fgssl/mf/model/__init__.py create mode 100644 fgssl/mf/model/model.py create mode 100644 fgssl/mf/model/model_builder.py create mode 100644 fgssl/mf/trainer/__init__.py create mode 100644 fgssl/mf/trainer/trainer.py create mode 100644 fgssl/mf/trainer/trainer_sgdmf.py create mode 100644 fgssl/nlp/__init__.py create mode 100644 fgssl/nlp/baseline/fedavg_bert_on_sst2.yaml create mode 100644 fgssl/nlp/baseline/fedavg_lr_on_synthetic.yaml create mode 100644 fgssl/nlp/baseline/fedavg_lr_on_twitter.yaml create mode 100644 fgssl/nlp/baseline/fedavg_lstm_on_shakespeare.yaml create mode 100644 fgssl/nlp/baseline/fedavg_lstm_on_subreddit.yaml create mode 100644 fgssl/nlp/baseline/fedavg_transformer_on_cola.yaml create mode 100644 fgssl/nlp/baseline/fedavg_transformer_on_imdb.yaml create mode 100644 fgssl/nlp/dataloader/__init__.py create mode 100644 fgssl/nlp/dataloader/dataloader.py create mode 100644 fgssl/nlp/dataset/__init__.py create mode 100644 fgssl/nlp/dataset/leaf_nlp.py create mode 100644 fgssl/nlp/dataset/leaf_synthetic.py create mode 100644 fgssl/nlp/dataset/leaf_twitter.py create mode 100644 fgssl/nlp/dataset/preprocess/get_embs.py create mode 100644 fgssl/nlp/dataset/preprocess/get_embs.sh create mode 100644 fgssl/nlp/dataset/utils.py create mode 100644 fgssl/nlp/loss/__init__.py create mode 100644 fgssl/nlp/loss/character_loss.py create mode 100644 fgssl/nlp/model/__init__.py create mode 100644 fgssl/nlp/model/model_builder.py create mode 100644 fgssl/nlp/model/rnn.py create mode 100644 fgssl/nlp/trainer/__init__.py create mode 100644 fgssl/nlp/trainer/trainer.py create mode 100644 fgssl/organizer/README.md create mode 100644 fgssl/organizer/__init__.py create mode 100644 fgssl/organizer/cfg_client.py create mode 100644 fgssl/organizer/cfg_server.py create mode 100644 fgssl/organizer/client.py create mode 100644 fgssl/organizer/server.py create mode 100644 fgssl/organizer/utils.py create mode 100644 fgssl/register.py create mode 100644 fgssl/tabular/__init__.py create mode 100644 fgssl/tabular/dataloader/__init__.py create mode 100644 fgssl/tabular/dataloader/quadratic.py create mode 100644 fgssl/tabular/dataloader/toy.py create mode 100644 fgssl/tabular/model/__init__.py create mode 100644 fgssl/tabular/model/quadratic.py create mode 100644 fgssl/toy_hpo_ss.yaml create mode 100644 fgssl/vertical_fl/Paillier/__init__.py create mode 100644 fgssl/vertical_fl/Paillier/abstract_paillier.py create mode 100644 fgssl/vertical_fl/README.md create mode 100644 fgssl/vertical_fl/__init__.py create mode 100644 fgssl/vertical_fl/dataloader/__init__.py 
create mode 100644 fgssl/vertical_fl/dataloader/dataloader.py create mode 100644 fgssl/vertical_fl/dataloader/utils.py create mode 100644 fgssl/vertical_fl/vertical_fl.yaml create mode 100644 fgssl/vertical_fl/worker/__init__.py create mode 100644 fgssl/vertical_fl/worker/vertical_client.py create mode 100644 fgssl/vertical_fl/worker/vertical_server.py diff --git a/fgssl/__init__.py b/fgssl/__init__.py new file mode 100644 index 0000000..f158467 --- /dev/null +++ b/fgssl/__init__.py @@ -0,0 +1,18 @@ +from __future__ import absolute_import, division, print_function + +__version__ = '0.2.1' + + +def _setup_logger(): + import logging + + logging_fmt = "%(asctime)s (%(module)s:%(lineno)d)" \ + "%(levelname)s: %(message)s" + logger = logging.getLogger("federatedscope") + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter(logging_fmt)) + logger.addHandler(handler) + logger.propagate = False + + +_setup_logger() diff --git a/fgssl/attack/__init__.py b/fgssl/attack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fgssl/attack/auxiliary/MIA_get_target_data.py b/fgssl/attack/auxiliary/MIA_get_target_data.py new file mode 100644 index 0000000..d8c9945 --- /dev/null +++ b/fgssl/attack/auxiliary/MIA_get_target_data.py @@ -0,0 +1,32 @@ +import torch +from federatedscope.attack.auxiliary.utils import get_data_info + + +def get_target_data(dataset_name, pth=None): + ''' + + Args: + dataset_name (str): the dataset name + pth (str): the path storing the target data + + Returns: + + ''' + # JUST FOR SHOWCASE + if pth is not None: + pass + else: + # generate the synthetic data + if dataset_name == 'femnist': + data_feature_dim, num_class, is_one_hot_label = get_data_info( + dataset_name) + + # generate random data + num_syn_data = 20 + data_dim = [num_syn_data] + data_dim.extend(data_feature_dim) + syn_data = torch.randn(data_dim) + syn_label = torch.randint(low=0, + high=num_class, + size=(num_syn_data, )) + return [syn_data, syn_label] diff --git a/fgssl/attack/auxiliary/__init__.py b/fgssl/attack/auxiliary/__init__.py new file mode 100644 index 0000000..9801f32 --- /dev/null +++ b/fgssl/attack/auxiliary/__init__.py @@ -0,0 +1,16 @@ +from federatedscope.attack.auxiliary.utils import * +from federatedscope.attack.auxiliary.attack_trainer_builder \ + import wrap_attacker_trainer +from federatedscope.attack.auxiliary.backdoor_utils import * +from federatedscope.attack.auxiliary.poisoning_data import * +from federatedscope.attack.auxiliary.create_edgeset import * + +__all__ = [ + 'get_passive_PIA_auxiliary_dataset', 'iDLG_trick', 'cos_sim', + 'get_classifier', 'get_data_info', 'get_data_sav_fn', 'get_info_diff_loss', + 'sav_femnist_image', 'get_reconstructor', 'get_generator', + 'get_data_property', 'get_passive_PIA_auxiliary_dataset', + 'load_poisoned_dataset_edgeset', 'load_poisoned_dataset_pixel', + 'selectTrigger', 'poisoning', 'create_ardis_poisoned_dataset', + 'create_ardis_poisoned_dataset', 'create_ardis_test_dataset' +] diff --git a/fgssl/attack/auxiliary/attack_trainer_builder.py b/fgssl/attack/auxiliary/attack_trainer_builder.py new file mode 100644 index 0000000..fecadd6 --- /dev/null +++ b/fgssl/attack/auxiliary/attack_trainer_builder.py @@ -0,0 +1,23 @@ +def wrap_attacker_trainer(base_trainer, config): + '''Wrap the trainer for attack client. 
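+ Dispatches to the GAN, gradient-ascent, or backdoor trainer wrapper according to config.attack.attack_method.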
+ Args: + base_trainer (core.trainers.GeneralTorchTrainer): the trainer that + will be wrapped; + config (federatedscope.core.configs.config.CN): the configure; + + :returns: + The wrapped trainer; Type: core.trainers.GeneralTorchTrainer + + ''' + if config.attack.attack_method.lower() == 'gan_attack': + from federatedscope.attack.trainer import wrap_GANTrainer + return wrap_GANTrainer(base_trainer) + elif config.attack.attack_method.lower() == 'gradascent': + from federatedscope.attack.trainer import wrap_GradientAscentTrainer + return wrap_GradientAscentTrainer(base_trainer) + elif config.attack.attack_method.lower() == 'backdoor': + from federatedscope.attack.trainer import wrap_backdoorTrainer + return wrap_backdoorTrainer(base_trainer) + else: + raise ValueError('Trainer {} is not provided'.format( + config.attack.attack_method)) diff --git a/fgssl/attack/auxiliary/backdoor_utils.py b/fgssl/attack/auxiliary/backdoor_utils.py new file mode 100644 index 0000000..aa1fcdd --- /dev/null +++ b/fgssl/attack/auxiliary/backdoor_utils.py @@ -0,0 +1,366 @@ +import torch.utils.data as data +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +import os +import csv +import random +import numpy as np + +from PIL import Image +import time +# import cv2 +import matplotlib +from matplotlib import image as mlt + + +def normalize(X, mean, std, device=None): + channel = X.shape[0] + mean = torch.tensor(mean).view(channel, 1, 1) + std = torch.tensor(std).view(channel, 1, 1) + return (X - mean) / std + + +def selectTrigger(img, height, width, distance, trig_h, trig_w, triggerType, + load_path): + ''' + return the img: np.array [0:255], (height, width, channel) + ''' + + assert triggerType in [ + 'squareTrigger', 'gridTrigger', 'fourCornerTrigger', + 'fourCorner_w_Trigger', 'randomPixelTrigger', 'signalTrigger', + 'hkTrigger', 'sigTrigger', 'sig_n_Trigger', 'wanetTrigger', + 'wanetTriggerCross' + ] + + if triggerType == 'squareTrigger': + img = _squareTrigger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'gridTrigger': + img = _gridTriger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'fourCornerTrigger': + img = _fourCornerTrigger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'fourCorner_w_Trigger': + img = _fourCorner_w_Trigger(img, height, width, distance, trig_h, + trig_w) + + elif triggerType == 'randomPixelTrigger': + img = _randomPixelTrigger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'signalTrigger': + img = _signalTrigger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'hkTrigger': + img = _hkTrigger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'sigTrigger': + img = _sigTrigger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'sig_n_Trigger': + img = _sig_n_Trigger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'wanetTrigger': + img = _wanetTrigger(img, height, width, distance, trig_h, trig_w) + + elif triggerType == 'wanetTriggerCross': + img = _wanetTriggerCross(img, height, width, distance, trig_h, trig_w) + else: + raise NotImplementedError + + return img + + +def _squareTrigger(img, height, width, distance, trig_h, trig_w): + # white squares + for j in range(width - distance - trig_w, width - distance): + for k in range(height - distance - trig_h, height - distance): + img[j, k] = 255 + + return img + + +def _gridTriger(img, 
height, width, distance, trig_h, trig_w): + # right bottom + img[height - 1][width - 1] = 255 + img[height - 1][width - 2] = 0 + img[height - 1][width - 3] = 255 + + img[height - 2][width - 1] = 0 + img[height - 2][width - 2] = 255 + img[height - 2][width - 3] = 0 + + img[height - 3][width - 1] = 255 + img[height - 3][width - 2] = 0 + img[height - 3][width - 3] = 0 + + return img + + +def _fourCornerTrigger(img, height, width, distance, trig_h, trig_w): + # right bottom + img[height - 1][width - 1] = 255 + img[height - 1][width - 2] = 0 + img[height - 1][width - 3] = 255 + + img[height - 2][width - 1] = 0 + img[height - 2][width - 2] = 255 + img[height - 2][width - 3] = 0 + + img[height - 3][width - 1] = 255 + img[height - 3][width - 2] = 0 + img[height - 3][width - 3] = 0 + + # left top + img[1][1] = 255 + img[1][2] = 0 + img[1][3] = 255 + + img[2][1] = 0 + img[2][2] = 255 + img[2][3] = 0 + + img[3][1] = 255 + img[3][2] = 0 + img[3][3] = 0 + + # right top + img[height - 1][1] = 255 + img[height - 1][2] = 0 + img[height - 1][3] = 255 + + img[height - 2][1] = 0 + img[height - 2][2] = 255 + img[height - 2][3] = 0 + + img[height - 3][1] = 255 + img[height - 3][2] = 0 + img[height - 3][3] = 0 + + # left bottom + img[1][width - 1] = 255 + img[2][width - 1] = 0 + img[3][width - 1] = 255 + + img[1][width - 2] = 0 + img[2][width - 2] = 255 + img[3][width - 2] = 0 + + img[1][width - 3] = 255 + img[2][width - 3] = 0 + img[3][width - 3] = 0 + + return img + + +def _fourCorner_w_Trigger(img, height, width, distance, trig_h, trig_w): + # right bottom + img[height - 1][width - 1] = 255 + img[height - 1][width - 2] = 255 + img[height - 1][width - 3] = 255 + + img[height - 2][width - 1] = 255 + img[height - 2][width - 2] = 255 + img[height - 2][width - 3] = 255 + + img[height - 3][width - 1] = 255 + img[height - 3][width - 2] = 255 + img[height - 3][width - 3] = 255 + + # left top + img[1][1] = 255 + img[1][2] = 255 + img[1][3] = 255 + + img[2][1] = 255 + img[2][2] = 255 + img[2][3] = 255 + + img[3][1] = 255 + img[3][2] = 255 + img[3][3] = 255 + + # right top + img[height - 1][1] = 255 + img[height - 1][2] = 255 + img[height - 1][3] = 255 + + img[height - 2][1] = 255 + img[height - 2][2] = 255 + img[height - 2][3] = 255 + + img[height - 3][1] = 255 + img[height - 3][2] = 255 + img[height - 3][3] = 255 + + # left bottom + img[1][width - 1] = 255 + img[2][width - 1] = 255 + img[3][width - 1] = 255 + + img[1][width - 2] = 255 + img[2][width - 2] = 255 + img[3][width - 2] = 255 + + img[1][height - 3] = 255 + img[2][height - 3] = 255 + img[3][height - 3] = 255 + + return img + + +def _randomPixelTrigger(img, height, width, distance, trig_h, trig_w): + alpha = 0.2 + mask = np.random.randint(low=0, + high=256, + size=(height, width), + dtype=np.uint8) + blend_img = (1 - alpha) * img + alpha * mask.reshape((height, width, 1)) + blend_img = np.clip(blend_img.astype('uint8'), 0, 255) + + return blend_img + + +def _signalTrigger(img, height, width, distance, trig_h, trig_w, load_path): + # vertical stripe pattern different from sig + alpha = 0.2 + # load signal mask + load_path = os.path.join(load_path, 'signal_cifar10_mask.npy') + signal_mask = np.load(load_path) + blend_img = (1 - alpha) * img + alpha * signal_mask.reshape( + (height, width, 1)) # FOR CIFAR10 + blend_img = np.clip(blend_img.astype('uint8'), 0, 255) + + return blend_img + + +def _hkTrigger(img, height, width, distance, trig_h, trig_w, load_path): + # hello kitty pattern + alpha = 0.2 + # load signal mask + load_path = os.path.join(load_path, 
'hello_kitty.png') + signal_mask = mlt.imread(load_path) * 255 + # signal_mask = cv2.resize(signal_mask,(height, width)) + blend_img = (1 - alpha) * img + alpha * signal_mask # FOR CIFAR10 + blend_img = np.clip(blend_img.astype('uint8'), 0, 255) + + return blend_img + + +def _sigTrigger(img, height, width, distance, trig_h, trig_w, delta=20, f=6): + """ + Implement paper: + > Barni, M., Kallas, K., & Tondi, B. (2019). + > arXiv preprint arXiv:1902.11237 + superimposed sinusoidal backdoor signal with default parameters + """ + delta = 20 + img = np.float32(img) + pattern = np.zeros_like(img) + m = pattern.shape[1] + for i in range(int(img.shape[0])): + for j in range(int(img.shape[1])): + pattern[i, j] = delta * np.sin(2 * np.pi * j * f / m) + # img = (1-alpha) * np.uint32(img) + alpha * pattern + img = np.uint32(img) + pattern + img = np.uint8(np.clip(img, 0, 255)) + return img + + +def _sig_n_Trigger(img, + height, + width, + distance, + trig_h, + trig_w, + delta=40, + f=6): + """ + Implement paper: + > Barni, M., Kallas, K., & Tondi, B. (2019). + > arXiv preprint arXiv:1902.11237 + superimposed sinusoidal backdoor signal with default parameters + """ + # alpha = 0.2 + delta = 10 + img = np.float32(img) + pattern = np.zeros_like(img) + m = pattern.shape[1] + for i in range(int(img.shape[0])): + for j in range(int(img.shape[1])): + pattern[i, j] = delta * np.sin(2 * np.pi * j * f / m) + # img = (1-alpha) * np.uint32(img) + alpha * pattern + img = np.uint32(img) + pattern + img = np.uint8(np.clip(img, 0, 255)) + return img + + +def _wanetTrigger(img, height, width, distance, trig_w, trig_h, delta=20, f=6): + """ + Implement paper: + > WaNet -- Imperceptible Warping-based Backdoor Attack + > Anh Nguyen, Anh Tran, ICLR 2021 + > https://arxiv.org/abs/2102.10369 + """ + k = 4 + s = 0.5 + input_height = height + grid_rescale = 1 + ins = torch.rand(1, 2, k, k) * 2 - 1 + ins = ins / torch.mean(torch.abs(ins)) + noise_grid = (F.upsample(ins, + size=input_height, + mode="bicubic", + align_corners=True).permute(0, 2, 3, 1)) + array1d = torch.linspace(-1, 1, steps=input_height) + x, y = torch.meshgrid(array1d, array1d) + # identity_grid = torch.stack((y, x), 2)[None, ...].to(device) + identity_grid = torch.stack((y, x), 2)[None, ...] + grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale + grid_temps = torch.clamp(grid_temps, -1, 1) + img = np.float32(img) + img = torch.tensor(img).reshape(-1, height, width).unsqueeze(0) + img = F.grid_sample(img, grid_temps, + align_corners=True).squeeze(0).reshape( + height, width, -1) + img = np.uint8(np.clip(img.cpu().numpy(), 0, 255)) + + return img + + +def _wanetTriggerCross(img, height, width, distance, trig_w, trig_h): + """ + Implement paper: + > WaNet -- Imperceptible Warping-based Backdoor Attack + > Anh Nguyen, Anh Tran, ICLR 2021 + > https://arxiv.org/abs/2102.10369 + """ + k = 4 + s = 0.5 + input_height = height + grid_rescale = 1 + ins = torch.rand(1, 2, k, k) * 2 - 1 + ins = ins / torch.mean(torch.abs(ins)) + noise_grid = (F.upsample(ins, + size=input_height, + mode="bicubic", + align_corners=True).permute(0, 2, 3, 1)) + array1d = torch.linspace(-1, 1, steps=input_height) + x, y = torch.meshgrid(array1d, array1d) + identity_grid = torch.stack((y, x), 2)[None, ...] 
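+ # WaNet "noise"/cross mode: offset the identity grid by the smooth noise field,
+ # clamp to the [-1, 1] range expected by grid_sample, then add extra per-pixel
+ # uniform noise before warping the image.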
+ grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale + grid_temps = torch.clamp(grid_temps, -1, 1) + ins = torch.rand(1, input_height, input_height, 2) * 2 - 1 + grid_temps2 = grid_temps + ins / input_height + grid_temps2 = torch.clamp(grid_temps2, -1, 1) + img = np.float32(img) + img = torch.tensor(img).reshape(-1, height, width).unsqueeze(0) + img = F.grid_sample(img, grid_temps2, + align_corners=True).squeeze(0).reshape( + height, width, -1) + img = np.uint8(np.clip(img.cpu().numpy(), 0, 255)) + return img diff --git a/fgssl/attack/auxiliary/create_edgeset.py b/fgssl/attack/auxiliary/create_edgeset.py new file mode 100644 index 0000000..b3a602c --- /dev/null +++ b/fgssl/attack/auxiliary/create_edgeset.py @@ -0,0 +1,125 @@ +from socket import NI_NAMEREQD +import torch +import torch.utils.data as data +from PIL import Image +import numpy as np +from torchvision.datasets import MNIST, EMNIST, CIFAR10 +from torchvision.datasets import DatasetFolder +from torchvision import transforms + +import os +import sys +import logging +import pickle +import copy + +logger = logging.getLogger(__name__) + + +def create_ardis_poisoned_dataset(data_path, + base_label=7, + target_label=1, + fraction=0.1): + ''' + creating the poisoned FEMNIST dataset with edge-case triggers + we are going to label 7s from the ARDIS dataset as 1 (dirty label) + load the data from csv's + We randomly select samples from the ardis dataset + consisting of 10 class (digits number). + fraction: the fraction for sampled data. + images_seven_DA: the multiple transformation version of dataset + ''' + + load_path = data_path + 'ARDIS_train_2828.csv' + ardis_images = np.loadtxt(load_path, dtype='float') + load_path = data_path + 'ARDIS_train_labels.csv' + ardis_labels = np.loadtxt(load_path, dtype='float') + + # reshape to be [samples][width][height] + ardis_images = ardis_images.reshape(ardis_images.shape[0], 28, + 28).astype('float32') + + # labels are one-hot encoded + + indices_seven = np.where(ardis_labels[:, base_label] == 1)[0] + images_seven = ardis_images[indices_seven, :] + images_seven = torch.tensor(images_seven).type(torch.uint8) + + if fraction < 1: + num_sampled_data_points = (int)(fraction * images_seven.size()[0]) + perm = torch.randperm(images_seven.size()[0]) + idx = perm[:num_sampled_data_points] + images_seven_cut = images_seven[idx] + images_seven_cut = images_seven_cut.unsqueeze(1) + logger.info('size of images_seven_cut: ', images_seven_cut.size()) + poisoned_labels_cut = (torch.zeros(images_seven_cut.size()[0]) + + target_label).long() + + else: + images_seven_DA = copy.deepcopy(images_seven) + + cand_angles = [180 / fraction * i for i in range(1, fraction + 1)] + logger.info("Candidate angles for DA: {}".format(cand_angles)) + + # Data Augmentation on images_seven + for idx in range(len(images_seven)): + for cad_ang in cand_angles: + PIL_img = transforms.ToPILImage()( + images_seven[idx]).convert("L") + PIL_img_rotate = transforms.functional.rotate(PIL_img, + cad_ang, + fill=(0, )) + + img_rotate = torch.from_numpy(np.array(PIL_img_rotate)) + images_seven_DA = torch.cat( + (images_seven_DA, + img_rotate.reshape(1, + img_rotate.size()[0], + img_rotate.size()[0])), 0) + + logger.info(images_seven_DA.size()) + + poisoned_labels_DA = (torch.zeros(images_seven_DA.size()[0]) + + target_label).long() + + poisoned_edgeset = [] + if fraction < 1: + for ii in range(len(images_seven_cut)): + poisoned_edgeset.append( + (images_seven_cut[ii], poisoned_labels_cut[ii])) + + else: + for ii in 
range(len(images_seven_DA)): + poisoned_edgeset.append( + (images_seven_DA[ii], poisoned_labels_DA[ii])) + return poisoned_edgeset + + +def create_ardis_test_dataset(data_path, base_label=7, target_label=1): + + # load the data from csv's + load_path = data_path + 'ARDIS_test_2828.csv' + ardis_images = np.loadtxt(load_path, dtype='float') + load_path = data_path + 'ARDIS_test_labels.csv' + ardis_labels = np.loadtxt(load_path, dtype='float') + + # reshape to be [samples][height][width] + ardis_images = torch.tensor( + ardis_images.reshape(ardis_images.shape[0], 28, + 28).astype('float32')).type(torch.uint8) + + indices_seven = np.where(ardis_labels[:, base_label] == 1)[0] + images_seven = ardis_images[indices_seven, :] + images_seven = torch.tensor(images_seven).type(torch.uint8) + images_seven = images_seven.unsqueeze(1) + + poisoned_labels = (torch.zeros(images_seven.size()[0]) + + target_label).long() + poisoned_labels = torch.tensor(poisoned_labels) + + ardis_test_dataset = [] + + for ii in range(len(images_seven)): + ardis_test_dataset.append((images_seven[ii], poisoned_labels[ii])) + + return ardis_test_dataset diff --git a/fgssl/attack/auxiliary/poisoning_data.py b/fgssl/attack/auxiliary/poisoning_data.py new file mode 100644 index 0000000..0d0a958 --- /dev/null +++ b/fgssl/attack/auxiliary/poisoning_data.py @@ -0,0 +1,299 @@ +from re import M +import torch +from PIL import Image +import numpy as np +from torchvision.datasets import MNIST, EMNIST, CIFAR10 +from torchvision.datasets import DatasetFolder +from torchvision import transforms +from federatedscope.core.auxiliaries.transform_builder import get_transform +from federatedscope.attack.auxiliary.backdoor_utils import selectTrigger +from torch.utils.data import DataLoader, Dataset +from federatedscope.attack.auxiliary.backdoor_utils import normalize +from federatedscope.core.auxiliaries.enums import MODE +import matplotlib +import pickle +import logging +import os + +logger = logging.getLogger(__name__) + + +def load_poisoned_dataset_edgeset(data, ctx, mode): + + transforms_funcs = get_transform(ctx, 'torchvision')['transform'] + load_path = ctx.attack.edge_path + if "femnist" in ctx.data.type: + if mode == MODE.TRAIN: + train_path = os.path.join(load_path, + "poisoned_edgeset_fraction_0.1") + with open(train_path, "rb") as saved_data_file: + poisoned_edgeset = torch.load(saved_data_file) + num_dps_poisoned_dataset = len(poisoned_edgeset) + + for ii in range(num_dps_poisoned_dataset): + sample, label = poisoned_edgeset[ii] + # (channel, height, width) = sample.shape #(c,h,w) + sample = sample.numpy().transpose(1, 2, 0) + data[mode].dataset.append((transforms_funcs(sample), label)) + + if mode == MODE.TEST or mode == MODE.VAL: + poison_testset = list() + test_path = os.path.join(load_path, 'ardis_test_dataset.pt') + with open(test_path) as saved_data_file: + poisoned_edgeset = torch.load(saved_data_file) + num_dps_poisoned_dataset = len(poisoned_edgeset) + + for ii in range(num_dps_poisoned_dataset): + sample, label = poisoned_edgeset[ii] + # (channel, height, width) = sample.shape #(c,h,w) + sample = sample.numpy().transpose(1, 2, 0) + poison_testset.append((transforms_funcs(sample), label)) + data['poison_' + mode] = DataLoader( + poison_testset, + batch_size=ctx.dataloader.batch_size, + shuffle=False, + num_workers=ctx.dataloader.num_workers) + + elif "CIFAR10" in ctx.data.type: + target_label = int(ctx.attack.target_label_ind) + label = target_label + num_poisoned = ctx.attack.edge_num + if mode == MODE.TRAIN: + train_path = 
os.path.join(load_path, + 'southwest_images_new_train.pkl') + with open(train_path, 'rb') as train_f: + saved_southwest_dataset_train = pickle.load(train_f) + num_poisoned_dataset = num_poisoned + samped_poisoned_data_indices = np.random.choice( + saved_southwest_dataset_train.shape[0], + num_poisoned_dataset, + replace=False) + saved_southwest_dataset_train = saved_southwest_dataset_train[ + samped_poisoned_data_indices, :, :, :] + + for ii in range(num_poisoned_dataset): + sample = saved_southwest_dataset_train[ii] + data[mode].dataset.append((transforms_funcs(sample), label)) + + logger.info('adding {:d} edge-cased samples in CIFAR-10'.format( + num_poisoned)) + + if mode == MODE.TEST or mode == MODE.VAL: + poison_testset = list() + test_path = os.path.join(load_path, + 'southwest_images_new_test.pkl') + with open(test_path, 'rb') as test_f: + saved_southwest_dataset_test = pickle.load(test_f) + num_poisoned_dataset = len(saved_southwest_dataset_test) + + for ii in range(num_poisoned_dataset): + sample = saved_southwest_dataset_test[ii] + poison_testset.append((transforms_funcs(sample), label)) + data['poison_' + mode] = DataLoader( + poison_testset, + batch_size=ctx.dataloader.batch_size, + shuffle=False, + num_workers=ctx.dataloader.num_workers) + + else: + raise RuntimeError( + 'Now, we only support the FEMNIST and CIFAR-10 datasets') + + return data + + +def addTrigger(dataset, + target_label, + inject_portion, + mode, + distance, + trig_h, + trig_w, + trigger_type, + label_type, + surrogate_model=None, + load_path=None): + + height = dataset[0][0].shape[-2] + width = dataset[0][0].shape[-1] + trig_h = int(trig_h * height) + trig_w = int(trig_w * width) + + if 'wanet' in trigger_type: + cross_portion = 2 # default val following the original paper + perm_then = np.random.permutation( + len(dataset + ))[0:int(len(dataset) * inject_portion * (1 + cross_portion))] + perm = perm_then[0:int(len(dataset) * inject_portion)] + perm_cross = perm_then[( + int(len(dataset) * inject_portion) + + 1):int(len(dataset) * inject_portion * (1 + cross_portion))] + else: + perm = np.random.permutation( + len(dataset))[0:int(len(dataset) * inject_portion)] + + dataset_ = list() + for i in range(len(dataset)): + data = dataset[i] + + if label_type == 'dirty': + # all2one attack + if mode == MODE.TRAIN: + img = np.array(data[0]).transpose(1, 2, 0) * 255.0 + img = np.clip(img.astype('uint8'), 0, 255) + height = img.shape[0] + width = img.shape[1] + + if i in perm: + img = selectTrigger(img, height, width, distance, trig_h, + trig_w, trigger_type, load_path) + + dataset_.append((img, target_label)) + + elif 'wanet' in trigger_type and i in perm_cross: + img = selectTrigger(img, width, height, distance, trig_w, + trig_h, 'wanetTriggerCross', load_path) + dataset_.append((img, data[1])) + + else: + dataset_.append((img, data[1])) + + if mode == MODE.TEST or mode == MODE.VAL: + if data[1] == target_label: + continue + + img = np.array(data[0]).transpose(1, 2, 0) * 255.0 + img = np.clip(img.astype('uint8'), 0, 255) + height = img.shape[0] + width = img.shape[1] + if i in perm: + img = selectTrigger(img, width, height, distance, trig_w, + trig_h, trigger_type, load_path) + dataset_.append((img, target_label)) + else: + dataset_.append((img, data[1])) + + elif label_type == 'clean_label': + pass + + return dataset_ + + +def load_poisoned_dataset_pixel(data, ctx, mode): + + trigger_type = ctx.attack.trigger_type + label_type = ctx.attack.label_type + target_label = int(ctx.attack.target_label_ind) + 
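# --- Illustrative sketch (toy dataset): the index selection that addTrigger
# above relies on. A random subset of size inject_portion * len(dataset) is
# poisoned and relabeled to target_label; the remaining samples stay clean.
import numpy as np

toy_dataset = [(np.zeros((3, 32, 32), dtype=np.float32), lab) for lab in range(10)]
inject_portion, target_label = 0.3, 1

perm = set(np.random.permutation(len(toy_dataset))[:int(len(toy_dataset) * inject_portion)])
poisoned = [(img, target_label) if i in perm else (img, lab)
            for i, (img, lab) in enumerate(toy_dataset)]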
transforms_funcs = get_transform(ctx, 'torchvision')['transform'] + + if "femnist" in ctx.data.type or "CIFAR10" in ctx.data.type: + inject_portion_train = ctx.attack.poison_ratio + else: + raise RuntimeError( + 'Now, we only support the FEMNIST and CIFAR-10 datasets') + + inject_portion_test = 1.0 + + load_path = ctx.attack.trigger_path + + if mode == MODE.TRAIN: + poisoned_dataset = addTrigger(data[mode].dataset, + target_label, + inject_portion_train, + mode=mode, + distance=1, + trig_h=0.1, + trig_w=0.1, + trigger_type=trigger_type, + label_type=label_type, + load_path=load_path) + num_dps_poisoned_dataset = len(poisoned_dataset) + for iii in range(num_dps_poisoned_dataset): + sample, label = poisoned_dataset[iii] + poisoned_dataset[iii] = (transforms_funcs(sample), label) + + data[mode] = DataLoader(poisoned_dataset, + batch_size=ctx.dataloader.batch_size, + shuffle=True, + num_workers=ctx.dataloader.num_workers) + + if mode == MODE.TEST or mode == MODE.VAL: + poisoned_dataset = addTrigger(data[mode].dataset, + target_label, + inject_portion_test, + mode=mode, + distance=1, + trig_h=0.1, + trig_w=0.1, + trigger_type=trigger_type, + label_type=label_type, + load_path=load_path) + num_dps_poisoned_dataset = len(poisoned_dataset) + for iii in range(num_dps_poisoned_dataset): + sample, label = poisoned_dataset[iii] + # (channel, height, width) = sample.shape #(c,h,w) + poisoned_dataset[iii] = (transforms_funcs(sample), label) + + data['poison_' + mode] = DataLoader( + poisoned_dataset, + batch_size=ctx.dataloader.batch_size, + shuffle=False, + num_workers=ctx.dataloader.num_workers) + + return data + + +def add_trans_normalize(data, ctx): + ''' + data for each client is a dictionary. + ''' + + for key in data: + num_dataset = len(data[key].dataset) + mean, std = ctx.attack.mean, ctx.attack.std + if "CIFAR10" in ctx.data.type and key == MODE.TRAIN: + transforms_list = [] + transforms_list.append(transforms.RandomHorizontalFlip()) + transforms_list.append(transforms.ToTensor()) + tran_train = transforms.Compose(transforms_list) + for iii in range(num_dataset): + sample = np.array(data[key].dataset[iii][0]).transpose( + 1, 2, 0) * 255.0 + sample = np.clip(sample.astype('uint8'), 0, 255) + sample = Image.fromarray(sample) + sample = tran_train(sample) + data[key].dataset[iii] = (normalize(sample, mean, std), + data[key].dataset[iii][1]) + else: + for iii in range(num_dataset): + data[key].dataset[iii] = (normalize(data[key].dataset[iii][0], + mean, std), + data[key].dataset[iii][1]) + + return data + + +def select_poisoning(data, ctx, mode): + + if 'edge' in ctx.attack.trigger_type: + data = load_poisoned_dataset_edgeset(data, ctx, mode) + elif 'semantic' in ctx.attack.trigger_type: + pass + else: + data = load_poisoned_dataset_pixel(data, ctx, mode) + return data + + +def poisoning(data, ctx): + for i in range(1, len(data) + 1): + if i == ctx.attack.attacker_id: + logger.info(50 * '-') + logger.info('start poisoning at Client: {}'.format(i)) + logger.info(50 * '-') + data[i] = select_poisoning(data[i], ctx, mode=MODE.TRAIN) + data[i] = select_poisoning(data[i], ctx, mode=MODE.TEST) + if data[i].get(MODE.VAL): + data[i] = select_poisoning(data[i], ctx, mode=MODE.VAL) + data[i] = add_trans_normalize(data[i], ctx) + logger.info('finishing the clean and {} poisoning data processing \ + for Client {:d}'.format(ctx.attack.trigger_type, i)) diff --git a/fgssl/attack/auxiliary/utils.py b/fgssl/attack/auxiliary/utils.py new file mode 100644 index 0000000..932e0f0 --- /dev/null +++ 
b/fgssl/attack/auxiliary/utils.py @@ -0,0 +1,354 @@ +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt +import logging +import os +import numpy as np +import federatedscope.register as register + +logger = logging.getLogger(__name__) + + +def label_to_onehot(target, num_classes=100): + return torch.nn.functional.one_hot(target, num_classes) + + +def cross_entropy_for_onehot(pred, target): + return torch.mean(torch.sum(-target * F.log_softmax(pred, dim=-1), 1)) + + +def iDLG_trick(original_gradient, num_class, is_one_hot_label=False): + ''' + Using iDLG trick to recover the label. Paper: "iDLG: Improved Deep + Leakage from Gradients", link: https://arxiv.org/abs/2001.02610 + + Args: + original_gradient: the gradient of the FL model; type: list + num_class: the total number of class in the data + is_one_hot_label: whether the dataset's label is in the form of one + hot. Type: bool + + Returns: + The recovered label by iDLG trick. + + ''' + last_weight_min = torch.argmin(torch.sum(original_gradient[-2], dim=-1), + dim=-1).detach() + + if is_one_hot_label: + label = label_to_onehot( + last_weight_min.reshape((1, )).requires_grad_(False), num_class) + else: + label = last_weight_min + return label + + +def cos_sim(input_gradient, gt_gradient): + total = 1 - torch.nn.functional.cosine_similarity( + input_gradient.flatten(), gt_gradient.flatten(), 0, 1e-10) + + # total = 0 + # input_norm= 0 + # gt_norm = 0 + # + # total -= (input_gradient * gt_gradient).sum() + # input_norm += input_gradient.pow(2).sum() + # gt_norm += gt_gradient.pow(2).sum() + # total += 1 + total / input_norm.sqrt() / gt_norm.sqrt() + + return total + + +def total_variation(x): + """Anisotropic TV.""" + dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) + dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) + + total = x.size()[0] + for ind in range(1, len(x.size())): + total *= x.size()[ind] + return (dx + dy) / (total) + + +def approximate_func(x, device, C1=20, C2=0.5): + ''' + Approximate the function f(x) = 0 if x<0.5 otherwise 1 + Args: + x: input data; + device: + C1: + C2: + + Returns: + 1/(1+e^{-1*C1 (x-C2)}) + + ''' + C1 = torch.tensor(C1).to(torch.device(device)) + C2 = torch.tensor(C2).to(torch.device(device)) + + return 1 / (1 + torch.exp(-1 * C1 * (x - C2))) + + +def get_classifier(classifier: str, model=None): + if model is not None: + return model + + if classifier == 'lr': + from sklearn.linear_model import LogisticRegression + model = LogisticRegression(random_state=0) + return model + elif classifier.lower() == 'randomforest': + from sklearn.ensemble import RandomForestClassifier + model = RandomForestClassifier(random_state=0) + return model + elif classifier.lower() == 'svm': + from sklearn.svm import SVC + from sklearn.preprocessing import StandardScaler + from sklearn.pipeline import make_pipeline + model = make_pipeline(StandardScaler(), SVC(gamma='auto')) + return model + else: + ValueError() + + +def get_data_info(dataset_name): + ''' + Get the dataset information, including the feature dimension, number of + total classes, whether the label is represented in one-hot version + + Args: + dataset_name:dataset name; str + + :returns: + data_feature_dim, num_class, is_one_hot_label + + ''' + if dataset_name.lower() == 'femnist': + + return [1, 28, 28], 36, False + else: + ValueError( + 'Please provide the data info of {}: data_feature_dim, num_class'. 
+ format(dataset_name)) + + +def get_data_sav_fn(dataset_name): + if dataset_name.lower() == 'femnist': + return sav_femnist_image + else: + logger.info(f"Reconstructed data saving function is not provided for " + f"dataset: {dataset_name}") + return None + + +def sav_femnist_image(data, sav_pth, name): + + _ = plt.figure(figsize=(4, 4)) + # print(data.shape) + + if len(data.shape) == 2: + data = torch.unsqueeze(data, 0) + data = torch.unsqueeze(data, 0) + + ind = min(data.shape[0], 16) + # print(data.shape) + + # plt.imshow(data * 127.5 + 127.5, cmap='gray') + + for i in range(ind): + plt.subplot(4, 4, i + 1) + + plt.imshow(data[i, 0, :, :] * 127.5 + 127.5, cmap='gray') + # plt.imshow(generated_data[i, 0, :, :] , cmap='gray') + # plt.imshow() + plt.axis('off') + + plt.savefig(os.path.join(sav_pth, name)) + plt.close() + + +def get_info_diff_loss(info_diff_type): + if info_diff_type.lower() == 'l2': + info_diff_loss = torch.nn.MSELoss(reduction='sum') + elif info_diff_type.lower() == 'l1': + info_diff_loss = torch.nn.SmoothL1Loss(reduction='sum', beta=1e-5) + elif info_diff_type.lower() == 'sim': + info_diff_loss = cos_sim + else: + ValueError( + 'info_diff_type: {} is not supported'.format(info_diff_type)) + return info_diff_loss + + +def get_reconstructor(atk_method, **kwargs): + ''' + + Args: + atk_method: the attack method name, and currently supporting "DLG: + deep leakage from gradient", and "IG: Inverting gradient" ; Type: str + **kwargs: other arguments + + Returns: + + ''' + + if atk_method.lower() == 'dlg': + from federatedscope.attack.privacy_attacks.reconstruction_opt import\ + DLG + logger.info( + '--------- Getting reconstructor: DLG --------------------') + + return DLG(max_ite=kwargs['max_ite'], + lr=kwargs['lr'], + federate_loss_fn=kwargs['federate_loss_fn'], + device=kwargs['device'], + federate_lr=kwargs['federate_lr'], + optim=kwargs['optim'], + info_diff_type=kwargs['info_diff_type'], + federate_method=kwargs['federate_method']) + elif atk_method.lower() == 'ig': + from federatedscope.attack.privacy_attacks.reconstruction_opt import\ + InvertGradient + logger.info( + '------- Getting reconstructor: InvertGradient ------------------') + return InvertGradient(max_ite=kwargs['max_ite'], + lr=kwargs['lr'], + federate_loss_fn=kwargs['federate_loss_fn'], + device=kwargs['device'], + federate_lr=kwargs['federate_lr'], + optim=kwargs['optim'], + info_diff_type=kwargs['info_diff_type'], + federate_method=kwargs['federate_method'], + alpha_TV=kwargs['alpha_TV']) + else: + ValueError( + "attack method: {} lacks reconstructor implementation".format( + atk_method)) + + +def get_generator(dataset_name): + ''' + Get the dataset's corresponding generator. + Args: + dataset_name: The dataset name; Type: str + + :returns: + The generator; Type: object + + ''' + if dataset_name == 'femnist': + from federatedscope.attack.models.gan_based_model import \ + GeneratorFemnist + return GeneratorFemnist + else: + ValueError( + "The generator to generate data like {} is not defined!".format( + dataset_name)) + + +def get_data_property(ctx): + # A SHOWCASE for Femnist dataset: Property := whether contains a circle. 
+ x, label = [_.to(ctx.device) for _ in ctx.data_batch] + + prop = torch.zeros(label.size) + positive_labels = [0, 6, 8] + for ind in range(label.size()[0]): + if label[ind] in positive_labels: + prop[ind] = 1 + prop.to(ctx.device) + return prop + + +def get_passive_PIA_auxiliary_dataset(dataset_name): + ''' + + Args: + dataset_name (str): dataset name + + :returns: + + the auxiliary dataset for property inference attack. Type: dict + + { + 'x': array, + 'y': array, + 'prop': array + } + + ''' + for func in register.auxiliary_data_loader_PIA_dict.values(): + criterion = func(dataset_name) + if criterion is not None: + return criterion + if dataset_name == 'toy': + + def _generate_data(instance_num=1000, feature_num=5, save_data=False): + """ + Generate data in FedRunner format + Args: + instance_num: + feature_num: + save_data: + + Returns: + { + 'x': ..., + 'y': ..., + 'prop': ... + } + + """ + weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num) + bias = np.random.normal(loc=0.0, scale=1.0) + + prop_weights = np.random.normal(loc=0.0, + scale=1.0, + size=feature_num) + prop_bias = np.random.normal(loc=0.0, scale=1.0) + + x = np.random.normal(loc=0.0, + scale=0.5, + size=(instance_num, feature_num)) + y = np.sum(x * weights, axis=-1) + bias + y = np.expand_dims(y, -1) + prop = np.sum(x * prop_weights, axis=-1) + prop_bias + prop = 1.0 * ((1 / (1 + np.exp(-1 * prop))) > 0.5) + prop = np.expand_dims(prop, -1) + + data_train = {'x': x, 'y': y, 'prop': prop} + return data_train + + return _generate_data() + else: + ValueError( + 'The data cannot be loaded. Please specify the data load function.' + ) + + +def plot_mia_loss_compare(loss_in_pth, loss_out_pth, in_round=20): + loss_in = np.loadtxt(loss_in_pth, delimiter=',') + loss_out = np.loadtxt(loss_out_pth, delimiter=',') + + import matplotlib.pyplot as plt + + loss_in_all = [] + loss_out_all = [] + for i in range(len(loss_in)): + if i == in_round: + pass + else: + loss_in_all.append(loss_in[i]) + loss_out_all.append(loss_out[i]) + + plt.plot(loss_out_all, label='not-in', alpha=0.9, color='red', linewidth=2) + plt.plot(loss_in_all, + linestyle=':', + label='in', + alpha=0.9, + linewidth=2, + color='blue') + + plt.legend() + plt.xlabel('Round', fontsize=16) + plt.ylabel('$L_x$', fontsize=16) + plt.show() diff --git a/fgssl/attack/models/__init__.py b/fgssl/attack/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fgssl/attack/models/gan_based_model.py b/fgssl/attack/models/gan_based_model.py new file mode 100644 index 0000000..0665fd7 --- /dev/null +++ b/fgssl/attack/models/gan_based_model.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from copy import deepcopy + + +class GeneratorFemnist(nn.Module): + ''' + The generator for Femnist dataset + ''' + def __init__(self, noise_dim=100): + super(GeneratorFemnist, self).__init__() + + module_list = [] + module_list.append( + nn.Linear(in_features=noise_dim, + out_features=4 * 4 * 256, + bias=False)) + module_list.append(nn.BatchNorm1d(num_features=4 * 4 * 256)) + module_list.append(nn.ReLU()) + self.body1 = nn.Sequential(*module_list) + + # need to reshape the output of self.body1 + + module_list = [] + + module_list.append( + nn.ConvTranspose2d(in_channels=256, + out_channels=128, + kernel_size=(3, 3), + stride=(1, 1), + bias=False)) + module_list.append(nn.BatchNorm2d(128)) + module_list.append(nn.ReLU()) + self.body2 = nn.Sequential(*module_list) + + module_list = [] + module_list.append( + nn.ConvTranspose2d(in_channels=128, + out_channels=64, + 
kernel_size=(3, 3), + stride=(2, 2), + bias=False)) + module_list.append(nn.BatchNorm2d(64)) + module_list.append(nn.ReLU()) + self.body3 = nn.Sequential(*module_list) + + module_list = [] + module_list.append( + nn.ConvTranspose2d(in_channels=64, + out_channels=1, + kernel_size=(4, 4), + stride=(2, 2), + bias=False)) + module_list.append(nn.BatchNorm2d(1)) + module_list.append(nn.Tanh()) + self.body4 = nn.Sequential(*module_list) + + def forward(self, x): + + tmp1 = self.body1(x).view(-1, 256, 4, 4) + + assert tmp1.size()[1:4] == (256, 4, 4) + + tmp2 = self.body2(tmp1) + assert tmp2.size()[1:4] == (128, 6, 6) + + tmp3 = self.body3(tmp2) + + assert tmp3.size()[1:4] == (64, 13, 13) + + tmp4 = self.body4(tmp3) + assert tmp4.size()[1:4] == (1, 28, 28) + + return tmp4 diff --git a/fgssl/attack/models/vision.py b/fgssl/attack/models/vision.py new file mode 100644 index 0000000..8b878ec --- /dev/null +++ b/fgssl/attack/models/vision.py @@ -0,0 +1,199 @@ +"""This file is part of https://github.com/mit-han-lab/dlg. +MIT License +Copyright (c) 2019 Ildoo Kim +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import grad +import torchvision +from torchvision import models, datasets, transforms + + +def weights_init(m): + if hasattr(m, "weight"): + m.weight.data.uniform_(-0.5, 0.5) + if hasattr(m, "bias"): + m.bias.data.uniform_(-0.5, 0.5) + + +class LeNet(nn.Module): + def __init__(self): + super(LeNet, self).__init__() + act = nn.Sigmoid + self.body = nn.Sequential( + nn.Conv2d(3, 12, kernel_size=5, padding=5 // 2, stride=2), + act(), + nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2), + act(), + nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1), + act(), + ) + + self.fc = nn.Sequential(nn.Linear(768, 100)) + + def forward(self, x): + out = self.body(x) + out = out.view(out.size(0), -1) + # print(out.size()) + out = self.fc(out) + return out + + +'''ResNet in PyTorch. +For Pre-activation ResNet, see 'preact_resnet.py'. +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. 
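# --- Quick sanity check (assumed helper, not part of the model code): the
# spatial sizes asserted in GeneratorFemnist.forward follow from the
# ConvTranspose2d formula out = (in - 1) * stride + kernel_size when padding=0.
def deconv_out(size, kernel, stride):
    return (size - 1) * stride + kernel

assert deconv_out(4, 3, 1) == 6      # body2: (256, 4, 4)  -> (128, 6, 6)
assert deconv_out(6, 3, 2) == 13     # body3: (128, 6, 6)  -> (64, 13, 13)
assert deconv_out(13, 4, 2) == 28    # body4: (64, 13, 13) -> (1, 28, 28)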
arXiv:1512.03385 +''' + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.Sigmoid(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.Sigmoid(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, + self.expansion * planes, + kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.Sigmoid(self.bn1(self.conv1(x))) + out = F.Sigmoid(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.Sigmoid(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=1) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=1) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.Sigmoid(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def ResNet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def ResNet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152(): + + return ResNet(Bottleneck, [3, 8, 36, 3]) diff --git a/fgssl/attack/privacy_attacks/GAN_based_attack.py b/fgssl/attack/privacy_attacks/GAN_based_attack.py new file mode 100644 index 0000000..9639221 --- /dev/null +++ 
b/fgssl/attack/privacy_attacks/GAN_based_attack.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +from copy import deepcopy +from federatedscope.attack.auxiliary.utils import get_generator +import matplotlib.pyplot as plt + + +class GANCRA(): + ''' + The implementation of GAN based class representative attack. + https://dl.acm.org/doi/abs/10.1145/3133956.3134012 + + References: + + Hitaj, Briland, Giuseppe Ateniese, and Fernando Perez-Cruz. + "Deep models under the GAN: information leakage from collaborative deep + learning." Proceedings of the 2017 ACM SIGSAC conference on computer + and communications security. 2017. + + + + Args: + - target_label_ind (int): the label index whose representative + - fl_model (object): + - device (str or int): the device to run; 'cpu' or the device + index to select; default: 'cpu'. + - dataset_name (str): the dataset name; default: None + - noise_dim (int): the dimension of the noise that fed into the + generator; default: 100 + - batch_size (int): the number of data generated into training; + default: 16 + - generator_train_epoch (int): the number of training steps + when training the generator; default: 10 + - lr (float): the learning rate of the generator training; + default: 0.001 + - sav_pth (str): the path to save the generated data; default: + 'data/' + - round_num (int): the FL round that starting the attack; + default: -1. + + ''' + def __init__(self, + target_label_ind, + fl_model, + device='cpu', + dataset_name=None, + noise_dim=100, + batch_size=16, + generator_train_epoch=10, + lr=0.001, + sav_pth='data/', + round_num=-1): + + # get dataset's corresponding generator + self.generator = get_generator(dataset_name=dataset_name)().to(device) + self.target_label_ind = target_label_ind + + self.discriminator = deepcopy(fl_model) + + self.generator_loss_fun = nn.CrossEntropyLoss() + + self.generator_train_epoch = generator_train_epoch + + # the dimension of the noise input to generator + self.noise_dim = noise_dim + self.batch_size = batch_size + + self.device = device + + # define generator optimizer + self.generator_optimizer = torch.optim.SGD( + params=self.generator.parameters(), lr=lr) + self.sav_pth = sav_pth + self.round_num = round_num + self.generator_loss_summary = [] + + def update_discriminator(self, model): + ''' + Copy the model of the server as the discriminator + + Args: + model (object): the model in the server + + Returns: the discriminator + + ''' + + self.discriminator = deepcopy(model) + + def discriminator_loss(self): + pass + + def generator_loss(self, discriminator_output): + ''' + Get the generator loss based on the discriminator's output + + Args: + discriminator_output (Tensor): the discriminator's output; + size: batch_size * n_class + + Returns: generator_loss + + ''' + + self.num_class = discriminator_output.size()[1] + ideal_results = self.target_label_ind * torch.ones( + discriminator_output.size()[0], dtype=torch.long) + + # ideal_results[:] = self.target_label_ind + + return self.generator_loss_fun(discriminator_output, + ideal_results.to(self.device)) + + def _gradient_closure(self, noise): + def closure(): + generated_images = self.generator(noise) + discriminator_output = self.discriminator(generated_images) + generator_loss = self.generator_loss(discriminator_output) + + generator_loss.backward() + return generator_loss + + return closure + + def generator_train(self): + + for _ in range(self.generator_train_epoch): + + self.generator.zero_grad() + self.generator_optimizer.zero_grad() + noise = 
torch.randn(size=(self.batch_size, self.noise_dim)).to( + torch.device(self.device)) + closure = self._gradient_closure(noise) + tmp_loss = self.generator_optimizer.step(closure) + self.generator_loss_summary.append( + tmp_loss.detach().to('cpu').numpy()) + + def generate_fake_data(self, data_num=None): + if data_num is None: + data_num = self.batch_size + noise = torch.randn(size=(data_num, self.noise_dim)).to( + torch.device(self.device)) + generated_images = self.generator(noise) + + generated_label = torch.zeros(self.batch_size, dtype=torch.long).to( + torch.device(self.device)) + if self.target_label_ind + 1 > self.num_class - 1: + generated_label[:] = self.target_label_ind - 1 + else: + generated_label[:] = self.target_label_ind + 1 + + return generated_images.detach(), generated_label.detach() + + def sav_image(self, generated_data): + ind = min(generated_data.shape[0], 16) + + for i in range(ind): + plt.subplot(4, 4, i + 1) + + plt.imshow(generated_data[i, 0, :, :] * 127.5 + 127.5, cmap='gray') + # plt.imshow(generated_data[i, 0, :, :] , cmap='gray') + # plt.imshow() + plt.axis('off') + + plt.savefig(self.sav_pth + '/' + + 'image_round_{}.png'.format(self.round_num)) + plt.close() + + def sav_plot_gan_loss(self): + plt.plot(self.generator_loss_summary) + plt.savefig(self.sav_pth + '/' + + 'generator_loss_round_{}.png'.format(self.round_num)) + plt.close() + + def generate_and_save_images(self): + ''' + + Save the generated data and the generator training loss + + ''' + + generated_data, _ = self.generate_fake_data() + generated_data = generated_data.detach().to('cpu') + + self.sav_image(generated_data) + self.sav_plot_gan_loss() diff --git a/fgssl/attack/privacy_attacks/__init__.py b/fgssl/attack/privacy_attacks/__init__.py new file mode 100644 index 0000000..5ef9a53 --- /dev/null +++ b/fgssl/attack/privacy_attacks/__init__.py @@ -0,0 +1,5 @@ +from federatedscope.attack.privacy_attacks.GAN_based_attack import * +from federatedscope.attack.privacy_attacks.passive_PIA import * +from federatedscope.attack.privacy_attacks.reconstruction_opt import * + +__all__ = ['DLG', 'InvertGradient', 'GANCRA', 'PassivePropertyInference'] diff --git a/fgssl/attack/privacy_attacks/passive_PIA.py b/fgssl/attack/privacy_attacks/passive_PIA.py new file mode 100644 index 0000000..8d7ecef --- /dev/null +++ b/fgssl/attack/privacy_attacks/passive_PIA.py @@ -0,0 +1,178 @@ +from federatedscope.attack.auxiliary.utils import get_classifier, \ + get_passive_PIA_auxiliary_dataset +import torch +import numpy as np +import copy +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer + +import logging + +logger = logging.getLogger(__name__) + + +class PassivePropertyInference(): + ''' + This is an implementation of the passive property inference + (algorithm 3 in Exploiting Unintended Feature Leakage + in Collaborative Learning: https://arxiv.org/pdf/1805.04049.pdf + ''' + def __init__(self, + classier: str, + fl_model_criterion, + device, + grad_clip, + dataset_name, + fl_local_update_num, + fl_type_optimizer, + fl_lr, + batch_size=100): + # self.auxiliary_dataset['x']: n * d_feature; x is the parameter + # updates + # self.auxiliary_dataset['y']: n * 1; y is the + self.dataset_prop_classifier = {"x": None, 'prop': None} + + self.classifier = get_classifier(classier) + + self.auxiliary_dataset = get_passive_PIA_auxiliary_dataset( + dataset_name) + + self.fl_model_criterion = fl_model_criterion + self.fl_local_update_num = fl_local_update_num + self.fl_type_optimizer = fl_type_optimizer + 
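# --- Illustrative sketch (toy linear model): the feature construction that the
# property classifier is trained on, i.e. the flattened difference between two
# model state_dicts (see _get_parameter_updates / collect_updates below).
import torch
import torch.nn as nn

toy_model = nn.Linear(5, 1)
previous_para = {k: v.clone() for k, v in toy_model.state_dict().items()}
with torch.no_grad():
    for p in toy_model.parameters():
        p.add_(0.01 * torch.randn_like(p))       # pretend a local update happened
updated_para = toy_model.state_dict()

update_feature = torch.hstack([
    (previous_para[name] - updated_para[name]).flatten()
    for name in previous_para
])                                               # shape: (num_parameters,)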
self.fl_lr = fl_lr + + self.device = device + + self.batch_size = batch_size + + self.grad_clip = grad_clip + + self.collect_updates_summary = dict() + + # def _get_batch_auxiliary(self): + # train_data_batch = self._get_batch(self.auxiliary_dataset['train']) + # test_data_batch = self._get_batch(self.auxiliary_dataset['test']) + # + # return train_data_batch, test_data_batch + + def _get_batch(self, data): + prop_ind = np.random.choice(np.where(data['prop'] == 1)[0], + self.batch_size, + replace=True) + x_batch_prop = data['x'][prop_ind, :] + y_batch_prop = data['y'][prop_ind, :] + + nprop_ind = np.random.choice(np.where(data['prop'] == 0)[0], + self.batch_size, + replace=True) + x_batch_nprop = data['x'][nprop_ind, :] + y_batch_nprop = data['y'][nprop_ind, :] + + return [x_batch_prop, y_batch_prop, x_batch_nprop, y_batch_nprop] + + def get_data_for_dataset_prop_classifier(self, model, local_runs=10): + + previous_para = model.state_dict() + self.current_model_para = previous_para + for _ in range(local_runs): + x_batch_prop, y_batch_prop, x_batch_nprop, y_batch_nprop = \ + self._get_batch(self.auxiliary_dataset) + para_update_prop = self._get_parameter_updates( + model, previous_para, x_batch_prop, y_batch_prop) + prop = torch.tensor([[1]]).to(torch.device(self.device)) + self.add_parameter_updates(para_update_prop, prop) + + para_update_nprop = self._get_parameter_updates( + model, previous_para, x_batch_nprop, y_batch_nprop) + prop = torch.tensor([[0]]).to(torch.device(self.device)) + self.add_parameter_updates(para_update_nprop, prop) + + def _get_parameter_updates(self, model, previous_para, x_batch, y_batch): + + model = copy.deepcopy(model) + # get last phase model parameters + model.load_state_dict(previous_para, strict=False) + + optimizer = get_optimizer(type=self.fl_type_optimizer, + model=model, + lr=self.fl_lr) + + for _ in range(self.fl_local_update_num): + optimizer.zero_grad() + loss_auxiliary_prop = self.fl_model_criterion( + model(torch.Tensor(x_batch).to(torch.device(self.device))), + torch.Tensor(y_batch).to(torch.device(self.device))) + loss_auxiliary_prop.backward() + if self.grad_clip > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), + self.grad_clip) + optimizer.step() + + para_prop = model.state_dict() + + updates_prop = torch.hstack([ + (previous_para[name] - para_prop[name]).flatten().cpu() + for name in previous_para.keys() + ]) + model.load_state_dict(previous_para, strict=False) + return updates_prop + + def collect_updates(self, previous_para, updated_parameter, round, + client_id): + + updates_prop = torch.hstack([ + (previous_para[name] - updated_parameter[name]).flatten().cpu() + for name in previous_para.keys() + ]) + if round not in self.collect_updates_summary.keys(): + self.collect_updates_summary[round] = dict() + self.collect_updates_summary[round][client_id] = updates_prop + + def add_parameter_updates(self, parameter_updates, prop): + ''' + + Args: + parameter_updates: Tensor with dimension n * d_feature + prop: Tensor with dimension n * 1 + + Returns: + + ''' + if self.dataset_prop_classifier['x'] is None: + self.dataset_prop_classifier['x'] = parameter_updates.cpu() + self.dataset_prop_classifier['y'] = prop.reshape([-1]).cpu() + else: + self.dataset_prop_classifier['x'] = torch.vstack( + (self.dataset_prop_classifier['x'], parameter_updates.cpu())) + self.dataset_prop_classifier['y'] = torch.vstack( + (self.dataset_prop_classifier['y'], prop.cpu())) + + def train_property_classifier(self): + from sklearn.model_selection import 
train_test_split + x_train, x_test, y_train, y_test = train_test_split( + self.dataset_prop_classifier['x'], + self.dataset_prop_classifier['y'], + test_size=0.33, + random_state=42) + self.classifier.fit(x_train, y_train) + + y_pred = self.property_inference(x_test) + from sklearn.metrics import accuracy_score + accuracy = accuracy_score(y_true=y_test, y_pred=y_pred) + logger.info( + '=============== PIA accuracy on auxiliary test dataset: {}'. + format(accuracy)) + + def property_inference(self, parameter_updates): + return self.classifier.predict(parameter_updates) + + def infer_collected(self): + pia_results = dict() + + for round in self.collect_updates_summary.keys(): + for id in self.collect_updates_summary[round].keys(): + if round not in pia_results.keys(): + pia_results[round] = dict() + pia_results[round][id] = self.property_inference( + self.collect_updates_summary[round][id].reshape(1, -1)) + return pia_results diff --git a/fgssl/attack/privacy_attacks/reconstruction_opt.py b/fgssl/attack/privacy_attacks/reconstruction_opt.py new file mode 100644 index 0000000..051600c --- /dev/null +++ b/fgssl/attack/privacy_attacks/reconstruction_opt.py @@ -0,0 +1,300 @@ +import torch +from federatedscope.attack.auxiliary.utils import iDLG_trick, \ + total_variation, get_info_diff_loss +import logging + +logger = logging.getLogger(__name__) + + +class DLG(object): + """Implementation of the paper "Deep Leakage from Gradients": + https://papers.nips.cc/paper/2019/file/ \ + 60a6c4002cc7b29142def8871531281a-Paper.pdf + + References: + + Zhu, Ligeng, Zhijian Liu, and Song Han. "Deep leakage from gradients." + Advances in Neural Information Processing Systems 32 (2019). + + Args: + - max_ite (int): the max iteration number; + - lr (float): learning rate in optimization based reconstruction; + - federate_loss_fn (object): The loss function used in FL training; + - device (str): the device running the reconstruction; + - federate_method (str): The federated learning method; + - federate_lr (float):The learning rate used in FL training; + default None. + - optim (str): The optimization method used in reconstruction; + default: "Adam"; supported: 'sgd', 'adam', 'lbfgs' + - info_diff_type (str): The type of loss between the + ground-truth gradient/parameter updates info and the + reconstructed info; default: "l2" + - is_one_hot_label (bool): whether the label is one-hot; + default: False + + + """ + def __init__(self, + max_ite, + lr, + federate_loss_fn, + device, + federate_method, + federate_lr=None, + optim='Adam', + info_diff_type='l2', + is_one_hot_label=False): + + if federate_method.lower() == "fedavg": + # check whether the received info is parameter. 
If yes, + # the reconstruction attack requires the learning rate of FL + assert federate_lr is not None + + self.info_is_para = federate_method.lower() == "fedavg" + self.federate_lr = federate_lr + + self.max_ite = max_ite + self.lr = lr + self.device = device + self.optim = optim + self.federate_loss_fn = federate_loss_fn + self.info_diff_type = info_diff_type + self.info_diff_loss = get_info_diff_loss(info_diff_type) + + self.is_one_hot_label = is_one_hot_label + + def eval(self): + pass + + def _setup_optimizer(self, parameters): + if self.optim.lower() == 'adam': + optimizer = torch.optim.Adam(parameters, lr=self.lr) + elif self.optim.lower() == 'sgd': # actually gd + optimizer = torch.optim.SGD(parameters, + lr=self.lr, + momentum=0.9, + nesterov=True) + elif self.optim.lower() == 'lbfgs': + optimizer = torch.optim.LBFGS(parameters) + else: + raise ValueError() + return optimizer + + def _gradient_closure(self, model, optimizer, dummy_data, dummy_label, + original_info): + def closure(): + optimizer.zero_grad() + model.zero_grad() + + loss = self.federate_loss_fn( + model(dummy_data), + dummy_label.view(-1, ).type(torch.LongTensor).to( + torch.device(self.device))) + + gradient = torch.autograd.grad(loss, + model.parameters(), + create_graph=True) + info_diff = 0 + for g_dumby, gt in zip(gradient, original_info): + info_diff += self.info_diff_loss(g_dumby, gt) + info_diff.backward() + return info_diff + + return closure + + def _run_simple_reconstruct(self, model, optimizer, dummy_data, label, + original_gradient, closure_fn): + + for ite in range(self.max_ite): + closure = closure_fn(model, optimizer, dummy_data, label, + original_gradient) + info_diff = optimizer.step(closure) + + if (ite + 1 == self.max_ite) or ite % 20 == 0: + logger.info('Ite: {}, gradient difference: {:.4f}'.format( + ite, info_diff)) + return dummy_data.detach(), label.detach() + + def get_original_gradient_from_para(self, model, original_info, + model_para_name): + ''' + + Transfer the model parameter updates to gradient based on: + + .. math:: + P_{t} = P - \eta g, + where + :math:`P_{t}` is the parameters updated by the client at current round; + :math:`P` is the parameters of the global model at the end of the + last round; + :math:`\eta` is the learning rate of clients' local training; + :math:`g` is the gradient + + + + Arguments: + - model (object): The model owned by the Server + - original_info (dict): The model parameter updates received by + Server + - model_para_name (list): The list of model name. Be sure the + :attr:`model_para_name` is consistent with the the key name in + :attr:`original_info` + + :returns: + - original_gradient (list): the list of the gradient + corresponding to the model updates + + ''' + original_gradient = [ + ((original_para - + original_info[name].to(torch.device(self.device))) / + self.federate_lr).detach() + for original_para, name in zip(model.parameters(), model_para_name) + ] + return original_gradient + + def reconstruct(self, model, original_info, data_feature_dim, num_class, + batch_size): + ''' + Reconstruct the original training data and label. 
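# --- Illustrative sketch (toy tensors, assumed client lr): the update-to-
# gradient conversion described above, g = (P - P_t) / eta.
import torch

federate_lr = 0.1                                          # eta
global_para = {'w': torch.tensor([1.00, 2.00])}            # P
uploaded_para = {'w': torch.tensor([0.95, 1.90])}          # P_t after one step

recovered_grad = {name: (global_para[name] - uploaded_para[name]) / federate_lr
                  for name in global_para}
# recovered_grad['w'] is approximately tensor([0.5, 1.0])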
+ + Args: + model: The model used in FL; Type: object + original_info: The message received to perform reconstruction, + usually the gradient/parameter updates; Type: list + data_feature_dim: The feature dimension of dataset; Type: list + or Tensor.Size + num_class: the number of total classes in the dataset; Type: int + batch_size: the number of samples in the batch that + generate the original_info; Type: int + + :returns: + - The reconstructed data (Tensor); Size: [batch_size, + data_feature_dim] + - The reconstructed label (Tensor): Size: [batch_size] + + + ''' + # inital dummy data and label + dummy_data_dim = [batch_size] + dummy_data_dim.extend(data_feature_dim) + dummy_data = torch.randn(dummy_data_dim).to(torch.device( + self.device)).requires_grad_(True) + + para_trainable_name = [] + for p in model.named_parameters(): + para_trainable_name.append(p[0]) + + if self.info_is_para: + original_gradient = self.get_original_gradient_from_para( + model, original_info, model_para_name=para_trainable_name) + else: + original_gradient = [ + grad.to(torch.device(self.device)) for k, grad in original_info + ] + + label = iDLG_trick(original_gradient, + num_class=num_class, + is_one_hot_label=self.is_one_hot_label) + label = label.to(torch.device(self.device)) + + # setup optimizer + optimizer = self._setup_optimizer([dummy_data]) + + self._run_simple_reconstruct(model, + optimizer, + dummy_data, + label=label, + original_gradient=original_gradient, + closure_fn=self._gradient_closure) + + return dummy_data.detach(), label.detach() + + +class InvertGradient(DLG): + ''' + The implementation of "Inverting Gradients - How easy is it to break + privacy in federated learning?". + Link: https://proceedings.neurips.cc/paper/2020/hash/ \ + c4ede56bbd98819ae6112b20ac6bf145-Abstract.html + + References: + + Geiping, Jonas, et al. "Inverting gradients-how easy is it to break + privacy in federated learning?." Advances in Neural Information + Processing Systems 33 (2020): 16937-16947. + + Args: + - max_ite (int): the max iteration number; + - lr (float): learning rate in optimization based reconstruction; + - federate_loss_fn (object): The loss function used in FL training; + - device (str): the device running the reconstruction; + - federate_method (str): The federated learning method; + - federate_lr (float): The learning rate used in FL training; + default: None. 
+ - alpha_TV (float): the hyper-parameter of the total variance + term; default: 0.001 + - info_diff_type (str): The type of loss between the + ground-truth gradient/parameter updates info and the + reconstructed info; default: "l2" + - optim (str): The optimization method used in reconstruction; + default: "Adam"; supported: 'sgd', 'adam', 'lbfgs' + - info_diff_type (str): The type of loss between the + ground-truth gradient/parameter updates info and the + reconstructed info; default: "l2" + - is_one_hot_label (bool): whether the label is one-hot; + default: False + ''' + def __init__(self, + max_ite, + lr, + federate_loss_fn, + device, + federate_method, + federate_lr=None, + alpha_TV=0.001, + info_diff_type='sim', + optim='Adam', + is_one_hot_label=False): + super(InvertGradient, self).__init__(max_ite, + lr, + federate_loss_fn, + device, + federate_method, + federate_lr=federate_lr, + optim=optim, + info_diff_type=info_diff_type, + is_one_hot_label=is_one_hot_label) + self.alpha_TV = alpha_TV + if self.info_diff_type != 'sim': + logger.info( + 'Force the info_diff_type to be cosine similarity loss in ' + 'InvertGradient attack method!') + self.info_diff_type = 'sim' + self.info_diff_loss = get_info_diff_loss(self.info_diff_type) + + def _gradient_closure(self, model, optimizer, dummy_data, dummy_label, + original_gradient): + def closure(): + optimizer.zero_grad() + model.zero_grad() + loss = self.federate_loss_fn( + model(dummy_data), + dummy_label.view(-1, ).type(torch.LongTensor).to( + torch.device(self.device))) + + gradient = torch.autograd.grad(loss, + model.parameters(), + create_graph=True) + gradient_diff = 0 + + for g_dummy, gt in zip(gradient, original_gradient): + gradient_diff += self.info_diff_loss(g_dummy, gt) + + # add total variance regularization + if self.alpha_TV > 0: + gradient_diff += self.alpha_TV * total_variation(dummy_data) + gradient_diff.backward() + return gradient_diff + + return closure diff --git a/fgssl/attack/trainer/GAN_trainer.py b/fgssl/attack/trainer/GAN_trainer.py new file mode 100644 index 0000000..8cfcb31 --- /dev/null +++ b/fgssl/attack/trainer/GAN_trainer.py @@ -0,0 +1,104 @@ +import logging +from typing import Type + +from federatedscope.core.trainers import GeneralTorchTrainer +from federatedscope.attack.privacy_attacks.GAN_based_attack import GANCRA + +logger = logging.getLogger(__name__) + + +def wrap_GANTrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + ''' + Warp the trainer for gan_based class representative attack. 
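# --- Illustrative sketch (random toy tensors): the matching objective used by
# InvertGradient above, i.e. a cosine-similarity distance between dummy and
# ground-truth gradients plus an alpha_TV-weighted total-variation prior.
import torch
import torch.nn.functional as F

def tv(x):                                       # anisotropic total variation
    return (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean() + \
           (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()

dummy_grads = [torch.randn(8, 4), torch.randn(8)]
true_grads = [torch.randn(8, 4), torch.randn(8)]
dummy_img = torch.rand(1, 1, 28, 28, requires_grad=True)

alpha_TV = 0.001
loss = sum(1 - F.cosine_similarity(g.flatten(), t.flatten(), dim=0)
           for g, t in zip(dummy_grads, true_grads))
loss = loss + alpha_TV * tv(dummy_img)
loss.backward()   # in this toy setup only the TV term reaches dummy_img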
+ + Args: + base_trainer: Type: core.trainers.GeneralTorchTrainer + + :returns: + The wrapped trainer; Type: core.trainers.GeneralTorchTrainer + + ''' + + # ---------------- attribute-level plug-in ----------------------- + + base_trainer.ctx.target_label_ind = \ + base_trainer.cfg.attack.target_label_ind + base_trainer.ctx.gan_cra = GANCRA(base_trainer.cfg.attack.target_label_ind, + base_trainer.ctx.model, + dataset_name=base_trainer.cfg.data.type, + device=base_trainer.ctx.device, + sav_pth=base_trainer.cfg.outdir) + + # ---- action-level plug-in ------- + + base_trainer.register_hook_in_train(new_hook=hood_on_fit_start_generator, + trigger='on_fit_start', + insert_mode=-1) + base_trainer.register_hook_in_train(new_hook=hook_on_gan_cra_train, + trigger='on_batch_start', + insert_mode=-1) + base_trainer.register_hook_in_train( + new_hook=hook_on_batch_injected_data_generation, + trigger='on_batch_start', + insert_mode=-1) + base_trainer.register_hook_in_train( + new_hook=hook_on_batch_forward_injected_data, + trigger='on_batch_forward', + insert_mode=-1) + + base_trainer.register_hook_in_train( + new_hook=hook_on_data_injection_sav_data, + trigger='on_fit_end', + insert_mode=-1) + + return base_trainer + + +def hood_on_fit_start_generator(ctx): + ''' + count the FL training round before fitting + Args: + ctx (): + + Returns: + + ''' + ctx.gan_cra.round_num += 1 + logger.info('----- Round {}: GAN training ............'.format( + ctx.gan_cra.round_num)) + + +def hook_on_batch_forward_injected_data(ctx): + ''' + inject the generated data into training batch loss + Args: + ctx (): + + Returns: + + ''' + x, label = [_.to(ctx.device) for _ in ctx.injected_data] + pred = ctx.model(x) + if len(label.size()) == 0: + label = label.unsqueeze(0) + ctx.loss_task += ctx.criterion(pred, label) + ctx.y_true_injected = label + ctx.y_prob_injected = pred + + +def hook_on_batch_injected_data_generation(ctx): + '''generate the injected data + ''' + ctx.injected_data = ctx.gan_cra.generate_fake_data() + + +def hook_on_gan_cra_train(ctx): + + ctx.gan_cra.update_discriminator(ctx.model) + ctx.gan_cra.generator_train() + + +def hook_on_data_injection_sav_data(ctx): + + ctx.gan_cra.generate_and_save_images() diff --git a/fgssl/attack/trainer/MIA_invert_gradient_trainer.py b/fgssl/attack/trainer/MIA_invert_gradient_trainer.py new file mode 100644 index 0000000..981d4c8 --- /dev/null +++ b/fgssl/attack/trainer/MIA_invert_gradient_trainer.py @@ -0,0 +1,139 @@ +import logging +from typing import Type + +import torch + +from federatedscope.core.trainers import GeneralTorchTrainer +from federatedscope.core.data.wrap_dataset import WrapDataset +from federatedscope.attack.auxiliary.MIA_get_target_data import get_target_data + +logger = logging.getLogger(__name__) + + +def wrap_GradientAscentTrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + ''' + wrap the gradient_invert trainer + + Args: + base_trainer: Type: core.trainers.GeneralTorchTrainer + + :returns: + The wrapped trainer; Type: core.trainers.GeneralTorchTrainer + + ''' + + # base_trainer.ctx.target_data = get_target_data() + base_trainer.ctx.target_data_dataloader = WrapDataset( + get_target_data(base_trainer.cfg.data.type)) + base_trainer.ctx.target_data = get_target_data(base_trainer.cfg.data.type) + + base_trainer.ctx.is_target_batch = False + base_trainer.ctx.finish_injected = False + + base_trainer.ctx.target_data_loss = [] + + base_trainer.ctx.outdir = base_trainer.cfg.outdir + base_trainer.ctx.round = -1 + 
base_trainer.ctx.inject_round = base_trainer.cfg.attack.inject_round + base_trainer.ctx.mia_is_simulate_in = \ + base_trainer.cfg.attack.mia_is_simulate_in + base_trainer.ctx.mia_simulate_in_round = \ + base_trainer.cfg.attack.mia_simulate_in_round + + base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_count_round, + trigger='on_fit_start', + insert_mode=-1) + + base_trainer.register_hook_in_train( + new_hook=hook_on_batch_start_replace_data_batch, + trigger='on_batch_start', + insert_mode=-1) + + base_trainer.replace_hook_in_train( + new_hook=hook_on_batch_backward_invert_gradient, + target_trigger='on_batch_backward', + target_hook_name='_hook_on_batch_backward') + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_loss_on_target_data, + trigger='on_fit_start', + insert_mode=-1) + + # plot the target data loss at the end of fitting + + return base_trainer + + +def hook_on_fit_start_count_round(ctx): + ctx.round += 1 + logger.info("============== round: {} ====================".format( + ctx.round)) + + +def hook_on_batch_start_replace_data_batch(ctx): + # replace the data batch to the target data + # check whether need to replace the data; if yes, replace the current + # batch to target batch + if ctx.finish_injected == False and ctx.round >= ctx.inject_round: + logger.info("---------- inject the target data ---------") + ctx.data_batch = ctx.target_data + ctx.is_target_batch = True + logger.info(ctx.target_data[0].size()) + elif ctx.round == ctx.inject_round + ctx.mia_simulate_in_round and \ + ctx.mia_is_simulate_in: + # to simulate the case that the target data is in the training dataset + logger.info( + "---------- put the target data into training in round {}---------" + .format(ctx.round)) + ctx.data_batch = ctx.target_data + ctx.is_target_batch = False + else: + ctx.is_target_batch = False + + +def hook_on_batch_backward_invert_gradient(ctx): + if ctx.is_target_batch: + # if the current data batch is the target data, perform gradient ascent + ctx.optimizer.zero_grad() + ctx.loss_batch.backward() + original_grad = [] + + for param in ctx["model"].parameters(): + original_grad.append(param.grad.detach()) + param.grad = -1 * param.grad + + modified_grad = [] + for param in ctx.model.parameters(): + modified_grad.append(param.grad.detach()) + + ctx["optimizer"].step() + logger.info('-------------- Gradient ascent finished -------------') + ctx.finish_injected = True + + else: + # if current batch is not target data, perform regular backward step + ctx.optimizer.zero_grad() + ctx.loss_task.backward() + if ctx.grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), + ctx.grad_clip) + ctx.optimizer.step() + + +def hook_on_fit_start_loss_on_target_data(ctx): + # monitor the loss on the target data after performing gradient ascent + # action. 
+ if ctx.finish_injected: + tmp_loss = [] + x, label = [_.to(ctx.device) for _ in ctx.target_data] + logger.info(x.size()) + num_target = x.size()[0] + + for i in range(num_target): + x_i = x[i, :].unsqueeze(0) + label_i = label[i].reshape(-1) + pred = ctx.model(x_i) + tmp_loss.append( + ctx.criterion(pred, label_i).detach().cpu().numpy()) + ctx.target_data_loss.append(tmp_loss) diff --git a/fgssl/attack/trainer/PIA_trainer.py b/fgssl/attack/trainer/PIA_trainer.py new file mode 100644 index 0000000..d0826b3 --- /dev/null +++ b/fgssl/attack/trainer/PIA_trainer.py @@ -0,0 +1,18 @@ +from typing import Type + +from federatedscope.core.trainers import GeneralTorchTrainer +from federatedscope.attack.auxiliary.utils import get_data_property + + +def wrap_ActivePIATrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + base_trainer.ctx.alpha_prop_loss = base_trainer._cfg.attack.alpha_prop_loss + + +def hood_on_batch_start_get_prop(ctx): + ctx.prop = get_data_property(ctx.data_batch) + + +def hook_on_batch_forward_add_PIA_loss(ctx): + ctx.loss_batch = ctx.alpha_prop_loss * ctx.loss_batch + ( + 1 - ctx.alpha_prop_loss) * ctx.criterion(ctx.y_prob, ctx.prop) diff --git a/fgssl/attack/trainer/__init__.py b/fgssl/attack/trainer/__init__.py new file mode 100644 index 0000000..37d4d78 --- /dev/null +++ b/fgssl/attack/trainer/__init__.py @@ -0,0 +1,16 @@ +from federatedscope.attack.trainer.GAN_trainer import * +from federatedscope.attack.trainer.MIA_invert_gradient_trainer import * +from federatedscope.attack.trainer.PIA_trainer import * +from federatedscope.attack.trainer.backdoor_trainer import * +from federatedscope.attack.trainer.benign_trainer import * + +__all__ = [ + 'wrap_GANTrainer', 'hood_on_fit_start_generator', + 'hook_on_batch_forward_injected_data', + 'hook_on_batch_injected_data_generation', 'hook_on_gan_cra_train', + 'hook_on_data_injection_sav_data', 'wrap_GradientAscentTrainer', + 'hook_on_fit_start_count_round', 'hook_on_batch_start_replace_data_batch', + 'hook_on_batch_backward_invert_gradient', + 'hook_on_fit_start_loss_on_target_data', 'wrap_backdoorTrainer', + 'wrap_benignTrainer' +] diff --git a/fgssl/attack/trainer/backdoor_trainer.py b/fgssl/attack/trainer/backdoor_trainer.py new file mode 100644 index 0000000..b00faf1 --- /dev/null +++ b/fgssl/attack/trainer/backdoor_trainer.py @@ -0,0 +1,180 @@ +import logging +from typing import Type +import torch +import numpy as np +import copy + +from federatedscope.core.trainers import GeneralTorchTrainer +from torch.nn.utils import parameters_to_vector, vector_to_parameters + +logger = logging.getLogger(__name__) + + +def wrap_backdoorTrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + + # ---------------- attribute-level plug-in ----------------------- + base_trainer.ctx.target_label_ind \ + = base_trainer.cfg.attack.target_label_ind + base_trainer.ctx.trigger_type = base_trainer.cfg.attack.trigger_type + base_trainer.ctx.label_type = base_trainer.cfg.attack.label_type + + # ---- action-level plug-in ------- + + if base_trainer.cfg.attack.self_opt: + + base_trainer.ctx.self_lr = base_trainer.cfg.attack.self_lr + base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_init_local_opt, + trigger='on_fit_start', + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_reset_opt, + trigger='on_fit_end', + insert_pos=0) + + scale_poisoning = base_trainer.cfg.attack.scale_poisoning + 
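# --- Illustrative sketch (toy model/batch): the gradient-ascent step used by
# the MIA trainer above. After backward(), every gradient is negated before
# optimizer.step(), so the loss on the injected target batch is maximized.
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(toy_model(x), y)
loss.backward()
for param in toy_model.parameters():
    param.grad = -1 * param.grad                 # flip the sign: ascend
optimizer.step()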
pgd_poisoning = base_trainer.cfg.attack.pgd_poisoning + + if scale_poisoning or pgd_poisoning: + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_init_local_model, + trigger='on_fit_start', + insert_pos=-1) + + if base_trainer.cfg.attack.scale_poisoning: + + base_trainer.ctx.scale_para = base_trainer.cfg.attack.scale_para + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_end_scale_poisoning, + trigger="on_fit_end", + insert_pos=-1) + + if base_trainer.cfg.attack.pgd_poisoning: + + base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch + base_trainer.ctx.pgd_lr = base_trainer.cfg.attack.pgd_lr + base_trainer.ctx.pgd_eps = base_trainer.cfg.attack.pgd_eps + base_trainer.ctx.batch_index = 0 + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_init_local_pgd, + trigger='on_fit_start', + insert_pos=-1) + + base_trainer.register_hook_in_train( + new_hook=hook_on_batch_end_project_grad, + trigger='on_batch_end', + insert_pos=-1) + + base_trainer.register_hook_in_train( + new_hook=hook_on_epoch_end_project_grad, + trigger='on_epoch_end', + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_reset_opt, + trigger='on_fit_end', + insert_pos=0) + + return base_trainer + + +def hook_on_fit_start_init_local_opt(ctx): + + ctx.original_epoch = ctx["num_train_epoch"] + ctx["num_train_epoch"] = ctx.self_epoch + + +def hook_on_fit_end_reset_opt(ctx): + + ctx["num_train_epoch"] = ctx.original_epoch + + +def hook_on_fit_start_init_local_model(ctx): + + # the original global model + ctx.original_model = copy.deepcopy(ctx.model) + + +def hook_on_fit_end_scale_poisoning(ctx): + + # conduct the scale poisoning + scale_para = ctx.scale_para + + v = torch.nn.utils.parameters_to_vector(ctx.original_model.parameters()) + logger.info("the Norm of the original global model: {}".format( + torch.norm(v).item())) + + v = torch.nn.utils.parameters_to_vector(ctx.model.parameters()) + logger.info("Attacker before scaling : Norm = {}".format( + torch.norm(v).item())) + + ctx.original_model = list(ctx.original_model.parameters()) + + for idx, param in enumerate(ctx.model.parameters()): + param.data = (param.data - ctx.original_model[idx] + ) * scale_para + ctx.original_model[idx] + + v = torch.nn.utils.parameters_to_vector(ctx.model.parameters()) + logger.info("Attacker after scaling : Norm = {}".format( + torch.norm(v).item())) + + logger.info('finishing model scaling poisoning attack') + + +def hook_on_fit_start_init_local_pgd(ctx): + + ctx.original_optimizer = ctx.optimizer + ctx.original_epoch = ctx["num_train_epoch"] + ctx["num_train_epoch"] = ctx.self_epoch + ctx.optimizer = torch.optim.SGD(ctx.model.parameters(), lr=ctx.pgd_lr) + # looks like adversary needs same lr to hide with others + + +def hook_on_batch_end_project_grad(ctx): + ''' + after every 10 iters, we project update on the predefined norm ball. 
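+    Concretely, whenever the projection step fires and ||w - w_g|| > eps
+    (w: current parameters, w_g: the global model received at fit start),
+    w is replaced by w_g + eps * (w - w_g) / ||w - w_g||.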
+ ''' + eps = ctx.pgd_eps + project_frequency = 10 + ctx.batch_index += 1 + w = list(ctx.model.parameters()) + w_vec = parameters_to_vector(w) + model_original_vec = parameters_to_vector( + list(ctx.original_model.parameters())) + # make sure you project on last iteration otherwise, + # high LR pushes you really far + if (ctx.batch_index % project_frequency + == 0) and (torch.norm(w_vec - model_original_vec) > eps): + # project back into norm ball + w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm( + w_vec - model_original_vec) + model_original_vec + # plug w_proj back into model + vector_to_parameters(w_proj_vec, w) + + +def hook_on_epoch_end_project_grad(ctx): + ''' + after the whole epoch, we project the update on the predefined norm ball. + ''' + ctx.batch_index = 0 + eps = ctx.pgd_eps + w = list(ctx.model.parameters()) + w_vec = parameters_to_vector(w) + model_original_vec = parameters_to_vector( + list(ctx.original_model.parameters())) + if (torch.norm(w_vec - model_original_vec) > eps): + # project back into norm ball + w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm( + w_vec - model_original_vec) + model_original_vec + # plug w_proj back into model + vector_to_parameters(w_proj_vec, w) + + +def hook_on_fit_end_reset_pgd(ctx): + + ctx.optimizer = ctx.original_optimizer diff --git a/fgssl/attack/trainer/benign_trainer.py b/fgssl/attack/trainer/benign_trainer.py new file mode 100644 index 0000000..e6bd30c --- /dev/null +++ b/fgssl/attack/trainer/benign_trainer.py @@ -0,0 +1,81 @@ +import logging +from typing import Type +import numpy as np + +from federatedscope.core.trainers import GeneralTorchTrainer + +logger = logging.getLogger(__name__) + + +def wrap_benignTrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + ''' + Warp the benign trainer for backdoor attack: + We just add the normalization operation. + Args: + base_trainer: Type: core.trainers.GeneralTorchTrainer + :returns: + The wrapped trainer; Type: core.trainers.GeneralTorchTrainer + ''' + base_trainer.register_hook_in_eval(new_hook=hook_on_fit_end_test_poison, + trigger='on_fit_end', + insert_pos=0) + + return base_trainer + + +def hook_on_fit_end_test_poison(ctx): + """ + Evaluate metrics of poisoning attacks. 
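+    The hook iterates over ctx.data['poison_' + cur_split], collects the
+    labels and model outputs, and logs both the number of poisoned samples
+    and the resulting poisoning accuracy.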
+ """ + + ctx['poison_' + ctx.cur_split + '_loader'] = ctx.data['poison_' + + ctx.cur_split] + ctx['poison_' + ctx.cur_split + '_data'] = ctx.data['poison_' + + ctx.cur_split].dataset + ctx['num_poison_' + ctx.cur_split + '_data'] = len( + ctx.data['poison_' + ctx.cur_split].dataset) + setattr(ctx, "poison_{}_y_true".format(ctx.cur_split), []) + setattr(ctx, "poison_{}_y_prob".format(ctx.cur_split), []) + setattr(ctx, "poison_num_samples_{}".format(ctx.cur_split), 0) + + for batch_idx, (samples, targets) in enumerate( + ctx['poison_' + ctx.cur_split + '_loader']): + samples, targets = samples.to(ctx.device), targets.to(ctx.device) + pred = ctx.model(samples) + if len(targets.size()) == 0: + targets = targets.unsqueeze(0) + ctx.poison_y_true = targets + ctx.poison_y_prob = pred + ctx.poison_batch_size = len(targets) + + ctx.get("poison_{}_y_true".format(ctx.cur_split)).append( + ctx.poison_y_true.detach().cpu().numpy()) + + ctx.get("poison_{}_y_prob".format(ctx.cur_split)).append( + ctx.poison_y_prob.detach().cpu().numpy()) + + setattr( + ctx, "poison_num_samples_{}".format(ctx.cur_split), + ctx.get("poison_num_samples_{}".format(ctx.cur_split)) + + ctx.poison_batch_size) + + setattr(ctx, "poison_{}_y_true".format(ctx.cur_split), + np.concatenate(ctx.get("poison_{}_y_true".format(ctx.cur_split)))) + setattr(ctx, "poison_{}_y_prob".format(ctx.cur_split), + np.concatenate(ctx.get("poison_{}_y_prob".format(ctx.cur_split)))) + + logger.info('the {} poisoning samples: {:d}'.format( + ctx.cur_split, ctx.get("poison_num_samples_{}".format(ctx.cur_split)))) + + poison_true = ctx['poison_' + ctx.cur_split + '_y_true'] + poison_prob = ctx['poison_' + ctx.cur_split + '_y_prob'] + + poison_pred = np.argmax(poison_prob, axis=1) + + correct = poison_true == poison_pred + + poisoning_acc = float(np.sum(correct)) / len(correct) + + logger.info('the {} poisoning accuracy: {:f}'.format( + ctx.cur_split, poisoning_acc)) diff --git a/fgssl/attack/worker_as_attacker/__init__.py b/fgssl/attack/worker_as_attacker/__init__.py new file mode 100644 index 0000000..ee3a8f3 --- /dev/null +++ b/fgssl/attack/worker_as_attacker/__init__.py @@ -0,0 +1,12 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from federatedscope.attack.worker_as_attacker.active_client import * +from federatedscope.attack.worker_as_attacker.server_attacker import * + +__all__ = [ + 'plot_target_loss', 'sav_target_loss', 'callback_funcs_for_finish', + 'add_atk_method_to_Client_GradAscent', 'PassiveServer', 'PassivePIAServer', + 'BackdoorServer' +] diff --git a/fgssl/attack/worker_as_attacker/active_client.py b/fgssl/attack/worker_as_attacker/active_client.py new file mode 100644 index 0000000..651b2c3 --- /dev/null +++ b/fgssl/attack/worker_as_attacker/active_client.py @@ -0,0 +1,51 @@ +import matplotlib.pyplot as plt +import numpy as np +import os +from federatedscope.core.message import Message +import logging + +logger = logging.getLogger(__name__) + + +def plot_target_loss(loss_list, outdir): + ''' + + Args: + loss_list: the list of loss regrading the target data + outdir: the directory to store the loss + + ''' + + target_data_loss = np.vstack(loss_list) + logger.info(target_data_loss.shape) + plt.plot(target_data_loss) + plt.savefig(os.path.join(outdir, 'target_loss.png')) + plt.close() + + +def sav_target_loss(loss_list, outdir): + target_data_loss = np.vstack(loss_list) + np.savetxt(os.path.join(outdir, 'target_loss.txt'), + target_data_loss.transpose(), + delimiter=',') + 
+ +def callback_funcs_for_finish(self, message: Message): + logger.info("============== receiving Finish Message ==============") + if message.content is not None: + self.trainer.update(message.content) + if self.is_attacker and self._cfg.attack.attack_method.lower( + ) == "gradascent": + logger.info( + "============== start attack post-processing ==============") + plot_target_loss(self.trainer.ctx.target_data_loss, + self.trainer.ctx.outdir) + sav_target_loss(self.trainer.ctx.target_data_loss, + self.trainer.ctx.outdir) + + +def add_atk_method_to_Client_GradAscent(client_class): + + setattr(client_class, 'callback_funcs_for_finish', + callback_funcs_for_finish) + return client_class diff --git a/fgssl/attack/worker_as_attacker/server_attacker.py b/fgssl/attack/worker_as_attacker/server_attacker.py new file mode 100644 index 0000000..2265682 --- /dev/null +++ b/fgssl/attack/worker_as_attacker/server_attacker.py @@ -0,0 +1,367 @@ +from federatedscope.core.workers import Server +from federatedscope.core.message import Message + +from federatedscope.core.auxiliaries.criterion_builder import get_criterion +import copy +from federatedscope.attack.auxiliary.utils import get_data_sav_fn, \ + get_reconstructor + +import logging + +import torch +import numpy as np +from federatedscope.attack.privacy_attacks.passive_PIA import \ + PassivePropertyInference + +logger = logging.getLogger(__name__) + + +class BackdoorServer(Server): + ''' + For backdoor attacks, we will choose different sampling stratergies. + fix-frequency, all-round ,or random sampling. + ''' + def __init__(self, + ID=-1, + state=0, + config=None, + data=None, + model=None, + client_num=5, + total_round_num=10, + device='cpu', + strategy=None, + unseen_clients_id=None, + **kwargs): + super(BackdoorServer, self).__init__(ID=ID, + state=state, + data=data, + model=model, + config=config, + client_num=client_num, + total_round_num=total_round_num, + device=device, + strategy=strategy, + **kwargs) + + def broadcast_model_para(self, + msg_type='model_para', + sample_client_num=-1, + filter_unseen_clients=True): + """ + To broadcast the message to all clients or sampled clients + + Arguments: + msg_type: 'model_para' or other user defined msg_type + sample_client_num: the number of sampled clients in the broadcast + behavior. And sample_client_num = -1 denotes to broadcast to + all the clients. + filter_unseen_clients: whether filter out the unseen clients that + do not contribute to FL process by training on their local + data and uploading their local model update. The splitting is + useful to check participation generalization gap in [ICLR'22, + What Do We Mean by Generalization in Federated Learning?] 
+ You may want to set it to be False when in evaluation stage + """ + + if filter_unseen_clients: + # to filter out the unseen clients when sampling + self.sampler.change_state(self.unseen_clients_id, 'unseen') + + if sample_client_num > 0: # only activated at training process + attacker_id = self._cfg.attack.attacker_id + setting = self._cfg.attack.setting + insert_round = self._cfg.attack.insert_round + + if attacker_id == -1 or self._cfg.attack.attack_method == '': + + receiver = np.random.choice(np.arange(1, self.client_num + 1), + size=sample_client_num, + replace=False).tolist() + + elif setting == 'fix': + if self.state % self._cfg.attack.freq == 0: + client_list = np.delete(np.arange(1, self.client_num + 1), + self._cfg.attack.attacker_id - 1) + receiver = np.random.choice(client_list, + size=sample_client_num - 1, + replace=False).tolist() + receiver.insert(0, self._cfg.attack.attacker_id) + logger.info('starting the fix-frequency poisoning attack') + logger.info( + 'starting poisoning round: {:d}, the attacker ID: {:d}' + .format(self.state, self._cfg.attack.attacker_id)) + else: + client_list = np.delete(np.arange(1, self.client_num + 1), + self._cfg.attack.attacker_id - 1) + receiver = np.random.choice(client_list, + size=sample_client_num, + replace=False).tolist() + + elif setting == 'single' and self.state == insert_round: + client_list = np.delete(np.arange(1, self.client_num + 1), + self._cfg.attack.attacker_id - 1) + receiver = np.random.choice(client_list, + size=sample_client_num - 1, + replace=False).tolist() + receiver.insert(0, self._cfg.attack.attacker_id) + logger.info('starting the single-shot poisoning attack') + logger.info( + 'starting poisoning round: {:d}, the attacker ID: {:d}'. + format(self.state, self._cfg.attack.attacker_id)) + + elif self._cfg.attack.setting == 'all': + + client_list = np.delete(np.arange(1, self.client_num + 1), + self._cfg.attack.attacker_id - 1) + receiver = np.random.choice(client_list, + size=sample_client_num - 1, + replace=False).tolist() + receiver.insert(0, self._cfg.attack.attacker_id) + logger.info('starting the all-round poisoning attack') + logger.info( + 'starting poisoning round: {:d}, the attacker ID: {:d}'. 
+ format(self.state, self._cfg.attack.attacker_id)) + + else: + receiver = np.random.choice(np.arange(1, self.client_num + 1), + size=sample_client_num, + replace=False).tolist() + + else: + # broadcast to all clients + receiver = list(self.comm_manager.neighbors.keys()) + + if self._noise_injector is not None and msg_type == 'model_para': + # Inject noise only when broadcast parameters + for model_idx_i in range(len(self.models)): + num_sample_clients = [ + v["num_sample"] for v in self.join_in_info.values() + ] + self._noise_injector(self._cfg, num_sample_clients, + self.models[model_idx_i]) + + skip_broadcast = self._cfg.federate.method in ["local", "global"] + if self.model_num > 1: + model_para = [{} if skip_broadcast else model.state_dict() + for model in self.models] + else: + model_para = {} if skip_broadcast else self.model.state_dict() + + self.comm_manager.send( + Message(msg_type=msg_type, + sender=self.ID, + receiver=receiver, + state=min(self.state, self.total_round_num), + content=model_para)) + if self._cfg.federate.online_aggr: + for idx in range(self.model_num): + self.aggregators[idx].reset() + + if filter_unseen_clients: + # restore the state of the unseen clients within sampler + self.sampler.change_state(self.unseen_clients_id, 'seen') + + +class PassiveServer(Server): + ''' + In passive attack, the server store the model and the message collected + from the client,and perform the optimization based reconstruction, + such as DLG, InvertGradient. + ''' + def __init__(self, + ID=-1, + state=0, + data=None, + model=None, + client_num=5, + total_round_num=10, + device='cpu', + strategy=None, + state_to_reconstruct=None, + client_to_reconstruct=None, + **kwargs): + super(PassiveServer, self).__init__(ID=ID, + state=state, + data=data, + model=model, + client_num=client_num, + total_round_num=total_round_num, + device=device, + strategy=strategy, + **kwargs) + + # self.offline_reconstruct = offline_reconstruct + self.atk_method = self._cfg.attack.attack_method + self.state_to_reconstruct = state_to_reconstruct + self.client_to_reconstruct = client_to_reconstruct + self.reconstruct_data = dict() + + # the loss function of the global model; the global model can be + # obtained in self.aggregator.model + self.model_criterion = get_criterion(self._cfg.criterion.type, + device=self.device) + + from federatedscope.attack.auxiliary.utils import get_data_info + self.data_dim, self.num_class, self.is_one_hot_label = get_data_info( + self._cfg.data.type) + + self.reconstructor = self._get_reconstructor() + + self.reconstructed_data_sav_fn = get_data_sav_fn(self._cfg.data.type) + + self.reconstruct_data_summary = dict() + + def _get_reconstructor(self): + + return get_reconstructor( + self.atk_method, + max_ite=self._cfg.attack.max_ite, + lr=self._cfg.attack.reconstruct_lr, + federate_loss_fn=self.model_criterion, + device=self.device, + federate_lr=self._cfg.train.optimizer.lr, + optim=self._cfg.attack.reconstruct_optim, + info_diff_type=self._cfg.attack.info_diff_type, + federate_method=self._cfg.federate.method, + alpha_TV=self._cfg.attack.alpha_TV) + + def _reconstruct(self, model_para, batch_size, state, sender): + logger.info('-------- reconstruct round:{}, client:{}---------'.format( + state, sender)) + dummy_data, dummy_label = self.reconstructor.reconstruct( + model=copy.deepcopy(self.model).to(torch.device(self.device)), + original_info=model_para, + data_feature_dim=self.data_dim, + num_class=self.num_class, + batch_size=batch_size) + if state not in 
self.reconstruct_data.keys(): + self.reconstruct_data[state] = dict() + self.reconstruct_data[state][sender] = [ + dummy_data.cpu(), dummy_label.cpu() + ] + + def run_reconstruct(self, state_list=None, sender_list=None): + + if state_list is None: + state_list = self.msg_buffer['train'].keys() + + # After FL running, using gradient based reconstruction method to + # recover client's private training data + for state in state_list: + if sender_list is None: + sender_list = self.msg_buffer['train'][state].keys() + for sender in sender_list: + content = self.msg_buffer['train'][state][sender] + self._reconstruct(model_para=content[1], + batch_size=content[0], + state=state, + sender=sender) + + def callback_funcs_model_para(self, message: Message): + if self.is_finish: + return 'finish' + + round, sender, content = message.state, message.sender, message.content + self.sampler.change_state(sender, 'idle') + if round not in self.msg_buffer['train']: + self.msg_buffer['train'][round] = dict() + self.msg_buffer['train'][round][sender] = content + + # run reconstruction before the clear of self.msg_buffer + + if self.state_to_reconstruct is None or message.state in \ + self.state_to_reconstruct: + if self.client_to_reconstruct is None or message.sender in \ + self.client_to_reconstruct: + self.run_reconstruct(state_list=[message.state], + sender_list=[message.sender]) + if self.reconstructed_data_sav_fn is not None: + self.reconstructed_data_sav_fn( + data=self.reconstruct_data[message.state][ + message.sender][0], + sav_pth=self._cfg.outdir, + name='image_state_{}_client_{}.png'.format( + message.state, message.sender)) + + self.check_and_move_on() + + +class PassivePIAServer(Server): + ''' + The implementation of the batch property classifier, the algorithm 3 in + paper: Exploiting Unintended Feature Leakage in Collaborative Learning + + References: + + Melis, Luca, Congzheng Song, Emiliano De Cristofaro and Vitaly + Shmatikov. 
“Exploiting Unintended Feature Leakage in Collaborative + Learning.” 2019 IEEE Symposium on Security and Privacy (SP) (2019): 691-706 + ''' + def __init__(self, + ID=-1, + state=0, + data=None, + model=None, + client_num=5, + total_round_num=10, + device='cpu', + strategy=None, + **kwargs): + super(PassivePIAServer, self).__init__(ID=ID, + state=state, + data=data, + model=model, + client_num=client_num, + total_round_num=total_round_num, + device=device, + strategy=strategy, + **kwargs) + + # self.offline_reconstruct = offline_reconstruct + self.atk_method = self._cfg.attack.attack_method + self.pia_attacker = PassivePropertyInference( + classier=self._cfg.attack.classifier_PIA, + fl_model_criterion=get_criterion(self._cfg.criterion.type, + device=self.device), + device=self.device, + grad_clip=self._cfg.grad.grad_clip, + dataset_name=self._cfg.data.type, + fl_local_update_num=self._cfg.train.local_update_steps, + fl_type_optimizer=self._cfg.train.optimizer.type, + fl_lr=self._cfg.train.optimizer.lr, + batch_size=100) + + # self.optimizer = get_optimizer( + # type=self._cfg.fedopt.type_optimizer, model=self.model, + # lr=self._cfg.fedopt.optimizer.lr) + # print(self.optimizer) + def callback_funcs_model_para(self, message: Message): + if self.is_finish: + return 'finish' + + round, sender, content = message.state, message.sender, message.content + self.sampler.change_state(sender, 'idle') + if round not in self.msg_buffer['train']: + self.msg_buffer['train'][round] = dict() + self.msg_buffer['train'][round][sender] = content + + # collect the updates + self.pia_attacker.collect_updates( + previous_para=self.model.state_dict(), + updated_parameter=content[1], + round=round, + client_id=sender) + self.pia_attacker.get_data_for_dataset_prop_classifier( + model=self.model) + + if self._cfg.federate.online_aggr: + # TODO: put this line to `check_and_move_on` + # currently, no way to know the latest `sender` + self.aggregator.inc(content) + self.check_and_move_on() + + if self.state == self.total_round_num: + self.pia_attacker.train_property_classifier() + self.pia_results = self.pia_attacker.infer_collected() + print(self.pia_results) diff --git a/fgssl/autotune/__init__.py b/fgssl/autotune/__init__.py new file mode 100644 index 0000000..194971f --- /dev/null +++ b/fgssl/autotune/__init__.py @@ -0,0 +1,9 @@ +from federatedscope.autotune.choice_types import Continuous, Discrete +from federatedscope.autotune.utils import parse_search_space, \ + config2cmdargs, config2str +from federatedscope.autotune.algos import get_scheduler + +__all__ = [ + 'Continuous', 'Discrete', 'parse_search_space', 'config2cmdargs', + 'config2str', 'get_scheduler' +] diff --git a/fgssl/autotune/algos.py b/fgssl/autotune/algos.py new file mode 100644 index 0000000..b7f66f1 --- /dev/null +++ b/fgssl/autotune/algos.py @@ -0,0 +1,482 @@ +import os +import logging +from copy import deepcopy +from contextlib import redirect_stdout +import threading +import math + +import ConfigSpace as CS +import yaml +import numpy as np + +from federatedscope.core.auxiliaries.utils import setup_seed +from federatedscope.core.auxiliaries.data_builder import get_data +from federatedscope.core.auxiliaries.worker_builder import get_client_cls, \ + get_server_cls +from federatedscope.core.fed_runner import FedRunner +from federatedscope.autotune.utils import parse_search_space, \ + config2cmdargs, config2str, summarize_hpo_results + +logger = logging.getLogger(__name__) + + +def make_trial(trial_cfg): + setup_seed(trial_cfg.seed) + data, 
modified_config = get_data(config=trial_cfg.clone()) + trial_cfg.merge_from_other_cfg(modified_config) + trial_cfg.freeze() + # TODO: enable client-wise configuration + Fed_runner = FedRunner(data=data, + server_class=get_server_cls(trial_cfg), + client_class=get_client_cls(trial_cfg), + config=trial_cfg.clone()) + results = Fed_runner.run() + key1, key2 = trial_cfg.hpo.metric.split('.') + return results[key1][key2] + + +class TrialExecutor(threading.Thread): + """This class is responsible for executing the FL procedure with + a given trial configuration in another thread. + """ + def __init__(self, cfg_idx, signal, returns, trial_config): + threading.Thread.__init__(self) + + self._idx = cfg_idx + self._signal = signal + self._returns = returns + self._trial_cfg = trial_config + + def run(self): + setup_seed(self._trial_cfg.seed) + data, modified_config = get_data(config=self._trial_cfg.clone()) + self._trial_cfg.merge_from_other_cfg(modified_config) + self._trial_cfg.freeze() + # TODO: enable client-wise configuration + Fed_runner = FedRunner(data=data, + server_class=get_server_cls(self._trial_cfg), + client_class=get_client_cls(self._trial_cfg), + config=self._trial_cfg.clone()) + results = Fed_runner.run() + key1, key2 = self._trial_cfg.hpo.metric.split('.') + self._returns['perf'] = results[key1][key2] + self._returns['cfg_idx'] = self._idx + self._signal.set() + + +def get_scheduler(init_cfg): + """To instantiate a scheduler object for conducting HPO + Arguments: + init_cfg (federatedscope.core.configs.config.CN): configuration. + """ + + if init_cfg.hpo.scheduler in [ + 'sha', 'rs', 'bo_kde', 'bohb', 'hb', 'bo_gp', 'bo_rf' + ]: + scheduler = SuccessiveHalvingAlgo(init_cfg) + # elif init_cfg.hpo.scheduler == 'pbt': + # scheduler = PBT(init_cfg) + elif init_cfg.hpo.scheduler.startswith('wrap'): + scheduler = SHAWrapFedex(init_cfg) + return scheduler + + +class Scheduler(object): + """The base class for describing HPO algorithms + """ + def __init__(self, cfg): + """ + Arguments: + cfg (federatedscope.core.configs.config.CN): dict like object, + where each key-value pair corresponds to a field and its + choices. + """ + + self._cfg = cfg + # Create hpo working folder + os.makedirs(self._cfg.hpo.working_folder, exist_ok=True) + self._search_space = parse_search_space(self._cfg.hpo.ss) + + self._init_configs = self._setup() + + logger.info(self._init_configs) + + def _setup(self): + """Prepare the initial configurations based on the search space. + """ + raise NotImplementedError + + def _evaluate(self, configs): + """To evaluate (i.e., conduct the FL procedure) for a given + collection of configurations. + """ + raise NotImplementedError + + def optimize(self): + """To optimize the hyperparameters, that is, executing the HPO + algorithm and then returning the results. + """ + raise NotImplementedError + + +class ModelFreeBase(Scheduler): + """To attempt a collection of configurations exhaustively. 
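+    When cfg.hpo.num_workers is positive, trials are dispatched to that many
+    TrialExecutor threads and collected through threading.Event flags;
+    otherwise the configurations are evaluated sequentially via make_trial().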
+ """ + def _setup(self): + self._search_space.seed(self._cfg.seed + 19) + return [ + cfg.get_dictionary() + for cfg in self._search_space.sample_configuration( + size=self._cfg.hpo.init_cand_num) + ] + + def _evaluate(self, configs): + if self._cfg.hpo.num_workers: + # execute FL in parallel by multi-threading + flags = [ + threading.Event() for _ in range(self._cfg.hpo.num_workers) + ] + for i in range(len(flags)): + flags[i].set() + threads = [None for _ in range(len(flags))] + thread_results = [dict() for _ in range(len(flags))] + + perfs = [None for _ in range(len(configs))] + for i, config in enumerate(configs): + available_worker = 0 + while not flags[available_worker].is_set(): + available_worker = (available_worker + 1) % len(threads) + if thread_results[available_worker]: + completed_trial_results = thread_results[available_worker] + cfg_idx = completed_trial_results['cfg_idx'] + perfs[cfg_idx] = completed_trial_results['perf'] + logger.info( + "Evaluate the {}-th config {} and get performance {}". + format(cfg_idx, configs[cfg_idx], perfs[cfg_idx])) + thread_results[available_worker].clear() + + trial_cfg = self._cfg.clone() + trial_cfg.merge_from_list(config2cmdargs(config)) + flags[available_worker].clear() + trial = TrialExecutor(i, flags[available_worker], + thread_results[available_worker], + trial_cfg) + trial.start() + threads[available_worker] = trial + + for i in range(len(flags)): + if not flags[i].is_set(): + threads[i].join() + for i in range(len(thread_results)): + if thread_results[i]: + completed_trial_results = thread_results[i] + cfg_idx = completed_trial_results['cfg_idx'] + perfs[cfg_idx] = completed_trial_results['perf'] + logger.info( + "Evaluate the {}-th config {} and get performance {}". + format(cfg_idx, configs[cfg_idx], perfs[cfg_idx])) + thread_results[i].clear() + + else: + perfs = [None] * len(configs) + for i, config in enumerate(configs): + trial_cfg = self._cfg.clone() + trial_cfg.merge_from_list(config2cmdargs(config)) + perfs[i] = make_trial(trial_cfg) + logger.info( + "Evaluate the {}-th config {} and get performance {}". + format(i, config, perfs[i])) + + return perfs + + def optimize(self): + perfs = self._evaluate(self._init_configs) + + results = summarize_hpo_results(self._init_configs, + perfs, + white_list=set( + self._search_space.keys()), + desc=self._cfg.hpo.larger_better) + logger.info( + "========================== HPO Final ==========================") + logger.info("\n{}".format(results)) + logger.info("====================================================") + + return results + + +class IterativeScheduler(ModelFreeBase): + """The base class for HPO algorithms that divide the whole optimization + procedure into iterations. + """ + def _setup(self): + self._stage = 0 + return super(IterativeScheduler, self)._setup() + + def _stop_criterion(self, configs, last_results): + """To determine whether the algorithm should be terminated. + + Arguments: + configs (list): each element is a trial configuration. + last_results (DataFrame): each row corresponds to a specific + configuration as well as its latest performance. + :returns: whether to terminate. + :rtype: bool + """ + raise NotImplementedError + + def _iteration(self, configs): + """To evaluate the given collection of configurations at this stage. + + Arguments: + configs (list): each element is a trial configuration. + :returns: the performances of the given configurations. 
+ :rtype: list + """ + + perfs = self._evaluate(configs) + return perfs + + def _generate_next_population(self, configs, perfs): + """To generate the configurations for the next stage. + + Arguments: + configs (list): the configurations of last stage. + perfs (list): their corresponding performances. + :returns: configuration for the next stage. + :rtype: list + """ + + raise NotImplementedError + + def optimize(self): + current_configs = deepcopy(self._init_configs) + last_results = None + while not self._stop_criterion(current_configs, last_results): + current_perfs = self._iteration(current_configs) + last_results = summarize_hpo_results( + current_configs, + current_perfs, + white_list=set(self._search_space.keys()), + desc=self._cfg.hpo.larger_better) + self._stage += 1 + logger.info( + "========================== Stage{} ==========================" + .format(self._stage)) + logger.info("\n{}".format(last_results)) + logger.info("====================================================") + current_configs = self._generate_next_population( + current_configs, current_perfs) + + return current_configs + + +class SuccessiveHalvingAlgo(IterativeScheduler): + """Successive Halving Algorithm (SHA) tailored to FL setting, where, + in each iteration, just a limited number of communication rounds are + allowed for each trial. + """ + def _setup(self): + init_configs = super(SuccessiveHalvingAlgo, self)._setup() + + for trial_cfg in init_configs: + trial_cfg['federate.save_to'] = os.path.join( + self._cfg.hpo.working_folder, + "{}.pth".format(config2str(trial_cfg))) + + if self._cfg.hpo.sha.budgets: + for trial_cfg in init_configs: + trial_cfg[ + 'federate.total_round_num'] = self._cfg.hpo.sha.budgets[ + self._stage] + trial_cfg['eval.freq'] = self._cfg.hpo.sha.budgets[self._stage] + + return init_configs + + def _stop_criterion(self, configs, last_results): + return len(configs) <= 1 + + def _generate_next_population(self, configs, perfs): + indices = [(i, val) for i, val in enumerate(perfs)] + indices.sort(key=lambda x: x[1], reverse=self._cfg.hpo.larger_better) + next_population = [ + configs[tp[0]] for tp in + indices[:math. 
+ ceil(float(len(indices)) / self._cfg.hpo.sha.elim_rate)] + ] + + for trial_cfg in next_population: + if 'federate.restore_from' not in trial_cfg: + trial_cfg['federate.restore_from'] = trial_cfg[ + 'federate.save_to'] + if self._cfg.hpo.sha.budgets and self._stage < len( + self._cfg.hpo.sha.budgets): + trial_cfg[ + 'federate.total_round_num'] = self._cfg.hpo.sha.budgets[ + self._stage] + trial_cfg['eval.freq'] = self._cfg.hpo.sha.budgets[self._stage] + + return next_population + + +class SHAWrapFedex(SuccessiveHalvingAlgo): + """This SHA is customized as a wrapper for FedEx algorithm.""" + def _make_local_perturbation(self, config): + neighbor = dict() + for k in config: + if 'fedex' in k or 'fedopt' in k or k in [ + 'federate.save_to', 'federate.total_round_num', 'eval.freq' + ]: + # a workaround + continue + hyper = self._search_space.get(k) + if isinstance(hyper, CS.UniformFloatHyperparameter): + lb, ub = hyper.lower, hyper.upper + diameter = self._cfg.hpo.table.eps * (ub - lb) + new_val = (config[k] - + 0.5 * diameter) + np.random.uniform() * diameter + neighbor[k] = float(np.clip(new_val, lb, ub)) + elif isinstance(hyper, CS.UniformIntegerHyperparameter): + lb, ub = hyper.lower, hyper.upper + diameter = self._cfg.hpo.table.eps * (ub - lb) + new_val = round( + float((config[k] - 0.5 * diameter) + + np.random.uniform() * diameter)) + neighbor[k] = int(np.clip(new_val, lb, ub)) + elif isinstance(hyper, CS.CategoricalHyperparameter): + if len(hyper.choices) == 1: + neighbor[k] = config[k] + else: + threshold = self._cfg.hpo.table.eps * len( + hyper.choices) / (len(hyper.choices) - 1) + rn = np.random.uniform() + new_val = np.random.choice( + hyper.choices) if rn <= threshold else config[k] + if type(new_val) in [np.int32, np.int64]: + neighbor[k] = int(new_val) + elif type(new_val) in [np.float32, np.float64]: + neighbor[k] = float(new_val) + else: + neighbor[k] = str(new_val) + else: + raise TypeError("Value of {} has an invalid type {}".format( + k, type(config[k]))) + + return neighbor + + def _setup(self): + # self._cache_yaml() + init_configs = super(SHAWrapFedex, self)._setup() + new_init_configs = [] + for idx, trial_cfg in enumerate(init_configs): + arms = dict(("arm{}".format(1 + j), + self._make_local_perturbation(trial_cfg)) + for j in range(self._cfg.hpo.table.num - 1)) + arms['arm0'] = dict( + (k, v) for k, v in trial_cfg.items() if k in arms['arm1']) + with open( + os.path.join(self._cfg.hpo.working_folder, + f'{idx}_tmp_grid_search_space.yaml'), + 'w') as f: + yaml.dump(arms, f) + new_trial_cfg = dict() + for k in trial_cfg: + if k not in arms['arm0']: + new_trial_cfg[k] = trial_cfg[k] + new_trial_cfg['hpo.table.idx'] = idx + new_trial_cfg['hpo.fedex.ss'] = os.path.join( + self._cfg.hpo.working_folder, + f"{new_trial_cfg['hpo.table.idx']}_tmp_grid_search_space.yaml") + new_trial_cfg['federate.save_to'] = os.path.join( + self._cfg.hpo.working_folder, "idx_{}.pth".format(idx)) + new_init_configs.append(new_trial_cfg) + + self._search_space.add_hyperparameter( + CS.CategoricalHyperparameter("hpo.table.idx", + choices=list( + range(len(new_init_configs))))) + + return new_init_configs + + +# TODO: refactor PBT to enable async parallel +# class PBT(IterativeScheduler): +# """Population-based training (the full paper "Population Based Training +# of Neural Networks" can be found at https://arxiv.org/abs/1711.09846) +# tailored to FL setting, where, in each iteration, just a limited number +# of communication rounds are allowed for each trial (We will provide the +# asynchornous 
version later). +# """ +# def _setup(self, raw_search_space): +# _ = super(PBT, self)._setup(raw_search_space) +# +# if global_cfg.hpo.init_strategy == 'random': +# init_configs = random_search( +# raw_search_space, +# sample_size=global_cfg.hpo.sha.elim_rate** +# global_cfg.hpo.sha.elim_round_num) +# elif global_cfg.hpo.init_strategy == 'grid': +# init_configs = grid_search(raw_search_space, \ +# sample_size=global_cfg.hpo.sha.elim_rate \ +# **global_cfg.hpo.sha.elim_round_num) +# else: +# raise ValueError( +# "SHA needs to use random/grid search to pick {} configs +# from the search space as initial candidates, but `{}` is +# specified as `hpo.init_strategy`" +# .format( +# global_cfg.hpo.sha.elim_rate** +# global_cfg.hpo.sha.elim_round_num, +# global_cfg.hpo.init_strategy)) +# +# for trial_cfg in init_configs: +# trial_cfg['federate.save_to'] = os.path.join( +# global_cfg.hpo.working_folder, +# "{}.pth".format(config2str(trial_cfg))) +# +# return init_configs +# +# def _stop_criterion(self, configs, last_results): +# if last_results is not None: +# if (global_cfg.hpo.larger_better +# and last_results.iloc[0]['performance'] >= +# global_cfg.hpo.pbt.perf_threshold) or ( +# (not global_cfg.hpo.larger_better) +# and last_results.iloc[0]['performance'] <= +# global_cfg.hpo.pbt.perf_threshold): +# return True +# return self._stage >= global_cfg.hpo.pbt.max_stage +# +# def _generate_next_population(self, configs, perfs): +# next_generation = [] +# for i in range(len(configs)): +# new_cfg = deepcopy(configs[i]) +# # exploit +# j = np.random.randint(len(configs)) +# if i != j and ( +# (global_cfg.hpo.larger_better and perfs[j] > perfs[i]) or +# ((not global_cfg.hpo.larger_better) and perfs[j] < perfs[i])): +# new_cfg['federate.restore_from'] = configs[j][ +# 'federate.save_to'] +# # explore +# for k in new_cfg: +# if isinstance(new_cfg[k], float): +# # according to the exploration strategy of the PBT +# paper +# new_cfg[k] *= float(np.random.choice([0.8, 1.2])) +# else: +# new_cfg['federate.restore_from'] = configs[i][ +# 'federate.save_to'] +# +# # update save path +# tmp_cfg = dict() +# for k in new_cfg: +# if k in self._original_search_space: +# tmp_cfg[k] = new_cfg[k] +# new_cfg['federate.save_to'] = os.path.join( +# global_cfg.hpo.working_folder, +# "{}.pth".format(config2str(tmp_cfg))) +# +# next_generation.append(new_cfg) +# +# return next_generation diff --git a/fgssl/autotune/choice_types.py b/fgssl/autotune/choice_types.py new file mode 100644 index 0000000..00cd490 --- /dev/null +++ b/fgssl/autotune/choice_types.py @@ -0,0 +1,162 @@ +# import os +# import sys +# file_dir = os.path.join(os.path.dirname(__file__), '../..') +# sys.path.append(file_dir) +import logging +import math +import yaml + +import numpy as np + +from federatedscope.core.configs.config import global_cfg + +logger = logging.getLogger(__name__) + + +def discretize(contd_choices, num_bkt): + '''Discretize a given continuous search space into the given number of buckets. + + Arguments: + contd_choices (Continuous): continuous choices. + num_bkt (int): number of buckets. + :returns: discritized choices. 
+ :rtype: Discrete + ''' + if contd_choices[0] >= .0 and global_cfg.hpo.log_scale: + loglb, logub = math.log( + np.clip(contd_choices[0], 1e-8, + contd_choices[1])), math.log(contd_choices[1]) + if num_bkt == 1: + choices = [math.exp(loglb + 0.5 * (logub - loglb))] + else: + bkt_size = (logub - loglb) / (num_bkt - 1) + choices = [math.exp(loglb + i * bkt_size) for i in range(num_bkt)] + else: + if num_bkt == 1: + choices = [ + contd_choices[0] + 0.5 * (contd_choices[1] - contd_choices[0]) + ] + else: + bkt_size = (contd_choices[1] - contd_choices[0]) / (num_bkt - 1) + choices = [contd_choices[0] + i * bkt_size for i in range(num_bkt)] + disc_choices = Discrete(*choices) + return disc_choices + + +class Continuous(tuple): + """Represents a continuous search space, e.g., in the range [0.001, 0.1]. + """ + def __new__(cls, lb, ub): + assert ub >= lb, "Invalid configuration where ub:{} is less than " \ + "lb:{}".format(ub, lb) + return tuple.__new__(cls, [lb, ub]) + + def __repr__(self): + return "Continuous(%s,%s)" % self + + def sample(self): + """Sample a value from this search space. + + :returns: the sampled value. + :rtype: float + """ + if self[0] >= .0 and global_cfg.hpo.log_scale: + loglb, logub = math.log(np.clip(self[0], 1e-8, + self[1])), math.log(self[1]) + return math.exp(loglb + np.random.rand() * (logub - loglb)) + else: + return float(self[0] + np.random.rand() * (self[1] - self[0])) + + def grid(self, grid_cnt): + """Generate a given nunber of grids from this search space. + + Arguments: + grid_cnt (int): the number of grids. + :returns: the sampled value. + :rtype: float + """ + discretized = discretize(self, grid_cnt) + return list(discretized) + + +def contd_constructor(loader, node): + value = loader.construct_scalar(node) + lb, ub = map(float, value.split(',')) + return Continuous(lb, ub) + + +yaml.add_constructor(u'!contd', contd_constructor) + + +class Discrete(tuple): + """Represents a discrete search space, e.g., {'abc', 'ijk', 'xyz'}. + """ + def __new__(cls, *args): + return tuple.__new__(cls, args) + + def __repr__(self): + return "Discrete(%s)" % ','.join(map(str, self)) + + def sample(self): + """Sample a value from this search space. + + :returns: the sampled value. + :rtype: depends on the original choices. 
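+        For example, Discrete('a', 'b', 'c').sample() returns one of the
+        three strings, chosen uniformly at random.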
+ """ + + return self[np.random.randint(len(self))] + + def grid(self, grid_cnt): + num_original = len(self) + assert grid_cnt <= num_original, "There are only {} choices to " \ + "produce grids, but {} " \ + "required".format(num_original, + grid_cnt) + if grid_cnt == 1: + selected = [self[len(self) // 2]] + else: + optimistic_step_size = (num_original - 1) // grid_cnt + between_end_len = optimistic_step_size * (grid_cnt - 1) + remainder = (num_original - 1) - between_end_len + one_side_remainder = remainder // 2 if remainder % 2 == 0 else \ + remainder // 2 + 1 + if one_side_remainder <= optimistic_step_size // 2: + step_size = optimistic_step_size + else: + step_size = (num_original - 1) // (grid_cnt - 1) + covered_range = (grid_cnt - 1) * step_size + start_idx = (max(num_original - 1, 1) - covered_range) // 2 + selected = [ + self[j] for j in range( + start_idx, + min(start_idx + + grid_cnt * step_size, num_original), step_size) + ] + return selected + + +def disc_constructor(loader, node): + value = loader.construct_sequence(node) + return Discrete(*value) + + +yaml.add_constructor(u'!disc', disc_constructor) + +# if __name__=="__main__": +# obj = Continuous(0.0, 0.01) +# print(obj.grid(1), obj.grid(2), obj.grid(3)) +# for _ in range(3): +# print(obj.sample()) +# cfg.merge_from_list(['hpo.log_scale', 'True']) +# print(obj.grid(1), obj.grid(2), obj.grid(3)) +# for _ in range(3): +# print(obj.sample()) +# +# obj = Discrete('a', 'b', 'c') +# print(obj.grid(1), obj.grid(2), obj.grid(3)) +# for _ in range(3): +# print(obj.sample()) +# obj = Discrete(1, 2, 3, 4, 5) +# print(obj.grid(1), obj.grid(2), obj.grid(3), obj.grid(4), obj.grid(5)) +# for _ in range(3): +# print(obj.sample()) diff --git a/fgssl/autotune/fedex/__init__.py b/fgssl/autotune/fedex/__init__.py new file mode 100644 index 0000000..ae2a876 --- /dev/null +++ b/fgssl/autotune/fedex/__init__.py @@ -0,0 +1,4 @@ +from federatedscope.autotune.fedex.server import FedExServer +from federatedscope.autotune.fedex.client import FedExClient + +__all__ = ['FedExServer', 'FedExClient'] diff --git a/fgssl/autotune/fedex/client.py b/fgssl/autotune/fedex/client.py new file mode 100644 index 0000000..6b3e5c2 --- /dev/null +++ b/fgssl/autotune/fedex/client.py @@ -0,0 +1,94 @@ +import logging +import json +import copy + +from federatedscope.core.message import Message +from federatedscope.core.workers import Client + +logger = logging.getLogger(__name__) + + +class FedExClient(Client): + """Some code snippets are borrowed from the open-sourced FedEx ( + https://github.com/mkhodak/FedEx) + """ + def _apply_hyperparams(self, hyperparams): + """Apply the given hyperparameters + Arguments: + hyperparams (dict): keys are hyperparameter names \ + and values are specific choices. 
+ """ + + cmd_args = [] + for k, v in hyperparams.items(): + cmd_args.append(k) + cmd_args.append(v) + + self._cfg.defrost() + self._cfg.merge_from_list(cmd_args, check_cfg=False) + self._cfg.freeze(inform=False, check_cfg=False) + + self.trainer.ctx.setup_vars() + + def callback_funcs_for_model_para(self, message: Message): + round, sender, content = message.state, message.sender, message.content + model_params, arms, hyperparams = content["model_param"], content[ + "arms"], content["hyperparam"] + attempt = { + 'Role': 'Client #{:d}'.format(self.ID), + 'Round': self.state + 1, + 'Arms': arms, + 'Hyperparams': hyperparams + } + logger.info(json.dumps(attempt)) + + self._apply_hyperparams(hyperparams) + + self.trainer.update(model_params) + + # self.model.load_state_dict(content) + self.state = round + sample_size, model_para_all, results = self.trainer.train() + if self._cfg.federate.share_local_model and not \ + self._cfg.federate.online_aggr: + model_para_all = copy.deepcopy(model_para_all) + logger.info( + self._monitor.format_eval_res(results, + rnd=self.state, + role='Client #{}'.format(self.ID), + return_raw=True)) + + results['arms'] = arms + content = (sample_size, model_para_all, results) + self.comm_manager.send( + Message(msg_type='model_para', + sender=self.ID, + receiver=[sender], + state=self.state, + content=content)) + + def callback_funcs_for_evaluate(self, message: Message): + sender = message.sender + self.state = message.state + if message.content is not None: + model_params = message.content["model_param"] + self.trainer.update(model_params) + if self._cfg.finetune.before_eval: + self.trainer.finetune() + metrics = {} + for split in self._cfg.eval.split: + eval_metrics = self.trainer.evaluate(target_data_split_name=split) + for key in eval_metrics: + + if self._cfg.federate.mode == 'distributed': + logger.info('Client #{:d}: (Evaluation ({:s} set) at ' + 'Round #{:d}) {:s} is {:.6f}'.format( + self.ID, split, self.state, key, + eval_metrics[key])) + metrics.update(**eval_metrics) + self.comm_manager.send( + Message(msg_type='metrics', + sender=self.ID, + receiver=[sender], + state=self.state, + content=metrics)) diff --git a/fgssl/autotune/fedex/server.py b/fgssl/autotune/fedex/server.py new file mode 100644 index 0000000..fac5e64 --- /dev/null +++ b/fgssl/autotune/fedex/server.py @@ -0,0 +1,450 @@ +import os +import logging +from itertools import product + +import yaml + +import numpy as np +from numpy.linalg import norm +from scipy.special import logsumexp + +from federatedscope.core.message import Message +from federatedscope.core.workers import Server +from federatedscope.core.auxiliaries.utils import merge_dict + +logger = logging.getLogger(__name__) + + +def discounted_mean(trace, factor=1.0): + + weight = factor**np.flip(np.arange(len(trace)), axis=0) + + return np.inner(trace, weight) / weight.sum() + + +class FedExServer(Server): + """Some code snippets are borrowed from the open-sourced FedEx ( + https://github.com/mkhodak/FedEx) + """ + def __init__(self, + ID=-1, + state=0, + config=None, + data=None, + model=None, + client_num=5, + total_round_num=10, + device='cpu', + strategy=None, + **kwargs): + + # initialize action space and the policy + with open(config.hpo.fedex.ss, 'r') as ips: + ss = yaml.load(ips, Loader=yaml.FullLoader) + + if next(iter(ss.keys())).startswith('arm'): + # This is a flattened action space + # ensure the order is unchanged + ss = sorted([(int(k[3:]), v) for k, v in ss.items()], + key=lambda x: x[0]) + self._grid = [] + self._cfsp 
= [[tp[1] for tp in ss]] + else: + # This is not a flat search space + # be careful for the order + self._grid = sorted(ss.keys()) + self._cfsp = [ss[pn] for pn in self._grid] + + sizes = [len(cand_set) for cand_set in self._cfsp] + eta0 = 'auto' if config.hpo.fedex.eta0 <= .0 else float( + config.hpo.fedex.eta0) + self._eta0 = [ + np.sqrt(2.0 * np.log(size)) if eta0 == 'auto' else eta0 + for size in sizes + ] + self._sched = config.hpo.fedex.sched + self._cutoff = config.hpo.fedex.cutoff + self._baseline = config.hpo.fedex.gamma + self._diff = config.hpo.fedex.diff + self._z = [np.full(size, -np.log(size)) for size in sizes] + self._theta = [np.exp(z) for z in self._z] + self._store = [0.0 for _ in sizes] + self._stop_exploration = False + self._trace = { + 'global': [], + 'refine': [], + 'entropy': [self.entropy()], + 'mle': [self.mle()] + } + + super(FedExServer, + self).__init__(ID, state, config, data, model, client_num, + total_round_num, device, strategy, **kwargs) + + if self._cfg.federate.restore_from != '': + if not os.path.exists(self._cfg.federate.restore_from): + logger.warning(f'Invalid `restore_from`:' + f' {self._cfg.federate.restore_from}.') + else: + pi_ckpt_path = self._cfg.federate.restore_from[ + :self._cfg.federate.restore_from.rfind('.')] \ + + "_fedex.yaml" + with open(pi_ckpt_path, 'r') as ips: + ckpt = yaml.load(ips, Loader=yaml.FullLoader) + self._z = [np.asarray(z) for z in ckpt['z']] + self._theta = [np.exp(z) for z in self._z] + self._store = ckpt['store'] + self._stop_exploration = ckpt['stop'] + self._trace = dict() + self._trace['global'] = ckpt['global'] + self._trace['refine'] = ckpt['refine'] + self._trace['entropy'] = ckpt['entropy'] + self._trace['mle'] = ckpt['mle'] + + def entropy(self): + entropy = 0.0 + for probs in product(*(theta[theta > 0.0] for theta in self._theta)): + prob = np.prod(probs) + entropy -= prob * np.log(prob) + return entropy + + def mle(self): + + return np.prod([theta.max() for theta in self._theta]) + + def trace(self, key): + '''returns trace of one of three tracked quantities + Args: + key (str): 'entropy', 'global', or 'refine' + Returns: + numpy vector with length equal to number of rounds up to now. 
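+        Note:
+            'global' stores the weighted validation loss before local
+            tuning, 'refine' the loss after tuning, and 'entropy'/'mle'
+            summarize the current sampling policy ('mle' is tracked too).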
+ ''' + + return np.array(self._trace[key]) + + def sample(self): + """samples from configs using current probability vector""" + + # determine index + if self._stop_exploration: + cfg_idx = [theta.argmax() for theta in self._theta] + else: + cfg_idx = [ + np.random.choice(len(theta), p=theta) for theta in self._theta + ] + + # get the sampled value(s) + if self._grid: + sampled_cfg = { + pn: cands[i] + for pn, cands, i in zip(self._grid, self._cfsp, cfg_idx) + } + else: + sampled_cfg = self._cfsp[0][cfg_idx[0]] + + return cfg_idx, sampled_cfg + + def broadcast_model_para(self, + msg_type='model_para', + sample_client_num=-1, + filter_unseen_clients=True): + """ + To broadcast the message to all clients or sampled clients + """ + if filter_unseen_clients: + # to filter out the unseen clients when sampling + self.sampler.change_state(self.unseen_clients_id, 'unseen') + + if sample_client_num > 0: + receiver = self.sampler.sample(size=sample_client_num) + else: + # broadcast to all clients + receiver = list(self.comm_manager.neighbors.keys()) + if msg_type == 'model_para': + self.sampler.change_state(receiver, 'working') + + if self._noise_injector is not None and msg_type == 'model_para': + # Inject noise only when broadcast parameters + for model_idx_i in range(len(self.models)): + num_sample_clients = [ + v["num_sample"] for v in self.join_in_info.values() + ] + self._noise_injector(self._cfg, num_sample_clients, + self.models[model_idx_i]) + + if self.model_num > 1: + model_para = [model.state_dict() for model in self.models] + else: + model_para = self.model.state_dict() + + # sample the hyper-parameter config specific to the clients + + for rcv_idx in receiver: + cfg_idx, sampled_cfg = self.sample() + content = { + 'model_param': model_para, + "arms": cfg_idx, + 'hyperparam': sampled_cfg + } + self.comm_manager.send( + Message(msg_type=msg_type, + sender=self.ID, + receiver=[rcv_idx], + state=self.state, + content=content)) + if self._cfg.federate.online_aggr: + for idx in range(self.model_num): + self.aggregators[idx].reset() + + if filter_unseen_clients: + # restore the state of the unseen clients within sampler + self.sampler.change_state(self.unseen_clients_id, 'seen') + + def callback_funcs_model_para(self, message: Message): + round, sender, content = message.state, message.sender, message.content + self.sampler.change_state(sender, 'idle') + # For a new round + if round not in self.msg_buffer['train'].keys(): + self.msg_buffer['train'][round] = dict() + + self.msg_buffer['train'][round][sender] = content + + if self._cfg.federate.online_aggr: + self.aggregator.inc(tuple(content[0:2])) + + return self.check_and_move_on() + + def update_policy(self, feedbacks): + """Update the policy. This implementation is borrowed from the + open-sourced FedEx ( + https://github.com/mkhodak/FedEx/blob/ \ + 150fac03857a3239429734d59d319da71191872e/hyper.py#L151) + Arguments: + feedbacks (list): each element is a dict containing "arms" and + necessary feedback. 
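+                Besides 'arms', the fields read below are
+                'val_avg_loss_before', 'val_avg_loss_after' and 'val_total'
+                (the latter is normalized and used as aggregation weights).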
+ """ + + index = [elem['arms'] for elem in feedbacks] + before = np.asarray( + [elem['val_avg_loss_before'] for elem in feedbacks]) + after = np.asarray([elem['val_avg_loss_after'] for elem in feedbacks]) + weight = np.asarray([elem['val_total'] for elem in feedbacks], + dtype=np.float64) + weight /= np.sum(weight) + + if self._trace['refine']: + trace = self.trace('refine') + if self._diff: + trace -= self.trace('global') + baseline = discounted_mean(trace, self._baseline) + else: + baseline = 0.0 + self._trace['global'].append(np.inner(before, weight)) + self._trace['refine'].append(np.inner(after, weight)) + if self._stop_exploration: + self._trace['entropy'].append(0.0) + self._trace['mle'].append(1.0) + return + + for i, (z, theta) in enumerate(zip(self._z, self._theta)): + grad = np.zeros(len(z)) + for idx, s, w in zip(index, + after - before if self._diff else after, + weight): + grad[idx[i]] += w * (s - baseline) / theta[idx[i]] + if self._sched == 'adaptive': + self._store[i] += norm(grad, float('inf'))**2 + denom = np.sqrt(self._store[i]) + elif self._sched == 'aggressive': + denom = 1.0 if np.all( + grad == 0.0) else norm(grad, float('inf')) + elif self._sched == 'auto': + self._store[i] += 1.0 + denom = np.sqrt(self._store[i]) + elif self._sched == 'constant': + denom = 1.0 + elif self._sched == 'scale': + denom = 1.0 / np.sqrt( + 2.0 * np.log(len(grad))) if len(grad) > 1 else float('inf') + else: + raise NotImplementedError + eta = self._eta0[i] / denom + z -= eta * grad + z -= logsumexp(z) + self._theta[i] = np.exp(z) + + self._trace['entropy'].append(self.entropy()) + self._trace['mle'].append(self.mle()) + if self._trace['entropy'][-1] < self._cutoff: + self._stop_exploration = True + + logger.info( + 'Server: Updated policy as {} with entropy {:f} and mle {:f}'. 
+ format(self._theta, self._trace['entropy'][-1], + self._trace['mle'][-1])) + + def check_and_move_on(self, + check_eval_result=False, + min_received_num=None): + """ + To check the message_buffer, when enough messages are receiving, + trigger some events (such as perform aggregation, evaluation, + and move to the next training round) + """ + if min_received_num is None: + min_received_num = self._cfg.federate.sample_client_num + assert min_received_num <= self.sample_client_num + + if check_eval_result: + min_received_num = len(list(self.comm_manager.neighbors.keys())) + + move_on_flag = True # To record whether moving to a new training + # round or finishing the evaluation + if self.check_buffer(self.state, min_received_num, check_eval_result): + + if not check_eval_result: # in the training process + mab_feedbacks = list() + # Get all the message + train_msg_buffer = self.msg_buffer['train'][self.state] + for model_idx in range(self.model_num): + model = self.models[model_idx] + aggregator = self.aggregators[model_idx] + msg_list = list() + for client_id in train_msg_buffer: + if self.model_num == 1: + msg_list.append( + tuple(train_msg_buffer[client_id][0:2])) + else: + train_data_size, model_para_multiple = \ + train_msg_buffer[client_id][0:2] + msg_list.append((train_data_size, + model_para_multiple[model_idx])) + + # collect feedbacks for updating the policy + if model_idx == 0: + mab_feedbacks.append( + train_msg_buffer[client_id][2]) + + # Trigger the monitor here (for training) + if 'dissim' in self._cfg.eval.monitoring: + from federatedscope.core.auxiliaries.utils import \ + calc_blocal_dissim + # TODO: fix load_state_dict + B_val = calc_blocal_dissim( + model.load_state_dict(strict=False), msg_list) + formatted_eval_res = self._monitor.format_eval_res( + B_val, rnd=self.state, role='Server #') + logger.info(formatted_eval_res) + + # Aggregate + agg_info = { + 'client_feedback': msg_list, + 'recover_fun': self.recover_fun + } + result = aggregator.aggregate(agg_info) + model.load_state_dict(result, strict=False) + # aggregator.update(result) + + # update the policy + self.update_policy(mab_feedbacks) + + self.state += 1 + if self.state % self._cfg.eval.freq == 0 and self.state != \ + self.total_round_num: + # Evaluate + logger.info( + 'Server: Starting evaluation at round {:d}.'.format( + self.state)) + self.eval() + + if self.state < self.total_round_num: + # Move to next round of training + logger.info( + f'----------- Starting a new training round (Round ' + f'#{self.state}) -------------') + # Clean the msg_buffer + self.msg_buffer['train'][self.state - 1].clear() + + self.broadcast_model_para( + msg_type='model_para', + sample_client_num=self.sample_client_num) + else: + # Final Evaluate + logger.info('Server: Training is finished! 
Starting ' + 'evaluation.') + self.eval() + + else: # in the evaluation process + # Get all the message & aggregate + formatted_eval_res = self.merge_eval_results_from_all_clients() + self.history_results = merge_dict(self.history_results, + formatted_eval_res) + self.check_and_save() + else: + move_on_flag = False + + return move_on_flag + + def check_and_save(self): + """ + To save the results and save model after each evaluation + """ + # early stopping + should_stop = False + + if "Results_weighted_avg" in self.history_results and \ + self._cfg.eval.best_res_update_round_wise_key in \ + self.history_results['Results_weighted_avg']: + should_stop = self.early_stopper.track_and_check( + self.history_results['Results_weighted_avg'][ + self._cfg.eval.best_res_update_round_wise_key]) + elif "Results_avg" in self.history_results and \ + self._cfg.eval.best_res_update_round_wise_key in \ + self.history_results['Results_avg']: + should_stop = self.early_stopper.track_and_check( + self.history_results['Results_avg'][ + self._cfg.eval.best_res_update_round_wise_key]) + else: + should_stop = False + + if should_stop: + self.state = self.total_round_num + 1 + + if should_stop or self.state == self.total_round_num: + logger.info('Server: Final evaluation is finished! Starting ' + 'merging results.') + # last round + self.save_best_results() + + if self._cfg.federate.save_to != '': + # save the policy + ckpt = dict() + z_list = [z.tolist() for z in self._z] + ckpt['z'] = z_list + ckpt['store'] = self._store + ckpt['stop'] = self._stop_exploration + ckpt['global'] = self.trace('global').tolist() + ckpt['refine'] = self.trace('refine').tolist() + ckpt['entropy'] = self.trace('entropy').tolist() + ckpt['mle'] = self.trace('mle').tolist() + pi_ckpt_path = self._cfg.federate.save_to[:self._cfg.federate. + save_to.rfind( + '.' 
+ )] + "_fedex.yaml" + with open(pi_ckpt_path, 'w') as ops: + yaml.dump(ckpt, ops) + + if self.model_num > 1: + model_para = [model.state_dict() for model in self.models] + else: + model_para = self.model.state_dict() + self.comm_manager.send( + Message(msg_type='finish', + sender=self.ID, + receiver=list(self.comm_manager.neighbors.keys()), + state=self.state, + content=model_para)) + + if self.state == self.total_round_num: + # break out the loop for distributed mode + self.state += 1 diff --git a/fgssl/autotune/hpbandster.py b/fgssl/autotune/hpbandster.py new file mode 100644 index 0000000..aa8e42f --- /dev/null +++ b/fgssl/autotune/hpbandster.py @@ -0,0 +1,136 @@ +import os +import time +import logging + +from os.path import join as osp +import numpy as np +import ConfigSpace as CS +import hpbandster.core.nameserver as hpns +from hpbandster.core.worker import Worker +from hpbandster.optimizers import BOHB, HyperBand, RandomSearch +from hpbandster.optimizers.iterations import SuccessiveHalving + +from federatedscope.autotune.utils import eval_in_fs + +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger(__name__) + + +def clear_cache(working_folder): + # Clear cached ckpt + for name in os.listdir(working_folder): + if name.endswith('.pth'): + os.remove(osp(working_folder, name)) + + +class MyRandomSearch(RandomSearch): + def __init__(self, working_folder, **kwargs): + self.working_folder = working_folder + super(MyRandomSearch, self).__init__(**kwargs) + + +class MyBOHB(BOHB): + def __init__(self, working_folder, **kwargs): + self.working_folder = working_folder + super(MyBOHB, self).__init__(**kwargs) + + def get_next_iteration(self, iteration, iteration_kwargs={}): + if os.path.exists(self.working_folder): + clear_cache(self.working_folder) + return super(MyBOHB, self).get_next_iteration(iteration, + iteration_kwargs) + + +class MyHyperBand(HyperBand): + def __init__(self, working_folder, **kwargs): + self.working_folder = working_folder + super(MyHyperBand, self).__init__(**kwargs) + + def get_next_iteration(self, iteration, iteration_kwargs={}): + if os.path.exists(self.working_folder): + clear_cache(self.working_folder) + return super(MyHyperBand, + self).get_next_iteration(iteration, iteration_kwargs) + + +class MyWorker(Worker): + def __init__(self, cfg, ss, sleep_interval=0, *args, **kwargs): + super(MyWorker, self).__init__(**kwargs) + self.sleep_interval = sleep_interval + self.cfg = cfg + self._ss = ss + self._init_configs = [] + self._perfs = [] + + def compute(self, config, budget, **kwargs): + res = eval_in_fs(self.cfg, config, int(budget)) + config = dict(config) + config['federate.total_round_num'] = budget + self._init_configs.append(config) + self._perfs.append(float(res)) + time.sleep(self.sleep_interval) + logger.info(f'Evaluate the {len(self._perfs)-1}-th config ' + f'{config}, and get performance {res}') + return {'loss': float(res), 'info': res} + + def summarize(self): + from federatedscope.autotune.utils import summarize_hpo_results + results = summarize_hpo_results(self._init_configs, + self._perfs, + white_list=set(self._ss.keys()), + desc=self.cfg.hpo.larger_better) + logger.info( + "========================== HPO Final ==========================") + logger.info("\n{}".format(results)) + logger.info("====================================================") + + return results + + +def run_hpbandster(cfg, scheduler): + config_space = scheduler._search_space + if cfg.hpo.scheduler.startswith('wrap_'): + ss = CS.ConfigurationSpace() + 
ss.add_hyperparameter(config_space['hpo.table.idx']) + config_space = ss + NS = hpns.NameServer(run_id=cfg.hpo.scheduler, host='127.0.0.1', port=0) + ns_host, ns_port = NS.start() + w = MyWorker(sleep_interval=0, + ss=config_space, + cfg=cfg, + nameserver='127.0.0.1', + nameserver_port=ns_port, + run_id=cfg.hpo.scheduler) + w.run(background=True) + opt_kwargs = { + 'configspace': config_space, + 'run_id': cfg.hpo.scheduler, + 'nameserver': '127.0.0.1', + 'nameserver_port': ns_port, + 'eta': cfg.hpo.sha.elim_rate, + 'min_budget': cfg.hpo.sha.budgets[0], + 'max_budget': cfg.hpo.sha.budgets[-1], + 'working_folder': cfg.hpo.working_folder + } + if cfg.hpo.scheduler in ['rs', 'wrap_rs']: + optimizer = MyRandomSearch(**opt_kwargs) + elif cfg.hpo.scheduler in ['hb', 'wrap_hb']: + optimizer = MyHyperBand(**opt_kwargs) + elif cfg.hpo.scheduler in ['bo_kde', 'bohb', 'wrap_bo_kde', 'wrap_bohb']: + optimizer = MyBOHB(**opt_kwargs) + else: + raise ValueError + + if cfg.hpo.sha.iter != 0: + n_iterations = cfg.hpo.sha.iter + else: + n_iterations = -int( + np.log(opt_kwargs['min_budget'] / opt_kwargs['max_budget']) / + np.log(opt_kwargs['eta'])) + 1 + res = optimizer.run(n_iterations=n_iterations) + optimizer.shutdown(shutdown_workers=True) + NS.shutdown() + all_runs = res.get_all_runs() + w.summarize() + + return [x.info for x in all_runs] diff --git a/fgssl/autotune/smac.py b/fgssl/autotune/smac.py new file mode 100644 index 0000000..99dd312 --- /dev/null +++ b/fgssl/autotune/smac.py @@ -0,0 +1,77 @@ +import logging +import numpy as np +import ConfigSpace as CS +from federatedscope.autotune.utils import eval_in_fs +from smac.facade.smac_bb_facade import SMAC4BB +from smac.facade.smac_hpo_facade import SMAC4HPO +from smac.scenario.scenario import Scenario + +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger(__name__) + + +def run_smac(cfg, scheduler): + init_configs = [] + perfs = [] + + def optimization_function_wrapper(config): + budget = cfg.hpo.sha.budgets[-1] + res = eval_in_fs(cfg, config, budget) + config = dict(config) + config['federate.total_round_num'] = budget + init_configs.append(config) + perfs.append(res) + logger.info(f'Evaluate the {len(perfs)-1}-th config ' + f'{config}, and get performance {res}') + return res + + def summarize(): + from federatedscope.autotune.utils import summarize_hpo_results + results = summarize_hpo_results(init_configs, + perfs, + white_list=set(config_space.keys()), + desc=cfg.hpo.larger_better) + logger.info( + "========================== HPO Final ==========================") + logger.info("\n{}".format(results)) + logger.info("====================================================") + + return perfs + + config_space = scheduler._search_space + if cfg.hpo.scheduler.startswith('wrap_'): + ss = CS.ConfigurationSpace() + ss.add_hyperparameter(config_space['hpo.table.idx']) + config_space = ss + + if cfg.hpo.sha.iter != 0: + n_iterations = cfg.hpo.sha.iter + else: + n_iterations = -int( + np.log(cfg.hpo.sha.budgets[0] / cfg.hpo.sha.budgets[-1]) / + np.log(cfg.hpo.sha.elim_rate)) + 1 + + scenario = Scenario({ + "run_obj": "quality", + "runcount-limit": n_iterations, + "cs": config_space, + "output_dir": cfg.hpo.working_folder, + "deterministic": "true", + "limit_resources": False + }) + + if cfg.hpo.scheduler.endswith('bo_gp'): + smac = SMAC4BB(model_type='gp', + scenario=scenario, + tae_runner=optimization_function_wrapper) + elif cfg.hpo.scheduler.endswith('bo_rf'): + smac = SMAC4HPO(scenario=scenario, + 
tae_runner=optimization_function_wrapper) + else: + raise NotImplementedError + try: + smac.optimize() + finally: + smac.solver.incumbent + summarize() + return perfs diff --git a/fgssl/autotune/utils.py b/fgssl/autotune/utils.py new file mode 100644 index 0000000..cd6523a --- /dev/null +++ b/fgssl/autotune/utils.py @@ -0,0 +1,176 @@ +import yaml +import pandas as pd +import ConfigSpace as CS + + +def parse_search_space(config_path): + """Parse yaml format configuration to generate search space + Arguments: + config_path (str): the path of the yaml file. + :returns: the search space. + :rtype: ConfigSpace object + """ + + ss = CS.ConfigurationSpace() + + with open(config_path, 'r') as ips: + raw_ss_config = yaml.load(ips, Loader=yaml.FullLoader) + + for k in raw_ss_config.keys(): + name = k + v = raw_ss_config[k] + hyper_type = v['type'] + del v['type'] + v['name'] = name + + if hyper_type == 'float': + hyper_config = CS.UniformFloatHyperparameter(**v) + elif hyper_type == 'int': + hyper_config = CS.UniformIntegerHyperparameter(**v) + elif hyper_type == 'cate': + hyper_config = CS.CategoricalHyperparameter(**v) + else: + raise ValueError("Unsupported hyper type {}".format(hyper_type)) + ss.add_hyperparameter(hyper_config) + + return ss + + +def config2cmdargs(config): + ''' + Arguments: + config (dict): key is cfg node name, value is the specified value. + Returns: + results (list): cmd args + ''' + + results = [] + for k, v in config.items(): + results.append(k) + results.append(v) + return results + + +def config2str(config): + ''' + Arguments: + config (dict): key is cfg node name, value is the choice of + hyper-parameter. + Returns: + name (str): the string representation of this config + ''' + + vals = [] + for k in config: + idx = k.rindex('.') + vals.append(k[idx + 1:]) + vals.append(str(config[k])) + name = '_'.join(vals) + return name + + +def summarize_hpo_results(configs, perfs, white_list=None, desc=False): + cols = [k for k in configs[0] if (white_list is None or k in white_list) + ] + ['performance'] + d = [[ + trial_cfg[k] + for k in trial_cfg if (white_list is None or k in white_list) + ] + [result] for trial_cfg, result in zip(configs, perfs)] + d = sorted(d, key=lambda ele: ele[-1], reverse=desc) + df = pd.DataFrame(d, columns=cols) + pd.set_option('display.max_colwidth', None) + pd.set_option('display.max_columns', None) + return df + + +def parse_logs(file_list): + import numpy as np + import matplotlib.pyplot as plt + + FONTSIZE = 40 + MARKSIZE = 25 + + def process(file): + history = [] + with open(file, 'r') as F: + for line in F: + try: + state, line = line.split('INFO: ') + config = eval(line[line.find('{'):line.find('}') + 1]) + performance = float( + line[line.find('performance'):].split(' ')[1]) + print(config, performance) + history.append((config, performance)) + except: + continue + best_seen = np.inf + tol_budget, tmp_b = 0, 0 + x, y = [], [] + + for config, performance in history: + tol_budget += config['federate.total_round_num'] + if best_seen > performance or config[ + 'federate.total_round_num'] > tmp_b: + best_seen = performance + x.append(tol_budget) + y.append(best_seen) + tmp_b = config['federate.total_round_num'] + return np.array(x) / tol_budget, np.array(y) + + # Draw + plt.figure(figsize=(10, 7.5)) + plt.xticks(fontsize=FONTSIZE) + plt.yticks(fontsize=FONTSIZE) + + plt.xlabel('Fraction of budget', size=FONTSIZE) + plt.ylabel('Loss', size=FONTSIZE) + + for file in file_list: + x, y = process(file) + plt.plot(x, y, linewidth=1, markersize=MARKSIZE) 
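+        # Each curve plots the running best ("incumbent") loss against the fraction
+        # of the total budget consumed so far; note that `best_seen` is reset to the
+        # latest result whenever a trial with a larger `federate.total_round_num`
+        # (i.e. a higher fidelity) appears.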
+ plt.legend(file_list, fontsize=23, loc='lower right') + plt.savefig('exp2.pdf', bbox_inches='tight') + plt.close() + + +def eval_in_fs(cfg, config, budget): + import ConfigSpace as CS + from federatedscope.core.auxiliaries.utils import setup_seed + from federatedscope.core.auxiliaries.data_builder import get_data + from federatedscope.core.auxiliaries.worker_builder import \ + get_client_cls, get_server_cls + from federatedscope.core.fed_runner import FedRunner + from federatedscope.autotune.utils import config2cmdargs + from os.path import join as osp + + if isinstance(config, CS.Configuration): + config = dict(config) + # Add FedEx related keys to config + if 'hpo.table.idx' in config.keys(): + idx = config['hpo.table.idx'] + config['hpo.fedex.ss'] = osp(cfg.hpo.working_folder, + f"{idx}_tmp_grid_search_space.yaml") + config['federate.save_to'] = osp(cfg.hpo.working_folder, + f"idx_{idx}.pth") + config['federate.restore_from'] = osp(cfg.hpo.working_folder, + f"idx_{idx}.pth") + # Global cfg + trial_cfg = cfg.clone() + # specify the configuration of interest + trial_cfg.merge_from_list(config2cmdargs(config)) + # specify the budget + trial_cfg.merge_from_list( + ["federate.total_round_num", + int(budget), "eval.freq", + int(budget)]) + setup_seed(trial_cfg.seed) + data, modified_config = get_data(config=trial_cfg.clone()) + trial_cfg.merge_from_other_cfg(modified_config) + trial_cfg.freeze() + Fed_runner = FedRunner(data=data, + server_class=get_server_cls(trial_cfg), + client_class=get_client_cls(trial_cfg), + config=trial_cfg.clone()) + results = Fed_runner.run() + key1, key2 = trial_cfg.hpo.metric.split('.') + return results[key1][key2] diff --git a/fgssl/contrib/README.md b/fgssl/contrib/README.md new file mode 100644 index 0000000..2402c97 --- /dev/null +++ b/fgssl/contrib/README.md @@ -0,0 +1,15 @@ +# Register + +In addition to the rich collection of datasets, models, evaluation metrics, etc., FederatedScope (FS) also allows users to create their own ingredients or introduce more customized modules to FS. 
Inspired by GraphGym, we provide `register` mechanism to help integrating your own components into the FS-based federated learning workflow, including: + +* Configurations [`federatedscope/contrib/configs`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/configs) +* Data [`federatedscope/contrib/data`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/data) +* Loss [`federatedscope/contrib/loss`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/loss) +* Metrics [`federatedscope/contrib/metrics`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/metrics) +* Model [`federatedscope/contrib/model`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/model) +* Optimizer [`federatedscope/contrib/optimizer`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/optimizer) +* Scheduler [`federatedscope/contrib/scheduler`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/scheduler) +* Splitter [`federatedscope/contrib/splitter`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/splitter) +* Trainer [`federatedscope/contrib/trainer`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/trainer) +* Worker [`federatedscope/contrib/worker`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/contrib/worker) + diff --git a/fgssl/contrib/__init__.py b/fgssl/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fgssl/contrib/configs/__init__.py b/fgssl/contrib/configs/__init__.py new file mode 100644 index 0000000..cef30fa --- /dev/null +++ b/fgssl/contrib/configs/__init__.py @@ -0,0 +1,14 @@ +import copy +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] + +# to ensure the sub-configs registered before set up the global config +all_sub_configs_contrib = copy.copy(__all__) +if "config" in all_sub_configs_contrib: + all_sub_configs_contrib.remove('config') diff --git a/fgssl/contrib/configs/myconfig.py b/fgssl/contrib/configs/myconfig.py new file mode 100644 index 0000000..f1c3db2 --- /dev/null +++ b/fgssl/contrib/configs/myconfig.py @@ -0,0 +1,23 @@ +from federatedscope.core.configs.config import CN + + +def extend_training_cfg(cfg): + # ------------------------------------------------------------------------ # + # Trainer related options + # ------------------------------------------------------------------------ # + cfg.data = CN() + + cfg.data.fgcl = False + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_training_cfg) + + +def assert_training_cfg(cfg): + if cfg.backend not in ['torch', 'tensorflow']: + raise ValueError( + "Value of 'cfg.backend' must be chosen from ['torch', 'tensorflow']." 
+ ) + +# from federatedscope.register import register_config +# register_config("fl_training", extend_training_cfg) \ No newline at end of file diff --git a/fgssl/contrib/data/__init__.py b/fgssl/contrib/data/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/data/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/data/example.py b/fgssl/contrib/data/example.py new file mode 100644 index 0000000..b896bdf --- /dev/null +++ b/fgssl/contrib/data/example.py @@ -0,0 +1,30 @@ +from federatedscope.register import register_data + + +def MyData(config, client_cfgs=None): + r""" + Returns: + data: + { + '{client_id}': { + 'train': Dataset or DataLoader, + 'test': Dataset or DataLoader, + 'val': Dataset or DataLoader + } + } + config: + cfg_node + """ + data = None + config = config + client_cfgs = client_cfgs + return data, config + + +def call_my_data(config, client_cfgs): + if config.data.type == "mydata": + data, modified_config = MyData(config, client_cfgs) + return data, modified_config + + +register_data("mydata", call_my_data) diff --git a/fgssl/contrib/loss/__init__.py b/fgssl/contrib/loss/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/loss/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/loss/example.py b/fgssl/contrib/loss/example.py new file mode 100644 index 0000000..012eca2 --- /dev/null +++ b/fgssl/contrib/loss/example.py @@ -0,0 +1,17 @@ +from federatedscope.register import register_criterion + + +def call_my_criterion(type, device): + try: + import torch.nn as nn + except ImportError: + nn = None + criterion = None + + if type == 'mycriterion': + if nn is not None: + criterion = nn.CrossEntropyLoss().to(device) + return criterion + + +register_criterion('mycriterion', call_my_criterion) diff --git a/fgssl/contrib/metrics/__init__.py b/fgssl/contrib/metrics/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/metrics/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/metrics/example.py b/fgssl/contrib/metrics/example.py new file mode 100644 index 0000000..f162606 --- /dev/null +++ b/fgssl/contrib/metrics/example.py @@ -0,0 +1,16 @@ +from federatedscope.register import register_metric + +METRIC_NAME = 'example' + + +def MyMetric(ctx, **kwargs): + return ctx.num_train_data + + +def call_my_metric(types): + if METRIC_NAME in types: + metric_builder = MyMetric + return METRIC_NAME, metric_builder + + +register_metric(METRIC_NAME, call_my_metric) diff --git a/fgssl/contrib/metrics/poison_acc.py b/fgssl/contrib/metrics/poison_acc.py new file mode 100644 index 0000000..0093e9e --- /dev/null +++ b/fgssl/contrib/metrics/poison_acc.py @@ -0,0 +1,31 @@ +from federatedscope.register import register_metric +import numpy as np + + +def compute_poison_metric(ctx): + + poison_true = ctx['poison_' + ctx.cur_split + 
'_y_true'] + poison_prob = ctx['poison_' + ctx.cur_split + '_y_prob'] + poison_pred = np.argmax(poison_prob, axis=1) + + correct = poison_true == poison_pred + + return float(np.sum(correct)) / len(correct) + + +def load_poison_metrics(ctx, y_true, y_pred, y_prob, **kwargs): + + if ctx.cur_split == 'train': + results = None + else: + results = compute_poison_metric(ctx) + + return results + + +def call_poison_metric(types): + if 'poison_attack_acc' in types: + return 'poison_attack_acc', load_poison_metrics + + +register_metric('poison_attack_acc', call_poison_metric) diff --git a/fgssl/contrib/model/GCL/__init__.py b/fgssl/contrib/model/GCL/__init__.py new file mode 100644 index 0000000..8cdcc45 --- /dev/null +++ b/fgssl/contrib/model/GCL/__init__.py @@ -0,0 +1,16 @@ +import GCL.losses +import GCL.augmentors +import GCL.eval +import GCL.models +import GCL.utils + +__version__ = '0.1.0' + +__all__ = [ + '__version__', + 'losses', + 'augmentors', + 'eval', + 'models', + 'utils' +] diff --git a/fgssl/contrib/model/GCL/augmentors/__init__.py b/fgssl/contrib/model/GCL/augmentors/__init__.py new file mode 100644 index 0000000..0d0b7be --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/__init__.py @@ -0,0 +1,32 @@ +from .augmentor import Graph, Augmentor, Compose, RandomChoice +from .identity import Identity +from .rw_sampling import RWSampling +from .ppr_diffusion import PPRDiffusion +from .markov_diffusion import MarkovDiffusion +from .edge_adding import EdgeAdding +from .edge_removing import EdgeRemoving +from .node_dropping import NodeDropping +from .node_shuffling import NodeShuffling +from .feature_masking import FeatureMasking +from .feature_dropout import FeatureDropout +from .edge_attr_masking import EdgeAttrMasking + +__all__ = [ + 'Graph', + 'Augmentor', + 'Compose', + 'RandomChoice', + 'EdgeAdding', + 'EdgeRemoving', + 'EdgeAttrMasking', + 'FeatureMasking', + 'FeatureDropout', + 'Identity', + 'PPRDiffusion', + 'MarkovDiffusion', + 'NodeDropping', + 'NodeShuffling', + 'RWSampling' +] + +classes = __all__ diff --git a/fgssl/contrib/model/GCL/augmentors/augmentor.py b/fgssl/contrib/model/GCL/augmentors/augmentor.py new file mode 100644 index 0000000..1626769 --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/augmentor.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import torch +from abc import ABC, abstractmethod +from typing import Optional, Tuple, NamedTuple, List + + +class Graph(NamedTuple): + x: torch.FloatTensor + edge_index: torch.LongTensor + edge_weights: Optional[torch.FloatTensor] + + def unfold(self) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor]]: + return self.x, self.edge_index, self.edge_weights + + +class Augmentor(ABC): + """Base class for graph augmentors.""" + def __init__(self): + pass + + @abstractmethod + def augment(self, g: Graph) -> Graph: + raise NotImplementedError(f"GraphAug.augment should be implemented.") + + def __call__( + self, x: torch.FloatTensor, + edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return self.augment(Graph(x, edge_index, edge_weight)).unfold() + + +class Compose(Augmentor): + def __init__(self, augmentors: List[Augmentor]): + super(Compose, self).__init__() + self.augmentors = augmentors + + def augment(self, g: Graph) -> Graph: + for aug in self.augmentors: + g = aug.augment(g) + return g + + +class RandomChoice(Augmentor): + def __init__(self, augmentors: List[Augmentor], num_choices: int): 
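+        # On each call, `augment` samples `num_choices` of the given augmentors
+        # uniformly without replacement and applies them in the sampled order.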
+ super(RandomChoice, self).__init__() + assert num_choices <= len(augmentors) + self.augmentors = augmentors + self.num_choices = num_choices + + def augment(self, g: Graph) -> Graph: + num_augmentors = len(self.augmentors) + perm = torch.randperm(num_augmentors) + idx = perm[:self.num_choices] + for i in idx: + aug = self.augmentors[i] + g = aug.augment(g) + return g diff --git a/fgssl/contrib/model/GCL/augmentors/edge_adding.py b/fgssl/contrib/model/GCL/augmentors/edge_adding.py new file mode 100644 index 0000000..4a5f895 --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/edge_adding.py @@ -0,0 +1,13 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import add_edge + + +class EdgeAdding(Augmentor): + def __init__(self, pe: float): + super(EdgeAdding, self).__init__() + self.pe = pe + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + edge_index = add_edge(edge_index, ratio=self.pe) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/model/GCL/augmentors/edge_attr_masking.py b/fgssl/contrib/model/GCL/augmentors/edge_attr_masking.py new file mode 100644 index 0000000..7344c0e --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/edge_attr_masking.py @@ -0,0 +1,14 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import drop_feature + + +class EdgeAttrMasking(Augmentor): + def __init__(self, pf: float): + super(EdgeAttrMasking, self).__init__() + self.pf = pf + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + if edge_weights is not None: + edge_weights = drop_feature(edge_weights, self.pf) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/model/GCL/augmentors/edge_removing.py b/fgssl/contrib/model/GCL/augmentors/edge_removing.py new file mode 100644 index 0000000..adfaeaf --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/edge_removing.py @@ -0,0 +1,13 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import dropout_adj + + +class EdgeRemoving(Augmentor): + def __init__(self, pe: float): + super(EdgeRemoving, self).__init__() + self.pe = pe + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + edge_index, edge_weights = dropout_adj(edge_index, edge_attr=edge_weights, p=self.pe) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/model/GCL/augmentors/feature_dropout.py b/fgssl/contrib/model/GCL/augmentors/feature_dropout.py new file mode 100644 index 0000000..0395435 --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/feature_dropout.py @@ -0,0 +1,13 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import dropout_feature + + +class FeatureDropout(Augmentor): + def __init__(self, pf: float): + super(FeatureDropout, self).__init__() + self.pf = pf + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + x = dropout_feature(x, self.pf) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/model/GCL/augmentors/feature_masking.py b/fgssl/contrib/model/GCL/augmentors/feature_masking.py new file mode 100644 index 0000000..9d0acc6 --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/feature_masking.py @@ -0,0 +1,13 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import 
drop_feature + + +class FeatureMasking(Augmentor): + def __init__(self, pf: float): + super(FeatureMasking, self).__init__() + self.pf = pf + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + x = drop_feature(x, self.pf) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/model/GCL/augmentors/functional.py b/fgssl/contrib/model/GCL/augmentors/functional.py new file mode 100644 index 0000000..1a0fa3d --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/functional.py @@ -0,0 +1,332 @@ +import torch +import networkx as nx +import torch.nn.functional as F + +from typing import Optional +from GCL.utils import normalize +from torch_sparse import SparseTensor, coalesce +from torch_scatter import scatter +from torch_geometric.transforms import GDC +from torch.distributions import Uniform, Beta +from torch_geometric.utils import dropout_adj, to_networkx, to_undirected, degree, to_scipy_sparse_matrix, \ + from_scipy_sparse_matrix, sort_edge_index, add_self_loops, subgraph +from torch.distributions.bernoulli import Bernoulli + + +def permute(x: torch.Tensor) -> torch.Tensor: + """ + Randomly permute node embeddings or features. + + Args: + x: The latent embedding or node feature. + + Returns: + torch.Tensor: Embeddings or features resulting from permutation. + """ + return x[torch.randperm(x.size(0))] + + +def get_mixup_idx(x: torch.Tensor) -> torch.Tensor: + """ + Generate node IDs randomly for mixup; avoid mixup the same node. + + Args: + x: The latent embedding or node feature. + + Returns: + torch.Tensor: Random node IDs. + """ + mixup_idx = torch.randint(x.size(0) - 1, [x.size(0)]) + mixup_self_mask = mixup_idx - torch.arange(x.size(0)) + mixup_self_mask = (mixup_self_mask == 0) + mixup_idx += torch.ones(x.size(0), dtype=torch.int) * mixup_self_mask + return mixup_idx + + +def mixup(x: torch.Tensor, alpha: float) -> torch.Tensor: + """ + Randomly mixup node embeddings or features with other nodes'. + + Args: + x: The latent embedding or node feature. + alpha: The hyperparameter controlling the mixup coefficient. + + Returns: + torch.Tensor: Embeddings or features resulting from mixup. + """ + device = x.device + mixup_idx = get_mixup_idx(x).to(device) + lambda_ = Uniform(alpha, 1.).sample([1]).to(device) + x = (1 - lambda_) * x + lambda_ * x[mixup_idx] + return x + + +def multiinstance_mixup(x1: torch.Tensor, x2: torch.Tensor, + alpha: float, shuffle=False) -> (torch.Tensor, torch.Tensor): + """ + Randomly mixup node embeddings or features with nodes from other views. + + Args: + x1: The latent embedding or node feature from one view. + x2: The latent embedding or node feature from the other view. + alpha: The mixup coefficient `\lambda` follows `Beta(\alpha, \alpha)`. + shuffle: Whether to use fixed negative samples. + + Returns: + (torch.Tensor, torch.Tensor): Spurious positive samples and the mixup coefficient. 
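+        Note:
+            The coefficient `\lambda` is drawn from `Beta(\alpha, \alpha)` and the
+            positive sample is computed as `(1 - \lambda) * x1 + \lambda * x2[mixup_idx]`.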
+ """ + device = x1.device + lambda_ = Beta(alpha, alpha).sample([1]).to(device) + if shuffle: + mixup_idx = get_mixup_idx(x1).to(device) + else: + mixup_idx = x1.size(0) - torch.arange(x1.size(0)) - 1 + x_spurious = (1 - lambda_) * x1 + lambda_ * x2[mixup_idx] + + return x_spurious, lambda_ + + +def drop_feature(x: torch.Tensor, drop_prob: float) -> torch.Tensor: + device = x.device + drop_mask = torch.empty((x.size(1),), dtype=torch.float32).uniform_(0, 1) < drop_prob + drop_mask = drop_mask.to(device) + x = x.clone() + x[:, drop_mask] = 0 + + return x + + +def dropout_feature(x: torch.FloatTensor, drop_prob: float) -> torch.FloatTensor: + return F.dropout(x, p=1. - drop_prob) + + +class AugmentTopologyAttributes(object): + def __init__(self, pe=0.5, pf=0.5): + self.pe = pe + self.pf = pf + + def __call__(self, x, edge_index): + edge_index = dropout_adj(edge_index, p=self.pe)[0] + x = drop_feature(x, self.pf) + return x, edge_index + + +def get_feature_weights(x, centrality, sparse=True): + if sparse: + x = x.to(torch.bool).to(torch.float32) + else: + x = x.abs() + w = x.t() @ centrality + w = w.log() + + return normalize(w) + + +def drop_feature_by_weight(x, weights, drop_prob: float, threshold: float = 0.7): + weights = weights / weights.mean() * drop_prob + weights = weights.where(weights < threshold, torch.ones_like(weights) * threshold) # clip + drop_mask = torch.bernoulli(weights).to(torch.bool) + x = x.clone() + x[:, drop_mask] = 0. + return x + + +def get_eigenvector_weights(data): + def _eigenvector_centrality(data): + graph = to_networkx(data) + x = nx.eigenvector_centrality_numpy(graph) + x = [x[i] for i in range(data.num_nodes)] + return torch.tensor(x, dtype=torch.float32).to(data.edge_index.device) + + evc = _eigenvector_centrality(data) + scaled_evc = evc.where(evc > 0, torch.zeros_like(evc)) + scaled_evc = scaled_evc + 1e-8 + s = scaled_evc.log() + + edge_index = data.edge_index + s_row, s_col = s[edge_index[0]], s[edge_index[1]] + + return normalize(s_col), evc + + +def get_degree_weights(data): + edge_index_ = to_undirected(data.edge_index) + deg = degree(edge_index_[1]) + deg_col = deg[data.edge_index[1]].to(torch.float32) + scaled_deg_col = torch.log(deg_col) + + return normalize(scaled_deg_col), deg + + +def get_pagerank_weights(data, aggr: str = 'sink', k: int = 10): + def _compute_pagerank(edge_index, damp: float = 0.85, k: int = 10): + num_nodes = edge_index.max().item() + 1 + deg_out = degree(edge_index[0]) + x = torch.ones((num_nodes,)).to(edge_index.device).to(torch.float32) + + for i in range(k): + edge_msg = x[edge_index[0]] / deg_out[edge_index[0]] + agg_msg = scatter(edge_msg, edge_index[1], reduce='sum') + + x = (1 - damp) * x + damp * agg_msg + + return x + + pv = _compute_pagerank(data.edge_index, k=k) + pv_row = pv[data.edge_index[0]].to(torch.float32) + pv_col = pv[data.edge_index[1]].to(torch.float32) + s_row = torch.log(pv_row) + s_col = torch.log(pv_col) + if aggr == 'sink': + s = s_col + elif aggr == 'source': + s = s_row + elif aggr == 'mean': + s = (s_col + s_row) * 0.5 + else: + s = s_col + + return normalize(s), pv + + +def drop_edge_by_weight(edge_index, weights, drop_prob: float, threshold: float = 0.7): + weights = weights / weights.mean() * drop_prob + weights = weights.where(weights < threshold, torch.ones_like(weights) * threshold) + drop_mask = torch.bernoulli(1. 
- weights).to(torch.bool) + + return edge_index[:, drop_mask] + + +class AdaptivelyAugmentTopologyAttributes(object): + def __init__(self, edge_weights, feature_weights, pe=0.5, pf=0.5, threshold=0.7): + self.edge_weights = edge_weights + self.feature_weights = feature_weights + self.pe = pe + self.pf = pf + self.threshold = threshold + + def __call__(self, x, edge_index): + edge_index = drop_edge_by_weight(edge_index, self.edge_weights, self.pe, self.threshold) + x = drop_feature_by_weight(x, self.feature_weights, self.pf, self.threshold) + + return x, edge_index + + +def get_subgraph(x, edge_index, idx): + adj = to_scipy_sparse_matrix(edge_index).tocsr() + x_sampled = x[idx] + edge_index_sampled = from_scipy_sparse_matrix(adj[idx, :][:, idx]) + return x_sampled, edge_index_sampled + + +def sample_nodes(x, edge_index, sample_size): + idx = torch.randperm(x.size(0))[:sample_size] + return get_subgraph(x, edge_index, idx), idx + + +def compute_ppr(edge_index, edge_weight=None, alpha=0.2, eps=0.1, ignore_edge_attr=True, add_self_loop=True): + N = edge_index.max().item() + 1 + if ignore_edge_attr or edge_weight is None: + edge_weight = torch.ones( + edge_index.size(1), device=edge_index.device) + if add_self_loop: + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value=1, num_nodes=N) + edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) + edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) + edge_index, edge_weight = GDC().transition_matrix( + edge_index, edge_weight, N, normalization='sym') + diff_mat = GDC().diffusion_matrix_exact( + edge_index, edge_weight, N, method='ppr', alpha=alpha) + edge_index, edge_weight = GDC().sparsify_dense(diff_mat, method='threshold', eps=eps) + edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) + edge_index, edge_weight = GDC().transition_matrix( + edge_index, edge_weight, N, normalization='sym') + + return edge_index, edge_weight + + +def get_sparse_adj(edge_index: torch.LongTensor, edge_weight: torch.FloatTensor = None, + add_self_loop: bool = True) -> torch.sparse.Tensor: + num_nodes = edge_index.max().item() + 1 + num_edges = edge_index.size(1) + + if edge_weight is None: + edge_weight = torch.ones((num_edges,), dtype=torch.float32, device=edge_index.device) + + if add_self_loop: + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value=1, num_nodes=num_nodes) + edge_index, edge_weight = coalesce(edge_index, edge_weight, num_nodes, num_nodes) + + edge_index, edge_weight = GDC().transition_matrix( + edge_index, edge_weight, num_nodes, normalization='sym') + + adj_t = torch.sparse_coo_tensor(edge_index, edge_weight, size=(num_nodes, num_nodes)).coalesce() + + return adj_t.t() + + +def compute_markov_diffusion( + edge_index: torch.LongTensor, edge_weight: torch.FloatTensor = None, + alpha: float = 0.1, degree: int = 10, + sp_eps: float = 1e-3, add_self_loop: bool = True): + adj = get_sparse_adj(edge_index, edge_weight, add_self_loop) + + z = adj.to_dense() + t = adj.to_dense() + for _ in range(degree): + t = (1.0 - alpha) * torch.spmm(adj, t) + z += t + z /= degree + z = z + alpha * adj + + adj_t = z.t() + + return GDC().sparsify_dense(adj_t, method='threshold', eps=sp_eps) + + +def coalesce_edge_index(edge_index: torch.Tensor, edge_weights: Optional[torch.Tensor] = None) -> (torch.Tensor, torch.FloatTensor): + num_edges = edge_index.size()[1] + num_nodes = edge_index.max().item() + 1 + edge_weights = edge_weights if edge_weights is not None else 
torch.ones((num_edges,), dtype=torch.float32, device=edge_index.device) + + return coalesce(edge_index, edge_weights, m=num_nodes, n=num_nodes) + + +def add_edge(edge_index: torch.Tensor, ratio: float) -> torch.Tensor: + num_edges = edge_index.size()[1] + num_nodes = edge_index.max().item() + 1 + num_add = int(num_edges * ratio) + + new_edge_index = torch.randint(0, num_nodes - 1, size=(2, num_add)).to(edge_index.device) + edge_index = torch.cat([edge_index, new_edge_index], dim=1) + + edge_index = sort_edge_index(edge_index)[0] + + return coalesce_edge_index(edge_index)[0] + + +def drop_node(edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor] = None, keep_prob: float = 0.5) -> (torch.Tensor, Optional[torch.Tensor]): + num_nodes = edge_index.max().item() + 1 + probs = torch.tensor([keep_prob for _ in range(num_nodes)]) + dist = Bernoulli(probs) + + subset = dist.sample().to(torch.bool).to(edge_index.device) + edge_index, edge_weight = subgraph(subset, edge_index, edge_weight) + + return edge_index, edge_weight + + +def random_walk_subgraph(edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, batch_size: int = 1000, length: int = 10): + num_nodes = edge_index.max().item() + 1 + + row, col = edge_index + adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes)) + + start = torch.randint(0, num_nodes, size=(batch_size, ), dtype=torch.long).to(edge_index.device) + node_idx = adj.random_walk(start.flatten(), length).view(-1) + + edge_index, edge_weight = subgraph(node_idx, edge_index, edge_weight) + + return edge_index, edge_weight diff --git a/fgssl/contrib/model/GCL/augmentors/identity.py b/fgssl/contrib/model/GCL/augmentors/identity.py new file mode 100644 index 0000000..1717195 --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/identity.py @@ -0,0 +1,9 @@ +from GCL.augmentors.augmentor import Graph, Augmentor + + +class Identity(Augmentor): + def __init__(self): + super(Identity, self).__init__() + + def augment(self, g: Graph) -> Graph: + return g diff --git a/fgssl/contrib/model/GCL/augmentors/markov_diffusion.py b/fgssl/contrib/model/GCL/augmentors/markov_diffusion.py new file mode 100644 index 0000000..6bd16d6 --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/markov_diffusion.py @@ -0,0 +1,27 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import compute_markov_diffusion + + +class MarkovDiffusion(Augmentor): + def __init__(self, alpha: float = 0.05, order: int = 16, sp_eps: float = 1e-4, use_cache: bool = True, + add_self_loop: bool = True): + super(MarkovDiffusion, self).__init__() + self.alpha = alpha + self.order = order + self.sp_eps = sp_eps + self._cache = None + self.use_cache = use_cache + self.add_self_loop = add_self_loop + + def augment(self, g: Graph) -> Graph: + if self._cache is not None and self.use_cache: + return self._cache + x, edge_index, edge_weights = g.unfold() + edge_index, edge_weights = compute_markov_diffusion( + edge_index, edge_weights, + alpha=self.alpha, degree=self.order, + sp_eps=self.sp_eps, add_self_loop=self.add_self_loop + ) + res = Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) + self._cache = res + return res diff --git a/fgssl/contrib/model/GCL/augmentors/node_dropping.py b/fgssl/contrib/model/GCL/augmentors/node_dropping.py new file mode 100644 index 0000000..d9e0dce --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/node_dropping.py @@ -0,0 +1,15 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from 
GCL.augmentors.functional import drop_node + + +class NodeDropping(Augmentor): + def __init__(self, pn: float): + super(NodeDropping, self).__init__() + self.pn = pn + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + + edge_index, edge_weights = drop_node(edge_index, edge_weights, keep_prob=1. - self.pn) + + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/model/GCL/augmentors/node_shuffling.py b/fgssl/contrib/model/GCL/augmentors/node_shuffling.py new file mode 100644 index 0000000..ac35551 --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/node_shuffling.py @@ -0,0 +1,12 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import permute + + +class NodeShuffling(Augmentor): + def __init__(self): + super(NodeShuffling, self).__init__() + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + x = permute(x) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/model/GCL/augmentors/ppr_diffusion.py b/fgssl/contrib/model/GCL/augmentors/ppr_diffusion.py new file mode 100644 index 0000000..d33194d --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/ppr_diffusion.py @@ -0,0 +1,24 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import compute_ppr + + +class PPRDiffusion(Augmentor): + def __init__(self, alpha: float = 0.2, eps: float = 1e-4, use_cache: bool = True, add_self_loop: bool = True): + super(PPRDiffusion, self).__init__() + self.alpha = alpha + self.eps = eps + self._cache = None + self.use_cache = use_cache + self.add_self_loop = add_self_loop + + def augment(self, g: Graph) -> Graph: + if self._cache is not None and self.use_cache: + return self._cache + x, edge_index, edge_weights = g.unfold() + edge_index, edge_weights = compute_ppr( + edge_index, edge_weights, + alpha=self.alpha, eps=self.eps, ignore_edge_attr=False, add_self_loop=self.add_self_loop + ) + res = Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) + self._cache = res + return res diff --git a/fgssl/contrib/model/GCL/augmentors/rw_sampling.py b/fgssl/contrib/model/GCL/augmentors/rw_sampling.py new file mode 100644 index 0000000..f5176dc --- /dev/null +++ b/fgssl/contrib/model/GCL/augmentors/rw_sampling.py @@ -0,0 +1,16 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import random_walk_subgraph + + +class RWSampling(Augmentor): + def __init__(self, num_seeds: int, walk_length: int): + super(RWSampling, self).__init__() + self.num_seeds = num_seeds + self.walk_length = walk_length + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + + edge_index, edge_weights = random_walk_subgraph(edge_index, edge_weights, batch_size=self.num_seeds, length=self.walk_length) + + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/model/GCL/eval/__init__.py b/fgssl/contrib/model/GCL/eval/__init__.py new file mode 100644 index 0000000..610af7f --- /dev/null +++ b/fgssl/contrib/model/GCL/eval/__init__.py @@ -0,0 +1,16 @@ +from .eval import BaseEvaluator, BaseSKLearnEvaluator, get_split, from_predefined_split +from .logistic_regression import LREvaluator +from .svm import SVMEvaluator +from .random_forest import RFEvaluator + +__all__ = [ + 'BaseEvaluator', + 'BaseSKLearnEvaluator', + 'LREvaluator', + 'SVMEvaluator', + 'RFEvaluator', + 'get_split', + 'from_predefined_split' +] + 
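+# Typical usage (a sketch; `z` stands for node embeddings and `y` for node labels,
+# both assumed to be torch tensors):
+#     split = get_split(num_samples=z.size(0), train_ratio=0.1, test_ratio=0.8)
+#     result = LREvaluator()(z, y, split)  # -> {'micro_f1': ..., 'macro_f1': ...}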
+classes = __all__ diff --git a/fgssl/contrib/model/GCL/eval/eval.py b/fgssl/contrib/model/GCL/eval/eval.py new file mode 100644 index 0000000..77d4a2c --- /dev/null +++ b/fgssl/contrib/model/GCL/eval/eval.py @@ -0,0 +1,77 @@ +import torch +import numpy as np + +from abc import ABC, abstractmethod +from sklearn.metrics import f1_score +from sklearn.model_selection import PredefinedSplit, GridSearchCV + + +def get_split(num_samples: int, train_ratio: float = 0.1, test_ratio: float = 0.8): + assert train_ratio + test_ratio < 1 + train_size = int(num_samples * train_ratio) + test_size = int(num_samples * test_ratio) + indices = torch.randperm(num_samples) + return { + 'train': indices[:train_size], + 'valid': indices[train_size: test_size + train_size], + 'test': indices[test_size + train_size:] + } + + +def from_predefined_split(data): + assert all([mask is not None for mask in [data.train_mask, data.test_mask, data.val_mask]]) + num_samples = data.num_nodes + indices = torch.arange(num_samples) + return { + 'train': indices[data.train_mask], + 'valid': indices[data.val_mask], + 'test': indices[data.test_mask] + } + + +def split_to_numpy(x, y, split): + keys = ['train', 'test', 'valid'] + objs = [x, y] + return [obj[split[key]].detach().cpu().numpy() for obj in objs for key in keys] + + +def get_predefined_split(x_train, x_val, y_train, y_val, return_array=True): + test_fold = np.concatenate([-np.ones_like(y_train), np.zeros_like(y_val)]) + ps = PredefinedSplit(test_fold) + if return_array: + x = np.concatenate([x_train, x_val], axis=0) + y = np.concatenate([y_train, y_val], axis=0) + return ps, [x, y] + return ps + + +class BaseEvaluator(ABC): + @abstractmethod + def evaluate(self, x: torch.FloatTensor, y: torch.LongTensor, split: dict) -> dict: + pass + + def __call__(self, x: torch.FloatTensor, y: torch.LongTensor, split: dict) -> dict: + for key in ['train', 'test', 'valid']: + assert key in split + + result = self.evaluate(x, y, split) + return result + + +class BaseSKLearnEvaluator(BaseEvaluator): + def __init__(self, evaluator, params): + self.evaluator = evaluator + self.params = params + + def evaluate(self, x, y, split): + x_train, x_test, x_val, y_train, y_test, y_val = split_to_numpy(x, y, split) + ps, [x_train, y_train] = get_predefined_split(x_train, x_val, y_train, y_val) + classifier = GridSearchCV(self.evaluator, self.params, cv=ps, scoring='accuracy', verbose=0) + classifier.fit(x_train, y_train) + test_macro = f1_score(y_test, classifier.predict(x_test), average='macro') + test_micro = f1_score(y_test, classifier.predict(x_test), average='micro') + + return { + 'micro_f1': test_micro, + 'macro_f1': test_macro, + } diff --git a/fgssl/contrib/model/GCL/eval/logistic_regression.py b/fgssl/contrib/model/GCL/eval/logistic_regression.py new file mode 100644 index 0000000..9273045 --- /dev/null +++ b/fgssl/contrib/model/GCL/eval/logistic_regression.py @@ -0,0 +1,80 @@ +import torch +from tqdm import tqdm +from torch import nn +from torch.optim import Adam +from sklearn.metrics import f1_score + +from GCL.eval import BaseEvaluator + + +class LogisticRegression(nn.Module): + def __init__(self, num_features, num_classes): + super(LogisticRegression, self).__init__() + self.fc = nn.Linear(num_features, num_classes) + torch.nn.init.xavier_uniform_(self.fc.weight.data) + + def forward(self, x): + z = self.fc(x) + return z + + +class LREvaluator(BaseEvaluator): + def __init__(self, num_epochs: int = 5000, learning_rate: float = 0.01, + weight_decay: float = 0.0, test_interval: int = 
20): + self.num_epochs = num_epochs + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.test_interval = test_interval + + def evaluate(self, x: torch.FloatTensor, y: torch.LongTensor, split: dict): + device = x.device + x = x.detach().to(device) + input_dim = x.size()[1] + y = y.to(device) + num_classes = y.max().item() + 1 + classifier = LogisticRegression(input_dim, num_classes).to(device) + optimizer = Adam(classifier.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + output_fn = nn.LogSoftmax(dim=-1) + criterion = nn.NLLLoss() + + best_val_micro = 0 + best_test_micro = 0 + best_test_macro = 0 + best_epoch = 0 + + with tqdm(total=self.num_epochs, desc='(LR)', + bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]') as pbar: + for epoch in range(self.num_epochs): + classifier.train() + optimizer.zero_grad() + + output = classifier(x[split['train']]) + loss = criterion(output_fn(output), y[split['train']]) + + loss.backward() + optimizer.step() + + if (epoch + 1) % self.test_interval == 0: + classifier.eval() + y_test = y[split['test']].detach().cpu().numpy() + y_pred = classifier(x[split['test']]).argmax(-1).detach().cpu().numpy() + test_micro = f1_score(y_test, y_pred, average='micro') + test_macro = f1_score(y_test, y_pred, average='macro') + + y_val = y[split['valid']].detach().cpu().numpy() + y_pred = classifier(x[split['valid']]).argmax(-1).detach().cpu().numpy() + val_micro = f1_score(y_val, y_pred, average='micro') + + if val_micro > best_val_micro: + best_val_micro = val_micro + best_test_micro = test_micro + best_test_macro = test_macro + best_epoch = epoch + + pbar.set_postfix({'best test F1Mi': best_test_micro, 'F1Ma': best_test_macro}) + pbar.update(self.test_interval) + + return { + 'micro_f1': best_test_micro, + 'macro_f1': best_test_macro + } diff --git a/fgssl/contrib/model/GCL/eval/random_forest.py b/fgssl/contrib/model/GCL/eval/random_forest.py new file mode 100644 index 0000000..00d02fc --- /dev/null +++ b/fgssl/contrib/model/GCL/eval/random_forest.py @@ -0,0 +1,9 @@ +from sklearn.ensemble import RandomForestClassifier +from GCL.eval import BaseSKLearnEvaluator + + +class RFEvaluator(BaseSKLearnEvaluator): + def __init__(self, params=None): + if params is None: + params = {'n_estimators': [100, 200, 500, 1000]} + super(RFEvaluator, self).__init__(RandomForestClassifier(), params) diff --git a/fgssl/contrib/model/GCL/eval/svm.py b/fgssl/contrib/model/GCL/eval/svm.py new file mode 100644 index 0000000..2d38ed8 --- /dev/null +++ b/fgssl/contrib/model/GCL/eval/svm.py @@ -0,0 +1,13 @@ +from sklearn.svm import LinearSVC, SVC +from GCL.eval import BaseSKLearnEvaluator + + +class SVMEvaluator(BaseSKLearnEvaluator): + def __init__(self, linear=True, params=None): + if linear: + self.evaluator = LinearSVC() + else: + self.evaluator = SVC() + if params is None: + params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]} + super(SVMEvaluator, self).__init__(self.evaluator, params) diff --git a/fgssl/contrib/model/GCL/losses/__init__.py b/fgssl/contrib/model/GCL/losses/__init__.py new file mode 100644 index 0000000..c36dd8a --- /dev/null +++ b/fgssl/contrib/model/GCL/losses/__init__.py @@ -0,0 +1,24 @@ +from .jsd import JSD, DebiasedJSD, HardnessJSD +from .vicreg import VICReg +from .infonce import InfoNCE, InfoNCESP, DebiasedInfoNCE, HardnessInfoNCE +from .triplet import TripletMargin, TripletMarginSP +from .bootstrap import BootstrapLatent +from .barlow_twins import BarlowTwins +from .losses import Loss + 
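+# Every loss below implements the shared `Loss` interface (a usage sketch;
+# `tau=0.2` is only an example value):
+#     loss_fn = InfoNCE(tau=0.2)
+#     loss = loss_fn(anchor, sample, pos_mask=pos_mask, neg_mask=neg_mask)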
+__all__ = [ + 'Loss', + 'InfoNCE', + 'InfoNCESP', + 'DebiasedInfoNCE', + 'HardnessInfoNCE', + 'JSD', + 'DebiasedJSD', + 'HardnessJSD', + 'TripletMargin', + 'TripletMarginSP', + 'VICReg', + 'BarlowTwins' +] + +classes = __all__ diff --git a/fgssl/contrib/model/GCL/losses/barlow_twins.py b/fgssl/contrib/model/GCL/losses/barlow_twins.py new file mode 100644 index 0000000..d32aaa4 --- /dev/null +++ b/fgssl/contrib/model/GCL/losses/barlow_twins.py @@ -0,0 +1,34 @@ +import torch +from .losses import Loss + + +def bt_loss(h1: torch.Tensor, h2: torch.Tensor, lambda_, batch_norm=True, eps=1e-15, *args, **kwargs): + batch_size = h1.size(0) + feature_dim = h1.size(1) + + if lambda_ is None: + lambda_ = 1. / feature_dim + + if batch_norm: + z1_norm = (h1 - h1.mean(dim=0)) / (h1.std(dim=0) + eps) + z2_norm = (h2 - h2.mean(dim=0)) / (h2.std(dim=0) + eps) + c = (z1_norm.T @ z2_norm) / batch_size + else: + c = h1.T @ h2 / batch_size + + off_diagonal_mask = ~torch.eye(feature_dim).bool() + loss = (1 - c.diagonal()).pow(2).sum() + loss += lambda_ * c[off_diagonal_mask].pow(2).sum() + + return loss + + +class BarlowTwins(Loss): + def __init__(self, lambda_: float = None, batch_norm: bool = True, eps: float = 1e-5): + self.lambda_ = lambda_ + self.batch_norm = batch_norm + self.eps = eps + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor: + loss = bt_loss(anchor, sample, self.lambda_, self.batch_norm, self.eps) + return loss.mean() diff --git a/fgssl/contrib/model/GCL/losses/bootstrap.py b/fgssl/contrib/model/GCL/losses/bootstrap.py new file mode 100644 index 0000000..e0362c8 --- /dev/null +++ b/fgssl/contrib/model/GCL/losses/bootstrap.py @@ -0,0 +1,16 @@ +import torch +import torch.nn.functional as F +from .losses import Loss + + +class BootstrapLatent(Loss): + def __init__(self): + super(BootstrapLatent, self).__init__() + + def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs) -> torch.FloatTensor: + anchor = F.normalize(anchor, dim=-1, p=2) + sample = F.normalize(sample, dim=-1, p=2) + + similarity = anchor @ sample.t() + loss = (similarity * pos_mask).sum(dim=-1) + return loss.mean() diff --git a/fgssl/contrib/model/GCL/losses/infonce.py b/fgssl/contrib/model/GCL/losses/infonce.py new file mode 100644 index 0000000..0276e19 --- /dev/null +++ b/fgssl/contrib/model/GCL/losses/infonce.py @@ -0,0 +1,189 @@ +import torch +import numpy as np +import torch.nn.functional as F + +from .losses import Loss + + +def _similarity(h1: torch.Tensor, h2: torch.Tensor): + h1 = F.normalize(h1) + h2 = F.normalize(h2) + return h1 @ h2.t() + + +class InfoNCESP(Loss): + """ + InfoNCE loss for single positive. 
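+    With f(x) = exp(x / tau) applied to the cosine similarities, the per-anchor
+    loss is -log( sum_pos f(sim) / (sum_pos f(sim) + sum_neg f(sim)) ),
+    where the negative mask is taken as 1 - pos_mask.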
+ """ + def __init__(self, tau): + super(InfoNCESP, self).__init__() + self.tau = tau + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + f = lambda x: torch.exp(x / self.tau) + sim = f(_similarity(anchor, sample)) # anchor x sample + assert sim.size() == pos_mask.size() # sanity check + + neg_mask = 1 - pos_mask + pos = (sim * pos_mask).sum(dim=1) + neg = (sim * neg_mask).sum(dim=1) + + loss = pos / (pos + neg) + loss = -torch.log(loss) + + return loss.mean() + + +class InfoNCE(Loss): + def __init__(self, tau): + super(InfoNCE, self).__init__() + self.tau = tau + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + sim = _similarity(anchor, sample) / self.tau + exp_sim = torch.exp(sim) * (pos_mask + neg_mask) + log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True)) + loss = log_prob * pos_mask + loss = loss.sum(dim=1) / pos_mask.sum(dim=1) + return -loss.mean() + + +class DebiasedInfoNCE(Loss): + def __init__(self, tau, tau_plus=0.1): + super(DebiasedInfoNCE, self).__init__() + self.tau = tau + self.tau_plus = tau_plus + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + sim = _similarity(anchor, sample) / self.tau + exp_sim = torch.exp(sim) + + pos_sum = (exp_sim * pos_mask).sum(dim=1) + pos = pos_sum / pos_mask.int().sum(dim=1) + neg_sum = (exp_sim * neg_mask).sum(dim=1) + ng = (-num_neg * self.tau_plus * pos + neg_sum) / (1 - self.tau_plus) + ng = torch.clamp(ng, min=num_neg * np.e ** (-1. / self.tau)) + + log_prob = sim - torch.log((pos + ng).sum(dim=1, keepdim=True)) + loss = log_prob * pos_mask + loss = loss.sum(dim=1) / pos_mask.sum(dim=1) + return loss.mean() + + +class HardnessInfoNCE(Loss): + def __init__(self, tau, tau_plus=0.1, beta=1.0): + super(HardnessInfoNCE, self).__init__() + self.tau = tau + self.tau_plus = tau_plus + self.beta = beta + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + sim = _similarity(anchor, sample) / self.tau + exp_sim = torch.exp(sim) + + pos = (exp_sim * pos_mask).sum(dim=1) / pos_mask.int().sum(dim=1) + imp = torch.exp(self.beta * (sim * neg_mask)) + reweight_neg = (imp * (exp_sim * neg_mask)).sum(dim=1) / imp.mean(dim=1) + ng = (-num_neg * self.tau_plus * pos + reweight_neg) / (1 - self.tau_plus) + ng = torch.clamp(ng, min=num_neg * np.e ** (-1. 
/ self.tau)) + + log_prob = sim - torch.log((pos + ng).sum(dim=1, keepdim=True)) + loss = log_prob * pos_mask + loss = loss.sum(dim=1) / pos_mask.sum(dim=1) + return loss.mean() + + +class HardMixingLoss(torch.nn.Module): + def __init__(self, projection): + super(HardMixingLoss, self).__init__() + self.projection = projection + + @staticmethod + def tensor_similarity(z1, z2): + z1 = F.normalize(z1, dim=-1) # [N, d] + z2 = F.normalize(z2, dim=-1) # [N, s, d] + return torch.bmm(z2, z1.unsqueeze(dim=-1)).squeeze() + + def forward(self, z1: torch.Tensor, z2: torch.Tensor, threshold=0.1, s=150, mixup=0.2, *args, **kwargs): + f = lambda x: torch.exp(x / self.tau) + num_samples = z1.shape[0] + device = z1.device + + threshold = int(num_samples * threshold) + + refl1 = _similarity(z1, z1).diag() + refl2 = _similarity(z2, z2).diag() + pos_similarity = f(_similarity(z1, z2)) + neg_similarity1 = torch.cat([_similarity(z1, z1), _similarity(z1, z2)], dim=1) # [n, 2n] + neg_similarity2 = torch.cat([_similarity(z2, z1), _similarity(z2, z2)], dim=1) + neg_similarity1, indices1 = torch.sort(neg_similarity1, descending=True) + neg_similarity2, indices2 = torch.sort(neg_similarity2, descending=True) + neg_similarity1 = f(neg_similarity1) + neg_similarity2 = f(neg_similarity2) + z_pool = torch.cat([z1, z2], dim=0) + hard_samples1 = z_pool[indices1[:, :threshold]] # [N, k, d] + hard_samples2 = z_pool[indices2[:, :threshold]] + hard_sample_idx1 = torch.randint(hard_samples1.shape[1], size=[num_samples, 2 * s]).to(device) # [N, 2 * s] + hard_sample_idx2 = torch.randint(hard_samples2.shape[1], size=[num_samples, 2 * s]).to(device) + hard_sample_draw1 = hard_samples1[ + torch.arange(num_samples).unsqueeze(-1), hard_sample_idx1] # [N, 2 * s, d] + hard_sample_draw2 = hard_samples2[torch.arange(num_samples).unsqueeze(-1), hard_sample_idx2] + hard_sample_mixing1 = mixup * hard_sample_draw1[:, :s, :] + (1 - mixup) * hard_sample_draw1[:, s:, :] + hard_sample_mixing2 = mixup * hard_sample_draw2[:, :s, :] + (1 - mixup) * hard_sample_draw2[:, s:, :] + + h_m1 = self.projection(hard_sample_mixing1) + h_m2 = self.projection(hard_sample_mixing2) + + neg_m1 = f(self.tensor_similarity(z1, h_m1)).sum(dim=1) + neg_m2 = f(self.tensor_similarity(z2, h_m2)).sum(dim=1) + pos = pos_similarity.diag() + neg1 = neg_similarity1.sum(dim=1) + neg2 = neg_similarity2.sum(dim=1) + loss1 = -torch.log(pos / (neg1 + neg_m1 - refl1)) + loss2 = -torch.log(pos / (neg2 + neg_m2 - refl2)) + loss = (loss1 + loss2) * 0.5 + loss = loss.mean() + return loss + + +class RingLoss(torch.nn.Module): + def __init__(self): + super(RingLoss, self).__init__() + + def forward(self, h1: torch.Tensor, h2: torch.Tensor, y: torch.Tensor, tau, threshold=0.1, *args, **kwargs): + f = lambda x: torch.exp(x / tau) + num_samples = h1.shape[0] + device = h1.device + threshold = int(num_samples * threshold) + + false_neg_mask = torch.zeros((num_samples, 2 * num_samples), dtype=torch.int).to(device) + for i in range(num_samples): + false_neg_mask[i] = (y == y[i]).repeat(2) + + pos_sim = f(_similarity(h1, h2)) + neg_sim1 = torch.cat([_similarity(h1, h1), _similarity(h1, h2)], dim=1) # [n, 2n] + neg_sim2 = torch.cat([_similarity(h2, h1), _similarity(h2, h2)], dim=1) + neg_sim1, indices1 = torch.sort(neg_sim1, descending=True) + neg_sim2, indices2 = torch.sort(neg_sim2, descending=True) + + y_repeated = y.repeat(2) + false_neg_cnt = torch.zeros((num_samples)).to(device) + for i in range(num_samples): + false_neg_cnt[i] = (y_repeated[indices1[i, threshold:-threshold]] == y[i]).sum() 
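+        # Note: `false_neg_mask` / `false_neg_cnt` only count how many same-class
+        # (false-negative) entries remain in the truncated ring [threshold:-threshold];
+        # they do not enter the loss terms computed below.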
+ + neg_sim1 = f(neg_sim1[:, threshold:-threshold]) + neg_sim2 = f(neg_sim2[:, threshold:-threshold]) + + pos = pos_sim.diag() + neg1 = neg_sim1.sum(dim=1) + neg2 = neg_sim2.sum(dim=1) + + loss1 = -torch.log(pos / neg1) + loss2 = -torch.log(pos / neg2) + + loss = (loss1 + loss2) * 0.5 + loss = loss.mean() + + return loss diff --git a/fgssl/contrib/model/GCL/losses/jsd.py b/fgssl/contrib/model/GCL/losses/jsd.py new file mode 100644 index 0000000..8efa121 --- /dev/null +++ b/fgssl/contrib/model/GCL/losses/jsd.py @@ -0,0 +1,77 @@ +import numpy as np +import torch.nn.functional as F + +from .losses import Loss + + +class JSD(Loss): + def __init__(self, discriminator=lambda x, y: x @ y.t()): + super(JSD, self).__init__() + self.discriminator = discriminator + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + num_pos = pos_mask.int().sum() + similarity = self.discriminator(anchor, sample) + + E_pos = (np.log(2) - F.softplus(- similarity * pos_mask)).sum() + E_pos /= num_pos + + neg_sim = similarity * neg_mask + E_neg = (F.softplus(- neg_sim) + neg_sim - np.log(2)).sum() + E_neg /= num_neg + + return E_neg - E_pos + + +class DebiasedJSD(Loss): + def __init__(self, discriminator=lambda x, y: x @ y.t(), tau_plus=0.1): + super(DebiasedJSD, self).__init__() + self.discriminator = discriminator + self.tau_plus = tau_plus + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + num_pos = pos_mask.int().sum() + similarity = self.discriminator(anchor, sample) + + pos_sim = similarity * pos_mask + E_pos = np.log(2) - F.softplus(- pos_sim) + E_pos -= (self.tau_plus / (1 - self.tau_plus)) * (F.softplus(-pos_sim) + pos_sim) + E_pos = E_pos.sum() / num_pos + + neg_sim = similarity * neg_mask + E_neg = (F.softplus(- neg_sim) + neg_sim - np.log(2)) / (1 - self.tau_plus) + E_neg = E_neg.sum() / num_neg + + return E_neg - E_pos + + +class HardnessJSD(Loss): + def __init__(self, discriminator=lambda x, y: x @ y.t(), tau_plus=0.1, beta=0.05): + super(HardnessJSD, self).__init__() + self.discriminator = discriminator + self.tau_plus = tau_plus + self.beta = beta + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + num_pos = pos_mask.int().sum() + similarity = self.discriminator(anchor, sample) + + pos_sim = similarity * pos_mask + E_pos = np.log(2) - F.softplus(- pos_sim) + E_pos -= (self.tau_plus / (1 - self.tau_plus)) * (F.softplus(-pos_sim) + pos_sim) + E_pos = E_pos.sum() / num_pos + + neg_sim = similarity * neg_mask + E_neg = F.softplus(- neg_sim) + neg_sim + + reweight = -2 * neg_sim / max(neg_sim.max(), neg_sim.min().abs()) + reweight = (self.beta * reweight).exp() + reweight /= reweight.mean(dim=1, keepdim=True) + + E_neg = (reweight * E_neg) / (1 - self.tau_plus) - np.log(2) + E_neg = E_neg.sum() / num_neg + + return E_neg - E_pos diff --git a/fgssl/contrib/model/GCL/losses/losses.py b/fgssl/contrib/model/GCL/losses/losses.py new file mode 100644 index 0000000..79524c0 --- /dev/null +++ b/fgssl/contrib/model/GCL/losses/losses.py @@ -0,0 +1,12 @@ +import torch +from abc import ABC, abstractmethod + + +class Loss(ABC): + @abstractmethod + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor: + pass + + def __call__(self, anchor, sample, pos_mask=None, neg_mask=None, *args, **kwargs) -> torch.FloatTensor: + loss = self.compute(anchor, sample, pos_mask, neg_mask, *args, **kwargs) + return loss diff --git 
a/fgssl/contrib/model/GCL/losses/triplet.py b/fgssl/contrib/model/GCL/losses/triplet.py new file mode 100644 index 0000000..3530a34 --- /dev/null +++ b/fgssl/contrib/model/GCL/losses/triplet.py @@ -0,0 +1,81 @@ +import torch +from .losses import Loss + + +class TripletMarginSP(Loss): + def __init__(self, margin: float = 1.0, p: float = 2, *args, **kwargs): + super(TripletMarginSP, self).__init__() + self.loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=p, reduction='none') + self.margin = margin + + def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs): + neg_mask = 1. - pos_mask + + num_pos = pos_mask.to(torch.long).sum(dim=1) + num_neg = neg_mask.to(torch.long).sum(dim=1) + + dist = torch.cdist(anchor, sample, p=2) # [num_anchors, num_samples] + + pos_dist = pos_mask * dist + neg_dist = neg_mask * dist + + pos_dist, neg_dist = pos_dist.sum(dim=1), neg_dist.sum(dim=1) + + loss = pos_dist / num_pos - neg_dist / num_neg + self.margin + loss = torch.where(loss > 0, loss, torch.zeros_like(loss)) + + return loss.mean() + + +class TripletMargin(Loss): + def __init__(self, margin: float = 1.0, p: float = 2, *args, **kwargs): + super(TripletMargin, self).__init__() + self.loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=p, reduction='none') + self.margin = margin + + def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs): + num_anchors = anchor.size()[0] + num_samples = sample.size()[0] + + # Key idea here: + # (1) Use all possible triples (will be num_anchors * num_positives * num_negatives triples in total) + # (2) Use PyTorch's TripletMarginLoss to compute the marginal loss for each triple + # (3) Since TripletMarginLoss accepts input tensors of shape (B, D), where B is the batch size, + # we have to manually construct all triples and flatten them as an input tensor in the + # shape of (num_triples, D). + # (4) We first compute loss for all triples (including those that are not anchor - positive - negative), which + # will be num_anchors * num_samples * num_samples triples, and then filter them with masks. + + # compute negative mask + neg_mask = 1. 
- pos_mask if neg_mask is None else neg_mask + + anchor = torch.unsqueeze(anchor, dim=1) # [N, 1, D] + anchor = torch.unsqueeze(anchor, dim=1) # [N, 1, 1, D] + anchor = anchor.expand(-1, num_samples, num_samples, -1) # [N, M, M, D] + anchor = torch.flatten(anchor, end_dim=1) # [N * M * M, D] + + pos_sample = torch.unsqueeze(sample, dim=0) # [1, M, D] + pos_sample = torch.unsqueeze(pos_sample, dim=2) # [1, M, 1, D] + pos_sample = pos_sample.expand(num_anchors, -1, num_samples, -1) # [N, M, M, D] + pos_sample = torch.flatten(pos_sample, end_dim=1) # [N * M * M, D] + + neg_sample = torch.unsqueeze(sample, dim=0) # [1, M, D] + neg_sample = torch.unsqueeze(neg_sample, dim=0) # [1, 1, M, D] + neg_sample = neg_sample.expand(num_anchors, -1, num_samples, -1) # [N, M, M, D] + neg_sample = torch.flatten(neg_sample, end_dim=1) # [N * M * M, D] + + loss = self.loss_fn(anchor, pos_sample, neg_sample) # [N, M, M] + loss = loss.view(num_anchors, num_samples, num_samples) + + pos_mask1 = torch.unsqueeze(pos_mask, dim=2) # [N, M, 1] + pos_mask1 = pos_mask1.expand(-1, -1, num_samples) # [N, M, M] + neg_mask1 = torch.unsqueeze(neg_mask, dim=1) # [N, 1, M] + neg_mask1 = neg_mask1.expand(-1, num_samples, -1) # [N, M, M] + + pair_mask = pos_mask1 * neg_mask1 # [N, M, M] + num_pairs = pair_mask.sum() + + loss = loss * pair_mask + loss = loss.sum() + + return loss / num_pairs diff --git a/fgssl/contrib/model/GCL/losses/vicreg.py b/fgssl/contrib/model/GCL/losses/vicreg.py new file mode 100644 index 0000000..284f5d3 --- /dev/null +++ b/fgssl/contrib/model/GCL/losses/vicreg.py @@ -0,0 +1,43 @@ +import torch +import torch.nn.functional as F +from .losses import Loss + + +class VICReg(Loss): + def __init__(self, sim_weight=25.0, var_weight=25.0, cov_weight=1.0, eps=1e-4): + super(VICReg, self).__init__() + self.sim_weight = sim_weight + self.var_weight = var_weight + self.cov_weight = cov_weight + self.eps = eps + + @staticmethod + def invariance_loss(h1, h2): + return F.mse_loss(h1, h2) + + def variance_loss(self, h1, h2): + std_z1 = torch.sqrt(h1.var(dim=0) + self.eps) + std_z2 = torch.sqrt(h2.var(dim=0) + self.eps) + std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2)) + return std_loss + + @staticmethod + def covariance_loss(h1, h2): + num_nodes, hidden_dim = h1.size() + + h1 = h1 - h1.mean(dim=0) + h2 = h2 - h2.mean(dim=0) + cov_z1 = (h1.T @ h1) / (num_nodes - 1) + cov_z2 = (h2.T @ h2) / (num_nodes - 1) + + diag = torch.eye(hidden_dim, device=h1.device) + cov_loss = cov_z1[~diag.bool()].pow_(2).sum() / hidden_dim + cov_z2[~diag.bool()].pow_(2).sum() / hidden_dim + return cov_loss + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor: + sim_loss = self.invariance_loss(anchor, sample) + var_loss = self.variance_loss(anchor, sample) + cov_loss = self.covariance_loss(anchor, sample) + + loss = self.sim_weight * sim_loss + self.var_weight * var_loss + self.cov_weight * cov_loss + return loss.mean() diff --git a/fgssl/contrib/model/GCL/models/__init__.py b/fgssl/contrib/model/GCL/models/__init__.py new file mode 100644 index 0000000..c1859c0 --- /dev/null +++ b/fgssl/contrib/model/GCL/models/__init__.py @@ -0,0 +1,15 @@ +from .samplers import SameScaleSampler, CrossScaleSampler, get_sampler +from .contrast_model import SingleBranchContrast, DualBranchContrast, WithinEmbedContrast, BootstrapContrast + + +__all__ = [ + 'SingleBranchContrast', + 'DualBranchContrast', + 'WithinEmbedContrast', + 'BootstrapContrast', + 'SameScaleSampler', + 'CrossScaleSampler', 
+ 'get_sampler' +] + +classes = __all__ diff --git a/fgssl/contrib/model/GCL/models/contrast_model.py b/fgssl/contrib/model/GCL/models/contrast_model.py new file mode 100644 index 0000000..b305e46 --- /dev/null +++ b/fgssl/contrib/model/GCL/models/contrast_model.py @@ -0,0 +1,120 @@ +import torch + +from GCL.losses import Loss +from GCL.models import get_sampler + + +def add_extra_mask(pos_mask, neg_mask=None, extra_pos_mask=None, extra_neg_mask=None): + if extra_pos_mask is not None: + pos_mask = torch.bitwise_or(pos_mask.bool(), extra_pos_mask.bool()).float() + if extra_neg_mask is not None: + neg_mask = torch.bitwise_and(neg_mask.bool(), extra_neg_mask.bool()).float() + else: + neg_mask = 1. - pos_mask + return pos_mask, neg_mask + + +class SingleBranchContrast(torch.nn.Module): + def __init__(self, loss: Loss, mode: str, intraview_negs: bool = False, **kwargs): + super(SingleBranchContrast, self).__init__() + assert mode == 'G2L' # only global-local pairs allowed in single-branch contrastive learning + self.loss = loss + self.mode = mode + self.sampler = get_sampler(mode, intraview_negs=intraview_negs) + self.kwargs = kwargs + + def forward(self, h, g, batch=None, hn=None, extra_pos_mask=None, extra_neg_mask=None): + if batch is None: # for single-graph datasets + assert hn is not None + anchor, sample, pos_mask, neg_mask = self.sampler(anchor=g, sample=h, neg_sample=hn) + else: # for multi-graph datasets + assert batch is not None + anchor, sample, pos_mask, neg_mask = self.sampler(anchor=g, sample=h, batch=batch) + + pos_mask, neg_mask = add_extra_mask(pos_mask, neg_mask, extra_pos_mask, extra_neg_mask) + loss = self.loss(anchor=anchor, sample=sample, pos_mask=pos_mask, neg_mask=neg_mask, **self.kwargs) + return loss + + +class DualBranchContrast(torch.nn.Module): + def __init__(self, loss: Loss, mode: str, intraview_negs: bool = False, **kwargs): + super(DualBranchContrast, self).__init__() + self.loss = loss + self.mode = mode + self.sampler = get_sampler(mode, intraview_negs=intraview_negs) + self.kwargs = kwargs + + def forward(self, h1=None, h2=None, g1=None, g2=None, batch=None, h3=None, h4=None, + extra_pos_mask=None, extra_neg_mask=None): + if self.mode == 'L2L': + assert h1 is not None and h2 is not None + anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=h1, sample=h2) + anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=h2, sample=h1) + elif self.mode == 'G2G': + assert g1 is not None and g2 is not None + anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=g1, sample=g2) + anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=g2, sample=g1) + else: # global-to-local + if batch is None or batch.max().item() + 1 <= 1: # single graph + assert all(v is not None for v in [h1, h2, g1, g2, h3, h4]) + anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=g1, sample=h2, neg_sample=h4) + anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=g2, sample=h1, neg_sample=h3) + else: # multiple graphs + assert all(v is not None for v in [h1, h2, g1, g2, batch]) + anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=g1, sample=h2, batch=batch) + anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=g2, sample=h1, batch=batch) + + pos_mask1, neg_mask1 = add_extra_mask(pos_mask1, neg_mask1, extra_pos_mask, extra_neg_mask) + pos_mask2, neg_mask2 = add_extra_mask(pos_mask2, neg_mask2, extra_pos_mask, extra_neg_mask) + l1 = self.loss(anchor=anchor1, sample=sample1, pos_mask=pos_mask1, neg_mask=neg_mask1, **self.kwargs) + l2 = 
self.loss(anchor=anchor2, sample=sample2, pos_mask=pos_mask2, neg_mask=neg_mask2, **self.kwargs) + + return (l1 + l2) * 0.5 + + +class BootstrapContrast(torch.nn.Module): + def __init__(self, loss, mode='L2L'): + super(BootstrapContrast, self).__init__() + self.loss = loss + self.mode = mode + self.sampler = get_sampler(mode, intraview_negs=False) + + def forward(self, h1_pred=None, h2_pred=None, h1_target=None, h2_target=None, + g1_pred=None, g2_pred=None, g1_target=None, g2_target=None, + batch=None, extra_pos_mask=None): + if self.mode == 'L2L': + assert all(v is not None for v in [h1_pred, h2_pred, h1_target, h2_target]) + anchor1, sample1, pos_mask1, _ = self.sampler(anchor=h1_target, sample=h2_pred) + anchor2, sample2, pos_mask2, _ = self.sampler(anchor=h2_target, sample=h1_pred) + elif self.mode == 'G2G': + assert all(v is not None for v in [g1_pred, g2_pred, g1_target, g2_target]) + anchor1, sample1, pos_mask1, _ = self.sampler(anchor=g1_target, sample=g2_pred) + anchor2, sample2, pos_mask2, _ = self.sampler(anchor=g2_target, sample=g1_pred) + else: + assert all(v is not None for v in [h1_pred, h2_pred, g1_target, g2_target]) + if batch is None or batch.max().item() + 1 <= 1: # single graph + pos_mask1 = pos_mask2 = torch.ones([1, h1_pred.shape[0]], device=h1_pred.device) + anchor1, sample1 = g1_target, h2_pred + anchor2, sample2 = g2_target, h1_pred + else: + anchor1, sample1, pos_mask1, _ = self.sampler(anchor=g1_target, sample=h2_pred, batch=batch) + anchor2, sample2, pos_mask2, _ = self.sampler(anchor=g2_target, sample=h1_pred, batch=batch) + + pos_mask1, _ = add_extra_mask(pos_mask1, extra_pos_mask=extra_pos_mask) + pos_mask2, _ = add_extra_mask(pos_mask2, extra_pos_mask=extra_pos_mask) + l1 = self.loss(anchor=anchor1, sample=sample1, pos_mask=pos_mask1) + l2 = self.loss(anchor=anchor2, sample=sample2, pos_mask=pos_mask2) + + return (l1 + l2) * 0.5 + + +class WithinEmbedContrast(torch.nn.Module): + def __init__(self, loss: Loss, **kwargs): + super(WithinEmbedContrast, self).__init__() + self.loss = loss + self.kwargs = kwargs + + def forward(self, h1, h2): + l1 = self.loss(anchor=h1, sample=h2, **self.kwargs) + l2 = self.loss(anchor=h2, sample=h1, **self.kwargs) + return (l1 + l2) * 0.5 diff --git a/fgssl/contrib/model/GCL/models/samplers.py b/fgssl/contrib/model/GCL/models/samplers.py new file mode 100644 index 0000000..1c03982 --- /dev/null +++ b/fgssl/contrib/model/GCL/models/samplers.py @@ -0,0 +1,81 @@ +import torch +from abc import ABC, abstractmethod +from torch_scatter import scatter + + +class Sampler(ABC): + def __init__(self, intraview_negs=False): + self.intraview_negs = intraview_negs + + def __call__(self, anchor, sample, *args, **kwargs): + ret = self.sample(anchor, sample, *args, **kwargs) + if self.intraview_negs: + ret = self.add_intraview_negs(*ret) + return ret + + @abstractmethod + def sample(self, anchor, sample, *args, **kwargs): + pass + + @staticmethod + def add_intraview_negs(anchor, sample, pos_mask, neg_mask): + num_nodes = anchor.size(0) + device = anchor.device + intraview_pos_mask = torch.zeros_like(pos_mask, device=device) + intraview_neg_mask = torch.ones_like(pos_mask, device=device) - torch.eye(num_nodes, device=device) + new_sample = torch.cat([sample, anchor], dim=0) # (M+N) * K + new_pos_mask = torch.cat([pos_mask, intraview_pos_mask], dim=1) # M * (M+N) + new_neg_mask = torch.cat([neg_mask, intraview_neg_mask], dim=1) # M * (M+N) + return anchor, new_sample, new_pos_mask, new_neg_mask + + +class SameScaleSampler(Sampler): + def 
__init__(self, *args, **kwargs): + super(SameScaleSampler, self).__init__(*args, **kwargs) + + def sample(self, anchor, sample, *args, **kwargs): + assert anchor.size(0) == sample.size(0) + num_nodes = anchor.size(0) + device = anchor.device + pos_mask = torch.eye(num_nodes, dtype=torch.float32, device=device) + neg_mask = 1. - pos_mask + return anchor, sample, pos_mask, neg_mask + + +class CrossScaleSampler(Sampler): + def __init__(self, *args, **kwargs): + super(CrossScaleSampler, self).__init__(*args, **kwargs) + + def sample(self, anchor, sample, batch=None, neg_sample=None, use_gpu=True, *args, **kwargs): + num_graphs = anchor.shape[0] # M + num_nodes = sample.shape[0] # N + device = sample.device + + if neg_sample is not None: + assert num_graphs == 1 # only one graph, explicit negative samples are needed + assert sample.shape == neg_sample.shape + pos_mask1 = torch.ones((num_graphs, num_nodes), dtype=torch.float32, device=device) + pos_mask0 = torch.zeros((num_graphs, num_nodes), dtype=torch.float32, device=device) + pos_mask = torch.cat([pos_mask1, pos_mask0], dim=1) # M * 2N + sample = torch.cat([sample, neg_sample], dim=0) # 2N * K + else: + assert batch is not None + if use_gpu: + ones = torch.eye(num_nodes, dtype=torch.float32, device=device) # N * N + pos_mask = scatter(ones, batch, dim=0, reduce='sum') # M * N + else: + pos_mask = torch.zeros((num_graphs, num_nodes), dtype=torch.float32).to(device) + for node_idx, graph_idx in enumerate(batch): + pos_mask[graph_idx][node_idx] = 1. # M * N + + neg_mask = 1. - pos_mask + return anchor, sample, pos_mask, neg_mask + + +def get_sampler(mode: str, intraview_negs: bool) -> Sampler: + if mode in {'L2L', 'G2G'}: + return SameScaleSampler(intraview_negs=intraview_negs) + elif mode == 'G2L': + return CrossScaleSampler(intraview_negs=intraview_negs) + else: + raise RuntimeError(f'unsupported mode: {mode}') diff --git a/fgssl/contrib/model/GCL/utils.py b/fgssl/contrib/model/GCL/utils.py new file mode 100644 index 0000000..27582a0 --- /dev/null +++ b/fgssl/contrib/model/GCL/utils.py @@ -0,0 +1,74 @@ +from typing import * +import os +import torch +import dgl +import random +import numpy as np + + +def split_dataset(dataset, split_mode, *args, **kwargs): + assert split_mode in ['rand', 'ogb', 'wikics', 'preload'] + if split_mode == 'rand': + assert 'train_ratio' in kwargs and 'test_ratio' in kwargs + train_ratio = kwargs['train_ratio'] + test_ratio = kwargs['test_ratio'] + num_samples = dataset.x.size(0) + train_size = int(num_samples * train_ratio) + test_size = int(num_samples * test_ratio) + indices = torch.randperm(num_samples) + return { + 'train': indices[:train_size], + 'val': indices[train_size: test_size + train_size], + 'test': indices[test_size + train_size:] + } + elif split_mode == 'ogb': + return dataset.get_idx_split() + elif split_mode == 'wikics': + assert 'split_idx' in kwargs + split_idx = kwargs['split_idx'] + return { + 'train': dataset.train_mask[:, split_idx], + 'test': dataset.test_mask, + 'val': dataset.val_mask[:, split_idx] + } + elif split_mode == 'preload': + assert 'preload_split' in kwargs + assert kwargs['preload_split'] is not None + train_mask, test_mask, val_mask = kwargs['preload_split'] + return { + 'train': train_mask, + 'test': test_mask, + 'val': val_mask + } + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.manual_seed(seed) + 
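+ # Seeding all CUDA devices (next line) complements the cudnn flags above; note that
+ # full determinism still depends on the operators used, so results may vary slightly.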
torch.cuda.manual_seed_all(seed) + + +def normalize(s): + return (s.max() - s) / (s.max() - s.mean()) + + +def build_dgl_graph(edge_index: torch.Tensor) -> dgl.DGLGraph: + row, col = edge_index + return dgl.graph((row, col)) + + +def batchify_dict(dicts: List[dict], aggr_func=lambda x: x): + res = dict() + for d in dicts: + for k, v in d.items(): + if k not in res: + res[k] = [v] + else: + res[k].append(v) + res = {k: aggr_func(v) for k, v in res.items()} + return res diff --git a/fgssl/contrib/model/__init__.py b/fgssl/contrib/model/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/model/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/model/aug_base_model.py b/fgssl/contrib/model/aug_base_model.py new file mode 100644 index 0000000..fb437f7 --- /dev/null +++ b/fgssl/contrib/model/aug_base_model.py @@ -0,0 +1,107 @@ +from federatedscope.register import register_model +import torch +import torch.nn.functional as F + +from torch.nn import ModuleList +from torch_geometric.data import Data +import pyro +from torch_geometric.nn import GCNConv +import torch.nn as nn +from torch_geometric.utils import to_dense_adj +import torch.distributions as dist +# Build you torch or tf model class here +class GCN_AUG(torch.nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden=64, + max_depth=2, + dropout=.0): + super(GCN_AUG, self).__init__() + self.convs = ModuleList() + for i in range(max_depth): + if i == 0: + self.convs.append(GCNConv(in_channels, hidden)) + elif (i + 1) == max_depth: + self.convs.append(GCNConv(hidden, out_channels)) + else: + self.convs.append(GCNConv(hidden, hidden)) + self.dropout = dropout + self.bn1 = nn.BatchNorm1d(hidden, momentum = 0.01) + self.prelu = nn.PReLU() + + def forward(self, data): + + if isinstance(data, Data): + x, edge_index = data.x, data.edge_index + elif isinstance(data, tuple): + x, edge_index = data + else: + raise TypeError('Unsupported data type!') + + adj_sampled, adj_logits, adj_orig = self.sample(0.8, x, edge_index) + + edge_index = adj_sampled.nonzero().t() + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + + if (i+1) == len(self.convs): + break + else: + z = self.prelu(x) + x = F.relu(F.dropout(self.bn1(x), p=self.dropout, training=self.training)) + + return x, z, adj_sampled , adj_logits, adj_orig + + def sample(self, alpha, x, edge_index): + + for i,conv in enumerate(self.convs): + if i == 0: + x = conv(x,edge_index) + + + adj_orig = torch.zeros((x.shape[0],x.shape[0])).to('cuda:2') + for row in edge_index.T: + i,j = row + adj_orig[i, j] = 1 + + adj_logits = x @ x.T + + edge_probs = adj_logits / torch.max(adj_logits) + + edge_probs = alpha * edge_probs + (1-alpha) * adj_orig + + edge_probs = torch.where(edge_probs < 0, torch.zeros_like(edge_probs), edge_probs) + + adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=1, + probs=edge_probs).rsample() + + adj_sampled = adj_sampled.triu(1) + adj_sampled = adj_sampled + adj_sampled.T + + adj_sampled.fill_diagonal_(1) + # D_norm = torch.diag(torch.pow(adj_sampled.sum(1), -0.5)) + # adj_sampled = D_norm @ adj_sampled @ D_norm + + return adj_sampled, adj_logits, adj_orig + + + +def gnnbuilder(model_config, input_shape): + x_shape, num_label, num_edge_features = input_shape + model = 
GCN_AUG(x_shape[-1], + model_config.out_channels, + hidden=model_config.hidden, + max_depth=model_config.layer, + dropout=model_config.dropout) + return model + + +def call_my_net(model_config, local_data): + if model_config.type == "gnn_gcn_aug": + model = gnnbuilder(model_config, local_data) + return model + + +register_model("gnn_gcn_aug", call_my_net) diff --git a/fgssl/contrib/model/example.py b/fgssl/contrib/model/example.py new file mode 100644 index 0000000..899246a --- /dev/null +++ b/fgssl/contrib/model/example.py @@ -0,0 +1,23 @@ +from federatedscope.register import register_model + + +# Build you torch or tf model class here +class MyNet(object): + pass + + +# Instantiate your model class with config and data +def ModelBuilder(model_config, local_data): + + model = MyNet() + + return model + + +def call_my_net(model_config, local_data): + if model_config.type == "mynet": + model = ModelBuilder(model_config, local_data) + return model + + +register_model("mynet", call_my_net) diff --git a/fgssl/contrib/model/model.py b/fgssl/contrib/model/model.py new file mode 100644 index 0000000..0af6c93 --- /dev/null +++ b/fgssl/contrib/model/model.py @@ -0,0 +1,436 @@ +# federatedscope/contrib/model/my_gcn.py +import pyro +import torch +import torch.nn.functional as F + +from torch.nn import ModuleList +from torch_geometric.data import Data +from torch_geometric.nn import GINConv +from torch_geometric.nn import GCNConv +from torch_geometric.utils import to_dense_adj + +from federatedscope.register import register_model +import torch.nn as nn +from federatedscope.core.mlp import MLP + + +class MyGCN(torch.nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden=64, + max_depth=2, + dropout=.0): + super(MyGCN, self).__init__() + self.convs = ModuleList() + for i in range(max_depth): + if i == 0: + self.convs.append(GCNConv(in_channels, hidden)) + elif (i + 1) == max_depth: + self.convs.append(GCNConv(hidden, out_channels)) + else: + self.convs.append(GCNConv(hidden, hidden)) + self.dropout = dropout + self.bn1 = nn.BatchNorm1d(hidden, momentum = 0.01) + self.prelu = nn.PReLU() + + def forward(self, data): + if isinstance(data, Data): + x, edge_index = data.x, data.edge_index + elif isinstance(data, tuple): + x, edge_index = data + else: + raise TypeError('Unsupported data type!') + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + + if (i+1) == len(self.convs): + break + else: + z = x + x = F.relu(F.dropout(self.bn1(x), p=self.dropout, training=self.training)) + + return x, z + def link_predictor(self, x, edge_index): + x = x[edge_index[0]] * x[edge_index[1]] + x = self.output(x) + return x + +class MYGIN(torch.nn.Module): + r"""Graph Isomorphism Network model from the "How Powerful are Graph + Neural Networks?" paper, in ICLR'19 + + Arguments: + in_channels (int): dimension of input. + out_channels (int): dimension of output. + hidden (int): dimension of hidden units, default=64. + max_depth (int): layers of GNN, default=2. + dropout (float): dropout ratio, default=.0. 
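+ Note: forward() returns a tuple (out, z), where z is the embedding taken from the
+ last-but-one GIN layer; callers that need intermediate features (e.g. the
+ contrastive trainers under contrib/trainer) unpack this second output.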
+ + """ + def __init__(self, + in_channels, + out_channels, + hidden=64, + max_depth=2, + dropout=.0): + super(MYGIN, self).__init__() + self.convs = ModuleList() + self.bn1 = nn.BatchNorm1d(hidden, momentum = 0.01) + for i in range(max_depth): + if i == 0: + self.convs.append( + GINConv(MLP([in_channels, hidden, hidden], + batch_norm=True))) + elif (i + 1) == max_depth: + self.convs.append( + GINConv( + MLP([hidden, hidden, out_channels], batch_norm=True))) + else: + self.convs.append( + GINConv(MLP([hidden, hidden, hidden], batch_norm=True))) + self.dropout = dropout + + def forward(self, data): + if isinstance(data, Data): + x, edge_index = data.x, data.edge_index + elif isinstance(data, tuple): + x, edge_index = data + else: + raise TypeError('Unsupported data type!') + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + + if (i + 1) == len(self.convs): + break + else: + z = x + x = F.relu(F.dropout(self.bn1(x), p=self.dropout, training=self.training)) + return x, z + def link_predictor(self, x, edge_index): + x = x[edge_index[0]] * x[edge_index[1]] + x = self.output(x) + return x +from torch_geometric.nn import SAGEConv +class MYsage(torch.nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden=64, + max_depth=2, + dropout=.0): + super(MYsage, self).__init__() + + self.num_layers = max_depth + self.dropout = dropout + + self.convs = torch.nn.ModuleList() + self.convs.append(SAGEConv(in_channels, hidden)) + for _ in range(self.num_layers - 2): + self.convs.append(SAGEConv(hidden, hidden)) + self.convs.append(SAGEConv(hidden, out_channels)) + def forward(self, data): + if isinstance(data, Data): + x, edge_index = data.x, data.edge_index + elif isinstance(data, tuple): + x, edge_index = data + else: + raise TypeError('Unsupported data type!') + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if (i + 1) == len(self.convs): + break + else: + z = x + x = F.relu(F.dropout(x, p=self.dropout, training=self.training)) + return x, z + def link_predictor(self, x, edge_index): + x = x[edge_index[0]] * x[edge_index[1]] + x = self.output(x) + return x +from torch_geometric.nn import GATConv + + + +class MYGAT(torch.nn.Module): + r"""GAT model from the "Graph Attention Networks" paper, in ICLR'18 + + Arguments: + in_channels (int): dimension of input. + out_channels (int): dimension of output. + hidden (int): dimension of hidden units, default=64. + max_depth (int): layers of GNN, default=2. + dropout (float): dropout ratio, default=.0. 
+ + """ + def __init__(self, + in_channels, + out_channels, + hidden=64, + max_depth=2, + dropout=.0): + super(MYGAT, self).__init__() + self.convs = ModuleList() + self.bn1 = nn.BatchNorm1d(hidden, momentum = 0.01) + for i in range(max_depth): + if i == 0: + self.convs.append(GATConv(in_channels, hidden)) + elif (i + 1) == max_depth: + self.convs.append(GATConv(hidden, out_channels)) + else: + self.convs.append(GATConv(hidden, hidden)) + self.dropout = dropout + self.prelu = nn.PReLU() + self.prelu2 = nn.PReLU() + self.augConv = GATConv(in_channels,hidden) + dim_list = [hidden for _ in range(2)] + self.output = MLP([hidden] + dim_list + [out_channels], + batch_norm=True) + def reset_parameters(self): + for m in self.convs: + m.reset_parameters() + + def forward(self, data): + if isinstance(data, Data): + x, edge_index = data.x, data.edge_index + elif isinstance(data, tuple): + x, edge_index = data + else: + raise TypeError('Unsupported data type!') + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if (i + 1) == len(self.convs): + break + else: + z = x + x = F.relu(F.dropout(x, p=self.dropout, training=self.training)) + return x, z + + def link_predictor(self, x, edge_index): + x = x[edge_index[0]] * x[edge_index[1]] + x = self.output(x) + return x + + + +class MyStar(torch.nn.Module): + r"""GAT model from the "Graph Attention Networks" paper, in ICLR'18 + + Arguments: + in_channels (int): dimension of input. + out_channels (int): dimension of output. + hidden (int): dimension of hidden units, default=64. + max_depth (int): layers of GNN, default=2. + dropout (float): dropout ratio, default=.0. + + """ + def __init__(self, + in_channels, + out_channels, + hidden=64, + max_depth=2, + dropout=.0): + super(MyStar, self).__init__() + self.convs = ModuleList() + self.bn1 = nn.BatchNorm1d(hidden, momentum = 0.01) + for i in range(max_depth): + if i == 0: + self.convs.append(GATConv(in_channels, hidden)) + elif (i + 1) == max_depth: + self.convs.append(GATConv(hidden, out_channels)) + else: + self.convs.append(GATConv(hidden, hidden)) + + self.pre = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden)) + self.embedding_s = torch.nn.Linear(64, hidden) + self.dropout = dropout + self.prelu = nn.PReLU() + self.prelu2 = nn.PReLU() + self.augConv = GATConv(in_channels,hidden) + dim_list = [hidden for _ in range(2)] + self.output = MLP([hidden] + dim_list + [out_channels], + batch_norm=True) + def reset_parameters(self): + for m in self.convs: + m.reset_parameters() + + def forward(self, data): + if isinstance(data, Data): + x, edge_index = data.x, data.edge_index + elif isinstance(data, tuple): + x, edge_index = data + else: + raise TypeError('Unsupported data type!') + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if (i + 1) == len(self.convs): + break + else: + z = x + x = F.relu(F.dropout(x, p=self.dropout, training=self.training)) + return x, z + + def link_predictor(self, x, edge_index): + x = x[edge_index[0]] * x[edge_index[1]] + x = self.output(x) + return x + + +def gcnbuilder(model_config, input_shape): + x_shape, num_label, num_edge_features = input_shape + model = MyGCN(x_shape[-1], + model_config.out_channels, + hidden=model_config.hidden, + max_depth=model_config.layer, + dropout=model_config.dropout) + return model + +def ginbuilder(model_config, input_shape): + x_shape, num_label, num_edge_features = input_shape + model = MYGIN(x_shape[-1], + model_config.out_channels, + hidden=model_config.hidden, + max_depth=model_config.layer, + 
dropout=model_config.dropout) + return model + +def gatbuilder(model_config, input_shape): + x_shape, num_label, num_edge_features = input_shape + model = MYGAT(x_shape[-1], + model_config.out_channels, + hidden=model_config.hidden, + max_depth=model_config.layer, + dropout=model_config.dropout) + return model + +def sagebuilder(model_config, input_shape): + x_shape, num_label, num_edge_features = input_shape + model = MYsage(x_shape[-1], + model_config.out_channels, + hidden=model_config.hidden, + max_depth=model_config.layer, + dropout=model_config.dropout) + return model + +def fedbuilder(model_config, input_shape): + x_shape, num_label, num_edge_features = input_shape + model = MyStar(x_shape[-1], + model_config.out_channels, + hidden=model_config.hidden, + max_depth=model_config.layer, + dropout=model_config.dropout) + return model + + +def call_my_net(model_config, local_data): + # Please name your gnn model with prefix 'gnn_' + if model_config.type == "gnn_mygcn": + model = gcnbuilder(model_config, local_data) + return model + if model_config.type == "gnn_mygin": + model = ginbuilder(model_config, local_data) + return model + if model_config.type == "gnn_mygat": + model = gatbuilder(model_config, local_data) + return model + if model_config.type == "gnn_mysage": + model = sagebuilder(model_config, local_data) + return model + if model_config.type == "gnn_fedstar": + model = fedbuilder(model_config, local_data) + return model + +register_model("gnn_fedstar",call_my_net) +register_model("gnn_mygin", call_my_net) +register_model("gnn_mygcn", call_my_net) +register_model("gnn_mygat", call_my_net) +register_model("gnn_mysage", call_my_net) + + + +from torch_geometric.utils import to_networkx, degree, to_dense_adj, to_scipy_sparse_matrix +from scipy import sparse as sp +def init_structure_encoding(args, gs, type_init): + + if type_init == 'rw': + for g in gs: + # Geometric diffusion features with Random Walk + A = to_scipy_sparse_matrix(g.edge_index, num_nodes=g.num_nodes) + D = (degree(g.edge_index[0], num_nodes=g.num_nodes) ** -1.0).numpy() + + Dinv=sp.diags(D) + RW=A*Dinv + M=RW + + SE_rw=[torch.from_numpy(M.diagonal()).float()] + M_power=M + for _ in range(args.n_rw-1): + M_power=M_power*M + SE_rw.append(torch.from_numpy(M_power.diagonal()).float()) + SE_rw=torch.stack(SE_rw,dim=-1) + + g['stc_enc'] = SE_rw + + elif type_init == 'dg': + for g in gs: + # PE_degree + g_dg = (degree(g.edge_index[0], num_nodes=g.num_nodes)).numpy().clip(1, args.n_dg) + SE_dg = torch.zeros([g.num_nodes, args.n_dg]) + for i in range(len(g_dg)): + SE_dg[i,int(g_dg[i]-1)] = 1 + + g['stc_enc'] = SE_dg + + elif type_init == 'rw_dg': + for g in gs: + # SE_rw + A = to_scipy_sparse_matrix(g.edge_index, num_nodes=g.num_nodes) + D = (degree(g.edge_index[0], num_nodes=g.num_nodes) ** -1.0).numpy() + + Dinv=sp.diags(D) + RW=A*Dinv + M=RW + + SE=[torch.from_numpy(M.diagonal()).float()] + M_power=M + for _ in range(args.n_rw-1): + M_power=M_power*M + SE.append(torch.from_numpy(M_power.diagonal()).float()) + SE_rw=torch.stack(SE,dim=-1) + + # PE_degree + g_dg = (degree(g.edge_index[0], num_nodes=g.num_nodes)).numpy().clip(1, args.n_dg) + SE_dg = torch.zeros([g.num_nodes, args.n_dg]) + for i in range(len(g_dg)): + SE_dg[i,int(g_dg[i]-1)] = 1 + + g['stc_enc'] = torch.cat([SE_rw, SE_dg], dim=1) + + elif type_init == "single": + A = to_scipy_sparse_matrix(gs.edge_index, num_nodes=g.num_nodes) + D = (degree(gs.edge_index[0], num_nodes=g.num_nodes) ** -1.0).numpy() + + Dinv = sp.diags(D) + RW = A * Dinv + M = RW + + SE = 
[torch.from_numpy(M.diagonal()).float()] + M_power = M + for _ in range(args.n_rw - 1): + M_power = M_power * M + SE.append(torch.from_numpy(M_power.diagonal()).float()) + SE_rw = torch.stack(SE, dim=-1) + + # PE_degree + g_dg = (degree(gs.edge_index[0], num_nodes=g.num_nodes)).numpy().clip(1, args.n_dg) + SE_dg = torch.zeros([gs.num_nodes, args.n_dg]) + for i in range(len(g_dg)): + SE_dg[i, int(g_dg[i] - 1)] = 1 + + gs['stc_enc'] = torch.cat([SE_rw, SE_dg], dim=1) + + return gs diff --git a/fgssl/contrib/model/resnet.py b/fgssl/contrib/model/resnet.py new file mode 100644 index 0000000..58d7421 --- /dev/null +++ b/fgssl/contrib/model/resnet.py @@ -0,0 +1,305 @@ +from federatedscope.register import register_model +'''Pre-activation ResNet in PyTorch. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Identity Mappings in Deep Residual Networks. arXiv:1603.05027 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PreActBlock(nn.Module): + '''Pre-activation version of the BasicBlock.''' + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(PreActBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = nn.Conv2d(in_planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False)) + + def forward(self, x): + out = F.relu(self.bn1(x)) + shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x + out = self.conv1(out) + out = self.conv2(F.relu(self.bn2(out))) + out += shortcut + return out + + +class PreActBottleneck(nn.Module): + '''Pre-activation version of the original Bottleneck module.''' + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(PreActBottleneck, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, + self.expansion * planes, + kernel_size=1, + bias=False) + + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False)) + + def forward(self, x): + out = F.relu(self.bn1(x)) + shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x + out = self.conv1(out) + out = self.conv2(F.relu(self.bn2(out))) + out = self.conv3(F.relu(self.bn3(out))) + out += shortcut + return out + + +class PreActResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(PreActResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + 
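+ # Stack `num_blocks` residual blocks: only the first uses the requested stride
+ # (spatial downsampling); the rest keep stride 1, and `in_planes` is updated to the
+ # expanded channel count for the next stage.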
strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = self.conv1(x) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def PreActResNet18(): + return PreActResNet(PreActBlock, [2, 2, 2, 2]) + + +def PreActResNet34(): + return PreActResNet(PreActBlock, [3, 4, 6, 3]) + + +def PreActResNet50(): + return PreActResNet(PreActBottleneck, [3, 4, 6, 3]) + + +def PreActResNet101(): + return PreActResNet(PreActBottleneck, [3, 4, 23, 3]) + + +def PreActResNet152(): + return PreActResNet(PreActBottleneck, [3, 8, 36, 3]) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, + self.expansion * planes, + kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * 
block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def ResNet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def ResNet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152(): + return ResNet(Bottleneck, [3, 8, 36, 3]) + + +def preact_resnet(model_config): + if '18' in model_config.type: + net = PreActResNet18() + elif '50' in model_config.type: + net = PreActResNet50() + return net + + +def resnet(model_config): + if '18' in model_config.type: + net = ResNet18() + elif '50' in model_config.type: + net = ResNet50() + return net + + +def call_resnet(model_config, local_data): + if 'resnet' in model_config.type and 'pre' in model_config.type: + model = preact_resnet(model_config) + return model + elif 'resnet' in model_config.type and 'pre' not in model_config.type: + model = resnet(model_config) + return model + + +register_model('resnet', call_resnet) diff --git a/fgssl/contrib/optimizer/__init__.py b/fgssl/contrib/optimizer/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/optimizer/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/optimizer/example.py b/fgssl/contrib/optimizer/example.py new file mode 100644 index 0000000..41c083b --- /dev/null +++ b/fgssl/contrib/optimizer/example.py @@ -0,0 +1,17 @@ +from federatedscope.register import register_optimizer + + +def call_my_optimizer(model, type, lr, **kwargs): + try: + import torch.optim as optim + except ImportError: + optim = None + optimizer = None + + if type == 'myoptimizer': + if optim is not None: + optimizer = optim.Adam(model.parameters(), lr=lr, **kwargs) + return optimizer + + +register_optimizer('myoptimizer', call_my_optimizer) diff --git a/fgssl/contrib/scheduler/__init__.py b/fgssl/contrib/scheduler/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/scheduler/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/scheduler/example.py b/fgssl/contrib/scheduler/example.py new file mode 100644 index 0000000..bf42d77 --- /dev/null +++ b/fgssl/contrib/scheduler/example.py @@ -0,0 +1,20 @@ +import math + +from federatedscope.register import register_scheduler + + +def call_my_scheduler(optimizer, type): + try: + import torch.optim as optim + except ImportError: + optim = None + scheduler = None + + if type == 'myscheduler': + if optim is not None: + lr_lambda = [lambda epoch: (epoch / 40) if epoch < 40 else 0.5 * (math.cos(40/100) * math.pi) + 1] + scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + return scheduler + + +register_scheduler('myscheduler', call_my_scheduler) diff --git a/fgssl/contrib/splitter/__init__.py b/fgssl/contrib/splitter/__init__.py new file mode 100644 index 
0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/splitter/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/splitter/example.py b/fgssl/contrib/splitter/example.py new file mode 100644 index 0000000..f82c072 --- /dev/null +++ b/fgssl/contrib/splitter/example.py @@ -0,0 +1,26 @@ +from federatedscope.register import register_splitter +from federatedscope.core.splitters import BaseSplitter + + +class MySplitter(BaseSplitter): + def __init__(self, client_num, **kwargs): + super(MySplitter, self).__init__(client_num, **kwargs) + + def __call__(self, dataset, *args, **kwargs): + # Dummy splitter, only for demonstration + per_samples = len(dataset) // self.client_num + data_list, cur_index = [], 0 + for i in range(self.client_num): + data_list.append( + [x for x in range(cur_index, cur_index + per_samples)]) + cur_index += per_samples + return data_list + + +def call_my_splitter(client_num, **kwargs): + if type == 'mysplitter': + splitter = MySplitter(client_num, **kwargs) + return splitter + + +register_splitter('mysplitter', call_my_splitter) diff --git a/fgssl/contrib/trainer/FLAG.py b/fgssl/contrib/trainer/FLAG.py new file mode 100644 index 0000000..2fffe68 --- /dev/null +++ b/fgssl/contrib/trainer/FLAG.py @@ -0,0 +1,638 @@ +import copy +from torch_geometric.nn import GCNConv +import torch +from copy import deepcopy +from federatedscope.core.auxiliaries.enums import MODE +from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler +from federatedscope.gfl.trainer import LinkFullBatchTrainer +from federatedscope.gfl.trainer import GraphMiniBatchTrainer +from federatedscope.gfl.trainer.nodetrainer import NodeFullBatchTrainer +from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.context import CtxVar +from torch_geometric.utils import remove_self_loops, add_self_loops, degree + +from federatedscope.gfl.loss.vat import VATLoss +from federatedscope.gfl.loss.suploss import SupConLoss +from federatedscope.core.trainers import GeneralTorchTrainer +from GCL.models import DualBranchContrast,SingleBranchContrast +import GCL.losses as L +import GCL.augmentors as A +from GCL.models.contrast_model import WithinEmbedContrast +import torch.nn.functional as F +from torch_geometric.nn import GINConv, global_add_pool +import torch.nn as nn +from torch_geometric.utils import to_dense_adj + + +MODE2MASK = { + 'train': 'train_edge_mask', + 'val': 'valid_edge_mask', + 'test': 'test_edge_mask' +} + +class FGCLTrainer1(NodeFullBatchTrainer): + def __init__(self, + model, + data, + device, + config, + only_for_eval=False, + monitor=None): + super(FGCLTrainer1, self).__init__(model, data, device, config, + only_for_eval, monitor) + self.global_model = copy.deepcopy(model) + self.state = 0 + # self.aug = GCNConv(config.) 
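+ # This constructor keeps a deep copy of the global model (used as the frozen global
+ # reference in the forward hook) and sets up, below: an InfoNCE-based dual-branch
+ # contrast (tau=0.1) between local and global views, a Barlow-Twins within-embedding
+ # contrast, and weak/strong augmentations (edge removing + feature masking at
+ # different rates). The trainer is selected via the 'fgcl1' key registered at the
+ # bottom of this file (presumably through the trainer type in the config).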
+ self.cos = torch.nn.CosineSimilarity(dim=-1) + self.contrast_model = DualBranchContrast(loss=L.InfoNCE(tau=0.1), mode='L2L').to(device) + self.withcontrast_model = WithinEmbedContrast(loss=L.BarlowTwins()).to(device) + self.augWeak = A.Compose([A.EdgeRemoving(pe=0.3), A.FeatureMasking(pf=0.3)]) + self.augStrongF = A.Compose([A.EdgeRemoving(pe=0.8), A.FeatureMasking(pf=0.5)]) + self.augNone = A.Identity() + self.ccKD = Correlation() + self.yn = 10 + self.mu = 50 + + def register_default_hooks_train(self): + super(FGCLTrainer1, self).register_default_hooks_train() + self.register_hook_in_train(new_hook=begin, + trigger='on_fit_start', + insert_pos=-1) + self.register_hook_in_train(new_hook=del_initialization_local, + trigger='on_fit_end', + insert_pos=-1) + self.register_hook_in_train(new_hook=record_initialization_global, + trigger='on_fit_start', + insert_pos=-1) + self.register_hook_in_train(new_hook=leave, + trigger='on_fit_end', + insert_pos=-1) + + def register_default_hooks_eval(self): + super(FGCLTrainer1, self).register_default_hooks_eval() + self.register_hook_in_eval(new_hook=begin, + trigger='on_fit_start', + insert_pos=-1) + self.register_hook_in_eval(new_hook=del_initialization_local, + trigger='on_fit_end', + insert_pos=-1) + self.register_hook_in_eval(new_hook=record_initialization_global, + trigger='on_fit_start', + insert_pos=-1) + self.register_hook_in_eval(new_hook=leave, + trigger='on_fit_end', + insert_pos=-1) + + def _hook_on_batch_forward(self, ctx): + + batch = ctx.data_batch.to(ctx.device) + mask = batch['{}_mask'.format(ctx.cur_split)].detach() + + label = batch.y[batch['{}_mask'.format(ctx.cur_split)]] + + self.global_model.to(ctx.device).eval() + + pred, raw_feature_local = ctx.model(batch) + + pred_global, raw_feature_global = self.global_model(batch) + + + pred = pred[mask] + pred_global = pred_global[mask].detach() + + loss1 = ctx.criterion(pred, label) + + batch1 = copy.deepcopy(batch) + batch2 = copy.deepcopy(batch) + + g1, edge_index1, edge_weight1 = self.augWeak(batch.x, batch.edge_index) + g2, edge_index2, edge_weight2 = self.augStrongF(batch.x, batch.edge_index) + + batch1.x = g1 + batch1.edge_index = edge_index1 + + batch2.x = g2 + batch2.edge_index = edge_index2 + + with torch.no_grad(): + pred_aug_global, globalOne = self.global_model(batch1) + + _, now2 = ctx.model(batch1) + + pred_aug_local, now = ctx.model(batch2) + + adj_orig = to_dense_adj(batch.edge_index, max_num_nodes=batch.x.shape[0]).squeeze(0).to(ctx.device) + + struct_kd = com_distillation_loss(pred_aug_global, pred_aug_local, adj_orig, adj_orig, 3) + simi_kd_loss = simi_kd(pred_aug_global, pred_aug_local, batch.edge_index, 4) + + # rkd_Loss = rkd_loss(pred_aug_local , pred_aug_global) + # "tag" + cc_loss = self.ccKD(pred_aug_local, pred_aug_global) + loss_ds = simi_kd_2(batch.edge_index,pred_aug_local,pred_aug_global,self.yn) * self.mu + loss_ff = edge_distribution_high(batch.edge_index,pred_aug_local,pred_aug_global) + globalOne = globalOne[mask] + now = now[mask] + now2 = now2[mask] + extra_pos_mask = torch.eq(label, label.unsqueeze(dim=1)).to(ctx.device) + extra_pos_mask.fill_diagonal_(True) + + extra_neg_mask = torch.ne(label, label.unsqueeze(dim=1)).to(ctx.device) + extra_neg_mask.fill_diagonal_(False) + + loss3 = self.contrast_model(globalOne, now, extra_pos_mask=extra_pos_mask,extra_neg_mask=extra_neg_mask) + loss3 = self.contrast_model(now2,now,extra_pos_mask=extra_pos_mask,extra_neg_mask=extra_neg_mask) + + # raw_feature_local_list = list() + # for clazz in 
range(int(ctx.cfg.model.out_channels)): + # temp = raw_feature_local[ clazz == label ] / (raw_feature_local[ clazz == label ].norm(dim=-1,keepdim=True) + 1e-6) + # mean_result = temp.sum(dim=0, keepdim=True) + # raw_feature_local_list.append(mean_result) + # + # graph_level_local = torch.concat(raw_feature_local_list,dim=0) + # + # + # + # raw_feature_global_list = list() + # for clazz in range(int(ctx.cfg.model.out_channels)): + # temp = raw_feature_global[clazz == label] / (raw_feature_global[clazz == label].norm(dim=-1,keepdim=True) + 1e-6) + # mean_result = temp.sum(dim=0,keepdim=True) + # raw_feature_global_list.append(mean_result) + # + # graph_level_global = torch.concat(raw_feature_global_list, dim=0) + # + # cos = torch.nn.CosineEmbeddingLoss() + # + # N_tensor = torch.ones(graph_level_global.shape[0]).to(ctx.device) + # cos_loss = cos(graph_level_global, graph_level_local, N_tensor) + ctx.loss_batch = loss1 + loss3 * 3 + + + + ctx.batch_size = torch.sum(mask).item() + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + + def _hook_on_fit_start_init(self, ctx): + # prepare model and optimizer + ctx.model.to(ctx.device) + + if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]: + # Initialize optimizer here to avoid the reuse of optimizers + # across different routines + ctx.optimizer = get_optimizer(ctx.model, + **ctx.cfg[ctx.cur_mode].optimizer) + ctx.scheduler = get_scheduler(ctx.optimizer, + **ctx.cfg[ctx.cur_mode].scheduler) + + # prepare statistics + ctx.loss_batch_total = CtxVar(0., LIFECYCLE.ROUTINE) + ctx.loss_regular_total = CtxVar(0., LIFECYCLE.ROUTINE) + ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE) + ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE) + ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE) + + +def record_initialization_local(ctx): + """Record weight denomaitor to cpu + + """ + ctx.weight_denomaitor = None + + +def del_initialization_local(ctx): + """Clear the variable to avoid memory leakage + + """ + ctx.weight_denomaitor = None + + +def record_initialization_global(ctx): + """Record the shared global model to cpu + + """ + + pass + + +def begin(ctx): + if 'lastModel' not in ctx.keys(): + ctx.lastModel = copy.deepcopy(ctx.model).to(ctx.device) + + +def leave(ctx): + ctx.lastModel = copy.deepcopy(ctx.model).to(ctx.device) + + +from federatedscope.register import register_trainer + + + + + + +class FGCLTrainer2(LinkFullBatchTrainer): + def __init__(self, + model, + data, + device, + config, + only_for_eval=False, + monitor=None): + super(FGCLTrainer2, self).__init__(model, data, device, config, + only_for_eval, monitor) + self.global_model = copy.deepcopy(model).to(device) + self.state = 0 + # self.aug = GCNConv(config.) 
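+ # Link-prediction variant: same weak/strong augmentation + dual-branch InfoNCE
+ # setup as FGCLTrainer1 (here with tau=0.2), but the contrastive/distillation terms
+ # are only added for graphs with fewer than 10000 nodes (see the branch in
+ # _hook_on_batch_forward below); larger graphs fall back to the plain CE loss.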
+ self.cos = torch.nn.CosineSimilarity(dim=-1) + self.contrast_model = DualBranchContrast(loss=L.InfoNCE(tau=0.2), mode='L2L').to(device) + self.withcontrast_model = WithinEmbedContrast(loss=L.BarlowTwins()).to(device) + self.augWeak = A.Compose([A.EdgeRemoving(pe=0.3), A.FeatureMasking(pf=0.3)]) + self.augStrongF = A.Compose([A.EdgeRemoving(pe=0.8), A.FeatureMasking(pf=0.3)]) + + + def _hook_on_batch_forward(self, ctx): + data = ctx.data + perm = ctx.data_batch + batch = ctx.data.to(ctx.device) + mask = ctx.data[MODE2MASK[ctx.cur_split]] + edges = data.edge_index.T[mask] + if data.x.shape[0] < 10000: + print("j") + if ctx.cur_split in ['train', 'val']: + z, h = ctx.model((data.x, ctx.input_edge_index)) + else: + z, h = ctx.model((data.x, data.edge_index)) + pred = ctx.model.link_predictor(h, edges[perm].T) + label = data.edge_type[mask][perm] + loss_ce = ctx.criterion(pred, label) + batch1 = copy.deepcopy(batch).to(ctx.device) + batch2 = copy.deepcopy(batch).to(ctx.device) + + g1, edge_index1, edge_weight1 = self.augWeak(batch.x, batch.edge_index) + g2, edge_index2, edge_weight2 = self.augStrongF(batch.x, batch.edge_index) + + batch1.x = g1 + batch1.edge_index = edge_index1 + + batch2.x = g2 + batch2.edge_index = edge_index2 + + with torch.no_grad(): + pred_aug_global, globalOne = self.global_model(batch1) + + + pred_aug_local, now = ctx.model(batch2) + + adj_orig = to_dense_adj(batch.edge_index, max_num_nodes=batch.x.shape[0]).squeeze(0).to(ctx.device) + + struct_kd = com_distillation_loss(pred_aug_global, pred_aug_local, adj_orig, adj_orig, 3) + simi_kd_loss = simi_kd(pred_aug_global, pred_aug_local, batch.edge_index, 4) + + globalOne = globalOne[mask] + now = now[mask] + + extra_pos_mask = torch.eq(label, label.unsqueeze(dim=1)).to(ctx.device) + extra_pos_mask.fill_diagonal_(True) + + extra_neg_mask = torch.ne(label, label.unsqueeze(dim=1)).to(ctx.device) + extra_neg_mask.fill_diagonal_(False) + + loss3 = self.contrast_model(globalOne, now, extra_pos_mask=extra_pos_mask,extra_neg_mask=extra_neg_mask) + + ctx.loss_batch = loss_ce + loss3 * 3 + + ctx.batch_size = len(label) + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + else: + if ctx.cur_split in ['train', 'val']: + z, h = ctx.model((data.x, ctx.input_edge_index)) + else: + z, h = ctx.model((data.x, data.edge_index)) + pred = ctx.model.link_predictor(h, edges[perm].T) + label = data.edge_type[mask][perm] + loss_ce = ctx.criterion(pred, label) + + ctx.loss_batch = loss_ce + ctx.batch_size = len(label) + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + + + + + +class FGCLTrainer3(GraphMiniBatchTrainer): + def __init__(self, + model, + data, + device, + config, + only_for_eval=False, + monitor=None): + super(FGCLTrainer3, self).__init__(model, data, device, config, + only_for_eval, monitor) + self.global_model = copy.deepcopy(model).to(device) + self.state = 0 + # self.aug = GCNConv(config.) 
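+ # Graph mini-batch variant; its _hook_on_batch_forward mirrors FGCLTrainer2's
+ # logic (same augmentations, tau=0.2, and the <10000-node guard for the
+ # contrastive term).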
+ self.cos = torch.nn.CosineSimilarity(dim=-1) + self.contrast_model = DualBranchContrast(loss=L.InfoNCE(tau=0.2), mode='L2L').to(device) + self.withcontrast_model = WithinEmbedContrast(loss=L.BarlowTwins()).to(device) + self.augWeak = A.Compose([A.EdgeRemoving(pe=0.3), A.FeatureMasking(pf=0.3)]) + self.augStrongF = A.Compose([A.EdgeRemoving(pe=0.8), A.FeatureMasking(pf=0.3)]) + + + def _hook_on_batch_forward(self, ctx): + data = ctx.data + perm = ctx.data_batch + batch = ctx.data.to(ctx.device) + mask = ctx.data[MODE2MASK[ctx.cur_split]] + edges = data.edge_index.T[mask] + if data.x.shape[0] < 10000: + # run the cross-model contrastive branch only on graphs small enough to densify the adjacency + if ctx.cur_split in ['train', 'val']: + z, h = ctx.model((data.x, ctx.input_edge_index)) + else: + z, h = ctx.model((data.x, data.edge_index)) + pred = ctx.model.link_predictor(h, edges[perm].T) + label = data.edge_type[mask][perm] + loss_ce = ctx.criterion(pred, label) + batch1 = copy.deepcopy(batch).to(ctx.device) + batch2 = copy.deepcopy(batch).to(ctx.device) + + g1, edge_index1, edge_weight1 = self.augWeak(batch.x, batch.edge_index) + g2, edge_index2, edge_weight2 = self.augStrongF(batch.x, batch.edge_index) + + batch1.x = g1 + batch1.edge_index = edge_index1 + + batch2.x = g2 + batch2.edge_index = edge_index2 + + with torch.no_grad(): + pred_aug_global, globalOne = self.global_model(batch1) + + + pred_aug_local, now = ctx.model(batch2) + + adj_orig = to_dense_adj(batch.edge_index, max_num_nodes=batch.x.shape[0]).squeeze(0).to(ctx.device) + + struct_kd = com_distillation_loss(pred_aug_global, pred_aug_local, adj_orig, adj_orig, 3) + simi_kd_loss = simi_kd(pred_aug_global, pred_aug_local, batch.edge_index, 4) + + globalOne = globalOne[mask] + now = now[mask] + + extra_pos_mask = torch.eq(label, label.unsqueeze(dim=1)).to(ctx.device) + extra_pos_mask.fill_diagonal_(True) + + extra_neg_mask = torch.ne(label, label.unsqueeze(dim=1)).to(ctx.device) + extra_neg_mask.fill_diagonal_(False) + + loss3 = self.contrast_model(globalOne, now, extra_pos_mask=extra_pos_mask, extra_neg_mask=extra_neg_mask) + + ctx.loss_batch = loss_ce + loss3 * 3 + + ctx.batch_size = len(label) + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + else: + if ctx.cur_split in ['train', 'val']: + z, h = ctx.model((data.x, ctx.input_edge_index)) + else: + z, h = ctx.model((data.x, data.edge_index)) + pred = ctx.model.link_predictor(h, edges[perm].T) + label = data.edge_type[mask][perm] + loss_ce = ctx.criterion(pred, label) + + ctx.loss_batch = loss_ce + ctx.batch_size = len(label) + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + + +def call_my_trainer(trainer_type): + if trainer_type == 'fgcl1': + trainer_builder = FGCLTrainer1 + return trainer_builder + elif trainer_type == 'fgcl2': + trainer_builder = FGCLTrainer2 + return trainer_builder + elif trainer_type == 'fgcl3': + trainer_builder = FGCLTrainer3 + return trainer_builder + + +register_trainer('fgcl1', call_my_trainer) +register_trainer('fgcl2', call_my_trainer) +register_trainer('fgcl3', call_my_trainer) + + + + + +def GSP(student_feat, teacher_feat): + student_feat = F.normalize(student_feat, p=2, dim=-1) + teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) + student_pw_sim = torch.mm(student_feat, student_feat.transpose(0, 1)) + teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)) + + loss_gsp = F.mse_loss(student_pw_sim, teacher_pw_sim) + + return loss_gsp + + + + +def _pdist(e, squared, eps): + e_square = e.pow(2).sum(dim=1) + prod = e @
e.t() + res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) + + if not squared: + res = res.sqrt() + + res = res.clone() + res[range(len(e)), range(len(e))] = 0 + return res + + +def rkd_loss(f_s, f_t, squared=False, eps=1e-12, distance_weight=25, angle_weight=50): + stu = f_s.view(f_s.shape[0], -1) + tea = f_t.view(f_t.shape[0], -1) + + # RKD distance loss + with torch.no_grad(): + t_d = _pdist(tea, squared, eps) + mean_td = t_d[t_d > 0].mean() + t_d = t_d / (mean_td + 1e-6) + + d = _pdist(stu, squared, eps) + mean_d = d[d > 0].mean() + d = d / (mean_d+1e-6) + + loss_d = F.smooth_l1_loss(d, t_d) + + # RKD Angle loss + with torch.no_grad(): + td = tea.unsqueeze(0) - tea.unsqueeze(1) + norm_td = F.normalize(td, p=2, dim=2) + t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) + + sd = stu.unsqueeze(0) - stu.unsqueeze(1) + norm_sd = F.normalize(sd, p=2, dim=2) + s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) + + loss_a = F.smooth_l1_loss(s_angle, t_angle) + + loss = distance_weight * loss_d + angle_weight * loss_a + return loss + +def sce_loss(x, y, alpha=3): + x = F.normalize(x, p=2, dim=-1) + y = F.normalize(y, p=2, dim=-1) + + loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) + + loss = loss.mean() + return loss + + +def kl(pred, pred_global, tau=2): + p_loss = F.kl_div(F.log_softmax(pred / tau, dim=-1), F.softmax(pred_global / tau, dim=-1), reduction='none') * (tau ** 2) / pred.shape[0] + q_loss = F.kl_div(F.log_softmax(pred_global / tau, dim=-1), F.softmax(pred / tau, dim=-1), reduction='none') * (tau ** 2) / pred.shape[0] + p_loss = p_loss.sum() + q_loss = q_loss.sum() + + loss = (p_loss + q_loss) / 2 + return loss + +def kd(pred, pred_global, tau): + p_s = F.log_softmax(pred / tau, dim=1) + p_t = F.softmax(pred_global / tau, dim=1) + loss = F.kl_div(p_s, p_t, reduction='none').sum(1).mean() + loss *= tau ** 2 + return loss + +def sp_loss(g_s, g_t): + return sum([similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]) + + +def similarity_loss(f_s, f_t): + bsz = f_s.shape[0] + f_s = f_s.view(bsz, -1) + f_t = f_t.view(bsz, -1) + + G_s = torch.mm(f_s, torch.t(f_s)) + G_s = torch.nn.functional.normalize(G_s) + G_t = torch.mm(f_t, torch.t(f_t)) + G_t = torch.nn.functional.normalize(G_t) + + G_diff = G_t - G_s + loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) + return loss + + +def com_distillation_loss(t_logits, s_logits, adj_orig, adj_sampled, temp): + + s_dist = F.log_softmax(s_logits / temp, dim=-1) + t_dist = F.softmax(t_logits / temp, dim=-1) + kd_loss = temp * temp * F.kl_div(s_dist, t_dist.detach()) + + + adj = torch.triu(adj_orig).detach() + edge_list = (adj + adj.T).nonzero().t() + + s_dist_neigh = F.log_softmax(s_logits[edge_list[0]] / temp, dim=-1) + t_dist_neigh = F.softmax(t_logits[edge_list[1]] / temp, dim=-1) + + kd_loss += temp * temp * F.kl_div(s_dist_neigh, t_dist_neigh.detach()) + + return kd_loss + + +def simi_kd_2(edge_index, feats, out,yn): + tau = 0.1 + kl = nn.KLDivLoss() + N = feats.shape[0] + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index, num_nodes=N) + src = edge_index[0] + dst = edge_index[1] + deg = degree(dst,num_nodes=N) + values, max_degree_nodes = torch.topk(deg, yn) + nodes_index = max_degree_nodes + loss = 0 + feats = F.softmax(feats / tau, dim=-1) + out = F.softmax(out / tau, dim=-1) + _1 = torch.cosine_similarity(feats[src], feats[dst], dim=-1) + _2 = torch.cosine_similarity(out[src], out[dst], dim=-1) + for index in nodes_index: + index_n = 
edge_index[:,torch.nonzero(edge_index[1] == index.item()).squeeze()] + _1 = torch.cosine_similarity(feats[index_n[0]], feats[index_n[1]], dim=-1) + _2 = torch.cosine_similarity(out[index_n[0]], out[index_n[1]], dim=-1) + _1 = F.log_softmax(_1, dim=0) + _2 = F.softmax(_2, dim=0) + loss += kl(_1, _2) + return loss + +def edge_distribution_high(edge_idx, feats, out): + + tau = 0.1 + src = edge_idx[0] + dst = edge_idx[1] + criterion_t = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) + + feats_abs = torch.abs(feats[src] - feats[dst]) + e_softmax = F.log_softmax(feats_abs / tau, dim=-1) + + out_1 = torch.abs(out[src] - out[dst]) + e_softmax_2 = F.log_softmax(out_1 / tau, dim=-1) + + loss_s = criterion_t(e_softmax, e_softmax_2) + return loss_s + + +def simi_kd(global_nodes, local_nodes, edge_index, temp): + adj_orig = to_dense_adj(edge_index).squeeze(0) + adj_orig.fill_diagonal_(True) + s_dist = F.log_softmax(local_nodes / temp, dim=-1) + t_dist = F.softmax(global_nodes / temp, dim=-1) + # kd_loss = temp * temp * F.kl_div(s_dist, t_dist.detach()) + local_simi = torch.cosine_similarity(local_nodes.unsqueeze(1), local_nodes.unsqueeze(0), dim=-1) + global_simi = torch.cosine_similarity(global_nodes.unsqueeze(1), global_nodes.unsqueeze(0), dim=-1) + + local_simi = torch.where(adj_orig > 0, local_simi, torch.zeros_like(local_simi)) + global_simi = torch.where(adj_orig > 0, global_simi, torch.zeros_like(global_simi)) + + s_dist_neigh = F.log_softmax(local_simi / temp, dim=-1) + t_dist_neigh = F.softmax(global_simi / temp, dim=-1) + + kd_loss = temp * temp * F.kl_div(s_dist_neigh, t_dist_neigh.detach()) + + return kd_loss + + +class Correlation(nn.Module): + """Similarity-preserving loss. My original reimplementation, + based on the paper, written before emailing the original authors.""" + def __init__(self): + super(Correlation, self).__init__() + + def forward(self, f_s, f_t): + return self.similarity_loss(f_s, f_t) + + def similarity_loss(self, f_s, f_t): + bsz = f_s.shape[0] + f_s = f_s.view(bsz, -1) + f_t = f_t.view(bsz, -1) + + G_s = torch.mm(f_s, torch.t(f_s)) + G_s = G_s / G_s.norm(2) + G_t = torch.mm(f_t, torch.t(f_t)) + G_t = G_t / G_t.norm(2) + + G_diff = G_t - G_s + loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) + return loss diff --git a/fgssl/contrib/trainer/GCL/__init__.py b/fgssl/contrib/trainer/GCL/__init__.py new file mode 100644 index 0000000..8cdcc45 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/__init__.py @@ -0,0 +1,16 @@ +import GCL.losses +import GCL.augmentors +import GCL.eval +import GCL.models +import GCL.utils + +__version__ = '0.1.0' + +__all__ = [ + '__version__', + 'losses', + 'augmentors', + 'eval', + 'models', + 'utils' +] diff --git a/fgssl/contrib/trainer/GCL/augmentors/__init__.py b/fgssl/contrib/trainer/GCL/augmentors/__init__.py new file mode 100644 index 0000000..0d0b7be --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/__init__.py @@ -0,0 +1,32 @@ +from .augmentor import Graph, Augmentor, Compose, RandomChoice +from .identity import Identity +from .rw_sampling import RWSampling +from .ppr_diffusion import PPRDiffusion +from .markov_diffusion import MarkovDiffusion +from .edge_adding import EdgeAdding +from .edge_removing import EdgeRemoving +from .node_dropping import NodeDropping +from .node_shuffling import NodeShuffling +from .feature_masking import FeatureMasking +from .feature_dropout import FeatureDropout +from .edge_attr_masking import EdgeAttrMasking + +__all__ = [ + 'Graph', + 'Augmentor', + 'Compose', + 'RandomChoice', + 'EdgeAdding', +
'EdgeRemoving', + 'EdgeAttrMasking', + 'FeatureMasking', + 'FeatureDropout', + 'Identity', + 'PPRDiffusion', + 'MarkovDiffusion', + 'NodeDropping', + 'NodeShuffling', + 'RWSampling' +] + +classes = __all__ diff --git a/fgssl/contrib/trainer/GCL/augmentors/augmentor.py b/fgssl/contrib/trainer/GCL/augmentors/augmentor.py new file mode 100644 index 0000000..1626769 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/augmentor.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import torch +from abc import ABC, abstractmethod +from typing import Optional, Tuple, NamedTuple, List + + +class Graph(NamedTuple): + x: torch.FloatTensor + edge_index: torch.LongTensor + edge_weights: Optional[torch.FloatTensor] + + def unfold(self) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor]]: + return self.x, self.edge_index, self.edge_weights + + +class Augmentor(ABC): + """Base class for graph augmentors.""" + def __init__(self): + pass + + @abstractmethod + def augment(self, g: Graph) -> Graph: + raise NotImplementedError(f"GraphAug.augment should be implemented.") + + def __call__( + self, x: torch.FloatTensor, + edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return self.augment(Graph(x, edge_index, edge_weight)).unfold() + + +class Compose(Augmentor): + def __init__(self, augmentors: List[Augmentor]): + super(Compose, self).__init__() + self.augmentors = augmentors + + def augment(self, g: Graph) -> Graph: + for aug in self.augmentors: + g = aug.augment(g) + return g + + +class RandomChoice(Augmentor): + def __init__(self, augmentors: List[Augmentor], num_choices: int): + super(RandomChoice, self).__init__() + assert num_choices <= len(augmentors) + self.augmentors = augmentors + self.num_choices = num_choices + + def augment(self, g: Graph) -> Graph: + num_augmentors = len(self.augmentors) + perm = torch.randperm(num_augmentors) + idx = perm[:self.num_choices] + for i in idx: + aug = self.augmentors[i] + g = aug.augment(g) + return g diff --git a/fgssl/contrib/trainer/GCL/augmentors/edge_adding.py b/fgssl/contrib/trainer/GCL/augmentors/edge_adding.py new file mode 100644 index 0000000..4a5f895 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/edge_adding.py @@ -0,0 +1,13 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import add_edge + + +class EdgeAdding(Augmentor): + def __init__(self, pe: float): + super(EdgeAdding, self).__init__() + self.pe = pe + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + edge_index = add_edge(edge_index, ratio=self.pe) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/trainer/GCL/augmentors/edge_attr_masking.py b/fgssl/contrib/trainer/GCL/augmentors/edge_attr_masking.py new file mode 100644 index 0000000..7344c0e --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/edge_attr_masking.py @@ -0,0 +1,14 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import drop_feature + + +class EdgeAttrMasking(Augmentor): + def __init__(self, pf: float): + super(EdgeAttrMasking, self).__init__() + self.pf = pf + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + if edge_weights is not None: + edge_weights = drop_feature(edge_weights, self.pf) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git 
a/fgssl/contrib/trainer/GCL/augmentors/edge_removing.py b/fgssl/contrib/trainer/GCL/augmentors/edge_removing.py new file mode 100644 index 0000000..adfaeaf --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/edge_removing.py @@ -0,0 +1,13 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import dropout_adj + + +class EdgeRemoving(Augmentor): + def __init__(self, pe: float): + super(EdgeRemoving, self).__init__() + self.pe = pe + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + edge_index, edge_weights = dropout_adj(edge_index, edge_attr=edge_weights, p=self.pe) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/trainer/GCL/augmentors/feature_dropout.py b/fgssl/contrib/trainer/GCL/augmentors/feature_dropout.py new file mode 100644 index 0000000..0395435 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/feature_dropout.py @@ -0,0 +1,13 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import dropout_feature + + +class FeatureDropout(Augmentor): + def __init__(self, pf: float): + super(FeatureDropout, self).__init__() + self.pf = pf + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + x = dropout_feature(x, self.pf) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/trainer/GCL/augmentors/feature_masking.py b/fgssl/contrib/trainer/GCL/augmentors/feature_masking.py new file mode 100644 index 0000000..9d0acc6 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/feature_masking.py @@ -0,0 +1,13 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import drop_feature + + +class FeatureMasking(Augmentor): + def __init__(self, pf: float): + super(FeatureMasking, self).__init__() + self.pf = pf + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + x = drop_feature(x, self.pf) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/trainer/GCL/augmentors/functional.py b/fgssl/contrib/trainer/GCL/augmentors/functional.py new file mode 100644 index 0000000..1a0fa3d --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/functional.py @@ -0,0 +1,332 @@ +import torch +import networkx as nx +import torch.nn.functional as F + +from typing import Optional +from GCL.utils import normalize +from torch_sparse import SparseTensor, coalesce +from torch_scatter import scatter +from torch_geometric.transforms import GDC +from torch.distributions import Uniform, Beta +from torch_geometric.utils import dropout_adj, to_networkx, to_undirected, degree, to_scipy_sparse_matrix, \ + from_scipy_sparse_matrix, sort_edge_index, add_self_loops, subgraph +from torch.distributions.bernoulli import Bernoulli + + +def permute(x: torch.Tensor) -> torch.Tensor: + """ + Randomly permute node embeddings or features. + + Args: + x: The latent embedding or node feature. + + Returns: + torch.Tensor: Embeddings or features resulting from permutation. + """ + return x[torch.randperm(x.size(0))] + + +def get_mixup_idx(x: torch.Tensor) -> torch.Tensor: + """ + Generate node IDs randomly for mixup; avoid mixup the same node. + + Args: + x: The latent embedding or node feature. + + Returns: + torch.Tensor: Random node IDs. 
+ """ + mixup_idx = torch.randint(x.size(0) - 1, [x.size(0)]) + mixup_self_mask = mixup_idx - torch.arange(x.size(0)) + mixup_self_mask = (mixup_self_mask == 0) + mixup_idx += torch.ones(x.size(0), dtype=torch.int) * mixup_self_mask + return mixup_idx + + +def mixup(x: torch.Tensor, alpha: float) -> torch.Tensor: + """ + Randomly mixup node embeddings or features with other nodes'. + + Args: + x: The latent embedding or node feature. + alpha: The hyperparameter controlling the mixup coefficient. + + Returns: + torch.Tensor: Embeddings or features resulting from mixup. + """ + device = x.device + mixup_idx = get_mixup_idx(x).to(device) + lambda_ = Uniform(alpha, 1.).sample([1]).to(device) + x = (1 - lambda_) * x + lambda_ * x[mixup_idx] + return x + + +def multiinstance_mixup(x1: torch.Tensor, x2: torch.Tensor, + alpha: float, shuffle=False) -> (torch.Tensor, torch.Tensor): + """ + Randomly mixup node embeddings or features with nodes from other views. + + Args: + x1: The latent embedding or node feature from one view. + x2: The latent embedding or node feature from the other view. + alpha: The mixup coefficient `\lambda` follows `Beta(\alpha, \alpha)`. + shuffle: Whether to use fixed negative samples. + + Returns: + (torch.Tensor, torch.Tensor): Spurious positive samples and the mixup coefficient. + """ + device = x1.device + lambda_ = Beta(alpha, alpha).sample([1]).to(device) + if shuffle: + mixup_idx = get_mixup_idx(x1).to(device) + else: + mixup_idx = x1.size(0) - torch.arange(x1.size(0)) - 1 + x_spurious = (1 - lambda_) * x1 + lambda_ * x2[mixup_idx] + + return x_spurious, lambda_ + + +def drop_feature(x: torch.Tensor, drop_prob: float) -> torch.Tensor: + device = x.device + drop_mask = torch.empty((x.size(1),), dtype=torch.float32).uniform_(0, 1) < drop_prob + drop_mask = drop_mask.to(device) + x = x.clone() + x[:, drop_mask] = 0 + + return x + + +def dropout_feature(x: torch.FloatTensor, drop_prob: float) -> torch.FloatTensor: + return F.dropout(x, p=1. - drop_prob) + + +class AugmentTopologyAttributes(object): + def __init__(self, pe=0.5, pf=0.5): + self.pe = pe + self.pf = pf + + def __call__(self, x, edge_index): + edge_index = dropout_adj(edge_index, p=self.pe)[0] + x = drop_feature(x, self.pf) + return x, edge_index + + +def get_feature_weights(x, centrality, sparse=True): + if sparse: + x = x.to(torch.bool).to(torch.float32) + else: + x = x.abs() + w = x.t() @ centrality + w = w.log() + + return normalize(w) + + +def drop_feature_by_weight(x, weights, drop_prob: float, threshold: float = 0.7): + weights = weights / weights.mean() * drop_prob + weights = weights.where(weights < threshold, torch.ones_like(weights) * threshold) # clip + drop_mask = torch.bernoulli(weights).to(torch.bool) + x = x.clone() + x[:, drop_mask] = 0. 
+ return x + + +def get_eigenvector_weights(data): + def _eigenvector_centrality(data): + graph = to_networkx(data) + x = nx.eigenvector_centrality_numpy(graph) + x = [x[i] for i in range(data.num_nodes)] + return torch.tensor(x, dtype=torch.float32).to(data.edge_index.device) + + evc = _eigenvector_centrality(data) + scaled_evc = evc.where(evc > 0, torch.zeros_like(evc)) + scaled_evc = scaled_evc + 1e-8 + s = scaled_evc.log() + + edge_index = data.edge_index + s_row, s_col = s[edge_index[0]], s[edge_index[1]] + + return normalize(s_col), evc + + +def get_degree_weights(data): + edge_index_ = to_undirected(data.edge_index) + deg = degree(edge_index_[1]) + deg_col = deg[data.edge_index[1]].to(torch.float32) + scaled_deg_col = torch.log(deg_col) + + return normalize(scaled_deg_col), deg + + +def get_pagerank_weights(data, aggr: str = 'sink', k: int = 10): + def _compute_pagerank(edge_index, damp: float = 0.85, k: int = 10): + num_nodes = edge_index.max().item() + 1 + deg_out = degree(edge_index[0]) + x = torch.ones((num_nodes,)).to(edge_index.device).to(torch.float32) + + for i in range(k): + edge_msg = x[edge_index[0]] / deg_out[edge_index[0]] + agg_msg = scatter(edge_msg, edge_index[1], reduce='sum') + + x = (1 - damp) * x + damp * agg_msg + + return x + + pv = _compute_pagerank(data.edge_index, k=k) + pv_row = pv[data.edge_index[0]].to(torch.float32) + pv_col = pv[data.edge_index[1]].to(torch.float32) + s_row = torch.log(pv_row) + s_col = torch.log(pv_col) + if aggr == 'sink': + s = s_col + elif aggr == 'source': + s = s_row + elif aggr == 'mean': + s = (s_col + s_row) * 0.5 + else: + s = s_col + + return normalize(s), pv + + +def drop_edge_by_weight(edge_index, weights, drop_prob: float, threshold: float = 0.7): + weights = weights / weights.mean() * drop_prob + weights = weights.where(weights < threshold, torch.ones_like(weights) * threshold) + drop_mask = torch.bernoulli(1. 
- weights).to(torch.bool) + + return edge_index[:, drop_mask] + + +class AdaptivelyAugmentTopologyAttributes(object): + def __init__(self, edge_weights, feature_weights, pe=0.5, pf=0.5, threshold=0.7): + self.edge_weights = edge_weights + self.feature_weights = feature_weights + self.pe = pe + self.pf = pf + self.threshold = threshold + + def __call__(self, x, edge_index): + edge_index = drop_edge_by_weight(edge_index, self.edge_weights, self.pe, self.threshold) + x = drop_feature_by_weight(x, self.feature_weights, self.pf, self.threshold) + + return x, edge_index + + +def get_subgraph(x, edge_index, idx): + adj = to_scipy_sparse_matrix(edge_index).tocsr() + x_sampled = x[idx] + edge_index_sampled = from_scipy_sparse_matrix(adj[idx, :][:, idx]) + return x_sampled, edge_index_sampled + + +def sample_nodes(x, edge_index, sample_size): + idx = torch.randperm(x.size(0))[:sample_size] + return get_subgraph(x, edge_index, idx), idx + + +def compute_ppr(edge_index, edge_weight=None, alpha=0.2, eps=0.1, ignore_edge_attr=True, add_self_loop=True): + N = edge_index.max().item() + 1 + if ignore_edge_attr or edge_weight is None: + edge_weight = torch.ones( + edge_index.size(1), device=edge_index.device) + if add_self_loop: + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value=1, num_nodes=N) + edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) + edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) + edge_index, edge_weight = GDC().transition_matrix( + edge_index, edge_weight, N, normalization='sym') + diff_mat = GDC().diffusion_matrix_exact( + edge_index, edge_weight, N, method='ppr', alpha=alpha) + edge_index, edge_weight = GDC().sparsify_dense(diff_mat, method='threshold', eps=eps) + edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) + edge_index, edge_weight = GDC().transition_matrix( + edge_index, edge_weight, N, normalization='sym') + + return edge_index, edge_weight + + +def get_sparse_adj(edge_index: torch.LongTensor, edge_weight: torch.FloatTensor = None, + add_self_loop: bool = True) -> torch.sparse.Tensor: + num_nodes = edge_index.max().item() + 1 + num_edges = edge_index.size(1) + + if edge_weight is None: + edge_weight = torch.ones((num_edges,), dtype=torch.float32, device=edge_index.device) + + if add_self_loop: + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value=1, num_nodes=num_nodes) + edge_index, edge_weight = coalesce(edge_index, edge_weight, num_nodes, num_nodes) + + edge_index, edge_weight = GDC().transition_matrix( + edge_index, edge_weight, num_nodes, normalization='sym') + + adj_t = torch.sparse_coo_tensor(edge_index, edge_weight, size=(num_nodes, num_nodes)).coalesce() + + return adj_t.t() + + +def compute_markov_diffusion( + edge_index: torch.LongTensor, edge_weight: torch.FloatTensor = None, + alpha: float = 0.1, degree: int = 10, + sp_eps: float = 1e-3, add_self_loop: bool = True): + adj = get_sparse_adj(edge_index, edge_weight, add_self_loop) + + z = adj.to_dense() + t = adj.to_dense() + for _ in range(degree): + t = (1.0 - alpha) * torch.spmm(adj, t) + z += t + z /= degree + z = z + alpha * adj + + adj_t = z.t() + + return GDC().sparsify_dense(adj_t, method='threshold', eps=sp_eps) + + +def coalesce_edge_index(edge_index: torch.Tensor, edge_weights: Optional[torch.Tensor] = None) -> (torch.Tensor, torch.FloatTensor): + num_edges = edge_index.size()[1] + num_nodes = edge_index.max().item() + 1 + edge_weights = edge_weights if edge_weights is not None else 
torch.ones((num_edges,), dtype=torch.float32, device=edge_index.device) + + return coalesce(edge_index, edge_weights, m=num_nodes, n=num_nodes) + + +def add_edge(edge_index: torch.Tensor, ratio: float) -> torch.Tensor: + num_edges = edge_index.size()[1] + num_nodes = edge_index.max().item() + 1 + num_add = int(num_edges * ratio) + + new_edge_index = torch.randint(0, num_nodes - 1, size=(2, num_add)).to(edge_index.device) + edge_index = torch.cat([edge_index, new_edge_index], dim=1) + + edge_index = sort_edge_index(edge_index)[0] + + return coalesce_edge_index(edge_index)[0] + + +def drop_node(edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor] = None, keep_prob: float = 0.5) -> (torch.Tensor, Optional[torch.Tensor]): + num_nodes = edge_index.max().item() + 1 + probs = torch.tensor([keep_prob for _ in range(num_nodes)]) + dist = Bernoulli(probs) + + subset = dist.sample().to(torch.bool).to(edge_index.device) + edge_index, edge_weight = subgraph(subset, edge_index, edge_weight) + + return edge_index, edge_weight + + +def random_walk_subgraph(edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, batch_size: int = 1000, length: int = 10): + num_nodes = edge_index.max().item() + 1 + + row, col = edge_index + adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes)) + + start = torch.randint(0, num_nodes, size=(batch_size, ), dtype=torch.long).to(edge_index.device) + node_idx = adj.random_walk(start.flatten(), length).view(-1) + + edge_index, edge_weight = subgraph(node_idx, edge_index, edge_weight) + + return edge_index, edge_weight diff --git a/fgssl/contrib/trainer/GCL/augmentors/identity.py b/fgssl/contrib/trainer/GCL/augmentors/identity.py new file mode 100644 index 0000000..1717195 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/identity.py @@ -0,0 +1,9 @@ +from GCL.augmentors.augmentor import Graph, Augmentor + + +class Identity(Augmentor): + def __init__(self): + super(Identity, self).__init__() + + def augment(self, g: Graph) -> Graph: + return g diff --git a/fgssl/contrib/trainer/GCL/augmentors/markov_diffusion.py b/fgssl/contrib/trainer/GCL/augmentors/markov_diffusion.py new file mode 100644 index 0000000..6bd16d6 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/markov_diffusion.py @@ -0,0 +1,27 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import compute_markov_diffusion + + +class MarkovDiffusion(Augmentor): + def __init__(self, alpha: float = 0.05, order: int = 16, sp_eps: float = 1e-4, use_cache: bool = True, + add_self_loop: bool = True): + super(MarkovDiffusion, self).__init__() + self.alpha = alpha + self.order = order + self.sp_eps = sp_eps + self._cache = None + self.use_cache = use_cache + self.add_self_loop = add_self_loop + + def augment(self, g: Graph) -> Graph: + if self._cache is not None and self.use_cache: + return self._cache + x, edge_index, edge_weights = g.unfold() + edge_index, edge_weights = compute_markov_diffusion( + edge_index, edge_weights, + alpha=self.alpha, degree=self.order, + sp_eps=self.sp_eps, add_self_loop=self.add_self_loop + ) + res = Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) + self._cache = res + return res diff --git a/fgssl/contrib/trainer/GCL/augmentors/node_dropping.py b/fgssl/contrib/trainer/GCL/augmentors/node_dropping.py new file mode 100644 index 0000000..d9e0dce --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/node_dropping.py @@ -0,0 +1,15 @@ +from GCL.augmentors.augmentor import Graph, Augmentor 
+from GCL.augmentors.functional import drop_node + + +class NodeDropping(Augmentor): + def __init__(self, pn: float): + super(NodeDropping, self).__init__() + self.pn = pn + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + + edge_index, edge_weights = drop_node(edge_index, edge_weights, keep_prob=1. - self.pn) + + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/trainer/GCL/augmentors/node_shuffling.py b/fgssl/contrib/trainer/GCL/augmentors/node_shuffling.py new file mode 100644 index 0000000..ac35551 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/node_shuffling.py @@ -0,0 +1,12 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import permute + + +class NodeShuffling(Augmentor): + def __init__(self): + super(NodeShuffling, self).__init__() + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + x = permute(x) + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/trainer/GCL/augmentors/ppr_diffusion.py b/fgssl/contrib/trainer/GCL/augmentors/ppr_diffusion.py new file mode 100644 index 0000000..d33194d --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/ppr_diffusion.py @@ -0,0 +1,24 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import compute_ppr + + +class PPRDiffusion(Augmentor): + def __init__(self, alpha: float = 0.2, eps: float = 1e-4, use_cache: bool = True, add_self_loop: bool = True): + super(PPRDiffusion, self).__init__() + self.alpha = alpha + self.eps = eps + self._cache = None + self.use_cache = use_cache + self.add_self_loop = add_self_loop + + def augment(self, g: Graph) -> Graph: + if self._cache is not None and self.use_cache: + return self._cache + x, edge_index, edge_weights = g.unfold() + edge_index, edge_weights = compute_ppr( + edge_index, edge_weights, + alpha=self.alpha, eps=self.eps, ignore_edge_attr=False, add_self_loop=self.add_self_loop + ) + res = Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) + self._cache = res + return res diff --git a/fgssl/contrib/trainer/GCL/augmentors/rw_sampling.py b/fgssl/contrib/trainer/GCL/augmentors/rw_sampling.py new file mode 100644 index 0000000..f5176dc --- /dev/null +++ b/fgssl/contrib/trainer/GCL/augmentors/rw_sampling.py @@ -0,0 +1,16 @@ +from GCL.augmentors.augmentor import Graph, Augmentor +from GCL.augmentors.functional import random_walk_subgraph + + +class RWSampling(Augmentor): + def __init__(self, num_seeds: int, walk_length: int): + super(RWSampling, self).__init__() + self.num_seeds = num_seeds + self.walk_length = walk_length + + def augment(self, g: Graph) -> Graph: + x, edge_index, edge_weights = g.unfold() + + edge_index, edge_weights = random_walk_subgraph(edge_index, edge_weights, batch_size=self.num_seeds, length=self.walk_length) + + return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights) diff --git a/fgssl/contrib/trainer/GCL/eval/__init__.py b/fgssl/contrib/trainer/GCL/eval/__init__.py new file mode 100644 index 0000000..610af7f --- /dev/null +++ b/fgssl/contrib/trainer/GCL/eval/__init__.py @@ -0,0 +1,16 @@ +from .eval import BaseEvaluator, BaseSKLearnEvaluator, get_split, from_predefined_split +from .logistic_regression import LREvaluator +from .svm import SVMEvaluator +from .random_forest import RFEvaluator + +__all__ = [ + 'BaseEvaluator', + 'BaseSKLearnEvaluator', + 'LREvaluator', + 'SVMEvaluator', + 'RFEvaluator', + 'get_split', 
+ 'from_predefined_split' +] + +classes = __all__ diff --git a/fgssl/contrib/trainer/GCL/eval/eval.py b/fgssl/contrib/trainer/GCL/eval/eval.py new file mode 100644 index 0000000..77d4a2c --- /dev/null +++ b/fgssl/contrib/trainer/GCL/eval/eval.py @@ -0,0 +1,77 @@ +import torch +import numpy as np + +from abc import ABC, abstractmethod +from sklearn.metrics import f1_score +from sklearn.model_selection import PredefinedSplit, GridSearchCV + + +def get_split(num_samples: int, train_ratio: float = 0.1, test_ratio: float = 0.8): + assert train_ratio + test_ratio < 1 + train_size = int(num_samples * train_ratio) + test_size = int(num_samples * test_ratio) + indices = torch.randperm(num_samples) + return { + 'train': indices[:train_size], + 'valid': indices[train_size: test_size + train_size], + 'test': indices[test_size + train_size:] + } + + +def from_predefined_split(data): + assert all([mask is not None for mask in [data.train_mask, data.test_mask, data.val_mask]]) + num_samples = data.num_nodes + indices = torch.arange(num_samples) + return { + 'train': indices[data.train_mask], + 'valid': indices[data.val_mask], + 'test': indices[data.test_mask] + } + + +def split_to_numpy(x, y, split): + keys = ['train', 'test', 'valid'] + objs = [x, y] + return [obj[split[key]].detach().cpu().numpy() for obj in objs for key in keys] + + +def get_predefined_split(x_train, x_val, y_train, y_val, return_array=True): + test_fold = np.concatenate([-np.ones_like(y_train), np.zeros_like(y_val)]) + ps = PredefinedSplit(test_fold) + if return_array: + x = np.concatenate([x_train, x_val], axis=0) + y = np.concatenate([y_train, y_val], axis=0) + return ps, [x, y] + return ps + + +class BaseEvaluator(ABC): + @abstractmethod + def evaluate(self, x: torch.FloatTensor, y: torch.LongTensor, split: dict) -> dict: + pass + + def __call__(self, x: torch.FloatTensor, y: torch.LongTensor, split: dict) -> dict: + for key in ['train', 'test', 'valid']: + assert key in split + + result = self.evaluate(x, y, split) + return result + + +class BaseSKLearnEvaluator(BaseEvaluator): + def __init__(self, evaluator, params): + self.evaluator = evaluator + self.params = params + + def evaluate(self, x, y, split): + x_train, x_test, x_val, y_train, y_test, y_val = split_to_numpy(x, y, split) + ps, [x_train, y_train] = get_predefined_split(x_train, x_val, y_train, y_val) + classifier = GridSearchCV(self.evaluator, self.params, cv=ps, scoring='accuracy', verbose=0) + classifier.fit(x_train, y_train) + test_macro = f1_score(y_test, classifier.predict(x_test), average='macro') + test_micro = f1_score(y_test, classifier.predict(x_test), average='micro') + + return { + 'micro_f1': test_micro, + 'macro_f1': test_macro, + } diff --git a/fgssl/contrib/trainer/GCL/eval/logistic_regression.py b/fgssl/contrib/trainer/GCL/eval/logistic_regression.py new file mode 100644 index 0000000..9273045 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/eval/logistic_regression.py @@ -0,0 +1,80 @@ +import torch +from tqdm import tqdm +from torch import nn +from torch.optim import Adam +from sklearn.metrics import f1_score + +from GCL.eval import BaseEvaluator + + +class LogisticRegression(nn.Module): + def __init__(self, num_features, num_classes): + super(LogisticRegression, self).__init__() + self.fc = nn.Linear(num_features, num_classes) + torch.nn.init.xavier_uniform_(self.fc.weight.data) + + def forward(self, x): + z = self.fc(x) + return z + + +class LREvaluator(BaseEvaluator): + def __init__(self, num_epochs: int = 5000, learning_rate: float = 0.01, + 
weight_decay: float = 0.0, test_interval: int = 20): + self.num_epochs = num_epochs + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.test_interval = test_interval + + def evaluate(self, x: torch.FloatTensor, y: torch.LongTensor, split: dict): + device = x.device + x = x.detach().to(device) + input_dim = x.size()[1] + y = y.to(device) + num_classes = y.max().item() + 1 + classifier = LogisticRegression(input_dim, num_classes).to(device) + optimizer = Adam(classifier.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + output_fn = nn.LogSoftmax(dim=-1) + criterion = nn.NLLLoss() + + best_val_micro = 0 + best_test_micro = 0 + best_test_macro = 0 + best_epoch = 0 + + with tqdm(total=self.num_epochs, desc='(LR)', + bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]') as pbar: + for epoch in range(self.num_epochs): + classifier.train() + optimizer.zero_grad() + + output = classifier(x[split['train']]) + loss = criterion(output_fn(output), y[split['train']]) + + loss.backward() + optimizer.step() + + if (epoch + 1) % self.test_interval == 0: + classifier.eval() + y_test = y[split['test']].detach().cpu().numpy() + y_pred = classifier(x[split['test']]).argmax(-1).detach().cpu().numpy() + test_micro = f1_score(y_test, y_pred, average='micro') + test_macro = f1_score(y_test, y_pred, average='macro') + + y_val = y[split['valid']].detach().cpu().numpy() + y_pred = classifier(x[split['valid']]).argmax(-1).detach().cpu().numpy() + val_micro = f1_score(y_val, y_pred, average='micro') + + if val_micro > best_val_micro: + best_val_micro = val_micro + best_test_micro = test_micro + best_test_macro = test_macro + best_epoch = epoch + + pbar.set_postfix({'best test F1Mi': best_test_micro, 'F1Ma': best_test_macro}) + pbar.update(self.test_interval) + + return { + 'micro_f1': best_test_micro, + 'macro_f1': best_test_macro + } diff --git a/fgssl/contrib/trainer/GCL/eval/random_forest.py b/fgssl/contrib/trainer/GCL/eval/random_forest.py new file mode 100644 index 0000000..00d02fc --- /dev/null +++ b/fgssl/contrib/trainer/GCL/eval/random_forest.py @@ -0,0 +1,9 @@ +from sklearn.ensemble import RandomForestClassifier +from GCL.eval import BaseSKLearnEvaluator + + +class RFEvaluator(BaseSKLearnEvaluator): + def __init__(self, params=None): + if params is None: + params = {'n_estimators': [100, 200, 500, 1000]} + super(RFEvaluator, self).__init__(RandomForestClassifier(), params) diff --git a/fgssl/contrib/trainer/GCL/eval/svm.py b/fgssl/contrib/trainer/GCL/eval/svm.py new file mode 100644 index 0000000..2d38ed8 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/eval/svm.py @@ -0,0 +1,13 @@ +from sklearn.svm import LinearSVC, SVC +from GCL.eval import BaseSKLearnEvaluator + + +class SVMEvaluator(BaseSKLearnEvaluator): + def __init__(self, linear=True, params=None): + if linear: + self.evaluator = LinearSVC() + else: + self.evaluator = SVC() + if params is None: + params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]} + super(SVMEvaluator, self).__init__(self.evaluator, params) diff --git a/fgssl/contrib/trainer/GCL/losses/__init__.py b/fgssl/contrib/trainer/GCL/losses/__init__.py new file mode 100644 index 0000000..c36dd8a --- /dev/null +++ b/fgssl/contrib/trainer/GCL/losses/__init__.py @@ -0,0 +1,24 @@ +from .jsd import JSD, DebiasedJSD, HardnessJSD +from .vicreg import VICReg +from .infonce import InfoNCE, InfoNCESP, DebiasedInfoNCE, HardnessInfoNCE +from .triplet import TripletMargin, TripletMarginSP +from .bootstrap import BootstrapLatent +from 
.barlow_twins import BarlowTwins +from .losses import Loss + +__all__ = [ + 'Loss', + 'InfoNCE', + 'InfoNCESP', + 'DebiasedInfoNCE', + 'HardnessInfoNCE', + 'JSD', + 'DebiasedJSD', + 'HardnessJSD', + 'TripletMargin', + 'TripletMarginSP', + 'VICReg', + 'BarlowTwins' +] + +classes = __all__ diff --git a/fgssl/contrib/trainer/GCL/losses/barlow_twins.py b/fgssl/contrib/trainer/GCL/losses/barlow_twins.py new file mode 100644 index 0000000..d32aaa4 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/losses/barlow_twins.py @@ -0,0 +1,34 @@ +import torch +from .losses import Loss + + +def bt_loss(h1: torch.Tensor, h2: torch.Tensor, lambda_, batch_norm=True, eps=1e-15, *args, **kwargs): + batch_size = h1.size(0) + feature_dim = h1.size(1) + + if lambda_ is None: + lambda_ = 1. / feature_dim + + if batch_norm: + z1_norm = (h1 - h1.mean(dim=0)) / (h1.std(dim=0) + eps) + z2_norm = (h2 - h2.mean(dim=0)) / (h2.std(dim=0) + eps) + c = (z1_norm.T @ z2_norm) / batch_size + else: + c = h1.T @ h2 / batch_size + + off_diagonal_mask = ~torch.eye(feature_dim).bool() + loss = (1 - c.diagonal()).pow(2).sum() + loss += lambda_ * c[off_diagonal_mask].pow(2).sum() + + return loss + + +class BarlowTwins(Loss): + def __init__(self, lambda_: float = None, batch_norm: bool = True, eps: float = 1e-5): + self.lambda_ = lambda_ + self.batch_norm = batch_norm + self.eps = eps + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor: + loss = bt_loss(anchor, sample, self.lambda_, self.batch_norm, self.eps) + return loss.mean() diff --git a/fgssl/contrib/trainer/GCL/losses/bootstrap.py b/fgssl/contrib/trainer/GCL/losses/bootstrap.py new file mode 100644 index 0000000..e0362c8 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/losses/bootstrap.py @@ -0,0 +1,16 @@ +import torch +import torch.nn.functional as F +from .losses import Loss + + +class BootstrapLatent(Loss): + def __init__(self): + super(BootstrapLatent, self).__init__() + + def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs) -> torch.FloatTensor: + anchor = F.normalize(anchor, dim=-1, p=2) + sample = F.normalize(sample, dim=-1, p=2) + + similarity = anchor @ sample.t() + loss = (similarity * pos_mask).sum(dim=-1) + return loss.mean() diff --git a/fgssl/contrib/trainer/GCL/losses/infonce.py b/fgssl/contrib/trainer/GCL/losses/infonce.py new file mode 100644 index 0000000..0276e19 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/losses/infonce.py @@ -0,0 +1,189 @@ +import torch +import numpy as np +import torch.nn.functional as F + +from .losses import Loss + + +def _similarity(h1: torch.Tensor, h2: torch.Tensor): + h1 = F.normalize(h1) + h2 = F.normalize(h2) + return h1 @ h2.t() + + +class InfoNCESP(Loss): + """ + InfoNCE loss for single positive. 
+ """ + def __init__(self, tau): + super(InfoNCESP, self).__init__() + self.tau = tau + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + f = lambda x: torch.exp(x / self.tau) + sim = f(_similarity(anchor, sample)) # anchor x sample + assert sim.size() == pos_mask.size() # sanity check + + neg_mask = 1 - pos_mask + pos = (sim * pos_mask).sum(dim=1) + neg = (sim * neg_mask).sum(dim=1) + + loss = pos / (pos + neg) + loss = -torch.log(loss) + + return loss.mean() + + +class InfoNCE(Loss): + def __init__(self, tau): + super(InfoNCE, self).__init__() + self.tau = tau + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + sim = _similarity(anchor, sample) / self.tau + exp_sim = torch.exp(sim) * (pos_mask + neg_mask) + log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True)) + loss = log_prob * pos_mask + loss = loss.sum(dim=1) / pos_mask.sum(dim=1) + return -loss.mean() + + +class DebiasedInfoNCE(Loss): + def __init__(self, tau, tau_plus=0.1): + super(DebiasedInfoNCE, self).__init__() + self.tau = tau + self.tau_plus = tau_plus + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + sim = _similarity(anchor, sample) / self.tau + exp_sim = torch.exp(sim) + + pos_sum = (exp_sim * pos_mask).sum(dim=1) + pos = pos_sum / pos_mask.int().sum(dim=1) + neg_sum = (exp_sim * neg_mask).sum(dim=1) + ng = (-num_neg * self.tau_plus * pos + neg_sum) / (1 - self.tau_plus) + ng = torch.clamp(ng, min=num_neg * np.e ** (-1. / self.tau)) + + log_prob = sim - torch.log((pos + ng).sum(dim=1, keepdim=True)) + loss = log_prob * pos_mask + loss = loss.sum(dim=1) / pos_mask.sum(dim=1) + return loss.mean() + + +class HardnessInfoNCE(Loss): + def __init__(self, tau, tau_plus=0.1, beta=1.0): + super(HardnessInfoNCE, self).__init__() + self.tau = tau + self.tau_plus = tau_plus + self.beta = beta + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + sim = _similarity(anchor, sample) / self.tau + exp_sim = torch.exp(sim) + + pos = (exp_sim * pos_mask).sum(dim=1) / pos_mask.int().sum(dim=1) + imp = torch.exp(self.beta * (sim * neg_mask)) + reweight_neg = (imp * (exp_sim * neg_mask)).sum(dim=1) / imp.mean(dim=1) + ng = (-num_neg * self.tau_plus * pos + reweight_neg) / (1 - self.tau_plus) + ng = torch.clamp(ng, min=num_neg * np.e ** (-1. 
/ self.tau)) + + log_prob = sim - torch.log((pos + ng).sum(dim=1, keepdim=True)) + loss = log_prob * pos_mask + loss = loss.sum(dim=1) / pos_mask.sum(dim=1) + return loss.mean() + + +class HardMixingLoss(torch.nn.Module): + def __init__(self, projection): + super(HardMixingLoss, self).__init__() + self.projection = projection + + @staticmethod + def tensor_similarity(z1, z2): + z1 = F.normalize(z1, dim=-1) # [N, d] + z2 = F.normalize(z2, dim=-1) # [N, s, d] + return torch.bmm(z2, z1.unsqueeze(dim=-1)).squeeze() + + def forward(self, z1: torch.Tensor, z2: torch.Tensor, threshold=0.1, s=150, mixup=0.2, *args, **kwargs): + f = lambda x: torch.exp(x / self.tau) + num_samples = z1.shape[0] + device = z1.device + + threshold = int(num_samples * threshold) + + refl1 = _similarity(z1, z1).diag() + refl2 = _similarity(z2, z2).diag() + pos_similarity = f(_similarity(z1, z2)) + neg_similarity1 = torch.cat([_similarity(z1, z1), _similarity(z1, z2)], dim=1) # [n, 2n] + neg_similarity2 = torch.cat([_similarity(z2, z1), _similarity(z2, z2)], dim=1) + neg_similarity1, indices1 = torch.sort(neg_similarity1, descending=True) + neg_similarity2, indices2 = torch.sort(neg_similarity2, descending=True) + neg_similarity1 = f(neg_similarity1) + neg_similarity2 = f(neg_similarity2) + z_pool = torch.cat([z1, z2], dim=0) + hard_samples1 = z_pool[indices1[:, :threshold]] # [N, k, d] + hard_samples2 = z_pool[indices2[:, :threshold]] + hard_sample_idx1 = torch.randint(hard_samples1.shape[1], size=[num_samples, 2 * s]).to(device) # [N, 2 * s] + hard_sample_idx2 = torch.randint(hard_samples2.shape[1], size=[num_samples, 2 * s]).to(device) + hard_sample_draw1 = hard_samples1[ + torch.arange(num_samples).unsqueeze(-1), hard_sample_idx1] # [N, 2 * s, d] + hard_sample_draw2 = hard_samples2[torch.arange(num_samples).unsqueeze(-1), hard_sample_idx2] + hard_sample_mixing1 = mixup * hard_sample_draw1[:, :s, :] + (1 - mixup) * hard_sample_draw1[:, s:, :] + hard_sample_mixing2 = mixup * hard_sample_draw2[:, :s, :] + (1 - mixup) * hard_sample_draw2[:, s:, :] + + h_m1 = self.projection(hard_sample_mixing1) + h_m2 = self.projection(hard_sample_mixing2) + + neg_m1 = f(self.tensor_similarity(z1, h_m1)).sum(dim=1) + neg_m2 = f(self.tensor_similarity(z2, h_m2)).sum(dim=1) + pos = pos_similarity.diag() + neg1 = neg_similarity1.sum(dim=1) + neg2 = neg_similarity2.sum(dim=1) + loss1 = -torch.log(pos / (neg1 + neg_m1 - refl1)) + loss2 = -torch.log(pos / (neg2 + neg_m2 - refl2)) + loss = (loss1 + loss2) * 0.5 + loss = loss.mean() + return loss + + +class RingLoss(torch.nn.Module): + def __init__(self): + super(RingLoss, self).__init__() + + def forward(self, h1: torch.Tensor, h2: torch.Tensor, y: torch.Tensor, tau, threshold=0.1, *args, **kwargs): + f = lambda x: torch.exp(x / tau) + num_samples = h1.shape[0] + device = h1.device + threshold = int(num_samples * threshold) + + false_neg_mask = torch.zeros((num_samples, 2 * num_samples), dtype=torch.int).to(device) + for i in range(num_samples): + false_neg_mask[i] = (y == y[i]).repeat(2) + + pos_sim = f(_similarity(h1, h2)) + neg_sim1 = torch.cat([_similarity(h1, h1), _similarity(h1, h2)], dim=1) # [n, 2n] + neg_sim2 = torch.cat([_similarity(h2, h1), _similarity(h2, h2)], dim=1) + neg_sim1, indices1 = torch.sort(neg_sim1, descending=True) + neg_sim2, indices2 = torch.sort(neg_sim2, descending=True) + + y_repeated = y.repeat(2) + false_neg_cnt = torch.zeros((num_samples)).to(device) + for i in range(num_samples): + false_neg_cnt[i] = (y_repeated[indices1[i, threshold:-threshold]] == y[i]).sum() 
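+        # `false_neg_mask` and `false_neg_cnt` tally same-class samples that survive the ring
+        # truncation below; as written they are diagnostic only and do not enter the returned loss.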
+ + neg_sim1 = f(neg_sim1[:, threshold:-threshold]) + neg_sim2 = f(neg_sim2[:, threshold:-threshold]) + + pos = pos_sim.diag() + neg1 = neg_sim1.sum(dim=1) + neg2 = neg_sim2.sum(dim=1) + + loss1 = -torch.log(pos / neg1) + loss2 = -torch.log(pos / neg2) + + loss = (loss1 + loss2) * 0.5 + loss = loss.mean() + + return loss diff --git a/fgssl/contrib/trainer/GCL/losses/jsd.py b/fgssl/contrib/trainer/GCL/losses/jsd.py new file mode 100644 index 0000000..8efa121 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/losses/jsd.py @@ -0,0 +1,77 @@ +import numpy as np +import torch.nn.functional as F + +from .losses import Loss + + +class JSD(Loss): + def __init__(self, discriminator=lambda x, y: x @ y.t()): + super(JSD, self).__init__() + self.discriminator = discriminator + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + num_pos = pos_mask.int().sum() + similarity = self.discriminator(anchor, sample) + + E_pos = (np.log(2) - F.softplus(- similarity * pos_mask)).sum() + E_pos /= num_pos + + neg_sim = similarity * neg_mask + E_neg = (F.softplus(- neg_sim) + neg_sim - np.log(2)).sum() + E_neg /= num_neg + + return E_neg - E_pos + + +class DebiasedJSD(Loss): + def __init__(self, discriminator=lambda x, y: x @ y.t(), tau_plus=0.1): + super(DebiasedJSD, self).__init__() + self.discriminator = discriminator + self.tau_plus = tau_plus + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + num_pos = pos_mask.int().sum() + similarity = self.discriminator(anchor, sample) + + pos_sim = similarity * pos_mask + E_pos = np.log(2) - F.softplus(- pos_sim) + E_pos -= (self.tau_plus / (1 - self.tau_plus)) * (F.softplus(-pos_sim) + pos_sim) + E_pos = E_pos.sum() / num_pos + + neg_sim = similarity * neg_mask + E_neg = (F.softplus(- neg_sim) + neg_sim - np.log(2)) / (1 - self.tau_plus) + E_neg = E_neg.sum() / num_neg + + return E_neg - E_pos + + +class HardnessJSD(Loss): + def __init__(self, discriminator=lambda x, y: x @ y.t(), tau_plus=0.1, beta=0.05): + super(HardnessJSD, self).__init__() + self.discriminator = discriminator + self.tau_plus = tau_plus + self.beta = beta + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs): + num_neg = neg_mask.int().sum() + num_pos = pos_mask.int().sum() + similarity = self.discriminator(anchor, sample) + + pos_sim = similarity * pos_mask + E_pos = np.log(2) - F.softplus(- pos_sim) + E_pos -= (self.tau_plus / (1 - self.tau_plus)) * (F.softplus(-pos_sim) + pos_sim) + E_pos = E_pos.sum() / num_pos + + neg_sim = similarity * neg_mask + E_neg = F.softplus(- neg_sim) + neg_sim + + reweight = -2 * neg_sim / max(neg_sim.max(), neg_sim.min().abs()) + reweight = (self.beta * reweight).exp() + reweight /= reweight.mean(dim=1, keepdim=True) + + E_neg = (reweight * E_neg) / (1 - self.tau_plus) - np.log(2) + E_neg = E_neg.sum() / num_neg + + return E_neg - E_pos diff --git a/fgssl/contrib/trainer/GCL/losses/losses.py b/fgssl/contrib/trainer/GCL/losses/losses.py new file mode 100644 index 0000000..79524c0 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/losses/losses.py @@ -0,0 +1,12 @@ +import torch +from abc import ABC, abstractmethod + + +class Loss(ABC): + @abstractmethod + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor: + pass + + def __call__(self, anchor, sample, pos_mask=None, neg_mask=None, *args, **kwargs) -> torch.FloatTensor: + loss = self.compute(anchor, sample, pos_mask, neg_mask, *args, **kwargs) + return loss 
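+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only; the tensors and shapes below are made up, and it
+    # assumes the GCL package is importable): a `Loss` such as InfoNCE is called with two sets of
+    # embeddings plus explicit positive/negative masks, which is how the contrast models drive it.
+    import torch
+    from GCL.losses import InfoNCE
+
+    z1, z2 = torch.randn(8, 16), torch.randn(8, 16)  # node embeddings from two augmented views
+    pos_mask = torch.eye(8)                          # node i in view 1 matches node i in view 2
+    neg_mask = 1. - pos_mask                         # every other pair is treated as a negative
+    print(InfoNCE(tau=0.2)(z1, z2, pos_mask=pos_mask, neg_mask=neg_mask))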
diff --git a/fgssl/contrib/trainer/GCL/losses/triplet.py b/fgssl/contrib/trainer/GCL/losses/triplet.py new file mode 100644 index 0000000..3530a34 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/losses/triplet.py @@ -0,0 +1,81 @@ +import torch +from .losses import Loss + + +class TripletMarginSP(Loss): + def __init__(self, margin: float = 1.0, p: float = 2, *args, **kwargs): + super(TripletMarginSP, self).__init__() + self.loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=p, reduction='none') + self.margin = margin + + def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs): + neg_mask = 1. - pos_mask + + num_pos = pos_mask.to(torch.long).sum(dim=1) + num_neg = neg_mask.to(torch.long).sum(dim=1) + + dist = torch.cdist(anchor, sample, p=2) # [num_anchors, num_samples] + + pos_dist = pos_mask * dist + neg_dist = neg_mask * dist + + pos_dist, neg_dist = pos_dist.sum(dim=1), neg_dist.sum(dim=1) + + loss = pos_dist / num_pos - neg_dist / num_neg + self.margin + loss = torch.where(loss > 0, loss, torch.zeros_like(loss)) + + return loss.mean() + + +class TripletMargin(Loss): + def __init__(self, margin: float = 1.0, p: float = 2, *args, **kwargs): + super(TripletMargin, self).__init__() + self.loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=p, reduction='none') + self.margin = margin + + def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs): + num_anchors = anchor.size()[0] + num_samples = sample.size()[0] + + # Key idea here: + # (1) Use all possible triples (will be num_anchors * num_positives * num_negatives triples in total) + # (2) Use PyTorch's TripletMarginLoss to compute the marginal loss for each triple + # (3) Since TripletMarginLoss accepts input tensors of shape (B, D), where B is the batch size, + # we have to manually construct all triples and flatten them as an input tensor in the + # shape of (num_triples, D). + # (4) We first compute loss for all triples (including those that are not anchor - positive - negative), which + # will be num_anchors * num_samples * num_samples triples, and then filter them with masks. + + # compute negative mask + neg_mask = 1. 
- pos_mask if neg_mask is None else neg_mask + + anchor = torch.unsqueeze(anchor, dim=1) # [N, 1, D] + anchor = torch.unsqueeze(anchor, dim=1) # [N, 1, 1, D] + anchor = anchor.expand(-1, num_samples, num_samples, -1) # [N, M, M, D] + anchor = torch.flatten(anchor, end_dim=1) # [N * M * M, D] + + pos_sample = torch.unsqueeze(sample, dim=0) # [1, M, D] + pos_sample = torch.unsqueeze(pos_sample, dim=2) # [1, M, 1, D] + pos_sample = pos_sample.expand(num_anchors, -1, num_samples, -1) # [N, M, M, D] + pos_sample = torch.flatten(pos_sample, end_dim=1) # [N * M * M, D] + + neg_sample = torch.unsqueeze(sample, dim=0) # [1, M, D] + neg_sample = torch.unsqueeze(neg_sample, dim=0) # [1, 1, M, D] + neg_sample = neg_sample.expand(num_anchors, -1, num_samples, -1) # [N, M, M, D] + neg_sample = torch.flatten(neg_sample, end_dim=1) # [N * M * M, D] + + loss = self.loss_fn(anchor, pos_sample, neg_sample) # [N, M, M] + loss = loss.view(num_anchors, num_samples, num_samples) + + pos_mask1 = torch.unsqueeze(pos_mask, dim=2) # [N, M, 1] + pos_mask1 = pos_mask1.expand(-1, -1, num_samples) # [N, M, M] + neg_mask1 = torch.unsqueeze(neg_mask, dim=1) # [N, 1, M] + neg_mask1 = neg_mask1.expand(-1, num_samples, -1) # [N, M, M] + + pair_mask = pos_mask1 * neg_mask1 # [N, M, M] + num_pairs = pair_mask.sum() + + loss = loss * pair_mask + loss = loss.sum() + + return loss / num_pairs diff --git a/fgssl/contrib/trainer/GCL/losses/vicreg.py b/fgssl/contrib/trainer/GCL/losses/vicreg.py new file mode 100644 index 0000000..284f5d3 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/losses/vicreg.py @@ -0,0 +1,43 @@ +import torch +import torch.nn.functional as F +from .losses import Loss + + +class VICReg(Loss): + def __init__(self, sim_weight=25.0, var_weight=25.0, cov_weight=1.0, eps=1e-4): + super(VICReg, self).__init__() + self.sim_weight = sim_weight + self.var_weight = var_weight + self.cov_weight = cov_weight + self.eps = eps + + @staticmethod + def invariance_loss(h1, h2): + return F.mse_loss(h1, h2) + + def variance_loss(self, h1, h2): + std_z1 = torch.sqrt(h1.var(dim=0) + self.eps) + std_z2 = torch.sqrt(h2.var(dim=0) + self.eps) + std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2)) + return std_loss + + @staticmethod + def covariance_loss(h1, h2): + num_nodes, hidden_dim = h1.size() + + h1 = h1 - h1.mean(dim=0) + h2 = h2 - h2.mean(dim=0) + cov_z1 = (h1.T @ h1) / (num_nodes - 1) + cov_z2 = (h2.T @ h2) / (num_nodes - 1) + + diag = torch.eye(hidden_dim, device=h1.device) + cov_loss = cov_z1[~diag.bool()].pow_(2).sum() / hidden_dim + cov_z2[~diag.bool()].pow_(2).sum() / hidden_dim + return cov_loss + + def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor: + sim_loss = self.invariance_loss(anchor, sample) + var_loss = self.variance_loss(anchor, sample) + cov_loss = self.covariance_loss(anchor, sample) + + loss = self.sim_weight * sim_loss + self.var_weight * var_loss + self.cov_weight * cov_loss + return loss.mean() diff --git a/fgssl/contrib/trainer/GCL/models/__init__.py b/fgssl/contrib/trainer/GCL/models/__init__.py new file mode 100644 index 0000000..c1859c0 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/models/__init__.py @@ -0,0 +1,15 @@ +from .samplers import SameScaleSampler, CrossScaleSampler, get_sampler +from .contrast_model import SingleBranchContrast, DualBranchContrast, WithinEmbedContrast, BootstrapContrast + + +__all__ = [ + 'SingleBranchContrast', + 'DualBranchContrast', + 'WithinEmbedContrast', + 'BootstrapContrast', + 'SameScaleSampler', + 
'CrossScaleSampler', + 'get_sampler' +] + +classes = __all__ diff --git a/fgssl/contrib/trainer/GCL/models/contrast_model.py b/fgssl/contrib/trainer/GCL/models/contrast_model.py new file mode 100644 index 0000000..6ce313c --- /dev/null +++ b/fgssl/contrib/trainer/GCL/models/contrast_model.py @@ -0,0 +1,155 @@ +import torch + +from GCL.losses import Loss +from GCL.models import get_sampler + + +def add_extra_mask(pos_mask, neg_mask=None, extra_pos_mask=None, extra_neg_mask=None): + if extra_pos_mask is not None: + pos_mask = torch.bitwise_or(pos_mask.bool(), extra_pos_mask.bool()).float() + if extra_neg_mask is not None: + neg_mask = torch.bitwise_and(neg_mask.bool(), extra_neg_mask.bool()).float() + else: + neg_mask = 1. - pos_mask + return pos_mask, neg_mask + + +class SingleBranchContrast(torch.nn.Module): + def __init__(self, loss: Loss, mode: str, intraview_negs: bool = False, **kwargs): + super(SingleBranchContrast, self).__init__() + assert mode == 'G2L' # only global-local pairs allowed in single-branch contrastive learning + self.loss = loss + self.mode = mode + self.sampler = get_sampler(mode, intraview_negs=intraview_negs) + self.kwargs = kwargs + + def forward(self, h, g, batch=None, hn=None, extra_pos_mask=None, extra_neg_mask=None): + if batch is None: # for single-graph datasets + assert hn is not None + anchor, sample, pos_mask, neg_mask = self.sampler(anchor=g, sample=h, neg_sample=hn) + else: # for multi-graph datasets + assert batch is not None + anchor, sample, pos_mask, neg_mask = self.sampler(anchor=g, sample=h, batch=batch) + + pos_mask, neg_mask = add_extra_mask(pos_mask, neg_mask, extra_pos_mask, extra_neg_mask) + loss = self.loss(anchor=anchor, sample=sample, pos_mask=pos_mask, neg_mask=neg_mask, **self.kwargs) + return loss + + +class DualBranchContrast(torch.nn.Module): + def __init__(self, loss: Loss, mode: str, intraview_negs: bool = False, **kwargs): + super(DualBranchContrast, self).__init__() + self.loss = loss + self.mode = mode + self.sampler = get_sampler(mode, intraview_negs=intraview_negs) + self.kwargs = kwargs + + def forward(self, h1=None, h2=None, g1=None, g2=None, batch=None, h3=None, h4=None, + extra_pos_mask=None, extra_neg_mask=None): + if self.mode == 'L2L': + assert h1 is not None and h2 is not None + anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=h1, sample=h2) + anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=h2, sample=h1) + elif self.mode == 'G2G': + assert g1 is not None and g2 is not None + anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=g1, sample=g2) + anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=g2, sample=g1) + else: # global-to-local + if batch is None or batch.max().item() + 1 <= 1: # single graph + assert all(v is not None for v in [h1, h2, g1, g2, h3, h4]) + anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=g1, sample=h2, neg_sample=h4) + anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=g2, sample=h1, neg_sample=h3) + else: # multiple graphs + assert all(v is not None for v in [h1, h2, g1, g2, batch]) + anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=g1, sample=h2, batch=batch) + anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=g2, sample=h1, batch=batch) + + pos_mask1, neg_mask1 = add_extra_mask(pos_mask1, neg_mask1, extra_pos_mask, extra_neg_mask) + pos_mask2, neg_mask2 = add_extra_mask(pos_mask2, neg_mask2, extra_pos_mask, extra_neg_mask) + l1 = self.loss(anchor=anchor1, sample=sample1, pos_mask=pos_mask1, 
neg_mask=neg_mask1, **self.kwargs)
+        l2 = self.loss(anchor=anchor2, sample=sample2, pos_mask=pos_mask2, neg_mask=neg_mask2, **self.kwargs)
+
+        return (l1 + l2) * 0.5
+
+
+class BootstrapContrast(torch.nn.Module):
+    def __init__(self, loss, mode='L2L'):
+        super(BootstrapContrast, self).__init__()
+        self.loss = loss
+        self.mode = mode
+        self.sampler = get_sampler(mode, intraview_negs=False)
+
+    def forward(self, h1_pred=None, h2_pred=None, h1_target=None, h2_target=None,
+                g1_pred=None, g2_pred=None, g1_target=None, g2_target=None,
+                batch=None, extra_pos_mask=None):
+        if self.mode == 'L2L':
+            assert all(v is not None for v in [h1_pred, h2_pred, h1_target, h2_target])
+            anchor1, sample1, pos_mask1, _ = self.sampler(anchor=h1_target, sample=h2_pred)
+            anchor2, sample2, pos_mask2, _ = self.sampler(anchor=h2_target, sample=h1_pred)
+        elif self.mode == 'G2G':
+            assert all(v is not None for v in [g1_pred, g2_pred, g1_target, g2_target])
+            anchor1, sample1, pos_mask1, _ = self.sampler(anchor=g1_target, sample=g2_pred)
+            anchor2, sample2, pos_mask2, _ = self.sampler(anchor=g2_target, sample=g1_pred)
+        else:
+            assert all(v is not None for v in [h1_pred, h2_pred, g1_target, g2_target])
+            if batch is None or batch.max().item() + 1 <= 1:  # single graph
+                pos_mask1 = pos_mask2 = torch.ones([1, h1_pred.shape[0]], device=h1_pred.device)
+                anchor1, sample1 = g1_target, h2_pred
+                anchor2, sample2 = g2_target, h1_pred
+            else:
+                anchor1, sample1, pos_mask1, _ = self.sampler(anchor=g1_target, sample=h2_pred, batch=batch)
+                anchor2, sample2, pos_mask2, _ = self.sampler(anchor=g2_target, sample=h1_pred, batch=batch)
+
+        pos_mask1, _ = 
add_extra_mask(pos_mask1, extra_pos_mask=extra_pos_mask) + pos_mask2, _ = add_extra_mask(pos_mask2, extra_pos_mask=extra_pos_mask) + l1 = self.loss(anchor=anchor1, sample=sample1, pos_mask=pos_mask1) + l2 = self.loss(anchor=anchor2, sample=sample2, pos_mask=pos_mask2) + + return (l1 + l2) * 0.5 + + +class WithinEmbedContrast(torch.nn.Module): + def __init__(self, loss: Loss, **kwargs): + super(WithinEmbedContrast, self).__init__() + self.loss = loss + self.kwargs = kwargs + + def forward(self, h1, h2): + l1 = self.loss(anchor=h1, sample=h2, **self.kwargs) + l2 = self.loss(anchor=h2, sample=h1, **self.kwargs) + return (l1 + l2) * 0.5 diff --git a/fgssl/contrib/trainer/GCL/models/samplers.py b/fgssl/contrib/trainer/GCL/models/samplers.py new file mode 100644 index 0000000..1c03982 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/models/samplers.py @@ -0,0 +1,81 @@ +import torch +from abc import ABC, abstractmethod +from torch_scatter import scatter + + +class Sampler(ABC): + def __init__(self, intraview_negs=False): + self.intraview_negs = intraview_negs + + def __call__(self, anchor, sample, *args, **kwargs): + ret = self.sample(anchor, sample, *args, **kwargs) + if self.intraview_negs: + ret = self.add_intraview_negs(*ret) + return ret + + @abstractmethod + def sample(self, anchor, sample, *args, **kwargs): + pass + + @staticmethod + def add_intraview_negs(anchor, sample, pos_mask, neg_mask): + num_nodes = anchor.size(0) + device = anchor.device + intraview_pos_mask = torch.zeros_like(pos_mask, device=device) + intraview_neg_mask = torch.ones_like(pos_mask, device=device) - torch.eye(num_nodes, device=device) + new_sample = torch.cat([sample, anchor], dim=0) # (M+N) * K + new_pos_mask = torch.cat([pos_mask, intraview_pos_mask], dim=1) # M * (M+N) + new_neg_mask = torch.cat([neg_mask, intraview_neg_mask], dim=1) # M * (M+N) + return anchor, new_sample, new_pos_mask, new_neg_mask + + +class SameScaleSampler(Sampler): + def __init__(self, *args, **kwargs): + super(SameScaleSampler, self).__init__(*args, **kwargs) + + def sample(self, anchor, sample, *args, **kwargs): + assert anchor.size(0) == sample.size(0) + num_nodes = anchor.size(0) + device = anchor.device + pos_mask = torch.eye(num_nodes, dtype=torch.float32, device=device) + neg_mask = 1. - pos_mask + return anchor, sample, pos_mask, neg_mask + + +class CrossScaleSampler(Sampler): + def __init__(self, *args, **kwargs): + super(CrossScaleSampler, self).__init__(*args, **kwargs) + + def sample(self, anchor, sample, batch=None, neg_sample=None, use_gpu=True, *args, **kwargs): + num_graphs = anchor.shape[0] # M + num_nodes = sample.shape[0] # N + device = sample.device + + if neg_sample is not None: + assert num_graphs == 1 # only one graph, explicit negative samples are needed + assert sample.shape == neg_sample.shape + pos_mask1 = torch.ones((num_graphs, num_nodes), dtype=torch.float32, device=device) + pos_mask0 = torch.zeros((num_graphs, num_nodes), dtype=torch.float32, device=device) + pos_mask = torch.cat([pos_mask1, pos_mask0], dim=1) # M * 2N + sample = torch.cat([sample, neg_sample], dim=0) # 2N * K + else: + assert batch is not None + if use_gpu: + ones = torch.eye(num_nodes, dtype=torch.float32, device=device) # N * N + pos_mask = scatter(ones, batch, dim=0, reduce='sum') # M * N + else: + pos_mask = torch.zeros((num_graphs, num_nodes), dtype=torch.float32).to(device) + for node_idx, graph_idx in enumerate(batch): + pos_mask[graph_idx][node_idx] = 1. # M * N + + neg_mask = 1. 
- pos_mask + return anchor, sample, pos_mask, neg_mask + + +def get_sampler(mode: str, intraview_negs: bool) -> Sampler: + if mode in {'L2L', 'G2G'}: + return SameScaleSampler(intraview_negs=intraview_negs) + elif mode == 'G2L': + return CrossScaleSampler(intraview_negs=intraview_negs) + else: + raise RuntimeError(f'unsupported mode: {mode}') diff --git a/fgssl/contrib/trainer/GCL/utils.py b/fgssl/contrib/trainer/GCL/utils.py new file mode 100644 index 0000000..27582a0 --- /dev/null +++ b/fgssl/contrib/trainer/GCL/utils.py @@ -0,0 +1,74 @@ +from typing import * +import os +import torch +import dgl +import random +import numpy as np + + +def split_dataset(dataset, split_mode, *args, **kwargs): + assert split_mode in ['rand', 'ogb', 'wikics', 'preload'] + if split_mode == 'rand': + assert 'train_ratio' in kwargs and 'test_ratio' in kwargs + train_ratio = kwargs['train_ratio'] + test_ratio = kwargs['test_ratio'] + num_samples = dataset.x.size(0) + train_size = int(num_samples * train_ratio) + test_size = int(num_samples * test_ratio) + indices = torch.randperm(num_samples) + return { + 'train': indices[:train_size], + 'val': indices[train_size: test_size + train_size], + 'test': indices[test_size + train_size:] + } + elif split_mode == 'ogb': + return dataset.get_idx_split() + elif split_mode == 'wikics': + assert 'split_idx' in kwargs + split_idx = kwargs['split_idx'] + return { + 'train': dataset.train_mask[:, split_idx], + 'test': dataset.test_mask, + 'val': dataset.val_mask[:, split_idx] + } + elif split_mode == 'preload': + assert 'preload_split' in kwargs + assert kwargs['preload_split'] is not None + train_mask, test_mask, val_mask = kwargs['preload_split'] + return { + 'train': train_mask, + 'test': test_mask, + 'val': val_mask + } + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def normalize(s): + return (s.max() - s) / (s.max() - s.mean()) + + +def build_dgl_graph(edge_index: torch.Tensor) -> dgl.DGLGraph: + row, col = edge_index + return dgl.graph((row, col)) + + +def batchify_dict(dicts: List[dict], aggr_func=lambda x: x): + res = dict() + for d in dicts: + for k, v in d.items(): + if k not in res: + res[k] = [v] + else: + res[k].append(v) + res = {k: aggr_func(v) for k, v in res.items()} + return res diff --git a/fgssl/contrib/trainer/__init__.py b/fgssl/contrib/trainer/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/trainer/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/trainer/example.py b/fgssl/contrib/trainer/example.py new file mode 100644 index 0000000..ffc2073 --- /dev/null +++ b/fgssl/contrib/trainer/example.py @@ -0,0 +1,16 @@ +from federatedscope.register import register_trainer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer + + +# Build your trainer here. 
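+# Usage sketch (assumption: the standard FederatedScope registration flow):
+# this module is picked up by the contrib trainer package's glob-based
+# `__all__`, and `register_trainer('mytrainer', call_my_trainer)` below then
+# lets an experiment select this class via, e.g.,
+#     trainer.type: mytrainer
+# in its YAML config (or `trainer.type mytrainer` on the command line).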
+class MyTrainer(GeneralTorchTrainer): + pass + + +def call_my_trainer(trainer_type): + if trainer_type == 'mytrainer': + trainer_builder = MyTrainer + return trainer_builder + + +register_trainer('mytrainer', call_my_trainer) diff --git a/fgssl/contrib/trainer/torch_example.py b/fgssl/contrib/trainer/torch_example.py new file mode 100644 index 0000000..18cd5d7 --- /dev/null +++ b/fgssl/contrib/trainer/torch_example.py @@ -0,0 +1,104 @@ +import inspect +from federatedscope.register import register_trainer +from federatedscope.core.trainers import BaseTrainer + +# An example for converting torch training process to FS training process + +# Refer to `federatedscope.core.trainers.BaseTrainer` for interface. + +# Try with FEMNIST: +# python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml \ +# trainer.type mytorchtrainer federate.sample_client_rate 0.01 \ +# federate.total_round_num 5 eval.best_res_update_round_wise_key test_loss + + +class MyTorchTrainer(BaseTrainer): + def __init__(self, model, data, device, **kwargs): + import torch + # NN modules + self.model = model + # FS `ClientData` or your own data + self.data = data + # Device name + self.device = device + # kwargs + self.kwargs = kwargs + # Criterion & Optimizer + self.criterion = torch.nn.CrossEntropyLoss() + self.optimizer = torch.optim.SGD(self.model.parameters(), + lr=0.001, + momentum=0.9, + weight_decay=1e-4) + + def train(self): + # _hook_on_fit_start_init + self.model.to(self.device) + self.model.train() + + total_loss = num_samples = 0 + # _hook_on_batch_start_init + for x, y in self.data['train']: + # _hook_on_batch_forward + x, y = x.to(self.device), y.to(self.device) + outputs = self.model(x) + loss = self.criterion(outputs, y) + + # _hook_on_batch_backward + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # _hook_on_batch_end + total_loss += loss.item() * y.shape[0] + num_samples += y.shape[0] + + # _hook_on_fit_end + return num_samples, self.model.cpu().state_dict(), \ + {'loss_total': total_loss, 'avg_loss': total_loss/float( + num_samples)} + + def evaluate(self, target_data_split_name='test'): + import torch + with torch.no_grad(): + self.model.to(self.device) + self.model.eval() + total_loss = num_samples = 0 + # _hook_on_batch_start_init + for x, y in self.data[target_data_split_name]: + # _hook_on_batch_forward + x, y = x.to(self.device), y.to(self.device) + pred = self.model(x) + loss = self.criterion(pred, y) + + # _hook_on_batch_end + total_loss += loss.item() * y.shape[0] + num_samples += y.shape[0] + + # _hook_on_fit_end + return { + f'{target_data_split_name}_loss': total_loss, + f'{target_data_split_name}_total': num_samples, + f'{target_data_split_name}_avg_loss': total_loss / + float(num_samples) + } + + def update(self, model_parameters, strict=False): + self.model.load_state_dict(model_parameters, strict) + + def get_model_para(self): + return self.model.cpu().state_dict() + + def print_trainer_meta_info(self): + sign = inspect.signature(self.__init__).parameters.values() + meta_info = tuple([(val.name, getattr(self, val.name)) + for val in sign]) + return f'{self.__class__.__name__}{meta_info}' + + +def call_my_torch_trainer(trainer_type): + if trainer_type == 'mytorchtrainer': + trainer_builder = MyTorchTrainer + return trainer_builder + + +register_trainer('mytorchtrainer', call_my_torch_trainer) diff --git a/fgssl/contrib/trainer/trainer2.py b/fgssl/contrib/trainer/trainer2.py new file mode 100644 index 0000000..e203069 --- /dev/null +++ 
b/fgssl/contrib/trainer/trainer2.py @@ -0,0 +1,113 @@ +import copy + +import torch +from copy import deepcopy +from federatedscope.core.auxiliaries.enums import MODE +from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler +from federatedscope.gfl.trainer.nodetrainer import NodeFullBatchTrainer +from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.context import CtxVar +from federatedscope.gfl.loss.vat import VATLoss +from federatedscope.core.trainers import GeneralTorchTrainer +from GCL.models import DualBranchContrast +import GCL.losses as L +import GCL.augmentors as A +import pyro +from GCL.models.contrast_model import WithinEmbedContrast +import torch.nn.functional as F +from torch_geometric.nn import GINConv, global_add_pool +import torch.nn as nn + + +class FGCLTrainer2(GeneralTorchTrainer): + def __init__(self, + model, + data, + device, + config, + only_for_eval=False, + monitor=None): + super(FGCLTrainer2, self).__init__(model, data, device, config, + only_for_eval, monitor) + + self.global_model = copy.deepcopy(model) + + def register_default_hooks_train(self): + super(FGCLTrainer2, self).register_default_hooks_train() + self.register_hook_in_train(new_hook=begin, + trigger='on_fit_start', + insert_pos=-1) + self.register_hook_in_train(new_hook=leave, + trigger='on_fit_end', + insert_pos=-1) + + def register_default_hooks_eval(self): + super(FGCLTrainer2, self).register_default_hooks_eval() + self.register_hook_in_eval(new_hook=begin, + trigger='on_fit_start', + insert_pos=-1) + self.register_hook_in_eval(new_hook=leave, + trigger='on_fit_end', + insert_pos=-1) + + def _hook_on_batch_forward(self, ctx): + + batch = ctx.data_batch.to(ctx.device) + mask = batch['{}_mask'.format(ctx.cur_split)].detach() + + label = batch.y[batch['{}_mask'.format(ctx.cur_split)]] + + self.global_model.to(ctx.device).eval() + + pred, raw_feature_local, adj_sampled , adj_logits, adj_orig = ctx.model(batch) + + pred_global, raw_feature_global, adj_sampled , adj_logits, adj_orig = self.global_model(batch) + + pred = pred[mask] + + loss1 = ctx.criterion(pred, label) + + + kd_loss = com_distillation_loss(raw_feature_global,raw_feature_local,adj_orig,adj_sampled,2) + norm_w = adj_orig.shape[0] ** 2 / float((adj_orig.shape[0] ** 2 - adj_orig.sum()) * 2) + pos_weight = torch.FloatTensor([float(adj_orig.shape[0] ** 2 - adj_orig.sum()) / adj_orig.sum()]).to(ctx.device) + ga_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) + + # ctx.loss_batch = loss1 + (loss2 + loss3) * 0.5 + ctx.loss_batch = loss1 + ctx.batch_size = torch.sum(mask).item() + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + + +def begin(ctx): + if 'lastModel' not in ctx.keys(): + ctx.lastModel = copy.deepcopy(ctx.model).to(ctx.device) + + +def leave(ctx): + ctx.lastModel = copy.deepcopy(ctx.model).to(ctx.device) + + + + + + +def com_distillation_loss(t_logits, s_logits, adj_orig, adj_sampled, temp): + + s_dist = F.log_softmax(s_logits / temp, dim=-1) + t_dist = F.softmax(t_logits / temp, dim=-1) + kd_loss = temp * temp * F.kl_div(s_dist, t_dist.detach()) + + + adj = torch.triu(adj_orig * adj_sampled).detach() + edge_list = (adj + adj.T).nonzero().t() + + s_dist_neigh = F.log_softmax(s_logits[edge_list[0]] / temp, dim=-1) + t_dist_neigh = 
F.softmax(t_logits[edge_list[1]] / temp, dim=-1) + + kd_loss += temp * temp * F.kl_div(s_dist_neigh, t_dist_neigh.detach()) + + return kd_loss \ No newline at end of file diff --git a/fgssl/contrib/worker/FLAG.py b/fgssl/contrib/worker/FLAG.py new file mode 100644 index 0000000..6767cee --- /dev/null +++ b/fgssl/contrib/worker/FLAG.py @@ -0,0 +1,341 @@ +import copy +import logging +import random + +import torch +from gensim.models import Word2Vec +from sklearn.manifold import TSNE +from tqdm import tqdm + +from federatedscope.core.message import Message +from federatedscope.register import register_worker +from federatedscope.core.workers import Server, Client + +logger = logging.getLogger(__name__) +from sklearn import manifold +import numpy as np +import pandas as pd +# import cca_core +# from CKA import linear_CKA, kernel_CKA +import matplotlib as mpl +import matplotlib.pyplot as plt +import gensim +from torch_geometric.utils import to_dense_adj +import seaborn as sns +import seaborn.objects as so + + +# Build your worker here. +class FGCLClient(Client): + def callback_funcs_for_model_para(self, message: Message): + round, sender, content = message.state, message.sender, message.content + + self.trainer.state = round + self.trainer.global_model.load_state_dict(content) + self.trainer.update(content) + self.state = round + sample_size, model_para, results = self.trainer.train() + if self._cfg.federate.share_local_model and not \ + self._cfg.federate.online_aggr: + model_para = copy.deepcopy(model_para) + logger.info( + self._monitor.format_eval_res(results, + rnd=self.state, + role='Client #{}'.format(self.ID))) + + # self.comm_manager.send( + # Message(msg_type='model_para', + # sender=self.ID, + # receiver=[sender], + # state=self.state, + # content=(sample_size, model_para, self.trainer.ctx.data['test']))) + # + self.comm_manager.send( + Message(msg_type='model_para', + sender=self.ID, + receiver=[sender], + state=self.state, + content=(sample_size, model_para))) + + +class FGCLServer(Server): + def _perform_federated_aggregation(self): + model_for_e = copy.deepcopy(self.model) + """ + Perform federated aggregation and update the global model + """ + train_msg_buffer = self.msg_buffer['train'][self.state] + for model_idx in range(self.model_num): + model = self.models[model_idx] + aggregator = self.aggregators[model_idx] + msg_list = list() + staleness = list() + + for client_id in train_msg_buffer.keys(): + if self.model_num == 1: + msg_list.append(train_msg_buffer[client_id]) + else: + train_data_size, model_para_multiple = \ + train_msg_buffer[client_id] + msg_list.append( + (train_data_size, model_para_multiple[model_idx])) + + # The staleness of the messages in train_msg_buffer + # should be 0 + staleness.append((client_id, 0)) + + for staled_message in self.staled_msg_buffer: + state, client_id, content = staled_message + if self.model_num == 1: + msg_list.append(content) + else: + train_data_size, model_para_multiple = content + msg_list.append( + (train_data_size, model_para_multiple[model_idx])) + + staleness.append((client_id, self.state - state)) + + # if self.state % 50 == 0: + # list_feat = list() + # data = self.data['test'].dataset[0] + # labels = data.y + # for tuple in msg_list: + # size, para = tuple + # # edge_index = data_m.dataset[0].edge_index + # # adj_orig = to_dense_adj(edge_index, max_num_nodes=data.x.shape[0]).squeeze(0) + # # awe = anonymous_walk_embedding(adj_orig.fill_diagonal_(True)) + # model_for_e.load_state_dict(para) + # model_for_e.eval() + # awe = 
1 + # x, f = model_for_e(data) + # f = f + # x = x + # list_feat.append(dict({"x": x, "f": f, "awe": awe, "id": len(list_feat)})) + # visual_cka(list_feat, self.state, labels, 6, "YlOrRd") # GnBu magma "YlOrRd" + + aggregated_num = len(msg_list) + agg_info = { + 'client_feedback': msg_list, + 'recover_fun': self.recover_fun, + 'staleness': staleness, + } + # logger.info(f'The staleness is {staleness}') + result = aggregator.aggregate(agg_info) + # Due to lazy load, we merge two state dict + merged_param = merge_param_dict(model.state_dict().copy(), result) + model.load_state_dict(merged_param, strict=False) + # if self.state == 196: + # visual_tsne(model, self.data['val'].dataset[0], self.__class__.__name__) + + return aggregated_num + + +def merge_param_dict(raw_param, filtered_param): + for key in filtered_param.keys(): + raw_param[key] = filtered_param[key] + return raw_param + + +def call_my_worker(method): + if method == 'fgcl': + worker_builder = {'client': FGCLClient, 'server': FGCLServer} + return worker_builder + + +register_worker('fgcl', call_my_worker) +def plot_features(features, labels, num_classes): + colors = ['C' + str(i) for i in range(num_classes)] + plt.figure(figsize=(6, 6)) + for l in range(num_classes): + plt.scatter( + features[labels == l, 0], + features[labels == l, 1], + c=colors[l], s=1, alpha=0.4) + plt.xticks([]) + plt.yticks([]) + plt.show() + + + + + +def visual_tsne(model, data,name): + labels = data.y + model.eval() + z,h = model(data) + num_class = labels.max().item() + 1 + z = z.detach().cpu().numpy() + tsne = manifold.TSNE(n_components=2, perplexity=35, init='pca') + plt.figure(figsize=(8, 8)) + x_tsne_data = list() + f = tsne.fit_transform(z) + for clazz in range(num_class): + fp = f[labels == clazz] + clazz = np.full(fp.shape[0], clazz) + clazz = np.expand_dims(clazz, axis=1) + fe = np.concatenate([fp, clazz], axis=1) + x_tsne_data.append(fe) + + x_tsne_data = np.concatenate(x_tsne_data, axis=0) + df_tsne = pd.DataFrame(x_tsne_data, columns=["dim1", "dim2", "class"]) + + sns.scatterplot(data=df_tsne, palette="bright", hue='class', x='dim1', y='dim2') + plt.legend([],[], frameon=False) + plt.xticks([]) + plt.yticks([]) + plt.xlabel("") + plt.ylabel("") + import os + if not os.path.exists('data/output/tsne/'): + os.mkdir("data/output/tsne/") + plt.savefig("data/output/tsne/result_"+ name + ".png", format='png', dpi=800, + pad_inches=0.1, bbox_inches='tight') + plt.show() + + +import numpy as np + + +def anonymous_walk_embedding(adj_matrix, dimensions=128, walk_length=5, num_walks=5): + import numpy as np + from sklearn.decomposition import TruncatedSVD + adj_matrix= adj_matrix.cpu().numpy() + rows, cols = adj_matrix.shape + walk_matrix = np.linalg.matrix_power(adj_matrix, walk_length) + row_sums = np.array(np.sum(walk_matrix, axis=1)).flatten() + walk_matrix = np.divide(walk_matrix, row_sums[:, np.newaxis]) + svd = TruncatedSVD(n_components=dimensions) + awe_representation = svd.fit_transform(walk_matrix) + return awe_representation + + +import logging + +import torch +from gensim.models import Word2Vec +from sklearn.manifold import TSNE +from tqdm import tqdm + +from federatedscope.core.message import Message +from federatedscope.register import register_worker +from federatedscope.core.workers import Server, Client + +logger = logging.getLogger(__name__) +from sklearn import manifold +import numpy as np +import pandas as pd +# import cca_core +# from CKA import linear_CKA, kernel_CKA +import matplotlib as mpl +import matplotlib.pyplot as plt +import 
gensim +from torch_geometric.utils import to_dense_adj +import seaborn as sns +import seaborn.objects as so + + +def call_my_worker(method): + if method == 'fgcl': + worker_builder = {'client': FGCLClient, 'server': FGCLServer} + return worker_builder + + +register_worker('fgcl', call_my_worker) +def plot_features(features, labels, num_classes): + colors = ['C' + str(i) for i in range(num_classes)] + plt.figure(figsize=(6, 6)) + for l in range(num_classes): + plt.scatter( + features[labels == l, 0], + features[labels == l, 1], + c=colors[l], s=1, alpha=0.4) + plt.xticks([]) + plt.yticks([]) + plt.show() + +import numpy as np + + + +def visual_cka(list_feat, state, labels, num_class, theme): + plt.figure() + client_number = len(list_feat) + client_list = [i for i in range(client_number)] + random.shuffle(client_list) + result = np.zeros(shape=(client_number, client_number)) + for i in range(client_number): + for j in range(client_number): + result[i][j] = CKA(list_feat[i]["f"].detach().numpy(), list_feat[j]["f"].detach().numpy()) + + np.save("data/output/cka/result_node.npy",arr=result) + sns.heatmap(data=result, vmin=0.0, vmax=1.0, cmap=theme).invert_yaxis() + plt.xticks([]) + plt.yticks([]) + plt.xlabel("") + plt.ylabel("") + plt.title("") + plt.savefig("data/output/cka/node_state_" + str(state) + ".png", dpi=600) + plt.show() + + plt.figure() + client_number = len(list_feat) + result = np.zeros(shape=(client_number, client_number)) + for i in range(client_number): + for j in range(client_number): + result[i][j] = CKA(list_feat[client_list[i]]["x"].detach().numpy(), list_feat[client_list[j]]["x"].detach().numpy()) + np.save("data/output/cka/result_node2.npy", arr=result) + sns.heatmap(data=result, vmin=0.0, vmax=1.0, cmap=theme).invert_yaxis() + plt.xticks([]) + plt.yticks([]) + plt.xlabel("") + plt.ylabel("") + plt.title("") + plt.savefig("data/output/cka/structure_state_" + str(state) + ".png", dpi=600) + plt.show() + + + + + +def CKA(X, Y): + '''Computes the CKA of two matrices. This is equation (1) from the paper''' + + nominator = unbiased_HSIC(np.dot(X, X.T), np.dot(Y, Y.T)) + denominator1 = unbiased_HSIC(np.dot(X, X.T), np.dot(X, X.T)) + denominator2 = unbiased_HSIC(np.dot(Y, Y.T), np.dot(Y, Y.T)) + + cka = nominator / np.sqrt(denominator1 * denominator2) + + return cka + + +def unbiased_HSIC(K, L): + '''Computes an unbiased estimator of HISC. 
This is equation (2) from the paper''' + + # create the unit **vector** filled with ones + n = K.shape[0] + ones = np.ones(shape=(n)) + + # fill the diagonal entries with zeros + np.fill_diagonal(K, val=0) # this is now K_tilde + np.fill_diagonal(L, val=0) # this is now L_tilde + + # first part in the square brackets + trace = np.trace(np.dot(K, L)) + + # middle part in the square brackets + nominator1 = np.dot(np.dot(ones.T, K), ones) + nominator2 = np.dot(np.dot(ones.T, L), ones) + denominator = (n - 1) * (n - 2) + middle = np.dot(nominator1, nominator2) / denominator + + # third part in the square brackets + multiplier1 = 2 / (n - 2) + multiplier2 = np.dot(np.dot(ones.T, K), np.dot(L, ones)) + last = multiplier1 * multiplier2 + + # complete equation + unbiased_hsic = 1 / (n * (n - 3)) * (trace + middle - last) + + return unbiased_hsic + diff --git a/fgssl/contrib/worker/__init__.py b/fgssl/contrib/worker/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/contrib/worker/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/contrib/worker/vis.py b/fgssl/contrib/worker/vis.py new file mode 100644 index 0000000..3de5141 --- /dev/null +++ b/fgssl/contrib/worker/vis.py @@ -0,0 +1,267 @@ +import copy +import logging + +import torch +from gensim.models import Word2Vec +from sklearn.manifold import TSNE +from tqdm import tqdm + +from federatedscope.core.message import Message +from federatedscope.register import register_worker +from federatedscope.core.workers import Server, Client + +logger = logging.getLogger(__name__) +from sklearn import manifold +import numpy as np +import pandas as pd +# import cca_core +# from CKA import linear_CKA, kernel_CKA +import matplotlib as mpl +import matplotlib.pyplot as plt +import gensim +from torch_geometric.utils import to_dense_adj +import seaborn as sns +import seaborn.objects as so + + +# Build your worker here. 
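+# Usage sketch (assumption: the worker is chosen by the method name passed to
+# the worker builder, matching `register_worker('vis', call_my_worker)` at the
+# end of this file): selecting the method 'vis' swaps in `visServer`, which
+# performs the usual FedAvg-style aggregation and additionally dumps a t-SNE
+# visualization of the aggregated model at a fixed round (198 below).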
+class visClient(Client): + pass + + +class visServer(Server): + def _perform_federated_aggregation(self): + model_for_e = copy.deepcopy(self.model) + """ + Perform federated aggregation and update the global model + """ + train_msg_buffer = self.msg_buffer['train'][self.state] + for model_idx in range(self.model_num): + model = self.models[model_idx] + aggregator = self.aggregators[model_idx] + msg_list = list() + staleness = list() + + for client_id in train_msg_buffer.keys(): + if self.model_num == 1: + msg_list.append(train_msg_buffer[client_id]) + else: + train_data_size, model_para_multiple = \ + train_msg_buffer[client_id] + msg_list.append( + (train_data_size, model_para_multiple[model_idx])) + + # The staleness of the messages in train_msg_buffer + # should be 0 + staleness.append((client_id, 0)) + + for staled_message in self.staled_msg_buffer: + state, client_id, content = staled_message + if self.model_num == 1: + msg_list.append(content) + else: + train_data_size, model_para_multiple = content + msg_list.append( + (train_data_size, model_para_multiple[model_idx])) + + staleness.append((client_id, self.state - state)) + + # if self.state % 50 == 0: + # list_feat = list() + # data = self.data['test'].dataset[0] + # labels = data.y + # for tuple in msg_list: + # size, para, data_m = tuple + # edge_index = data_m.dataset[0].edge_index + # adj_orig = to_dense_adj(edge_index, max_num_nodes=data.x.shape[0]).squeeze(0) + # awe = anonymous_walk_embedding(adj_orig.fill_diagonal_(True)) + # model_for_e.load_state_dict(para) + # model_for_e.eval() + # x, f = model_for_e(data) + # f = f + # x = x + # list_feat.append(dict({"x": x, "f": f, "awe": awe, "id": len(list_feat)})) + # + # visual_tsne(list_feat, self.state, labels, 6) + # visual_cka(list_feat, self.state, labels, 6, "YlOrRd") # GnBu magma "YlOrRd" + + aggregated_num = len(msg_list) + agg_info = { + 'client_feedback': msg_list, + 'recover_fun': self.recover_fun, + 'staleness': staleness, + } + # logger.info(f'The staleness is {staleness}') + result = aggregator.aggregate(agg_info) + # Due to lazy load, we merge two state dict + merged_param = merge_param_dict(model.state_dict().copy(), result) + model.load_state_dict(merged_param, strict=False) + if self.state == 198: + visual_tsne(model, self.data["val"].dataset[0], self.__class__.__name__) + + return aggregated_num + +def merge_param_dict(raw_param, filtered_param): + for key in filtered_param.keys(): + raw_param[key] = filtered_param[key] + return raw_param + +def call_my_worker(method): + if method == 'vis': + worker_builder = {'client': visClient, 'server': visServer} + return worker_builder + +register_worker('vis', call_my_worker) + +def plot_features(features, labels, num_classes): + colors = ['C' + str(i) for i in range(num_classes)] + plt.figure(figsize=(6, 6)) + for l in range(num_classes): + plt.scatter( + features[labels == l, 0], + features[labels == l, 1], + c=colors[l], s=1, alpha=0.4) + plt.xticks([]) + plt.yticks([]) + plt.show() + +def global_tsne(model,data,num_classes): + all_features, all_labels = [], [] + model.eval() + with torch.no_grad(): + for i, (data, labels) in tqdm(enumerate(val_loader)): + outputs, features = model(data) + all_features.append(features['pooled_feat'].data.cpu().numpy()) + all_labels.append(labels.data.cpu().numpy()) + all_features = np.concatenate(all_features, 0) + all_labels = np.concatenate(all_labels, 0) + + tsne = TSNE() + all_features = tsne.fit_transform(all_features) + plot_features(all_features, all_labels, num_classes) + + + 
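+# NOTE on `global_tsne` above: it iterates over a `val_loader` that is not
+# defined in this module; the helper assumes the caller provides a validation
+# DataLoader yielding `(data, labels)` batches in the enclosing scope before
+# calling it.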
+ + +def visual_tsne(model, data,name): + labels = data.y + model.eval() + z,h = model(data) + num_class = labels.max().item() + 1 + z = z.detach().cpu().numpy() + tsne = manifold.TSNE(n_components=2, perplexity=40, init='pca') + plt.figure(figsize=(8, 8)) + x_tsne_data = list() + f = tsne.fit_transform(z) + for clazz in range(num_class): + fp = f[labels == clazz] + clazz = np.full(fp.shape[0], clazz) + clazz = np.expand_dims(clazz, axis=1) + fe = np.concatenate([fp, clazz], axis=1) + x_tsne_data.append(fe) + + x_tsne_data = np.concatenate(x_tsne_data, axis=0) + df_tsne = pd.DataFrame(x_tsne_data, columns=["dim1", "dim2", "class"]) + + sns.relplot(data=df_tsne, palette="bright", hue='class', x='dim1', y='dim2') + plt.xticks([]) + plt.yticks([]) + import os + if not os.path.exists('data/output/tsne/'): + os.mkdir("data/output/tsne/") + plt.savefig("data/output/tsne/result_"+ name + ".pdf", format='pdf', dpi=800, + pad_inches=0, bbox_inches='tight') + plt.show() + + +import numpy as np + + +def anonymous_walk_embedding(adj_matrix, dimensions=128, walk_length=5, num_walks=5): + import numpy as np + from sklearn.decomposition import TruncatedSVD + adj_matrix= adj_matrix.cpu().numpy() + rows, cols = adj_matrix.shape + walk_matrix = np.linalg.matrix_power(adj_matrix, walk_length) + row_sums = np.array(np.sum(walk_matrix, axis=1)).flatten() + walk_matrix = np.divide(walk_matrix, row_sums[:, np.newaxis]) + svd = TruncatedSVD(n_components=dimensions) + awe_representation = svd.fit_transform(walk_matrix) + return awe_representation + + +def visual_cka(list_feat, state, labels, num_class, theme): + plt.figure(figsize=(8, 8), dpi=500) + client_number = len(list_feat) + result = np.zeros(shape=(client_number, client_number)) + for i in range(client_number): + for j in range(client_number): + result[i][j] = CKA(list_feat[i]["f"].detach().numpy(), list_feat[j]["f"].detach().numpy()) + + sns.heatmap(data=result, vmin=0.0, vmax=1.0, cmap=theme).invert_yaxis() + plt.xlabel("ClientId") + plt.ylabel("ClientId") + plt.title("node-level") + plt.savefig("data/output/cka/node_state_" + str(state) + ".png", dpi=500) + plt.show() + + plt.figure(figsize=(8, 8), dpi=500) + client_number = len(list_feat) + result = np.zeros(shape=(client_number, client_number)) + for i in range(client_number): + for j in range(client_number): + result[i][j] = CKA(list_feat[i]["x"].detach().numpy(), list_feat[j]["x"].detach().numpy()) + + sns.heatmap(data=result, vmin=0.0, vmax=1.0, cmap=theme).invert_yaxis() + plt.xlabel("ClientId") + plt.ylabel("ClientId") + plt.title("structure_level") + plt.savefig("data/output/cka/structure_state_" + str(state) + ".png", dpi=500) + plt.show() + + + + + +def CKA(X, Y): + '''Computes the CKA of two matrices. This is equation (1) from the paper''' + + nominator = unbiased_HSIC(np.dot(X, X.T), np.dot(Y, Y.T)) + denominator1 = unbiased_HSIC(np.dot(X, X.T), np.dot(X, X.T)) + denominator2 = unbiased_HSIC(np.dot(Y, Y.T), np.dot(Y, Y.T)) + + cka = nominator / np.sqrt(denominator1 * denominator2) + + return cka + + +def unbiased_HSIC(K, L): + '''Computes an unbiased estimator of HISC. 
This is equation (2) from the paper''' + + # create the unit **vector** filled with ones + n = K.shape[0] + ones = np.ones(shape=(n)) + + # fill the diagonal entries with zeros + np.fill_diagonal(K, val=0) # this is now K_tilde + np.fill_diagonal(L, val=0) # this is now L_tilde + + # first part in the square brackets + trace = np.trace(np.dot(K, L)) + + # middle part in the square brackets + nominator1 = np.dot(np.dot(ones.T, K), ones) + nominator2 = np.dot(np.dot(ones.T, L), ones) + denominator = (n - 1) * (n - 2) + middle = np.dot(nominator1, nominator2) / denominator + + # third part in the square brackets + multiplier1 = 2 / (n - 2) + multiplier2 = np.dot(np.dot(ones.T, K), np.dot(L, ones)) + last = multiplier1 * multiplier2 + + # complete equation + unbiased_hsic = 1 / (n * (n - 3)) * (trace + middle - last) + + return unbiased_hsic diff --git a/fgssl/core/__init__.py b/fgssl/core/__init__.py new file mode 100644 index 0000000..f8e91f2 --- /dev/null +++ b/fgssl/core/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division diff --git a/fgssl/core/aggregators/__init__.py b/fgssl/core/aggregators/__init__.py new file mode 100644 index 0000000..e65cbb3 --- /dev/null +++ b/fgssl/core/aggregators/__init__.py @@ -0,0 +1,19 @@ +from federatedscope.core.aggregators.aggregator import Aggregator, \ + NoCommunicationAggregator +from federatedscope.core.aggregators.clients_avg_aggregator import \ + ClientsAvgAggregator, OnlineClientsAvgAggregator +from federatedscope.core.aggregators.asyn_clients_avg_aggregator import \ + AsynClientsAvgAggregator +from federatedscope.core.aggregators.server_clients_interpolate_aggregator \ + import ServerClientsInterpolateAggregator +from federatedscope.core.aggregators.fedopt_aggregator import FedOptAggregator + +__all__ = [ + 'Aggregator', + 'NoCommunicationAggregator', + 'ClientsAvgAggregator', + 'OnlineClientsAvgAggregator', + 'AsynClientsAvgAggregator', + 'ServerClientsInterpolateAggregator', + 'FedOptAggregator', +] diff --git a/fgssl/core/aggregators/aggregator.py b/fgssl/core/aggregators/aggregator.py new file mode 100644 index 0000000..c8e2052 --- /dev/null +++ b/fgssl/core/aggregators/aggregator.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + + +class Aggregator(ABC): + def __init__(self): + pass + + @abstractmethod + def aggregate(self, agg_info): + pass + + +class NoCommunicationAggregator(Aggregator): + """"Clients do not communicate. 
Each client works locally
+    """
+    def aggregate(self, agg_info):
+        # do nothing
+        return {}
diff --git a/fgssl/core/aggregators/asyn_clients_avg_aggregator.py b/fgssl/core/aggregators/asyn_clients_avg_aggregator.py
new file mode 100644
index 0000000..39d33a7
--- /dev/null
+++ b/fgssl/core/aggregators/asyn_clients_avg_aggregator.py
@@ -0,0 +1,79 @@
+import copy
+import torch
+from federatedscope.core.aggregators import ClientsAvgAggregator
+
+
+class AsynClientsAvgAggregator(ClientsAvgAggregator):
+    """The aggregator used in asynchronous training, which discounts the
+    stale model updates
+    """
+    def __init__(self, model=None, device='cpu', config=None):
+        super(AsynClientsAvgAggregator, self).__init__(model, device, config)
+
+    def aggregate(self, agg_info):
+        """
+        To perform aggregation
+
+        Arguments:
+            agg_info (dict): the feedbacks from clients
+        :returns: the aggregated results
+        :rtype: dict
+        """
+
+        models = agg_info["client_feedback"]
+        recover_fun = agg_info['recover_fun'] if (
+            'recover_fun' in agg_info and self.cfg.federate.use_ss) else None
+        staleness = [x[1]
+                     for x in agg_info['staleness']]  # (client_id, staleness)
+        avg_model = self._para_weighted_avg(models,
+                                            recover_fun=recover_fun,
+                                            staleness=staleness)
+
+        # When using asynchronous training, the return feedback is model delta
+        # rather than the model param
+        updated_model = copy.deepcopy(avg_model)
+        init_model = self.model.state_dict()
+        for key in avg_model:
+            updated_model[key] = init_model[key] + avg_model[key]
+        return updated_model
+
+    def discount_func(self, staleness):
+        """
+        As an example, we discount the model update with staleness tau
+        as: (1.0/((1.0+tau)**factor)),
+        which has been used in previous studies such as FedAsync (Asynchronous
+        Federated Optimization) and FedBuff
+        (Federated Learning with Buffered Asynchronous Aggregation).
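+
+        For example, with factor = 0.5 and staleness tau = 3, the update from
+        that client is scaled by 1.0 / ((1.0 + 3) ** 0.5) = 0.5 (an
+        illustrative setting, not a recommended default).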
+ """ + return (1.0 / + ((1.0 + staleness)**self.cfg.asyn.staleness_discount_factor)) + + def _para_weighted_avg(self, models, recover_fun=None, staleness=None): + training_set_size = 0 + for i in range(len(models)): + sample_size, _ = models[i] + training_set_size += sample_size + + sample_size, avg_model = models[0] + for key in avg_model: + for i in range(len(models)): + local_sample_size, local_model = models[i] + + if self.cfg.federate.ignore_weight: + weight = 1.0 / len(models) + else: + weight = local_sample_size / training_set_size + + assert staleness is not None + weight *= self.discount_func(staleness[i]) + if isinstance(local_model[key], torch.Tensor): + local_model[key] = local_model[key].float() + else: + local_model[key] = torch.FloatTensor(local_model[key]) + + if i == 0: + avg_model[key] = local_model[key] * weight + else: + avg_model[key] += local_model[key] * weight + + return avg_model diff --git a/fgssl/core/aggregators/clients_avg_aggregator.py b/fgssl/core/aggregators/clients_avg_aggregator.py new file mode 100644 index 0000000..21ac60c --- /dev/null +++ b/fgssl/core/aggregators/clients_avg_aggregator.py @@ -0,0 +1,126 @@ +import os +import torch +from federatedscope.core.aggregators import Aggregator +from federatedscope.core.auxiliaries.utils import param2tensor + + +class ClientsAvgAggregator(Aggregator): + """Implementation of vanilla FedAvg refer to `Communication-efficient + learning of deep networks from decentralized data` [McMahan et al., 2017] + (http://proceedings.mlr.press/v54/mcmahan17a.html) + """ + def __init__(self, model=None, device='cpu', config=None): + super(Aggregator, self).__init__() + self.model = model + self.device = device + self.cfg = config + + def aggregate(self, agg_info): + """ + To preform aggregation + + Arguments: + agg_info (dict): the feedbacks from clients + :returns: the aggregated results + :rtype: dict + """ + + models = agg_info["client_feedback"] + recover_fun = agg_info['recover_fun'] if ( + 'recover_fun' in agg_info and self.cfg.federate.use_ss) else None + avg_model = self._para_weighted_avg(models, recover_fun=recover_fun) + + return avg_model + + def update(self, model_parameters): + ''' + Arguments: + model_parameters (dict): PyTorch Module object's state_dict. 
+ ''' + self.model.load_state_dict(model_parameters, strict=False) + + def save_model(self, path, cur_round=-1): + assert self.model is not None + + ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()} + torch.save(ckpt, path) + + def load_model(self, path): + assert self.model is not None + + if os.path.exists(path): + ckpt = torch.load(path, map_location=self.device) + self.model.load_state_dict(ckpt['model']) + return ckpt['cur_round'] + else: + raise ValueError("The file {} does NOT exist".format(path)) + + def _para_weighted_avg(self, models, recover_fun=None): + training_set_size = 0 + for i in range(len(models)): + sample_size, _ = models[i] + training_set_size += sample_size + + sample_size, avg_model = models[0] + for key in avg_model: + for i in range(len(models)): + local_sample_size, local_model = models[i] + + if self.cfg.federate.ignore_weight: + weight = 1.0 / len(models) + elif self.cfg.federate.use_ss: + # When using secret sharing, what the server receives + # are sample_size * model_para + weight = 1.0 + else: + weight = local_sample_size / training_set_size + + if not self.cfg.federate.use_ss: + local_model[key] = param2tensor(local_model[key]) + if i == 0: + avg_model[key] = local_model[key] * weight + else: + avg_model[key] += local_model[key] * weight + + if self.cfg.federate.use_ss and recover_fun: + avg_model[key] = recover_fun(avg_model[key]) + # When using secret sharing, what the server receives are + # sample_size * model_para + avg_model[key] /= training_set_size + avg_model[key] = torch.FloatTensor(avg_model[key]) + + return avg_model + + +class OnlineClientsAvgAggregator(ClientsAvgAggregator): + def __init__(self, + model=None, + device='cpu', + src_device='cpu', + config=None): + super(OnlineClientsAvgAggregator, self).__init__(model, device, config) + self.src_device = src_device + + def reset(self): + self.maintained = self.model.state_dict() + for key in self.maintained: + self.maintained[key].data = torch.zeros_like( + self.maintained[key], device=self.src_device) + self.cnt = 0 + + def inc(self, content): + if isinstance(content, tuple): + sample_size, model_params = content + for key in self.maintained: + # if model_params[key].device != self.maintained[key].device: + # model_params[key].to(self.maintained[key].device) + self.maintained[key] = (self.cnt * self.maintained[key] + + sample_size * model_params[key]) / ( + self.cnt + sample_size) + self.cnt += sample_size + else: + raise TypeError( + "{} is not a tuple (sample_size, model_para)".format(content)) + + def aggregate(self, agg_info): + return self.maintained diff --git a/fgssl/core/aggregators/fedopt_aggregator.py b/fgssl/core/aggregators/fedopt_aggregator.py new file mode 100644 index 0000000..47e1725 --- /dev/null +++ b/fgssl/core/aggregators/fedopt_aggregator.py @@ -0,0 +1,31 @@ +import torch + +from federatedscope.core.aggregators import ClientsAvgAggregator +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer + + +class FedOptAggregator(ClientsAvgAggregator): + """Implementation of FedOpt refer to `Adaptive Federated Optimization` [ + Reddi et al., 2021] + (https://openreview.net/forum?id=LkFG3lB13U5) + + """ + def __init__(self, config, model, device='cpu'): + super(FedOptAggregator, self).__init__(model, device, config) + self.optimizer = get_optimizer(model=self.model, + **config.fedopt.optimizer) + + def aggregate(self, agg_info): + new_model = super().aggregate(agg_info) + + model = self.model.cpu().state_dict() + with torch.no_grad(): + grads = 
{key: model[key] - new_model[key] for key in new_model} + + self.optimizer.zero_grad() + for key, p in self.model.named_parameters(): + if key in new_model.keys(): + p.grad = grads[key] + self.optimizer.step() + + return self.model.state_dict() diff --git a/fgssl/core/aggregators/server_clients_interpolate_aggregator.py b/fgssl/core/aggregators/server_clients_interpolate_aggregator.py new file mode 100644 index 0000000..200de25 --- /dev/null +++ b/fgssl/core/aggregators/server_clients_interpolate_aggregator.py @@ -0,0 +1,27 @@ +from federatedscope.core.aggregators import ClientsAvgAggregator + + +class ServerClientsInterpolateAggregator(ClientsAvgAggregator): + """" + # conduct aggregation by interpolating global model from server and + local models from clients + """ + def __init__(self, model=None, device='cpu', config=None, beta=1.0): + super(ServerClientsInterpolateAggregator, + self).__init__(model, device, config) + self.beta = beta # the weight for local models used in interpolation + + def aggregate(self, agg_info): + models = agg_info["client_feedback"] + global_model = self.model + elem_each_client = next(iter(models)) + assert len(elem_each_client) == 2, f"Require (sample_size, " \ + f"model_para) tuple for each " \ + f"client, i.e., len=2, but got " \ + f"len={len(elem_each_client)}" + avg_model_by_clients = self._para_weighted_avg(models) + global_local_models = [((1 - self.beta), global_model.state_dict()), + (self.beta, avg_model_by_clients)] + + avg_model_by_interpolate = self._para_weighted_avg(global_local_models) + return avg_model_by_interpolate diff --git a/fgssl/core/auxiliaries/ReIterator.py b/fgssl/core/auxiliaries/ReIterator.py new file mode 100644 index 0000000..f8ce408 --- /dev/null +++ b/fgssl/core/auxiliaries/ReIterator.py @@ -0,0 +1,19 @@ +class ReIterator: + def __init__(self, loader): + self.loader = loader + self.iterator = iter(loader) + self.reset_flag = False + + def __iter__(self): + return self + + def __next__(self): + try: + item = next(self.iterator) + except StopIteration: + self.reset() + item = next(self.iterator) + return item + + def reset(self): + self.iterator = iter(self.loader) diff --git a/fgssl/core/auxiliaries/__init__.py b/fgssl/core/auxiliaries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fgssl/core/auxiliaries/aggregator_builder.py b/fgssl/core/auxiliaries/aggregator_builder.py new file mode 100644 index 0000000..778ff0c --- /dev/null +++ b/fgssl/core/auxiliaries/aggregator_builder.py @@ -0,0 +1,54 @@ +import logging + +from federatedscope.core.configs import constants + +logger = logging.getLogger(__name__) + + +def get_aggregator(method, model=None, device=None, online=False, config=None): + if config.backend == 'tensorflow': + from federatedscope.cross_backends import FedAvgAggregator + return FedAvgAggregator(model=model, device=device) + else: + from federatedscope.core.aggregators import ClientsAvgAggregator, \ + OnlineClientsAvgAggregator, ServerClientsInterpolateAggregator, \ + FedOptAggregator, NoCommunicationAggregator, \ + AsynClientsAvgAggregator + + if method.lower() in constants.AGGREGATOR_TYPE: + aggregator_type = constants.AGGREGATOR_TYPE[method.lower()] + else: + aggregator_type = "clients_avg" + logger.warning( + 'Aggregator for method {} is not implemented. 
Will use default one' + .format(method)) + + if config.fedopt.use or aggregator_type == 'fedopt': + return FedOptAggregator(config=config, model=model, device=device) + elif aggregator_type == 'clients_avg': + if online: + return OnlineClientsAvgAggregator( + model=model, + device=device, + config=config, + src_device=device + if config.federate.share_local_model else 'cpu') + elif config.asyn.use: + return AsynClientsAvgAggregator(model=model, + device=device, + config=config) + else: + return ClientsAvgAggregator(model=model, + device=device, + config=config) + elif aggregator_type == 'server_clients_interpolation': + return ServerClientsInterpolateAggregator( + model=model, + device=device, + config=config, + beta=config.personalization.beta) + elif aggregator_type == 'no_communication': + return NoCommunicationAggregator() + else: + raise NotImplementedError( + "Aggregator {} is not implemented.".format(aggregator_type)) diff --git a/fgssl/core/auxiliaries/criterion_builder.py b/fgssl/core/auxiliaries/criterion_builder.py new file mode 100644 index 0000000..4192a9b --- /dev/null +++ b/fgssl/core/auxiliaries/criterion_builder.py @@ -0,0 +1,33 @@ +import logging +import federatedscope.register as register + +logger = logging.getLogger(__name__) + +try: + from torch import nn + from federatedscope.nlp.loss import * +except ImportError: + nn = None + +try: + from federatedscope.contrib.loss import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.loss`, some modules are not ' + f'available.') + + +def get_criterion(type, device): + for func in register.criterion_dict.values(): + criterion = func(type, device) + if criterion is not None: + return criterion + + if isinstance(type, str): + if hasattr(nn, type): + return getattr(nn, type)() + else: + raise NotImplementedError( + 'Criterion {} not implement'.format(type)) + else: + raise TypeError() diff --git a/fgssl/core/auxiliaries/data_builder.py b/fgssl/core/auxiliaries/data_builder.py new file mode 100644 index 0000000..617d968 --- /dev/null +++ b/fgssl/core/auxiliaries/data_builder.py @@ -0,0 +1,70 @@ +import logging + +from importlib import import_module +from federatedscope.core.data.utils import RegexInverseMap, load_dataset, \ + convert_data_mode +from federatedscope.core.auxiliaries.utils import setup_seed + +import federatedscope.register as register + +logger = logging.getLogger(__name__) + +try: + from federatedscope.contrib.data import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.data`, some modules are not ' + f'available.') + +# TODO: Add PyGNodeDataTranslator and PyGLinkDataTranslator +# TODO: move splitter to PyGNodeDataTranslator and PyGLinkDataTranslator +TRANS_DATA_MAP = { + 'BaseDataTranslator': [ + '.*?@.*?', 'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace', + 'sider', 'clintox', 'esol', 'freesolv', 'lipo' + ], + 'DummyDataTranslator': [ + 'toy', 'quadratic', 'femnist', 'celeba', 'shakespeare', 'twitter', + 'subreddit', 'synthetic', 'ciao', 'epinions', '.*?vertical_fl_data.*?', + '.*?movielens.*?', '.*?cikmcup.*?', 'graph_multi_domain.*?', 'cora', + 'citeseer', 'pubmed', 'dblp_conf', 'dblp_org', 'csbm.*?', 'fb15k-237', + 'wn18' + ], # Dummy for FL dataset +} +DATA_TRANS_MAP = RegexInverseMap(TRANS_DATA_MAP, None) + + +def get_data(config, client_cfgs=None): + """Instantiate the data and update the configuration accordingly if + necessary. + Arguments: + config: a cfg node object. 
+ client_cfgs: dict of client-specific cfg node object. + Returns: + obj: The dataset object. + cfg.node: The updated configuration. + """ + # Fix the seed for data generation + setup_seed(12345) + + for func in register.data_dict.values(): + data_and_config = func(config, client_cfgs) + if data_and_config is not None: + return data_and_config + + # Load dataset from source files + dataset, modified_config = load_dataset(config) + + # Perform translator to non-FL dataset + translator = getattr(import_module('federatedscope.core.data'), + DATA_TRANS_MAP[config.data.type.lower()])( + modified_config, client_cfgs) + data = translator(dataset) + + # Convert `StandaloneDataDict` to `ClientData` when in distribute mode + data = convert_data_mode(data, modified_config) + + # Restore the user-specified seed after the data generation + setup_seed(config.seed) + + return data, modified_config diff --git a/fgssl/core/auxiliaries/dataloader_builder.py b/fgssl/core/auxiliaries/dataloader_builder.py new file mode 100644 index 0000000..a412d76 --- /dev/null +++ b/fgssl/core/auxiliaries/dataloader_builder.py @@ -0,0 +1,74 @@ +from federatedscope.core.data.utils import filter_dict + +try: + import torch + from torch.utils.data import Dataset +except ImportError: + torch = None + Dataset = object + + +def get_dataloader(dataset, config, split='train'): + """ + Instantiate a DataLoader via config. + + Args: + dataset: dataset from which to load the data. + config: configs containing batch_size, shuffle, etc. + split: current split (default: 'train'), if split is 'test', shuffle + will be `False`. And in PyG, 'test' split will use + `NeighborSampler` by default. + + Returns: + dataloader: Instance of specific DataLoader configured by config. + + """ + # DataLoader builder only support torch backend now. 
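+    # Illustrative call (assumed setup: `cfg.backend == 'torch'`,
+    # `cfg.dataloader.type == 'base'`, and `train_set` is a torch Dataset):
+    #     loader = get_dataloader(train_set, cfg, split='train')
+    # which builds a torch DataLoader from the filtered `cfg.dataloader` args.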
+ if config.backend != 'torch': + return None + + if config.dataloader.type == 'base': + from torch.utils.data import DataLoader + loader_cls = DataLoader + elif config.dataloader.type == 'raw': + # No DataLoader + return dataset + elif config.dataloader.type == 'pyg': + from torch_geometric.loader import DataLoader as PyGDataLoader + loader_cls = PyGDataLoader + elif config.dataloader.type == 'graphsaint-rw': + if split == 'train': + from torch_geometric.loader import GraphSAINTRandomWalkSampler + loader_cls = GraphSAINTRandomWalkSampler + else: + from torch_geometric.loader import NeighborSampler + loader_cls = NeighborSampler + elif config.dataloader.type == 'neighbor': + from torch_geometric.loader import NeighborSampler + loader_cls = NeighborSampler + elif config.dataloader.type == 'mf': + from federatedscope.mf.dataloader import MFDataLoader + loader_cls = MFDataLoader + else: + raise ValueError(f'data.loader.type {config.data.loader.type} ' + f'not found!') + + raw_args = dict(config.dataloader) + if split != 'train': + raw_args['shuffle'] = False + raw_args['sizes'] = [-1] + raw_args['drop_last'] = False + # For evaluation in GFL + if config.dataloader.type in ['graphsaint-rw', 'neighbor']: + raw_args['batch_size'] = 4096 + dataset = dataset[0].edge_index + else: + if config.dataloader.type in ['graphsaint-rw']: + # Raw graph + dataset = dataset[0] + elif config.dataloader.type in ['neighbor']: + # edge_index of raw graph + dataset = dataset[0].edge_index + filtered_args = filter_dict(loader_cls.__init__, raw_args) + dataloader = loader_cls(dataset, **filtered_args) + return dataloader diff --git a/fgssl/core/auxiliaries/decorators.py b/fgssl/core/auxiliaries/decorators.py new file mode 100644 index 0000000..25d233f --- /dev/null +++ b/fgssl/core/auxiliaries/decorators.py @@ -0,0 +1,20 @@ +def use_diff(func): + def wrapper(self, *args, **kwargs): + if self.cfg.federate.use_diff: + # TODO: any issue for subclasses? + before_metric = self.evaluate(target_data_split_name='val') + + num_samples_train, model_para, result_metric = func( + self, *args, **kwargs) + + if self.cfg.federate.use_diff: + # TODO: any issue for subclasses? + after_metric = self.evaluate(target_data_split_name='val') + result_metric['val_total'] = before_metric['val_total'] + result_metric['val_avg_loss_before'] = before_metric[ + 'val_avg_loss'] + result_metric['val_avg_loss_after'] = after_metric['val_avg_loss'] + + return num_samples_train, model_para, result_metric + + return wrapper diff --git a/fgssl/core/auxiliaries/enums.py b/fgssl/core/auxiliaries/enums.py new file mode 100644 index 0000000..ffaf080 --- /dev/null +++ b/fgssl/core/auxiliaries/enums.py @@ -0,0 +1,37 @@ +class MODE: + """ + + Note: + Currently StrEnum cannot be imported with the environment + `sys.version_info < (3, 11)`, so we simply create a MODE class here. 
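+        Each mode is represented by a plain string, e.g., `MODE.TRAIN == 'train'`.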
+ """ + TRAIN = 'train' + TEST = 'test' + VAL = 'val' + FINETUNE = 'finetune' + + +class TRIGGER: + ON_FIT_START = 'on_fit_start' + ON_EPOCH_START = 'on_epoch_start' + ON_BATCH_START = 'on_batch_start' + ON_BATCH_FORWARD = 'on_batch_forward' + ON_BATCH_BACKWARD = 'on_batch_backward' + ON_BATCH_END = 'on_batch_end' + ON_EPOCH_END = 'on_epoch_end' + ON_FIT_END = 'on_fit_end' + + @classmethod + def contains(cls, item): + return item in [ + "on_fit_start", "on_epoch_start", "on_batch_start", + "on_batch_forward", "on_batch_backward", "on_batch_end", + "on_epoch_end", "on_fit_end" + ] + + +class LIFECYCLE: + ROUTINE = 'routine' + EPOCH = 'epoch' + BATCH = 'batch' + NONE = None diff --git a/fgssl/core/auxiliaries/logging.py b/fgssl/core/auxiliaries/logging.py new file mode 100644 index 0000000..92a8671 --- /dev/null +++ b/fgssl/core/auxiliaries/logging.py @@ -0,0 +1,255 @@ +import copy +import json +import logging +import os +import re +import time +from datetime import datetime + +import numpy as np + +from federatedscope.core.auxiliaries.utils import logger + + +class CustomFormatter(logging.Formatter): + """Logging colored formatter, adapted from + https://stackoverflow.com/a/56944256/3638629""" + def __init__(self, fmt): + super().__init__() + grey = '\x1b[38;21m' + blue = '\x1b[38;5;39m' + yellow = "\x1b[33;20m" + red = '\x1b[38;5;196m' + bold_red = '\x1b[31;1m' + reset = '\x1b[0m' + + self.FORMATS = { + logging.DEBUG: grey + fmt + reset, + logging.INFO: blue + fmt + reset, + logging.WARNING: yellow + fmt + reset, + logging.ERROR: red + fmt + reset, + logging.CRITICAL: bold_red + fmt + reset + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) + + +class LoggerPrecisionFilter(logging.Filter): + def __init__(self, precision): + super().__init__() + self.print_precision = precision + + def str_round(self, match_res): + return str(round(eval(match_res.group()), self.print_precision)) + + def filter(self, record): + # use regex to find float numbers and round them to specified precision + if not isinstance(record.msg, str): + record.msg = str(record.msg) + if record.msg != "": + if re.search(r"([-+]?\d+\.\d+)", record.msg): + record.msg = re.sub(r"([-+]?\d+\.\d+)", self.str_round, + record.msg) + return True + + +def update_logger(cfg, clear_before_add=False): + root_logger = logging.getLogger("federatedscope") + + # clear all existing handlers and add the default stream + if clear_before_add: + root_logger.handlers = [] + handler = logging.StreamHandler() + fmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + handler.setFormatter(CustomFormatter(fmt)) + + root_logger.addHandler(handler) + + # update level + if cfg.verbose > 0: + logging_level = logging.INFO + else: + logging_level = logging.WARN + root_logger.warning("Skip DEBUG/INFO messages") + root_logger.setLevel(logging_level) + + # ================ create outdir to save log, exp_config, models, etc,. 
+ if cfg.outdir == "": + cfg.outdir = os.path.join(os.getcwd(), "exp") + if cfg.expname == "": + cfg.expname = f"{cfg.federate.method}_{cfg.model.type}_on" \ + f"_{cfg.data.type}_lr{cfg.train.optimizer.lr}_lste" \ + f"p{cfg.train.local_update_steps}" + if cfg.expname_tag: + cfg.expname = f"{cfg.expname}_{cfg.expname_tag}" + cfg.outdir = os.path.join(cfg.outdir, cfg.expname) + + # if exist, make directory with given name and time + if os.path.isdir(cfg.outdir) and os.path.exists(cfg.outdir): + outdir = os.path.join(cfg.outdir, "sub_exp" + + datetime.now().strftime('_%Y%m%d%H%M%S') + ) # e.g., sub_exp_20220411030524 + while os.path.exists(outdir): + time.sleep(1) + outdir = os.path.join( + cfg.outdir, + "sub_exp" + datetime.now().strftime('_%Y%m%d%H%M%S')) + cfg.outdir = outdir + # if not, make directory with given name + os.makedirs(cfg.outdir) + + # create file handler which logs even debug messages + fh = logging.FileHandler(os.path.join(cfg.outdir, 'exp_print.log')) + fh.setLevel(logging.DEBUG) + logger_formatter = logging.Formatter( + "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") + fh.setFormatter(logger_formatter) + root_logger.addHandler(fh) + + # set print precision for terse logging + np.set_printoptions(precision=cfg.print_decimal_digits) + precision_filter = LoggerPrecisionFilter(cfg.print_decimal_digits) + # attach the filter to the fh handler to propagate the filter, since + # "Filters, unlike levels and handlers, do not propagate", + # ref https://stackoverflow.com/questions/6850798/why-doesnt-filter- + # attached-to-the-root-logger-propagate-to-descendant-loggers + for handler in root_logger.handlers: + handler.addFilter(precision_filter) + + import socket + root_logger.info(f"the current machine is at" + f" {socket.gethostbyname(socket.gethostname())}") + root_logger.info(f"the current dir is {os.getcwd()}") + root_logger.info(f"the output dir is {cfg.outdir}") + + if cfg.wandb.use: + import sys + sys.stderr = sys.stdout # make both stderr and stdout sent to wandb + # server + init_wandb(cfg) + + +def init_wandb(cfg): + try: + import wandb + # on some linux machines, we may need "thread" init to avoid memory + # leakage + os.environ["WANDB_START_METHOD"] = "thread" + except ImportError: + logger.error("cfg.wandb.use=True but not install the wandb package") + exit() + dataset_name = cfg.data.type + method_name = cfg.federate.method + exp_name = cfg.expname + + tmp_cfg = copy.deepcopy(cfg) + if tmp_cfg.is_frozen(): + tmp_cfg.defrost() + tmp_cfg.clear_aux_info( + ) # in most cases, no need to save the cfg_check_funcs via wandb + tmp_cfg.de_arguments() + import yaml + cfg_yaml = yaml.safe_load(tmp_cfg.dump()) + + wandb.init(project=cfg.wandb.name_project, + entity=cfg.wandb.name_user, + config=cfg_yaml, + group=dataset_name, + job_type=method_name, + name=exp_name, + notes=f"{method_name}, {exp_name}") + + +def logfile_2_wandb_dict(exp_log_f, raw_out=True): + """ + parse the logfiles [exp_print.log, eval_results.log] into + wandb_dict that contains non-nested dicts + + :param exp_log_f: opened exp_log file + :param raw_out: True indicates "exp_print.log", otherwise indicates + "eval_results.log", + the difference is whether contains the logger header such as + "2022-05-02 16:55:02,843 (client:197) INFO:" + + :return: tuple including (all_log_res, exp_stop_normal, last_line, + log_res_best) + """ + log_res_best = {} + exp_stop_normal = False + all_log_res = [] + last_line = None + for line in exp_log_f: + last_line = line + exp_stop_normal, log_res = 
logline_2_wandb_dict( + exp_stop_normal, line, log_res_best, raw_out) + if "'Role': 'Server #'" in line: + all_log_res.append(log_res) + return all_log_res, exp_stop_normal, last_line, log_res_best + + +def logline_2_wandb_dict(exp_stop_normal, line, log_res_best, raw_out): + log_res = {} + if "INFO:" in line and "Find new best result for" in line: + # Logger type 1, each line for each metric, e.g., + # 2022-03-22 10:48:42,562 (server:459) INFO: Find new best result + # for client_best_individual.test_acc with value 0.5911787974683544 + line = line.split("INFO: ")[1] + parse_res = line.split("with value") + best_key, best_val = parse_res[-2], parse_res[-1] + # client_best_individual.test_acc -> client_best_individual/test_acc + best_key = best_key.replace("Find new best result for", + "").replace(".", "/") + log_res_best[best_key.strip()] = float(best_val.strip()) + + if "Find new best result:" in line: + # each line for all metric of a role, e.g., + # Find new best result: {'Client #1': {'val_loss': + # 132.9812364578247, 'test_total': 36, 'test_avg_loss': + # 3.709533585442437, 'test_correct': 2.0, 'test_loss': + # 133.54320907592773, 'test_acc': 0.05555555555555555, 'val_total': + # 36, 'val_avg_loss': 3.693923234939575, 'val_correct': 4.0, + # 'val_acc': 0.1111111111111111}} + line = line.replace("Find new best result: ", "").replace("\'", "\"") + res = json.loads(s=line) + for best_type_key, val in res.items(): + for inner_key, inner_val in val.items(): + log_res_best[f"best_{best_type_key}/{inner_key}"] = inner_val + + if "'Role'" in line: + if raw_out: + line = line.split("INFO: ")[1] + res = line.replace("\'", "\"") + res = json.loads(s=res) + # pre-process the roles + cur_round = res['Round'] + if "Server" in res['Role']: + if cur_round != "Final" and 'Results_raw' in res: + res.pop('Results_raw') + role = res.pop('Role') + # parse the k-v pairs + for key, val in res.items(): + if not isinstance(val, dict): + log_res[f"{role}, {key}"] = val + else: + if cur_round != "Final": + if key == "Results_raw": + for key_inner, val_inner in res["Results_raw"].items(): + log_res[f"{role}, {key_inner}"] = val_inner + else: + for key_inner, val_inner in val.items(): + assert not isinstance(val_inner, dict), \ + "Un-expected log format" + log_res[f"{role}, {key}/{key_inner}"] = val_inner + else: + exp_stop_normal = True + if key == "Results_raw": + for final_type, final_type_dict in res[ + "Results_raw"].items(): + for inner_key, inner_val in final_type_dict.items( + ): + log_res_best[ + f"{final_type}/{inner_key}"] = inner_val + return exp_stop_normal, log_res diff --git a/fgssl/core/auxiliaries/metric_builder.py b/fgssl/core/auxiliaries/metric_builder.py new file mode 100644 index 0000000..ddbcbed --- /dev/null +++ b/fgssl/core/auxiliaries/metric_builder.py @@ -0,0 +1,21 @@ +import logging +import federatedscope.register as register + +logger = logging.getLogger(__name__) + +try: + from federatedscope.contrib.metrics import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.metrics`, some modules are not ' + f'available.') + + +def get_metric(types): + metrics = dict() + for func in register.metric_dict.values(): + res = func(types) + if res is not None: + name, metric = res + metrics[name] = metric + return metrics diff --git a/fgssl/core/auxiliaries/model_builder.py b/fgssl/core/auxiliaries/model_builder.py new file mode 100644 index 0000000..5221591 --- /dev/null +++ b/fgssl/core/auxiliaries/model_builder.py @@ -0,0 +1,164 @@ +import logging + +import 
numpy as np + +import federatedscope.register as register + +logger = logging.getLogger(__name__) + +try: + from federatedscope.contrib.model import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.model`, some modules are not ' + f'available.') + + +def get_shape_from_data(data, model_config, backend='torch'): + """ + Extract the input shape from the given data, which can be used to build + the data. Users can also use `data.input_shape` to specify the shape + Arguments: + data (`ClientData`): the data used for local training or evaluation + The expected data format: + 1): {train/val/test: {x:ndarray, y:ndarray}}} + 2): {train/val/test: DataLoader} + Returns: + shape (tuple): the input shape + """ + # Handle some special cases + if model_config.type.lower() in ['vmfnet', 'hmfnet']: + return data['train'].n_col if model_config.type.lower( + ) == 'vmfnet' else data['train'].n_row + elif model_config.type.lower() in [ + 'gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn' + ] or model_config.type.startswith('gnn_'): + num_label = data['num_label'] if 'num_label' in data else None + num_edge_features = data['data'][ + 'num_edge_features'] if model_config.type == 'mpnn' else None + if model_config.task.startswith('graph'): + # graph-level task + data_representative = next(iter(data['train'])) + return (data_representative.x.shape, num_label, num_edge_features) + else: + # node/link-level task + return (data['data'].x.shape, num_label, num_edge_features) + + if isinstance(data, dict): + keys = list(data.keys()) + if 'test' in keys: + key_representative = 'test' + elif 'val' in keys: + key_representative = 'val' + elif 'train' in keys: + key_representative = 'train' + elif 'data' in keys: + key_representative = 'data' + else: + key_representative = keys[0] + logger.warning(f'We chose the key {key_representative} as the ' + f'representative key to extract data shape.') + + data_representative = data[key_representative] + else: + # Handle the data with non-dict format + data_representative = data + + if isinstance(data_representative, dict): + if 'x' in data_representative: + shape = data_representative['x'].shape + if len(shape) == 1: # (batch, ) = (batch, 1) + return 1 + else: + return shape + elif backend == 'torch': + import torch + if issubclass(type(data_representative), torch.utils.data.DataLoader): + x, _ = next(iter(data_representative)) + return x.shape + else: + try: + x, _ = data_representative + return x.shape + except: + raise TypeError('Unsupported data type.') + elif backend == 'tensorflow': + # TODO: Handle more tensorflow type here + shape = data_representative['x'].shape + if len(shape) == 1: # (batch, ) = (batch, 1) + return 1 + else: + return shape + + +def get_model(model_config, local_data=None, backend='torch'): + """ + Arguments: + local_data (object): the model to be instantiated is + responsible for the given data. + Returns: + model (torch.Module): the instantiated model. + """ + if local_data is not None: + input_shape = get_shape_from_data(local_data, model_config, backend) + else: + input_shape = model_config.input_shape + + if input_shape is None: + logger.warning('The input shape is None. 
Please specify the ' + '`data.input_shape`(a tuple) or give the ' + 'representative data to `get_model` if necessary') + + for func in register.model_dict.values(): + model = func(model_config, input_shape) + if model is not None: + return model + + if model_config.type.lower() == 'lr': + if backend == 'torch': + from federatedscope.core.lr import LogisticRegression + model = LogisticRegression(in_channels=input_shape[-1], + class_num=model_config.out_channels) + elif backend == 'tensorflow': + from federatedscope.cross_backends import LogisticRegression + model = LogisticRegression(in_channels=input_shape[-1], + class_num=1, + use_bias=model_config.use_bias) + else: + raise ValueError + + elif model_config.type.lower() == 'mlp': + from federatedscope.core.mlp import MLP + model = MLP(channel_list=[input_shape[-1]] + [model_config.hidden] * + (model_config.layer - 1) + [model_config.out_channels], + dropout=model_config.dropout) + + elif model_config.type.lower() == 'quadratic': + from federatedscope.tabular.model import QuadraticModel + model = QuadraticModel(input_shape[-1], 1) + + elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']: + from federatedscope.cv.model import get_cnn + model = get_cnn(model_config, input_shape) + elif model_config.type.lower() in ['lstm']: + from federatedscope.nlp.model import get_rnn + model = get_rnn(model_config, input_shape) + elif model_config.type.lower().endswith('transformers'): + from federatedscope.nlp.model import get_transformer + model = get_transformer(model_config, input_shape) + elif model_config.type.lower() in [ + 'gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn' + ]: + from federatedscope.gfl.model import get_gnn + model = get_gnn(model_config, input_shape) + elif model_config.type.lower() in ['vmfnet', 'hmfnet']: + from federatedscope.mf.model.model_builder import get_mfnet + model = get_mfnet(model_config, input_shape) + else: + raise ValueError('Model {} is not provided'.format(model_config.type)) + + return model + + +def get_trainable_para_names(model): + return set(dict(list(model.named_parameters())).keys()) diff --git a/fgssl/core/auxiliaries/optimizer_builder.py b/fgssl/core/auxiliaries/optimizer_builder.py new file mode 100644 index 0000000..bd6d1bd --- /dev/null +++ b/fgssl/core/auxiliaries/optimizer_builder.py @@ -0,0 +1,48 @@ +import copy +import logging +import federatedscope.register as register + +logger = logging.getLogger(__name__) + +try: + import torch +except ImportError: + torch = None + +try: + from federatedscope.contrib.optimizer import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.optimizer`, some modules are not ' + f'available.') + + +def get_optimizer(model, type, lr, **kwargs): + if torch is None: + return None + # in case of users have not called the cfg.freeze() + tmp_kwargs = copy.deepcopy(kwargs) + if '__help_info__' in tmp_kwargs: + del tmp_kwargs['__help_info__'] + if '__cfg_check_funcs__' in tmp_kwargs: + del tmp_kwargs['__cfg_check_funcs__'] + if 'is_ready_for_run' in tmp_kwargs: + del tmp_kwargs['is_ready_for_run'] + + for func in register.optimizer_dict.values(): + optimizer = func(model, type, lr, **tmp_kwargs) + if optimizer is not None: + return optimizer + + if isinstance(type, str): + if hasattr(torch.optim, type): + if isinstance(model, torch.nn.Module): + return getattr(torch.optim, type)(model.parameters(), lr, + **tmp_kwargs) + else: + return getattr(torch.optim, type)(model, lr, **tmp_kwargs) + else: + raise NotImplementedError( + 
'Optimizer {} not implement'.format(type)) + else: + raise TypeError() diff --git a/fgssl/core/auxiliaries/regularizer_builder.py b/fgssl/core/auxiliaries/regularizer_builder.py new file mode 100644 index 0000000..75af98c --- /dev/null +++ b/fgssl/core/auxiliaries/regularizer_builder.py @@ -0,0 +1,30 @@ +from federatedscope.register import regularizer_dict +from federatedscope.core.regularizer.proximal_regularizer import * +try: + from torch.nn import Module +except ImportError: + Module = object + + +def get_regularizer(type): + if type is None or type == '': + return DummyRegularizer() + + for func in regularizer_dict.values(): + regularizer = func(type) + if regularizer is not None: + return regularizer() + + raise NotImplementedError( + "Regularizer {} is not implemented.".format(type)) + + +class DummyRegularizer(Module): + """Dummy regularizer that only returns zero. + + """ + def __init__(self): + super(DummyRegularizer, self).__init__() + + def forward(self, ctx): + return 0. diff --git a/fgssl/core/auxiliaries/sampler_builder.py b/fgssl/core/auxiliaries/sampler_builder.py new file mode 100644 index 0000000..0b7d2ff --- /dev/null +++ b/fgssl/core/auxiliaries/sampler_builder.py @@ -0,0 +1,20 @@ +import logging + +from federatedscope.core.sampler import UniformSampler, GroupSampler + +logger = logging.getLogger(__name__) + + +def get_sampler(sample_strategy='uniform', + client_num=None, + client_info=None, + bins=10): + if sample_strategy == 'uniform': + return UniformSampler(client_num=client_num) + elif sample_strategy == 'group': + return GroupSampler(client_num=client_num, + client_info=client_info, + bins=bins) + else: + raise ValueError( + f"The sample strategy {sample_strategy} has not been provided.") diff --git a/fgssl/core/auxiliaries/scheduler_builder.py b/fgssl/core/auxiliaries/scheduler_builder.py new file mode 100644 index 0000000..afc7a5f --- /dev/null +++ b/fgssl/core/auxiliaries/scheduler_builder.py @@ -0,0 +1,34 @@ +import logging +import federatedscope.register as register + +logger = logging.getLogger(__name__) + +try: + import torch +except ImportError: + torch = None + +try: + from federatedscope.contrib.scheduler import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.scheduler`, some modules are not ' + f'available.') + + +def get_scheduler(optimizer, type, **kwargs): + for func in register.scheduler_dict.values(): + scheduler = func(optimizer, type) + if scheduler is not None: + return scheduler + + if torch is None or type == '': + return None + if isinstance(type, str): + if hasattr(torch.optim.lr_scheduler, type): + return getattr(torch.optim.lr_scheduler, type)(optimizer, **kwargs) + else: + raise NotImplementedError( + 'Scheduler {} not implement'.format(type)) + else: + raise TypeError() diff --git a/fgssl/core/auxiliaries/splitter_builder.py b/fgssl/core/auxiliaries/splitter_builder.py new file mode 100644 index 0000000..fe2a138 --- /dev/null +++ b/fgssl/core/auxiliaries/splitter_builder.py @@ -0,0 +1,49 @@ +import logging +import federatedscope.register as register + +logger = logging.getLogger(__name__) + + +def get_splitter(config): + client_num = config.federate.client_num + if config.data.splitter_args: + kwargs = config.data.splitter_args[0] + else: + kwargs = {} + + for func in register.splitter_dict.values(): + splitter = func(client_num, **kwargs) + if splitter is not None: + return splitter + # Delay import + # generic splitter + if config.data.splitter == 'lda': + from 
federatedscope.core.splitters.generic import LDASplitter + splitter = LDASplitter(client_num, **kwargs) + # graph splitter + elif config.data.splitter == 'louvain': + from federatedscope.core.splitters.graph import LouvainSplitter + splitter = LouvainSplitter(client_num, **kwargs) + elif config.data.splitter == 'random': + from federatedscope.core.splitters.graph import RandomSplitter + splitter = RandomSplitter(client_num, **kwargs) + elif config.data.splitter == 'rel_type': + from federatedscope.core.splitters.graph import RelTypeSplitter + splitter = RelTypeSplitter(client_num, **kwargs) + elif config.data.splitter == 'scaffold': + from federatedscope.core.splitters.graph import ScaffoldSplitter + splitter = ScaffoldSplitter(client_num, **kwargs) + elif config.data.splitter == 'scaffold_lda': + from federatedscope.core.splitters.graph import ScaffoldLdaSplitter + splitter = ScaffoldLdaSplitter(client_num, **kwargs) + elif config.data.splitter == 'rand_chunk': + from federatedscope.core.splitters.graph import RandChunkSplitter + splitter = RandChunkSplitter(client_num, **kwargs) + elif config.data.splitter == 'iid': + from federatedscope.core.splitters.generic import IIDSplitter + splitter = IIDSplitter(client_num) + else: + logger.warning(f'Splitter {config.data.splitter} not found or not ' + f'used.') + splitter = None + return splitter diff --git a/fgssl/core/auxiliaries/trainer_builder.py b/fgssl/core/auxiliaries/trainer_builder.py new file mode 100644 index 0000000..41d3ffe --- /dev/null +++ b/fgssl/core/auxiliaries/trainer_builder.py @@ -0,0 +1,157 @@ +import logging +import importlib + +import federatedscope.register as register + +logger = logging.getLogger(__name__) + +try: + from federatedscope.contrib.trainer import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.trainer`, some modules are not ' + f'available.') + +TRAINER_CLASS_DICT = { + "cvtrainer": "CVTrainer", + "nlptrainer": "NLPTrainer", + "graphminibatch_trainer": "GraphMiniBatchTrainer", + "linkfullbatch_trainer": "LinkFullBatchTrainer", + "linkminibatch_trainer": "LinkMiniBatchTrainer", + "nodefullbatch_trainer": "NodeFullBatchTrainer", + "nodeminibatch_trainer": "NodeMiniBatchTrainer", + "flitplustrainer": "FLITPlusTrainer", + "flittrainer": "FLITTrainer", + "fedvattrainer": "FedVATTrainer", + "fedfocaltrainer": "FedFocalTrainer", + "mftrainer": "MFTrainer", +} + + +def get_trainer(model=None, + data=None, + device=None, + config=None, + only_for_eval=False, + is_attacker=False, + monitor=None): + if config.trainer.type == 'general': + if config.backend == 'torch': + from federatedscope.core.trainers import GeneralTorchTrainer + trainer = GeneralTorchTrainer(model=model, + data=data, + device=device, + config=config, + only_for_eval=only_for_eval, + monitor=monitor) + elif config.backend == 'tensorflow': + from federatedscope.core.trainers.tf_trainer import \ + GeneralTFTrainer + trainer = GeneralTFTrainer(model=model, + data=data, + device=device, + config=config, + only_for_eval=only_for_eval, + monitor=monitor) + else: + raise ValueError + elif config.trainer.type == 'none': + return None + elif config.trainer.type.lower() in TRAINER_CLASS_DICT: + if config.trainer.type.lower() in ['cvtrainer']: + dict_path = "federatedscope.cv.trainer.trainer" + elif config.trainer.type.lower() in ['nlptrainer']: + dict_path = "federatedscope.nlp.trainer.trainer" + elif config.trainer.type.lower() in [ + 'graphminibatch_trainer', + ]: + dict_path = 
"federatedscope.gfl.trainer.graphtrainer" + elif config.trainer.type.lower() in [ + 'linkfullbatch_trainer', 'linkminibatch_trainer' + ]: + dict_path = "federatedscope.gfl.trainer.linktrainer" + elif config.trainer.type.lower() in [ + 'nodefullbatch_trainer', 'nodeminibatch_trainer' + ]: + dict_path = "federatedscope.gfl.trainer.nodetrainer" + elif config.trainer.type.lower() in [ + 'flitplustrainer', 'flittrainer', 'fedvattrainer', + 'fedfocaltrainer' + ]: + dict_path = "federatedscope.gfl.flitplus.trainer" + elif config.trainer.type.lower() in ['mftrainer']: + dict_path = "federatedscope.mf.trainer.trainer" + else: + raise ValueError + + trainer_cls = getattr(importlib.import_module(name=dict_path), + TRAINER_CLASS_DICT[config.trainer.type.lower()]) + trainer = trainer_cls(model=model, + data=data, + device=device, + config=config, + only_for_eval=only_for_eval, + monitor=monitor) + else: + # try to find user registered trainer + trainer = None + for func in register.trainer_dict.values(): + trainer_cls = func(config.trainer.type) + if trainer_cls is not None: + trainer = trainer_cls(model=model, + data=data, + device=device, + config=config, + only_for_eval=only_for_eval, + monitor=monitor) + if trainer is None: + raise ValueError('Trainer {} is not provided'.format( + config.trainer.type)) + + # differential privacy plug-in + if config.nbafl.use: + from federatedscope.core.trainers import wrap_nbafl_trainer + trainer = wrap_nbafl_trainer(trainer) + if config.sgdmf.use: + from federatedscope.mf.trainer import wrap_MFTrainer + trainer = wrap_MFTrainer(trainer) + + # personalization plug-in + if config.federate.method.lower() == "pfedme": + from federatedscope.core.trainers import wrap_pFedMeTrainer + # wrap style: instance a (class A) -> instance a (class A) + trainer = wrap_pFedMeTrainer(trainer) + elif config.federate.method.lower() == "ditto": + from federatedscope.core.trainers import wrap_DittoTrainer + # wrap style: instance a (class A) -> instance a (class A) + trainer = wrap_DittoTrainer(trainer) + elif config.federate.method.lower() == "fedem": + from federatedscope.core.trainers import FedEMTrainer + # copy construct style: instance a (class A) -> instance b (class B) + trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer, + base_trainer=trainer) + + # attacker plug-in + if 'backdoor' in config.attack.attack_method: + from federatedscope.attack.trainer import wrap_benignTrainer + trainer = wrap_benignTrainer(trainer) + + if is_attacker: + if 'backdoor' in config.attack.attack_method: + logger.info('--------This client is a backdoor attacker --------') + else: + logger.info('-------- This client is an privacy attacker --------') + from federatedscope.attack.auxiliary.attack_trainer_builder \ + import wrap_attacker_trainer + trainer = wrap_attacker_trainer(trainer, config) + + elif 'backdoor' in config.attack.attack_method: + logger.info( + '----- This client is a benign client for backdoor attacks -----') + + # fed algorithm plug-in + if config.fedprox.use: + from federatedscope.core.trainers import wrap_fedprox_trainer + trainer = wrap_fedprox_trainer(trainer) + + return trainer diff --git a/fgssl/core/auxiliaries/transform_builder.py b/fgssl/core/auxiliaries/transform_builder.py new file mode 100644 index 0000000..6cd1d81 --- /dev/null +++ b/fgssl/core/auxiliaries/transform_builder.py @@ -0,0 +1,54 @@ +from importlib import import_module +import federatedscope.register as register + + +def get_transform(config, package): + r""" + + Args: + config: `CN` from 
`federatedscope/core/configs/config.py` + package: one of package from ['torchvision', 'torch_geometric', + 'torchtext', 'torchaudio'] + + Returns: + dict of transform functions. + + """ + transform_funcs = {} + for name in ['transform', 'target_transform', 'pre_transform']: + if config.data[name]: + transform_funcs[name] = config.data[name] + + # Transform are all None, do not import package and return dict with + # None value + if not transform_funcs: + return transform_funcs + + transforms = getattr(import_module(package), 'transforms') + + def convert(trans): + # Recursively converting expressions to functions + if isinstance(trans[0], str): + if len(trans) == 1: + trans.append({}) + transform_type, transform_args = trans + for func in register.transform_dict.values(): + transform_func = func(transform_type, transform_args) + if transform_func is not None: + return transform_func + transform_func = getattr(transforms, + transform_type)(**transform_args) + return transform_func + else: + transform = [convert(x) for x in trans] + if hasattr(transforms, 'Compose'): + return transforms.Compose(transform) + elif hasattr(transforms, 'Sequential'): + return transforms.Sequential(transform) + else: + return transform + + # return composed transform or return list of transform + for key in transform_funcs: + transform_funcs[key] = convert(config.data[key]) + return transform_funcs diff --git a/fgssl/core/auxiliaries/utils.py b/fgssl/core/auxiliaries/utils.py new file mode 100644 index 0000000..190a961 --- /dev/null +++ b/fgssl/core/auxiliaries/utils.py @@ -0,0 +1,305 @@ +import collections +import json +import logging +import math +import os +import random +import signal +import ssl +import urllib.request +from os import path as osp +import pickle + +import numpy as np + +# Blind torch +try: + import torch + import torchvision + import torch.distributions as distributions +except ImportError: + torch = None + torchvision = None + distributions = None + +logger = logging.getLogger(__name__) + + +def setup_seed(seed): + np.random.seed(seed) + random.seed(seed) + if torch is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + else: + import tensorflow as tf + tf.set_random_seed(seed) + + +def get_dataset(type, root, transform, target_transform, download=True): + if isinstance(type, str): + if hasattr(torchvision.datasets, type): + return getattr(torchvision.datasets, + type)(root=root, + transform=transform, + target_transform=target_transform, + download=download) + else: + raise NotImplementedError('Dataset {} not implement'.format(type)) + else: + raise TypeError() + + +def save_local_data(dir_path, + train_data=None, + train_targets=None, + test_data=None, + test_targets=None, + val_data=None, + val_targets=None): + r""" + https://github.com/omarfoq/FedEM/blob/main/data/femnist/generate_data.py + + save (`train_data`, `train_targets`) in {dir_path}/train.pt, + (`val_data`, `val_targets`) in {dir_path}/val.pt + and (`test_data`, `test_targets`) in {dir_path}/test.pt + :param dir_path: + :param train_data: + :param train_targets: + :param test_data: + :param test_targets: + :param val_data: + :param val_targets + """ + if (train_data is not None) and (train_targets is not None): + torch.save((train_data, train_targets), osp.join(dir_path, "train.pt")) + + if (test_data is not None) and (test_targets is not None): + torch.save((test_data, test_targets), osp.join(dir_path, "test.pt")) + + if (val_data is not None) and (val_targets is 
not None): + torch.save((val_data, val_targets), osp.join(dir_path, "val.pt")) + + +def filter_by_specified_keywords(param_name, filter_keywords): + ''' + Arguments: + param_name (str): parameter name. + Returns: + preserve (bool): whether to preserve this parameter. + ''' + preserve = True + for kw in filter_keywords: + if kw in param_name: + preserve = False + break + return preserve + + +def get_random(type, sample_shape, params, device): + if not hasattr(distributions, type): + raise NotImplementedError("Distribution {} is not implemented, " + "please refer to ```torch.distributions```" + "(https://pytorch.org/docs/stable/ " + "distributions.html).".format(type)) + generator = getattr(distributions, type)(**params) + return generator.sample(sample_shape=sample_shape).to(device) + + +def batch_iter(data, batch_size=64, shuffled=True): + assert 'x' in data and 'y' in data + data_x = data['x'] + data_y = data['y'] + data_size = len(data_y) + num_batches_per_epoch = math.ceil(data_size / batch_size) + + while True: + shuffled_index = np.random.permutation( + np.arange(data_size)) if shuffled else np.arange(data_size) + for batch in range(num_batches_per_epoch): + start_index = batch * batch_size + end_index = min(data_size, (batch + 1) * batch_size) + sample_index = shuffled_index[start_index:end_index] + yield {'x': data_x[sample_index], 'y': data_y[sample_index]} + + +def merge_dict(dict1, dict2): + # Merge results for history + for key, value in dict2.items(): + if key not in dict1: + if isinstance(value, dict): + dict1[key] = merge_dict({}, value) + else: + dict1[key] = [value] + else: + if isinstance(value, dict): + merge_dict(dict1[key], value) + else: + dict1[key].append(value) + return dict1 + + +def download_url(url: str, folder='folder'): + r"""Downloads the content of an url to a folder. + + Modified from `https://github.com/pyg-team/pytorch_geometric/blob/master + /torch_geometric/data/download.py` + + Args: + url (string): The url of target file. + folder (string): The target folder. + + Returns: + path (string): File path of downloaded files. + """ + + file = url.rpartition('/')[2] + file = file if file[0] == '?' 
else file.split('?')[0] + path = osp.join(folder, file) + if osp.exists(path): + logger.info(f'File {file} exists, use existing file.') + return path + + logger.info(f'Downloading {url}') + os.makedirs(folder, exist_ok=True) + ctx = ssl._create_unverified_context() + data = urllib.request.urlopen(url, context=ctx) + with open(path, 'wb') as f: + f.write(data.read()) + + return path + + +def move_to(obj, device): + import torch + if torch.is_tensor(obj): + return obj.to(device) + elif isinstance(obj, dict): + res = {} + for k, v in obj.items(): + res[k] = move_to(v, device) + return res + elif isinstance(obj, list): + res = [] + for v in obj: + res.append(move_to(v, device)) + return res + else: + raise TypeError("Invalid type for move_to") + + +def param2tensor(param): + import torch + if isinstance(param, list): + param = torch.FloatTensor(param) + elif isinstance(param, int): + param = torch.tensor(param, dtype=torch.long) + elif isinstance(param, float): + param = torch.tensor(param, dtype=torch.float) + return param + + +class Timeout(object): + def __init__(self, seconds, max_failure=5): + self.seconds = seconds + self.max_failure = max_failure + + def __enter__(self): + def signal_handler(signum, frame): + raise TimeoutError() + + if self.seconds > 0: + signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(self.seconds) + return self + + def __exit__(self, exc_type, exc_value, traceback): + signal.alarm(0) + + def reset(self): + signal.alarm(self.seconds) + + def block(self): + signal.alarm(0) + + def exceed_max_failure(self, num_failure): + return num_failure > self.max_failure + + +def format_log_hooks(hooks_set): + def format_dict(target_dict): + print_dict = collections.defaultdict(list) + for k, v in target_dict.items(): + for element in v: + print_dict[k].append(element.__name__) + return print_dict + + if isinstance(hooks_set, list): + print_obj = [format_dict(_) for _ in hooks_set] + elif isinstance(hooks_set, dict): + print_obj = format_dict(hooks_set) + return json.dumps(print_obj, indent=2).replace('\n', '\n\t') + + +def get_resource_info(filename): + if filename is None or not os.path.exists(filename): + logger.info('The device information file is not provided') + return None + + # Users can develop this loading function according to resource_info_file + # As an example, we use the device_info provided by FedScale (FedScale: + # Benchmarking Model and System Performance of Federated Learning + # at Scale), which can be downloaded from + # https://github.com/SymbioticLab/FedScale/blob/master/benchmark/dataset/ + # data/device_info/client_device_capacity The expected format is + # { INDEX:{'computation': FLOAT_VALUE_1, 'communication': FLOAT_VALUE_2}} + with open(filename, 'br') as f: + device_info = pickle.load(f) + return device_info + + +def calculate_time_cost(instance_number, + comm_size, + comp_speed=None, + comm_bandwidth=None, + augmentation_factor=3.0): + # Served as an example, this cost model is adapted from FedScale at + # https://github.com/SymbioticLab/FedScale/blob/master/fedscale/core/ + # internal/client.py#L35 (Apache License Version 2.0) + # Users can modify this function according to customized cost model + if comp_speed is not None and comm_bandwidth is not None: + comp_cost = augmentation_factor * instance_number * comp_speed + comm_cost = 2.0 * comm_size / comm_bandwidth + else: + comp_cost = 0 + comm_cost = 0 + + return comp_cost, comm_cost + + +def calculate_batch_epoch_num(steps, batch_or_epoch, num_data, batch_size, + drop_last): + 
num_batch_per_epoch = num_data // batch_size + int( + not drop_last and bool(num_data % batch_size)) + if num_batch_per_epoch == 0: + raise RuntimeError( + "The number of batch is 0, please check 'batch_size' or set " + "'drop_last' as False") + elif batch_or_epoch == "epoch": + num_epoch = steps + num_batch_last_epoch = num_batch_per_epoch + num_total_batch = steps * num_batch_per_epoch + else: + num_epoch = math.ceil(steps / num_batch_per_epoch) + num_batch_last_epoch = steps % num_batch_per_epoch or \ + num_batch_per_epoch + num_total_batch = steps + return num_batch_per_epoch, num_batch_last_epoch, num_epoch, \ + num_total_batch + + +def merge_param_dict(raw_param, filtered_param): + for key in filtered_param.keys(): + raw_param[key] = filtered_param[key] + return raw_param diff --git a/fgssl/core/auxiliaries/worker_builder.py b/fgssl/core/auxiliaries/worker_builder.py new file mode 100644 index 0000000..f0f94e3 --- /dev/null +++ b/fgssl/core/auxiliaries/worker_builder.py @@ -0,0 +1,109 @@ +import logging + +from federatedscope.core.configs import constants +from federatedscope.core.workers import Server, Client +import federatedscope.register as register + +logger = logging.getLogger(__name__) + +try: + from federatedscope.contrib.worker import * +except ImportError as error: + logger.warning( + f'{error} in `federatedscope.contrib.worker`, some modules are not ' + f'available.') + + +def get_client_cls(cfg): + for func in register.worker_dict.values(): + worker_class = func(cfg.federate.method.lower()) + if worker_class is not None: + return worker_class['client'] + + if cfg.hpo.fedex.use: + from federatedscope.autotune.fedex import FedExClient + return FedExClient + + if cfg.vertical.use: + from federatedscope.vertical_fl.worker import vFLClient + return vFLClient + + if cfg.federate.method.lower() in constants.CLIENTS_TYPE: + client_type = constants.CLIENTS_TYPE[cfg.federate.method.lower()] + else: + client_type = "normal" + logger.warning( + 'Clients for method {} is not implemented. Will use default one'. 
+ format(cfg.federate.method)) + + if client_type == 'fedsageplus': + from federatedscope.gfl.fedsageplus.worker import FedSagePlusClient + client_class = FedSagePlusClient + elif client_type == 'gcflplus': + from federatedscope.gfl.gcflplus.worker import GCFLPlusClient + client_class = GCFLPlusClient + else: + client_class = Client + + # add attack related method to client_class + + if cfg.attack.attack_method.lower() in constants.CLIENTS_TYPE: + client_atk_type = constants.CLIENTS_TYPE[ + cfg.attack.attack_method.lower()] + else: + client_atk_type = None + + if client_atk_type == 'gradascent': + from federatedscope.attack.worker_as_attacker.active_client import \ + add_atk_method_to_Client_GradAscent + logger.info("=========== add method to current client class ") + client_class = add_atk_method_to_Client_GradAscent(client_class) + return client_class + + +def get_server_cls(cfg): + for func in register.worker_dict.values(): + worker_class = func(cfg.federate.method.lower()) + if worker_class is not None: + return worker_class['server'] + + if cfg.hpo.fedex.use: + from federatedscope.autotune.fedex import FedExServer + return FedExServer + + if cfg.attack.attack_method.lower() in ['dlg', 'ig']: + from federatedscope.attack.worker_as_attacker.server_attacker import\ + PassiveServer + return PassiveServer + elif cfg.attack.attack_method.lower() in ['passivepia']: + from federatedscope.attack.worker_as_attacker.server_attacker import\ + PassivePIAServer + return PassivePIAServer + + elif cfg.attack.attack_method.lower() in ['backdoor']: + from federatedscope.attack.worker_as_attacker.server_attacker \ + import BackdoorServer + return BackdoorServer + + if cfg.vertical.use: + from federatedscope.vertical_fl.worker import vFLServer + return vFLServer + + if cfg.federate.method.lower() in constants.SERVER_TYPE: + server_type = constants.SERVER_TYPE[cfg.federate.method.lower()] + else: + server_type = "normal" + logger.warning( + 'Server for method {} is not implemented. Will use default one'. 
+ format(cfg.federate.method)) + + if server_type == 'fedsageplus': + from federatedscope.gfl.fedsageplus.worker import FedSagePlusServer + server_class = FedSagePlusServer + elif server_type == 'gcflplus': + from federatedscope.gfl.gcflplus.worker import GCFLPlusServer + server_class = GCFLPlusServer + else: + server_class = Server + + return server_class diff --git a/fgssl/core/cmd_args.py b/fgssl/core/cmd_args.py new file mode 100644 index 0000000..61b886f --- /dev/null +++ b/fgssl/core/cmd_args.py @@ -0,0 +1,47 @@ +import argparse +import sys +from federatedscope.core.configs.config import global_cfg + + +def parse_args(args=None): + parser = argparse.ArgumentParser(description='FederatedScope', + add_help=False) + parser.add_argument('--cfg', + dest='cfg_file', + help='Config file path', + required=False, + type=str) + parser.add_argument('--client_cfg', + dest='client_cfg_file', + help='Config file path for clients', + required=False, + default=None, + type=str) + parser.add_argument( + '--help', + nargs="?", + const="all", + default="", + ) + parser.add_argument('opts', + help='See federatedscope/core/configs for all options', + default=None, + nargs=argparse.REMAINDER) + parse_res = parser.parse_args(args) + init_cfg = global_cfg.clone() + # when users type only "main.py" or "main.py help" + if len(sys.argv) == 1 or parse_res.help == "all": + parser.print_help() + init_cfg.print_help() + sys.exit(1) + elif hasattr(parse_res, "help") and isinstance( + parse_res.help, str) and parse_res.help != "": + init_cfg.print_help(parse_res.help) + sys.exit(1) + elif hasattr(parse_res, "help") and isinstance( + parse_res.help, list) and len(parse_res.help) != 0: + for query in parse_res.help: + init_cfg.print_help(query) + sys.exit(1) + + return parse_res diff --git a/fgssl/core/communication.py b/fgssl/core/communication.py new file mode 100644 index 0000000..dbeb3e2 --- /dev/null +++ b/fgssl/core/communication.py @@ -0,0 +1,147 @@ +import grpc +from concurrent import futures + +from federatedscope.core.configs.config import global_cfg +from federatedscope.core.proto import gRPC_comm_manager_pb2, \ + gRPC_comm_manager_pb2_grpc +from federatedscope.core.gRPC_server import gRPCComServeFunc +from federatedscope.core.message import Message + + +class StandaloneCommManager(object): + """ + The communicator used for standalone mode + """ + def __init__(self, comm_queue, monitor=None): + self.comm_queue = comm_queue + self.neighbors = dict() + self.monitor = monitor # used to track the communication related + # metrics + + def receive(self): + # we don't need receive() in standalone + pass + + def add_neighbors(self, neighbor_id, address=None): + self.neighbors[neighbor_id] = address + + def get_neighbors(self, neighbor_id=None): + address = dict() + if neighbor_id: + if isinstance(neighbor_id, list): + for each_neighbor in neighbor_id: + address[each_neighbor] = self.get_neighbors(each_neighbor) + return address + else: + return self.neighbors[neighbor_id] + else: + # Get all neighbors + return self.neighbors + + def send(self, message): + self.comm_queue.append(message) + download_bytes, upload_bytes = message.count_bytes() + self.monitor.track_upload_bytes(upload_bytes) + + +class gRPCCommManager(object): + """ + The implementation of gRPCCommManager is referred to the tutorial on + https://grpc.io/docs/languages/python/ + """ + def __init__(self, host='0.0.0.0', port='50050', client_num=2): + self.host = host + self.port = port + options = [ + ("grpc.max_send_message_length", + 
global_cfg.distribute.grpc_max_send_message_length), + ("grpc.max_receive_message_length", + global_cfg.distribute.grpc_max_receive_message_length), + ("grpc.enable_http_proxy", + global_cfg.distribute.grpc_enable_http_proxy), + ] + self.server_funcs = gRPCComServeFunc() + self.grpc_server = self.serve(max_workers=client_num, + host=host, + port=port, + options=options) + self.neighbors = dict() + self.monitor = None # used to track the communication related metrics + + def serve(self, max_workers, host, port, options): + """ + This function is referred to + https://grpc.io/docs/languages/python/basics/#starting-the-server + """ + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=max_workers), + options=options) + gRPC_comm_manager_pb2_grpc.add_gRPCComServeFuncServicer_to_server( + self.server_funcs, server) + server.add_insecure_port("{}:{}".format(host, port)) + server.start() + + return server + + def add_neighbors(self, neighbor_id, address): + if isinstance(address, dict): + self.neighbors[neighbor_id] = '{}:{}'.format( + address['host'], address['port']) + elif isinstance(address, str): + self.neighbors[neighbor_id] = address + else: + raise TypeError(f"The type of address ({type(address)}) is not " + "supported yet") + + def get_neighbors(self, neighbor_id=None): + address = dict() + if neighbor_id: + if isinstance(neighbor_id, list): + for each_neighbor in neighbor_id: + address[each_neighbor] = self.get_neighbors(each_neighbor) + return address + else: + return self.neighbors[neighbor_id] + else: + # Get all neighbors + return self.neighbors + + def _send(self, receiver_address, message): + def _create_stub(receiver_address): + """ + This part is referred to + https://grpc.io/docs/languages/python/basics/#creating-a-stub + """ + channel = grpc.insecure_channel(receiver_address, + options=(('grpc.enable_http_proxy', + 0), )) + stub = gRPC_comm_manager_pb2_grpc.gRPCComServeFuncStub(channel) + return stub, channel + + stub, channel = _create_stub(receiver_address) + request = message.transform(to_list=True) + try: + stub.sendMessage(request) + except grpc._channel._InactiveRpcError: + pass + channel.close() + + def send(self, message): + receiver = message.receiver + if receiver is not None: + if not isinstance(receiver, list): + receiver = [receiver] + for each_receiver in receiver: + if each_receiver in self.neighbors: + receiver_address = self.neighbors[each_receiver] + self._send(receiver_address, message) + else: + for each_receiver in self.neighbors: + receiver_address = self.neighbors[each_receiver] + self._send(receiver_address, message) + + def receive(self): + received_msg = self.server_funcs.receive() + message = Message() + message.parse(received_msg.msg) + return message diff --git a/fgssl/core/configs/README.md b/fgssl/core/configs/README.md new file mode 100644 index 0000000..8150482 --- /dev/null +++ b/fgssl/core/configs/README.md @@ -0,0 +1,397 @@ +## Configurations +We summarize all the customizable configurations: +- [cfg_data.py](#data) +- [cfg_model.py](#model) +- [cfg_fl_algo.py](#federated-algorithms) +- [cfg_training.py](#federated-training) +- [cfg_fl_setting.py](#fl-setting) +- [cfg_evaluation.py](#evaluation) +- [cfg_asyn.py](#asynchronous-training-strategies) +- [cfg_differential_privacy.py](#differential-privacy) +- [cfg_hpo.py](#auto-tuning-components) +- [cfg_attack.py](#attack) + +### Data +The configurations related to the data/dataset are defined in `cfg_data.py`. 
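+As a minimal, illustrative sketch (values are placeholders; the field names follow the table below), these options can be set through the Python config API exposed by `global_cfg`:
+```python
+from federatedscope.core.configs.config import global_cfg
+
+cfg = global_cfg.clone()
+cfg.data.root = 'data'            # folder where the data files are located
+cfg.data.type = 'cora'            # dataset name
+cfg.data.splitter = 'louvain'     # graph splitter for standalone simulation
+cfg.data.splits = [0.8, 0.1, 0.1] # train/valid/test splits
+cfg.dataloader.batch_size = 64
+```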
+ +| Name | (Type) Default Value | Description | Note | +|:--------------------------------------------:|:-----:|:---------- |:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `data.root` | (string) 'data' | The folder where the data file located. `data.root` would be used together with `data.type` to load the dataset. | - | +| `data.type` | (string) 'toy' | Dataset name | CV: 'femnist', 'celeba' ; NLP: 'shakespeare', 'subreddit', 'twitter'; Graph: 'cora', 'citeseer', 'pubmed', 'dblp_conf', 'dblp_org', 'csbm', 'epinions', 'ciao', 'fb15k-237', 'wn18', 'fb15k' , 'MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'AIDS', 'NCI1', 'ENZYMES', 'DD', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY', 'IMDB-BINARY', 'IMDB-MULTI', 'HIV', 'ESOL', 'FREESOLV', 'LIPO', 'PCBA', 'MUV', 'BACE', 'BBBP', 'TOX21', 'TOXCAST', 'SIDER', 'CLINTOX', 'graph_multi_domain_mol', 'graph_multi_domain_small', 'graph_multi_domain_mix', 'graph_multi_domain_biochem'; MF: 'vflmovielens1m', 'vflmovielens10m', 'hflmovielens1m', 'hflmovielens10m', 'vflnetflix', 'hflnetflix'; Tabular: 'toy', 'synthetic'; External dataset: 'DNAME@torchvision', 'DNAME@torchtext', 'DNAME@huggingface_datasets', 'DNAME@openml'. | +| `data.args` | (list) [] | Args for the external dataset | Used for external dataset, eg. `[{'download': False}]` | +| `data.save_data` | (bool) False | Whether to save the generated toy data | - | +| `data.splitter` | (string) '' | Splitter name for standalone dataset | Generic splitter: 'lda'; Graph splitter: 'louvain', 'random', 'rel_type', 'graph_type', 'scaffold', 'scaffold_lda', 'rand_chunk' | +| `data.splitter_args` | (list) [] | Args for splitter. | Used for splitter, eg. `[{'alpha': 0.5}]` | +| `data.transform` | (list) [] | Transform for x of data | Used in `get_item` in torch.dataset, eg. `[['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]` | +| `data.target_transform` | (list) [] | Transform for y of data | Use as `data.transform` | +| `data.pre_transform` | (list) [] | Pre_transform for `torch_geometric` dataset | Use as `data.transform` | +| `dataloader.batch_size` | (int) 64 | batch_size for DataLoader | - | +| `dataloader.drop_last` | (bool) False | Whether drop last batch (if the number of last batch is smaller than batch_size) in DataLoader | - | +| `dataloader.sizes` | (list) [10, 5] | Sample size for graph DataLoader | The length of `dataloader.sizes` must meet the layer of GNN models. 
| +| `dataloader.shuffle` | (bool) True | Whether to shuffle the train DataLoader | - | +| `data.server_holds_all` | (bool) False | Only used in global mode: whether the server (the worker with idx 0) holds all data, which is useful for global training/evaluation | - | +| `data.subsample` | (float) 1.0 | Only used in LEAF datasets; the ratio of clients subsampled from all clients | - | +| `data.splits` | (list) [0.8, 0.1, 0.1] | Train, valid, test splits | - | +| `data.`
`consistent_label_distribution` | (bool) False | Make label distribution of train/val/test set over clients keep consistent during splitting | - | +| `data.cSBM_phi` | (list) [0.5, 0.5, 0.5] | Phi for cSBM graph dataset | - | +| `data.loader` | (string) '' | Graph sample name, used in minibatch trainer | 'graphsaint-rw': use `GraphSAINTRandomWalkSampler` as DataLoader; 'neighbor': use `NeighborSampler` as DataLoader. | +| `dataloader.num_workers` | (int) 0 | num_workers in DataLoader | - | +| `dataloader.walk_length` | (int) 2 | The length of each random walk in graphsaint. | - | +| `dataloader.num_steps` | (int) 30 | The number of iterations per epoch in graphsaint. | - | +| `data.quadratic.dim` | (int) 1 | Dim of synthetic quadratic  dataset | - | +| `data.quadratic.min_curv` | (float) 0.02 | Min_curve of synthetic quadratic dataset | - | +| `data.quadratic.max_curv` | (float) 12.5 | Max_cur of synthetic quadratic dataset | - | + + +### Model + +The configurations related to the model are defined in `cfg_model.py`. +| [General](#model-general) | [Criterion](#criterion) | [Regularization](#regularizer) | + +#### Model-General +| Name | (Type) Default Value | Description | Note | +|:--------------------------:|:--------------------:|:------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| `model.`
`model_num_per_trainer` | (int) 1 | Number of model per trainer | some methods may leverage more | +| `model.type` | (string) 'lr' | The model name used in FL | CV: 'convnet2', 'convnet5', 'vgg11', 'lr'; NLP: 'LSTM', 'MODEL@transformers'; Graph: 'gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn'; Tabular: 'mlp', 'lr', 'quadratic'; MF: 'vmfnet', 'hmfnet' | +| `model.use_bias` | (bool) True | Whether use bias in lr model | - | +| `model.task` | (string) 'node' | The task type of model, the default is `Classification` | NLP: 'PreTraining', 'QuestionAnswering', 'SequenceClassification', 'TokenClassification', 'Auto', 'WithLMHead'; Graph: 'NodeClassification', 'NodeRegression', 'LinkClassification', 'LinkRegression', 'GraphClassification', 'GraphRegression', | +| `model.hidden` | (int) 256 | Hidden layer dimension | - | +| `model.dropout` | (float) 0.5 | Dropout ratio | - | +| `model.in_channels` | (int) 0 | Input channels dimension | If 0, model will be built by `data.shape` | +| `model.out_channels` | (int) 1 | Output channels dimension | - | +| `model.layer` | (int) 2 | Model layer | - | +| `model.graph_pooling` | (string) 'mean' | Graph pooling method in graph-level task | 'add', 'mean' or 'max' | +| `model.embed_size` | (int) 8 | `embed_size` in LSTM | - | +| `model.num_item` | (int) 0 | Number of items in MF. | It will be overwritten by the real value of the dataset. | +| `model.num_user` | (int) 0 | Number of users in MF. | It will be overwritten by the real value of the dataset. | + +#### Criterion + +| Name | (Type) Default Value | Description | Note | +|:--------------------------:|:--------------------:|:------------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| `criterion.type` | (string) 'MSELoss' | Criterion type | Chosen from https://pytorch.org/docs/stable/nn.html#loss-functions , eg. 'CrossEntropyLoss', 'L1Loss', etc. | + +#### Regularizer + +| Name | (Type) Default Value | Description | Note | +|:--------------------------:|:--------------------:|:------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| `regularizer.type` | (string) ' ' | The type of the regularizer | Chosen from [`proximal_regularizer`] | +| `regularizer.mu` | (float) 0 | The factor that controls the loss of the regularization term | - | + + +### Federated Algorithms +The configurations related to specific federated algorithms, which are defined in `cfg_fl_algo.py`. + +| [FedOPT](#fedopt-for-fedopt-algorithm) | [FedProx](#fedprox-for-fedprox-algorithm) | [personalization](#personalization-for-personalization-algorithms) | [fedsageplus](#fedsageplus-for-fedsageplus-algorithm) | [gcflplus](#gcflplus-for-gcflplus-algorithm) | [flitplus](#flitplus-for-flitplus-algorithm) | + +#### `fedopt`: for FedOpt algorithm +| Name | (Type) Default Value | Description | Note | +|:----:|:-----:|:---------- |:---- | +| `fedopt.use` | (bool) False | Whether to run FL courses with FedOpt algorithm. | If False, all the related configurations (cfg.fedopt.xxx) would not take effect. | +| `fedopt.optimizer.type` | (string) 'SGD' | The type of optimizer used for FedOpt algorithm. | Currently we support all optimizers build in PyTorch (The modules under torch.optim). 
|
+| `fedopt.optimizer.lr` | (float) 0.01 | The learning rate used for the FedOpt optimizer. | - |
+#### `fedprox`: for FedProx algorithm
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `fedprox.use` | (bool) False | Whether to run FL courses with the FedProx algorithm. | If False, all the related configurations (cfg.fedprox.xxx) would not take effect. |
+| `fedprox.mu` | (float) 0.0 | The hyper-parameter $\mu$ used in the FedProx algorithm. | - |
+#### `personalization`: for personalization algorithms
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `personalization.local_param` | (list of str) [] | The client-distinct local param names, e.g., ['pre', 'bn'] | - |
+| `personalization.`
`share_non_trainable_para` | (bool) False | Whether to transmit non-trainable parameters between FL participants | - |
+| `personalization.`
`local_update_steps` | (int) -1 | The local training steps for personalized models | By default, -1 indicates that the personalized local update steps will be set to the same value as the valid `train.local_update_steps` |
+| `personalization.regular_weight` | (float) 0.1 | The regularization factor used for model para regularization methods such as Ditto and pFedMe. | The smaller the regular_weight, the stronger the emphasis on the personalized model. |
+| `personalization.lr` | (float) 0.0 | The personalized learning rate used in personalized FL algorithms. | The default value 0.0 indicates that the value will be set to the same as `train.optimizer.lr` in case users have not specified a valid `personalization.lr` |
+| `personalization.K` | (int) 5 | The local approximation steps for pFedMe. | - |
+| `personalization.beta` | (float) 5 | The average moving parameter for pFedMe. | - |
+#### `fedsageplus`: for fedsageplus algorithm
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `fedsageplus.num_pred` | (int) 5 | Number of nodes generated by the generator | - |
+| `fedsageplus.gen_hidden` | (int) 128 | Hidden layer dimension of the generator | - |
+| `fedsageplus.hide_portion` | (float) 0.5 | Portion of the graph to hide | - |
+| `fedsageplus.fedgen_epoch` | (int) 200 | Number of federated training rounds for the generator | - |
+| `fedsageplus.loc_epoch` | (int) 1 | Number of local pre-training rounds for the generator | - |
+| `fedsageplus.a` | (float) 1.0 | Coefficient for the criterion on the number of missing nodes | - |
+| `fedsageplus.b` | (float) 1.0 | Coefficient for the feature criterion | - |
+| `fedsageplus.c` | (float) 1.0 | Coefficient for the classification criterion | - |
+#### `gcflplus`: for gcflplus algorithm
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `gcflplus.EPS_1` | (float) 0.05 | Bound for mean_norm | - |
+| `gcflplus.EPS_2` | (float) 0.1 | Bound for max_norm | - |
+| `gcflplus.seq_length` | (int) 5 | Length of the gradient sequence | - |
+| `gcflplus.standardize` | (bool) False | Whether to standardize dtw_distances | - |
+#### `flitplus`: for flitplus algorithm
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `flitplus.tmpFed` | (float) 0.5 | gamma in focal loss (Eq.4) | - |
+| `flitplus.lambdavat` | (float) 0.5 | lambda in phi (Eq.10) | - |
+| `flitplus.factor_ema` | (float) 0.8 | beta in omega (Eq.12) | - |
+| `flitplus.weightReg` | (float) 1.0 | Balances lossLocalLabel and lossLocalVAT | - |
+
+
+### Federated training
+The configurations related to federated training are defined in `cfg_training.py`.
+Considering it's infeasible to list all the potential arguments for optimizers and schedulers, we allow users to add new parameters directly under the corresponding namespace.
+For example, we haven't defined the argument `train.optimizer.weight_decay` in `cfg_training.py`, but users are allowed to use it directly (see the sketch just before the local-training table below).
+If the optimizer doesn't require an argument named `weight_decay`, an error will be raised.
+
+| [Local Training](#local-training) | [Finetune](#fine-tuning) | [Grad Clipping](#grad-clipping) | [Early Stop](#early-stop) |
+
+#### Local training
+The following configurations are related to local training.
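+
+As a quick illustration of the open namespaces described above, here is a minimal sketch (illustrative only; `weight_decay`, `step_size` and the concrete values are arbitrary examples, and `global_cfg` is assumed to be the global config object exposed by `federatedscope.core.configs.config`):
+
+```python
+from federatedscope.core.configs.config import global_cfg  # assumed entry point
+
+cfg = global_cfg.clone()
+cfg.train.optimizer.type = 'SGD'
+cfg.train.optimizer.lr = 0.05
+# `weight_decay` is not declared in `cfg_training.py`; it is accepted because
+# `train.optimizer` is created with `CN(new_allowed=True)` and is later passed
+# as an argument to the chosen optimizer (here, `torch.optim.SGD`).
+cfg.train.optimizer.weight_decay = 5e-4
+cfg.train.scheduler.type = 'StepLR'
+cfg.train.scheduler.step_size = 10  # likewise forwarded to the chosen scheduler
+```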
+
+| Name | (Type) Default Value | Description | Note |
+|:--------------------------:|:--------------------:|:------------------------------------------------:|:----:|
+| `train.local_update_steps` | (int) 1 | The number of local training steps. | - |
+| `train.batch_or_epoch` | (string) 'batch' | The unit of local training. | `train.batch_or_epoch` specifies the unit that `train.local_update_steps` adopts. |
+| `train.optimizer` | - | - | You can add new parameters under `train.optimizer` according to the optimizer, e.g., you can set momentum by `cfg.train.optimizer.momentum`. All new parameters will be used as arguments for the chosen optimizer. |
+| `train.optimizer.type` | (string) 'SGD' | The type of optimizer used in local training. | Currently we support all optimizers built into PyTorch (the modules under `torch.optim`). |
+| `train.optimizer.lr` | (float) 0.1 | The learning rate used in local training. | - |
+| `train.scheduler` | - | - | Similar to `train.optimizer`, you can add new parameters as you need, e.g., `train.scheduler.step_size=10`. All new parameters will be used as arguments for the chosen scheduler. |
+| `train.scheduler.type` | (string) '' | The type of the scheduler used in local training. | Currently we support all schedulers built into PyTorch (the modules under `torch.optim.lr_scheduler`). |
+
+#### Fine tuning
+The following configurations are related to fine-tuning.
+
+| Name | (Type) Default Value | Description | Note |
+|:--------------------------:|:--------------------:|:------------------------------------------------:|:----:|
+| `finetune.before_eval` | (bool) False | Indicator of fine-tuning before evaluation | If `True`, each client will fine-tune its model before each evaluation. Note the fine-tuning is only conducted before evaluation and won't influence the uploaded weights in each round. |
+| `finetune.local_update_steps` | (int) 1 | The number of local fine-tuning steps | - |
+| `finetune.batch_or_epoch` | (string) 'epoch' | The unit of local fine-tuning. | Similar to `train.batch_or_epoch`, `finetune.batch_or_epoch` specifies the unit of `finetune.local_update_steps` |
+| `finetune.optimizer` | - | - | You can add new parameters under `finetune.optimizer` according to the type of optimizer. All new parameters will be used as arguments for the chosen optimizer. |
+| `finetune.optimizer.type` | (string) 'SGD' | The type of the optimizer used in fine-tuning. | Currently we support all optimizers built into PyTorch (the modules under `torch.optim`). |
+| `finetune.optimizer.lr` | (float) 0.1 | The learning rate used in local fine-tuning | - |
+| `finetune.scheduler` | - | - | Similar to `train.scheduler`, you can add new parameters as you need, and all new parameters will be used as arguments for the chosen scheduler. |
+
+#### Grad Clipping
+The following configurations are related to gradient clipping.
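+
+For intuition, the sketch below shows how a clipping threshold like this is typically applied in a PyTorch training step (illustrative only; the model, data and threshold are made up, and the built-in trainers may apply clipping differently):
+
+```python
+import torch
+
+model = torch.nn.Linear(4, 1)
+optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+grad_clip = 5.0  # would come from cfg.grad.grad_clip (a negative value means no clipping)
+
+x, y = torch.randn(8, 4), torch.randn(8, 1)
+loss = torch.nn.functional.mse_loss(model(x), y)
+loss.backward()
+if grad_clip > 0:
+    # clip the global gradient norm before the optimizer step
+    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+optimizer.step()
+```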
+
+| Name | (Type) Default Value | Description | Note |
+|:--------------------------:|:--------------------:|:------------------------------------------------:|:----:|
+| `grad.grad_clip` | (float) -1.0 | The threshold used in gradient clipping. | `grad.grad_clip < 0` means we don't clip the gradient. |
+
+#### Early Stop
+
+| Name | (Type) Default Value | Description | Note |
+|:----------------------------------------:|:--------------------:|:------------------------------------------------:|:----:|
+| `early_stop.patience` | (int) 5 | How long to wait after the last time the monitored metric improved. | Note that the actual_checking_round = `early_stop.patience` * `eval.freq`. To disable early stopping, set `early_stop.patience` <= 0 |
+| `early_stop.delta` | (float) 0. | Minimum change in the monitored metric to indicate an improvement. | - |
+| `early_stop.improve_indicator_mode` | (string) 'best' | Early stop when there is no improvement within the last `early_stop.patience` rounds, in ['mean', 'best'] | Chosen from 'mean' or 'best' |
+| `early_stop.the_smaller_the_better` | (bool) True | The optimized direction of the chosen metric | - |
+
+
+### FL Setting
+The configurations related to FL settings are defined in `cfg_fl_setting.py`.
+
+| [General](#federate-general-fl-setting) | [Distribute](#distribute-for-distribute-mode) | [Vertical](#vertical-for-vertical-federated-learning) |
+
+#### `federate`: general fl setting
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `federate.client_num` | (int) 0 | The number of clients involved in the FL course. | It can be set to 0 so that it is automatically specified by the partition of the dataset. |
+| `federate.sample_client_num` | (int) -1 | The number of sampled clients in each training round. | - |
+| `federate.sample_client_rate` | (float) -1.0 | The ratio of sampled clients in each training round. | - |
+| `federate.unseen_clients_rate` | (float) 0.0 | The ratio of clients served as unseen clients, which would not be used for training and only for evaluation. | - |
+| `federate.total_round_num` | (int) 50 | The maximum training round number of the FL course. | - |
+| `federate.mode` | (string) 'standalone'
Choices: {'standalone', 'distributed'} | The running mode of the FL course. | - |
+| `federate.share_local_model` | (bool) False | If `True`, only one model object is created in the FL course and shared among clients for efficient simulation. | - |
+| `federate.data_weighted_aggr` | (bool) False | If `True`, the aggregation weight of each client is the number of training samples in its dataset. | - |
+| `federate.online_aggr` | (bool) False | If `True`, an online aggregation mechanism would be applied for efficient simulation. | - |
+| `federate.make_global_eval` | (bool) False | If `True`, the evaluation would be performed on the server's test data; otherwise each client would perform evaluation on its local test set and the results would be merged. | - |
+| `federate.use_diff` | (bool) False | If `True`, the clients would return the variation in local training (i.e., $\delta$) instead of the updated models to the server for federated aggregation. | - |
+| `federate.merge_test_data` | (bool) False | If `True`, clients' test data would be merged and global evaluation would be performed for efficient simulation. | - |
+| `federate.method` | (string) 'FedAvg' | The method used for federated aggregation. | We support existing federated aggregation algorithms (such as 'FedAvg/FedOpt'), 'global' (centralized training), 'local' (isolated training), personalized algorithms ('Ditto/pFedMe/FedEM'), and allow developers to customize. |
+| `federate.ignore_weight` | (bool) False | If `True`, the clients' sample sizes are ignored and the model updates are uniformly averaged in federated aggregation. | - |
+| `federate.use_ss` | (bool) False | If `True`, additive secret sharing would be applied in the FL course. | Only used in vanilla FedAvg in this version. |
+| `federate.restore_from` | (string) '' | The checkpoint file to restore the model. | - |
+| `federate.save_to` | (string) '' | The path to save the model. | - |
+| `federate.join_in_info` | (list of string) [] | The information requirements (from server) for joining in the FL course. | We support 'num_sample/client_resource' and allow user customization. |
+| `federate.sampler` | (string) 'uniform'
Choices: {'uniform', 'group'} | The sampling strategy used by the server for client selection in each training round. | - |
+| `federate.`
`resource_info_file` | (string) '' | The device information file used to record computation and communication ability | - |
+#### `distribute`: for distribute mode
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `distribute.use` | (bool) False | Whether to run FL courses in distributed mode. | If `False`, all the related configurations (`cfg.distribute.xxx`) would not take effect. |
+| `distribute.server_host` | (string) '0.0.0.0' | The host of the server's IP address for communication | - |
+| `distribute.server_port` | (int) 50050 | The port of the server's IP address for communication | - |
+| `distribute.client_host` | (string) '0.0.0.0' | The host of the client's IP address for communication | - |
+| `distribute.client_port` | (int) 50050 | The port of the client's IP address for communication | - |
+| `distribute.role` | (string) 'client'
Choices: {'server', 'client'} | The role of the worker | - |
+| `distribute.data_file` | (string) 'data' | The path to the data file | - |
+| `distribute.data_idx` | (int) -1 | It is used to specify the data index in distributed mode when adopting a centralized dataset for simulation (formatted as {data_idx: data/dataloader}). | `data_idx=-1` means that the entire dataset is owned by the participant. For other invalid values except -1, we randomly sample the index for simulation. |
+| `distribute.`
`grpc_max_send_message_length` | (int) 100 * 1024 * 1024 | The maximum length (bytes) of sent gRPC messages | - |
+| `distribute.`
`grpc_max_receive_message_length` | (int) 100 * 1024 * 1024 | The maximum length (bytes) of received gRPC messages | - |
+| `distribute.grpc_enable_http_proxy` | (bool) False | Whether to enable the HTTP proxy | - |
+#### `vertical`: for vertical federated learning
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `vertical.use` | (bool) False | Whether to run vertical FL. | If `False`, all the related configurations (`cfg.vertical.xxx`) would not take effect. |
+| `vertical.encryption` | (string) 'paillier' | The encryption algorithm used in vertical FL. | - |
+| `vertical.dims` | (list of int) [5,10] | The dimensions of the input features for participants. | - |
+| `vertical.key_size` | (int) 3072 | The length (bit) of the public keys. | - |
+
+
+### Evaluation
+The configurations related to monitoring and evaluation are defined in `cfg_evaluation.py`.
+
+| [General](#evaluation-general) | [WandB](#wandb-for-wandb-tracking-and-visualization) |
+
+#### Evaluation General
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `eval.freq` | (int) 1 | The frequency at which we conduct evaluation. | - |
+| `eval.metrics` | (list of str) [] | The names of adopted evaluation metrics. | By default, we calculate ['loss', 'avg_loss', 'total']; all the supported metrics can be found in `core/monitors/metric_calculator.py` |
+| `eval.split` | (list of str) ['test', 'val'] | The names of the data splits on which we conduct evaluation. | - |
+| `eval.report` | (list of str) ['weighted_avg', 'avg', 'fairness', 'raw'] | The forms of the results reported to loggers | By default, we report comprehensive results: `weighted_avg` and `avg` indicate the weighted average and uniform average over all evaluated clients; `fairness` indicates reporting fairness-related results such as individual performance and std across all evaluated clients; `raw` indicates that we save and compress all clients' individual results without summarization, and users can flexibly post-process the saved results further. |
+| `eval.`
`best_res_update_round_wise_key` | (str) 'val_loss' | The metric name used as the primary key to check the performance improvement at each evaluation round. | - |
+| `eval.monitoring` | (list of str) [] | Extended monitoring methods or metrics, e.g., 'dissim' for B-local dissimilarity | - |
+| `eval.count_flops` | (bool) True | Whether to count FLOPs during the FL courses. | - |
+#### `wandb`: for wandb tracking and visualization
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `wandb.use` | (bool) False | Whether to use wandb to track and visualize the FL dynamics and results. | If `False`, all the related configurations (`wandb.xxx`) would not take effect. |
+| `wandb.name_user` | (str) '' | The user name used in wandb management | - |
+| `wandb.name_project` | (str) '' | The project name used in wandb management | - |
+| `wandb.online_track` | (bool) True | Whether to track the results in an online manner, i.e., log results at every evaluation round | - |
+| `wandb.client_train_info` | (bool) True | Whether to track the training info of clients | - |
+
+
+### Asynchronous Training Strategies
+The configurations related to applying asynchronous training strategies in FL are defined in `cfg_asyn.py`.
+
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+| `asyn.use` | (bool) False | Whether to use asynchronous training strategies. | If `False`, all the related configurations (`cfg.asyn.xxx`) would not take effect. |
+| `asyn.time_budget` | (int/float) 0 | The predefined time budget (seconds) for each training round. | `time_budget`<=0 means the time budget is not applied. |
+| `asyn.min_received_num` | (int) 2 | The minimal number of received feedback messages required for the server to trigger federated aggregation. | - |
+| `asyn.min_received_rate` | (float) -1.0 | The minimal ratio of received feedback w.r.t. the sampled clients required for the server to trigger federated aggregation. | - |
+| `asyn.staleness_toleration` | (int) 0 | The threshold of the tolerable staleness in federated aggregation. | - |
+| `asyn.`
`staleness_discount_factor` | (float) 1.0 | The discount factor for stale feedback in federated aggregation. | - |
+| `asyn.aggregator` | (string) 'goal_achieved'
Choices: {'goal_achieved', 'time_up'} | The condition for federated aggregation. | 'goal_achieved': perform aggregation when the defined number of feedback has been received; 'time_up': perform aggregation when the allocated time budget has run out. |
+| `asyn.broadcast_manner` | (string) 'after_aggregating'
Choices: {'after_aggregating', 'after_receiving'} | The broadcasting manner of the server. | 'after_aggregating': broadcast the up-to-date global model after performing federated aggregation; 'after_receiving': broadcast the up-to-date global model after receiving the model update from clients. |
+| `asyn.overselection` | (bool) False | Whether to use the overselection technique | - |
+
+
+### Differential Privacy
+| [NbAFL](#nbafl) | [SGDMF](#sgdmf) |
+
+#### NbAFL
+The configurations related to the NbAFL method.
+
+| Name | (Type) Default Value | Description | Note |
+|:----:|:--------------------:|:-------------------------------------------|:-----|
+| `nbafl.use` | (bool) False | The indicator of the NbAFL method. | - |
+| `nbafl.mu` | (float) 0. | The argument $\mu$ in NbAFL. | - |
+| `nbafl.epsilon` | (float) 100. | The $\epsilon$-DP guarantee used in NbAFL. | - |
+| `nbafl.w_clip` | (float) 1. | The threshold used for weight clipping. | - |
+| `nbafl.constant` | (float) 30. | The constant used in NbAFL. | - |
+
+#### SGDMF
+The configurations related to the SGDMF method (only used in matrix factorization tasks).
+
+| Name | (Type) Default Value | Description | Note |
+|:------------------:|:--------------------:|:-----------------------------------|:--------------------------------------------------------|
+| `sgdmf.use` | (bool) False | The indicator of the SGDMF method. | - |
+| `sgdmf.R` | (float) 5. | The upper bound of rating. | - |
+| `sgdmf.epsilon` | (float) 4. | The $\epsilon$ used in DP. | - |
+| `sgdmf.delta` | (float) 0.5 | The $\delta$ used in DP. | - |
+| `sgdmf.constant` | (float) 1. | The constant in SGDMF | - |
+| `dataloader.theta` | (int) -1 | - | -1 means per-rating privacy, otherwise per-user privacy |
+
+
+### Auto-tuning Components
+
+These arguments are exposed for customizing our provided auto-tuning components.
+
+| [General](#auto-tuning-general) | [SHA](#successive-halving-algorithm-sha) | [FedEx](#fedex) | [Wrappers for FedEx](#wrappers-for-fedex) |
+
+#### Auto-tuning General
+
+| Name | (Type) Default Value | Description | Note |
+|:----:|:--------------------:|:-------------------------------------------|:-----|
+| `hpo.working_folder` | (string) 'hpo' | Save model checkpoints and search space configurations to this folder. | Trials in the next stage of an iterative HPO algorithm can restore from the checkpoints of their corresponding last trials. |
+| `hpo.ss` | (string) '' | File path of the .yaml file that specifies the search space. | - |
+| `hpo.num_workers` | (int) 0 | The number of threads to concurrently attempt different hyperparameter configurations. | Multi-threading is banned in the current version. |
+| `hpo.init_cand_num` | (int) 16 | The number of initial hyperparameter configurations sampled from the search space. | - |
+| `hpo.larger_better` | (bool) False | The indicator of whether the larger metric is better. | - |
+| `hpo.scheduler` | (string) 'rs'
Choices: {'rs', 'sha', 'wrap_sha'} | Which algorithm to use. | - |
+| `hpo.metric` | (string) 'client_summarized_weighted_avg.val_loss' | Metric to be optimized. | - |
+
+#### Successive Halving Algorithm (SHA)
+
+| Name | (Type) Default Value | Description | Note |
+|:----:|:--------------------:|:-------------------------------------------|:-----|
+| `hpo.sha.elim_rate` | (int) 3 | Reserve only the top 1/`hpo.sha.elim_rate` hyperparameter configurations in each stage. | - |
+| `hpo.sha.budgets` | (list of int) [] | Budgets for each SHA stage. | - |
+
+
+#### FedEx
+
+| Name | (Type) Default Value | Description | Note |
+|:----:|:--------------------:|:-------------------------------------------|:-----|
+| `hpo.fedex.use` | (bool) False | Whether to use FedEx. | - |
+| `hpo.fedex.ss` | (string) '' | Path of the .yaml file specifying the search space to be explored. | - |
+| `hpo.fedex.flatten_ss` | (bool) True | Whether the search space has been flattened. | - |
+| `hpo.fedex.eta0` | (float) -1.0 | Initial learning rate. | -1.0 means automatically determine the learning rate based on the size of the search space. |
+| `hpo.fedex.sched` | (string) 'auto'
Choices: {'auto', 'adaptive', 'aggressive', 'constant', 'scale'} | The strategy to update step sizes | - |
+| `hpo.fedex.cutoff` | (float) 0.0 | The entropy level below which to stop updating the config. | - |
+| `hpo.fedex.gamma` | (float) 0.0 | The discount factor; 0.0 is most recent, 1.0 is mean. | - |
+| `hpo.fedex.diff` | (bool) False | Whether to use the difference of validation losses before and after the local update as the reward signal. | - |
+
+#### Wrappers for FedEx
+
+| Name | (Type) Default Value | Description | Note |
+|:----:|:--------------------:|:-------------------------------------------|:-----|
+| `hpo.table.eps` | (float) 0.1 | The probability to make a local perturbation. | Larger values lead to drastically different arms of the bandit FedEx attempts to solve. |
+| `hpo.table.num` | (int) 27 | The number of arms of the bandit FedEx attempts to solve. | - |
+| `hpo.table.idx` | (int) 0 | The key (i.e., name) of the hyperparameter the wrapper considers. | No need to change this argument. |
+
+
+### Attack
+
+The configurations related to attacks are defined in `cfg_attack.py`.
+
+| [Privacy Attack](#for-privacy-attack) | [Back-door Attack](#for-back-door-attack) |
+
+
+#### For Privacy Attack
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+`attack.attack_method` | (str) '' | Attack method name | Choices: {'gan_attack', 'GradAscent', 'PassivePIA', 'DLG', 'IG', 'backdoor'} |
+`attack.target_label_ind` | (int) -1 | The target label to attack | Used in the class representative attack (GAN-based method) and the back-door attack; default -1 means no label to target|
+`attack.attacker_id` | (int) -1 | The ID of the attack client | Default -1 means no client acts as the attacker; used in both privacy attacks and back-door attacks when a client is the attacker |
+`attack.reconstruct_lr` | (float) 0.01 | The learning rate of the optimization-based training data/label inference attack|-|
+`attack.reconstruct_optim` | (str) 'Adam' | The optimizer used in the optimization-based training data/label inference attack|Choices: {'Adam', 'SGD', 'LBFGS'}|
+`attack.info_diff_type` | (str) 'l2' | The distance to compare the ground-truth info (gradients or model updates) and the info generated by the dummy data. | Options: 'l2', 'l1', 'sim', representing L2, L1 and cosine similarity |
+`attack.max_ite` | (int) 400 | The maximum number of iterations of the optimization-based training data/label inference attack |-|
+`attack.alpha_TV` | (float) 0.001 | The hyperparameter of the total variation term | Used in the invert-gradient method |
+`attack.inject_round` | (int) 0 | The round to start performing the attack actions |-|
+`attack.classifier_PIA` | (str) 'randomforest' | The property inference classifier name |-|
+`attack.mia_simulate_in_round` | (int) 20 | The round to add the target data into the training batch | Used when simulating the case that the target data are in the training set |
+`attack.mia_is_simulate_in` | (bool) False | Whether to simulate the case that the target data are in the training set |-|
+
+#### For Back-door Attack
+| Name | (Type) Default Value | Description | Note |
+|:----:|:-----:|:---------- |:---- |
+`attack.edge_path` |(str) 'edge_data/' | The folder where the OOD data used by edge-case backdoor attacks is located |-|
+`attack.trigger_path` |(str) 'trigger/'| The folder where the trigger pictures used by pixel-wise backdoor attacks are located |-|
+`attack.setting` | (str) 'fix' | The setting about how to select the attack client.
| Choices: {'fix', 'single', 'all'}. The 'single' setting means the attack client can only be selected in the predefined round (cfg.attack.insert_round); the 'all' setting means the attack client can be selected in every round; the 'fix' setting means the attack client can be selected every freq rounds, where freq is defined by the cfg.attack.freq keyword.|
+`attack.freq` | (int) 10 |This keyword is used in the 'fix' setting. The attack client can be selected every freq rounds.|-|
+`attack.insert_round` |(int) 100000 |This keyword is used in the 'single' setting. The attack client can only be selected in the round specified by insert_round.|-|
+`attack.mean` |(list) [0.1307] |The mean value used in the normalization procedure of the poisoning data. |Notice: the length of this list must be the same as the number of channels of the dataset used.|
+`attack.std` |(list) [0.3081] |The std value used in the normalization procedure of the poisoning data.|Notice: the length of this list must be the same as the number of channels of the dataset used.|
+`attack.trigger_type`|(str) 'edge'|This keyword represents the type of trigger used|Choices: {'edge', 'gridTrigger', 'hkTrigger', 'sigTrigger', 'wanetTrigger', 'fourCornerTrigger'}|
+`attack.label_type` |(str) 'dirty'| This keyword represents the type of attack used.|It covers 'dirty'-label and 'clean'-label attacks; currently, we only support the 'dirty'-label attack. |
+`attack.edge_num` |(int) 100 | This keyword represents the number of good samples used for the edge-case attack.|-|
+`attack.poison_ratio` |(float) 0.5|This keyword represents the percentage of samples with pixel-wise triggers in the local dataset of the attack client|-|
+`attack.scale_poisoning` |(bool) False| This keyword represents whether to use the model scaling attack for the attack client. |-|
+`attack.scale_para` |(float) 1.0 |This keyword represents the value used to amplify the model update when conducting the model scaling attack.|-|
+`attack.pgd_poisoning` |(bool) False|This keyword represents whether to use PGD to train the local model of the attack client.
|-| +`attack.pgd_lr` | (float) 0.1 |This keyword represents learning rate of pgd training for attack client.|-| +`attack.pgd_eps`|(int) 2 | This keyword represents perturbation budget of pgd training for attack client.|-| +`attack.self_opt` |(bool) False |This keyword represents whether to use his own training procedure for attack client.|-| +`attack.self_lr` |(float) 0.05|This keyword represents learning rate of his own training procedure for attack client.|-| +`attack.self_epoch` |(int) 6 |This keyword represents epoch number of his own training procedure for attack client.|-| diff --git a/fgssl/core/configs/__init__.py b/fgssl/core/configs/__init__.py new file mode 100644 index 0000000..5c3259a --- /dev/null +++ b/fgssl/core/configs/__init__.py @@ -0,0 +1,29 @@ +import copy +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] + +# to ensure the sub-configs registered before set up the global config +all_sub_configs = copy.copy(__all__) +if "config" in all_sub_configs: + all_sub_configs.remove('config') + +from federatedscope.core.configs.config import CN, init_global_cfg +__all__ = __all__ + \ + [ + 'CN', + 'init_global_cfg' + ] + +# reorder the config to ensure the base config will be registered first +base_configs = [ + 'cfg_data', 'cfg_fl_setting', 'cfg_model', 'cfg_training', 'cfg_evaluation' +] +for base_config in base_configs: + all_sub_configs.pop(all_sub_configs.index(base_config)) + all_sub_configs.insert(0, base_config) diff --git a/fgssl/core/configs/cfg_asyn.py b/fgssl/core/configs/cfg_asyn.py new file mode 100644 index 0000000..2b3e7aa --- /dev/null +++ b/fgssl/core/configs/cfg_asyn.py @@ -0,0 +1,87 @@ +import logging + +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + + +def extend_asyn_cfg(cfg): + # ---------------------------------------------------------------------- # + # Asynchronous related options + # ---------------------------------------------------------------------- # + cfg.asyn = CN() + + cfg.asyn.use = False + cfg.asyn.time_budget = 0 + cfg.asyn.min_received_num = 2 + cfg.asyn.min_received_rate = -1.0 + cfg.asyn.staleness_toleration = 0 + cfg.asyn.staleness_discount_factor = 1.0 + cfg.asyn.aggregator = 'goal_achieved' # ['goal_achieved', 'time_up'] + # 'goal_achieved': perform aggregation when the defined number of feedback + # has been received; 'time_up': perform aggregation when the allocated + # time budget has been run out + cfg.asyn.broadcast_manner = 'after_aggregating' # ['after_aggregating', + # 'after_receiving'] 'after_aggregating': broadcast the up-to-date global + # model after performing federated aggregation; + # 'after_receiving': broadcast the up-to-date global model after receiving + # the model update from clients + cfg.asyn.overselection = False + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_asyn_cfg) + + +def assert_asyn_cfg(cfg): + if not cfg.asyn.use: + return True + # to ensure a valid time budget + assert isinstance(cfg.asyn.time_budget, int) or isinstance( + cfg.asyn.time_budget, float + ), "The time budget (seconds) must be an int or a float value, " \ + "but {} is got".format( + type(cfg.asyn.time_budget)) + + # min received num pre-process + min_received_num_valid = (0 < cfg.asyn.min_received_num <= + cfg.federate.sample_client_num) + 
min_received_rate_valid = (0 < cfg.asyn.min_received_rate <= 1) + # (a) sampling case + if min_received_rate_valid: + # (a.1) use min_received_rate + old_min_received_num = cfg.asyn.min_received_num + cfg.asyn.min_received_num = max( + 1, + int(cfg.asyn.min_received_rate * cfg.federate.sample_client_num)) + if min_received_num_valid: + logging.warning( + f"Users specify both valid min_received_rate as" + f" {cfg.asyn.min_received_rate} " + f"and min_received_num as {old_min_received_num}.\n" + f"\t\tWe will use the min_received_rate value to calculate " + f"the actual number of participated clients as" + f" {cfg.asyn.min_received_num}.") + # (a.2) use min_received_num, commented since the below two lines do not + # change anything elif min_received_rate: + # cfg.asyn.min_received_num = cfg.asyn.min_received_num + if not (min_received_num_valid or min_received_rate_valid): + # (b) non-sampling case, use all clients + cfg.asyn.min_received_num = cfg.federate.sample_client_num + + # to ensure a valid staleness toleation + assert cfg.asyn.staleness_toleration >= 0 and isinstance( + cfg.asyn.staleness_toleration, int + ), f"Please provide a valid staleness toleration value, " \ + f"expect an integer value that is larger or equal to 0, " \ + f"but got {cfg.asyn.staleness_toleration}." + + assert cfg.asyn.aggregator in ["goal_achieved", "time_up"], \ + f"Please specify the cfg.asyn.aggregator as string 'goal_achieved' " \ + f"or 'time_up'. But got {cfg.asyn.aggregator}." + assert cfg.asyn.broadcast_manner in ["after_aggregating", + "after_receiving"], \ + f"Please specify the cfg.asyn.broadcast_manner as the string " \ + f"'after_aggregating' or 'after_receiving'. " \ + f"But got {cfg.asyn.broadcast_manner}." + + +register_config("asyn", extend_asyn_cfg) diff --git a/fgssl/core/configs/cfg_attack.py b/fgssl/core/configs/cfg_attack.py new file mode 100644 index 0000000..c90e129 --- /dev/null +++ b/fgssl/core/configs/cfg_attack.py @@ -0,0 +1,66 @@ +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + + +def extend_attack_cfg(cfg): + + # ---------------------------------------------------------------------- # + # attack + # ---------------------------------------------------------------------- # + cfg.attack = CN() + cfg.attack.attack_method = '' + # for gan_attack + cfg.attack.target_label_ind = -1 + cfg.attack.attacker_id = -1 + + # for backdoor attack + + cfg.attack.edge_path = 'edge_data/' + cfg.attack.trigger_path = 'trigger/' + cfg.attack.setting = 'fix' + cfg.attack.freq = 10 + cfg.attack.insert_round = 100000 + cfg.attack.mean = [0.1307] + cfg.attack.std = [0.3081] + cfg.attack.trigger_type = 'edge' + cfg.attack.label_type = 'dirty' + # dirty, clean_label, dirty-label attack is all2one attack. + cfg.attack.edge_num = 100 + cfg.attack.poison_ratio = 0.5 + cfg.attack.scale_poisoning = False + cfg.attack.scale_para = 1.0 + cfg.attack.pgd_poisoning = False + cfg.attack.pgd_lr = 0.1 + cfg.attack.pgd_eps = 2 + cfg.attack.self_opt = False + cfg.attack.self_lr = 0.05 + cfg.attack.self_epoch = 6 + # Note: the mean and std should be the list type. 
+ + # for reconstruct_opt + cfg.attack.reconstruct_lr = 0.01 + cfg.attack.reconstruct_optim = 'Adam' + cfg.attack.info_diff_type = 'l2' + cfg.attack.max_ite = 400 + cfg.attack.alpha_TV = 0.001 + + # for active PIA attack + cfg.attack.alpha_prop_loss = 0 + + # for passive PIA attack + cfg.attack.classifier_PIA = 'randomforest' + + # for gradient Ascent --- MIA attack + cfg.attack.inject_round = 0 + cfg.attack.mia_simulate_in_round = 20 + cfg.attack.mia_is_simulate_in = False + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_attack_cfg) + + +def assert_attack_cfg(cfg): + pass + + +register_config("attack", extend_attack_cfg) diff --git a/fgssl/core/configs/cfg_data.py b/fgssl/core/configs/cfg_data.py new file mode 100644 index 0000000..b4d6caf --- /dev/null +++ b/fgssl/core/configs/cfg_data.py @@ -0,0 +1,126 @@ +import logging + +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + +logger = logging.getLogger(__name__) + + +def extend_data_cfg(cfg): + # ---------------------------------------------------------------------- # + # Dataset related options + # ---------------------------------------------------------------------- # + cfg.data = CN() + + cfg.data.root = 'data' + cfg.data.type = 'toy' + cfg.data.fgcl = False + cfg.data.save_data = False # whether to save the generated toy data + cfg.data.args = [] # args for external dataset, eg. [{'download': True}] + cfg.data.splitter = '' + cfg.data.splitter_args = [] # args for splitter, eg. [{'alpha': 0.5}] + cfg.data.transform = [ + ] # transform for x, eg. [['ToTensor'], ['Normalize', {'mean': [ + # 0.1307], 'std': [0.3081]}]] + cfg.data.target_transform = [] # target_transform for y, use as above + cfg.data.pre_transform = [ + ] # pre_transform for `torch_geometric` dataset, use as above + cfg.data.server_holds_all = False # whether the server (workers with + # idx 0) holds all data, useful in global training/evaluation case + cfg.data.subsample = 1.0 + cfg.data.splits = [0.8, 0.1, 0.1] # Train, valid, test splits + cfg.data.consistent_label_distribution = False # If True, the label + # distributions of train/val/test set over clients will be kept + # consistent during splitting + cfg.data.cSBM_phi = [0.5, 0.5, 0.5] + + # DataLoader related args + cfg.dataloader = CN() + cfg.dataloader.type = 'base' + cfg.dataloader.batch_size = 64 + cfg.dataloader.shuffle = True + cfg.dataloader.num_workers = 0 + cfg.dataloader.drop_last = False + cfg.dataloader.pin_memory = True + # GFL: graphsaint DataLoader + cfg.dataloader.walk_length = 2 + cfg.dataloader.num_steps = 30 + # GFL: neighbor sampler DataLoader + cfg.dataloader.sizes = [10, 5] + # DP: -1 means per-rating privacy, otherwise per-user privacy + cfg.dataloader.theta = -1 + + # quadratic + cfg.data.quadratic = CN() + cfg.data.quadratic.dim = 1 + cfg.data.quadratic.min_curv = 0.02 + cfg.data.quadratic.max_curv = 12.5 + + # --------------- outdated configs --------------- + # TODO: delete this code block + cfg.data.loader = '' + cfg.data.batch_size = 64 + cfg.data.shuffle = True + cfg.data.num_workers = 0 + cfg.data.drop_last = False + cfg.data.walk_length = 2 + cfg.data.num_steps = 30 + cfg.data.sizes = [10, 5] + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_data_cfg) + + +def assert_data_cfg(cfg): + if cfg.dataloader.type == 'graphsaint-rw': + assert cfg.model.layer == cfg.dataloader.walk_length, 'Sample ' \ + 'size ' \ + 
'mismatch' + if cfg.dataloader.type == 'neighbor': + assert cfg.model.layer == len( + cfg.dataloader.sizes), 'Sample size mismatch' + if '@' in cfg.data.type: + assert cfg.federate.client_num > 0, '`federate.client_num` should ' \ + 'be greater than 0 when using ' \ + 'external data' + assert cfg.data.splitter, '`data.splitter` should not be empty when ' \ + 'using external data' + # -------------------------------------------------------------------- + # For compatibility with older versions of FS + # TODO: delete this code block + if cfg.data.loader != '': + logger.warning('config `cfg.data.loader` will be remove in the ' + 'future, use `cfg.dataloader.type` instead.') + cfg.dataloader.type = cfg.data.loader + if cfg.data.batch_size != 64: + logger.warning('config `cfg.data.batch_size` will be remove in the ' + 'future, use `cfg.dataloader.batch_size` instead.') + cfg.dataloader.batch_size = cfg.data.batch_size + if not cfg.data.shuffle: + logger.warning('config `cfg.data.shuffle` will be remove in the ' + 'future, use `cfg.dataloader.shuffle` instead.') + cfg.dataloader.shuffle = cfg.data.shuffle + if cfg.data.num_workers != 0: + logger.warning('config `cfg.data.num_workers` will be remove in the ' + 'future, use `cfg.dataloader.num_workers` instead.') + cfg.dataloader.num_workers = cfg.data.num_workers + if cfg.data.drop_last: + logger.warning('config `cfg.data.drop_last` will be remove in the ' + 'future, use `cfg.dataloader.drop_last` instead.') + cfg.dataloader.drop_last = cfg.data.drop_last + if cfg.data.walk_length != 2: + logger.warning('config `cfg.data.walk_length` will be remove in the ' + 'future, use `cfg.dataloader.walk_length` instead.') + cfg.dataloader.walk_length = cfg.data.walk_length + if cfg.data.num_steps != 30: + logger.warning('config `cfg.data.num_steps` will be remove in the ' + 'future, use `cfg.dataloader.num_steps` instead.') + cfg.dataloader.num_steps = cfg.data.num_steps + if cfg.data.sizes != [10, 5]: + logger.warning('config `cfg.data.sizes` will be remove in the ' + 'future, use `cfg.dataloader.sizes` instead.') + cfg.dataloader.sizes = cfg.data.sizes + # -------------------------------------------------------------------- + + +register_config("data", extend_data_cfg) diff --git a/fgssl/core/configs/cfg_differential_privacy.py b/fgssl/core/configs/cfg_differential_privacy.py new file mode 100644 index 0000000..7a6ae3e --- /dev/null +++ b/fgssl/core/configs/cfg_differential_privacy.py @@ -0,0 +1,37 @@ +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + + +def extend_dp_cfg(cfg): + # ---------------------------------------------------------------------- # + # nbafl(dp) related options + # ---------------------------------------------------------------------- # + cfg.nbafl = CN() + + # Params + cfg.nbafl.use = False + cfg.nbafl.mu = 0. + cfg.nbafl.epsilon = 100. + cfg.nbafl.w_clip = 1. + cfg.nbafl.constant = 30. + + # ---------------------------------------------------------------------- # + # VFL-SGDMF(dp) related options + # ---------------------------------------------------------------------- # + cfg.sgdmf = CN() + + cfg.sgdmf.use = False # if use sgdmf algorithm + cfg.sgdmf.R = 5. # The upper bound of rating + cfg.sgdmf.epsilon = 4. # \epsilon in dp + cfg.sgdmf.delta = 0.5 # \delta in dp + cfg.sgdmf.constant = 1. 
# constant + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_dp_cfg) + + +def assert_dp_cfg(cfg): + pass + + +register_config("dp", extend_dp_cfg) diff --git a/fgssl/core/configs/cfg_evaluation.py b/fgssl/core/configs/cfg_evaluation.py new file mode 100644 index 0000000..09b9cdd --- /dev/null +++ b/fgssl/core/configs/cfg_evaluation.py @@ -0,0 +1,43 @@ +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + + +def extend_evaluation_cfg(cfg): + + # ---------------------------------------------------------------------- # + # Evaluation related options + # ---------------------------------------------------------------------- # + cfg.eval = CN( + new_allowed=True) # allow user to add their settings under `cfg.eval` + + cfg.eval.freq = 1 + cfg.eval.metrics = [] + cfg.eval.split = ['test', 'val'] + cfg.eval.report = ['weighted_avg', 'avg', 'fairness', + 'raw'] # by default, we report comprehensive results + cfg.eval.best_res_update_round_wise_key = "val_loss" + + # Monitoring, e.g., 'dissim' for B-local dissimilarity + cfg.eval.monitoring = [] + + cfg.eval.count_flops = True + + # ---------------------------------------------------------------------- # + # wandb related options + # ---------------------------------------------------------------------- # + cfg.wandb = CN() + cfg.wandb.use = False + cfg.wandb.name_user = '' + cfg.wandb.name_project = '' + cfg.wandb.online_track = True + cfg.wandb.client_train_info = False + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_evaluation_cfg) + + +def assert_evaluation_cfg(cfg): + pass + + +register_config("eval", extend_evaluation_cfg) diff --git a/fgssl/core/configs/cfg_fl_algo.py b/fgssl/core/configs/cfg_fl_algo.py new file mode 100644 index 0000000..8dbdd2a --- /dev/null +++ b/fgssl/core/configs/cfg_fl_algo.py @@ -0,0 +1,118 @@ +from federatedscope.core.configs.config import CN +from federatedscope.core.configs.yacs_config import Argument +from federatedscope.register import register_config + + +def extend_fl_algo_cfg(cfg): + # ---------------------------------------------------------------------- # + # fedopt related options, a general fl algorithm + # ---------------------------------------------------------------------- # + cfg.fedopt = CN() + + cfg.fedopt.use = False + + cfg.fedopt.optimizer = CN(new_allowed=True) + cfg.fedopt.optimizer.type = Argument( + 'SGD', description="optimizer type for FedOPT") + cfg.fedopt.optimizer.lr = Argument( + 0.01, description="learning rate for FedOPT optimizer") + + # ---------------------------------------------------------------------- # + # fedprox related options, a general fl algorithm + # ---------------------------------------------------------------------- # + cfg.fedprox = CN() + + cfg.fedprox.use = False + cfg.fedprox.mu = 0. 
+ + # ---------------------------------------------------------------------- # + # Personalization related options, pFL + # ---------------------------------------------------------------------- # + cfg.personalization = CN() + + # client-distinct param names, e.g., ['pre', 'post'] + cfg.personalization.local_param = [] + cfg.personalization.share_non_trainable_para = False + cfg.personalization.local_update_steps = -1 + # @regular_weight: + # The smaller the regular_weight is, the stronger emphasising on + # personalized model + # For Ditto, the default value=0.1, the search space is [0.05, 0.1, 0.2, + # 1, 2] + # For pFedMe, the default value=15 + cfg.personalization.regular_weight = 0.1 + + # @lr: + # 1) For pFedME, the personalized learning rate to calculate theta + # approximately using K steps + # 2) 0.0 indicates use the value according to optimizer.lr in case of + # users have not specify a valid lr + cfg.personalization.lr = 0.0 + + cfg.personalization.K = 5 # the local approximation steps for pFedMe + cfg.personalization.beta = 1.0 # the average moving parameter for pFedMe + + # ---------------------------------------------------------------------- # + # FedSage+ related options, gfl + # ---------------------------------------------------------------------- # + cfg.fedsageplus = CN() + + # Number of nodes generated by the generator + cfg.fedsageplus.num_pred = 5 + # Hidden layer dimension of generator + cfg.fedsageplus.gen_hidden = 128 + # Hide graph portion + cfg.fedsageplus.hide_portion = 0.5 + # Federated training round for generator + cfg.fedsageplus.fedgen_epoch = 200 + # Local pre-train round for generator + cfg.fedsageplus.loc_epoch = 1 + # Coefficient for criterion number of missing node + cfg.fedsageplus.a = 1.0 + # Coefficient for criterion feature + cfg.fedsageplus.b = 1.0 + # Coefficient for criterion classification + cfg.fedsageplus.c = 1.0 + + # ---------------------------------------------------------------------- # + # GCFL+ related options, gfl + # ---------------------------------------------------------------------- # + cfg.gcflplus = CN() + + # Bound for mean_norm + cfg.gcflplus.EPS_1 = 0.05 + # Bound for max_norm + cfg.gcflplus.EPS_2 = 0.1 + # Length of the gradient sequence + cfg.gcflplus.seq_length = 5 + # Whether standardized dtw_distances + cfg.gcflplus.standardize = False + + # ---------------------------------------------------------------------- # + # FLIT+ related options, gfl + # ---------------------------------------------------------------------- # + cfg.flitplus = CN() + + cfg.flitplus.tmpFed = 0.5 # gamma in focal loss (Eq.4) + cfg.flitplus.lambdavat = 0.5 # lambda in phi (Eq.10) + cfg.flitplus.factor_ema = 0.8 # beta in omega (Eq.12) + cfg.flitplus.weightReg = 1.0 # balance lossLocalLabel and lossLocalVAT + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_fl_algo_cfg) + + +def assert_fl_algo_cfg(cfg): + if cfg.personalization.local_update_steps == -1: + # By default, use the same step to normal mode + cfg.personalization.local_update_steps = \ + cfg.train.local_update_steps + cfg.personalization.local_update_steps = \ + cfg.train.local_update_steps + + if cfg.personalization.lr <= 0.0: + # By default, use the same lr to normal mode + cfg.personalization.lr = cfg.train.optimizer.lr + + +register_config("fl_algo", extend_fl_algo_cfg) diff --git a/fgssl/core/configs/cfg_fl_setting.py b/fgssl/core/configs/cfg_fl_setting.py new file mode 100644 index 0000000..4f374c2 --- /dev/null +++ 
b/fgssl/core/configs/cfg_fl_setting.py @@ -0,0 +1,183 @@ +import logging + +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + +logger = logging.getLogger(__name__) + + +def extend_fl_setting_cfg(cfg): + # ---------------------------------------------------------------------- # + # Federate learning related options + # ---------------------------------------------------------------------- # + cfg.federate = CN() + + cfg.federate.client_num = 0 + cfg.federate.sample_client_num = -1 + cfg.federate.sample_client_rate = -1.0 + cfg.federate.unseen_clients_rate = 0.0 + cfg.federate.total_round_num = 50 + cfg.federate.mode = 'standalone' + cfg.federate.share_local_model = False + cfg.federate.data_weighted_aggr = False # If True, the weight of aggr is + # the number of training samples in dataset. + cfg.federate.online_aggr = False + cfg.federate.make_global_eval = False + cfg.federate.use_diff = False + cfg.federate.merge_test_data = False # For efficient simulation, users + # can choose to merge the test data and perform global evaluation, + # instead of perform test at each client + + # the method name is used to internally determine composition of + # different aggregators, messages, handlers, etc., + cfg.federate.method = "FedAvg" + cfg.federate.ignore_weight = False + cfg.federate.use_ss = False # Whether to apply Secret Sharing + cfg.federate.restore_from = '' + cfg.federate.save_to = '' + cfg.federate.join_in_info = [ + ] # The information requirements (from server) for join_in + cfg.federate.sampler = 'uniform' # the strategy for sampling client + # in each training round, ['uniform', 'group'] + cfg.federate.resource_info_file = "" # the device information file to + # record computation and communication ability + + # ---------------------------------------------------------------------- # + # Distribute training related options + # ---------------------------------------------------------------------- # + cfg.distribute = CN() + + cfg.distribute.use = False + cfg.distribute.server_host = '0.0.0.0' + cfg.distribute.server_port = 50050 + cfg.distribute.client_host = '0.0.0.0' + cfg.distribute.client_port = 50050 + cfg.distribute.role = 'client' + cfg.distribute.data_file = 'data' + cfg.distribute.data_idx = -1 # data_idx is used to specify the data + # index in distributed mode when adopting a centralized dataset for + # simulation (formatted as {data_idx: data/dataloader}). + # data_idx = -1 means that the whole dataset is owned by the participant. + # when data_idx is other invalid values excepted for -1, we randomly + # sample the data_idx for simulation + cfg.distribute.grpc_max_send_message_length = 100 * 1024 * 1024 + cfg.distribute.grpc_max_receive_message_length = 100 * 1024 * 1024 + cfg.distribute.grpc_enable_http_proxy = False + + # ---------------------------------------------------------------------- # + # Vertical FL related options (for demo) + # ---------------------------------------------------------------------- # + cfg.vertical = CN() + cfg.vertical.use = False + cfg.vertical.encryption = 'paillier' + cfg.vertical.dims = [5, 10] + cfg.vertical.key_size = 3072 + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_fl_setting_cfg) + + +def assert_fl_setting_cfg(cfg): + assert cfg.federate.mode in ["standalone", "distributed"], \ + f"Please specify the cfg.federate.mode as the string standalone or " \ + f"distributed. But got {cfg.federate.mode}." 
+ + # ============= client num related ============== + assert not (cfg.federate.client_num == 0 + and cfg.federate.mode == 'distributed' + ), "Please configure the cfg.federate. in distributed mode. " + + assert 0 <= cfg.federate.unseen_clients_rate < 1, \ + "You specified in-valid cfg.federate.unseen_clients_rate" + if 0 < cfg.federate.unseen_clients_rate < 1 and cfg.federate.method in [ + "local", "global" + ]: + logger.warning( + "In local/global training mode, the unseen_clients_rate is " + "in-valid, plz check your config") + unseen_clients_rate = 0.0 + cfg.federate.unseen_clients_rate = unseen_clients_rate + else: + unseen_clients_rate = cfg.federate.unseen_clients_rate + participated_client_num = max( + 1, int((1 - unseen_clients_rate) * cfg.federate.client_num)) + + # sample client num pre-process + sample_client_num_valid = ( + 0 < cfg.federate.sample_client_num <= + cfg.federate.client_num) and cfg.federate.client_num != 0 + sample_client_rate_valid = (0 < cfg.federate.sample_client_rate <= 1) + + sample_cfg_valid = sample_client_rate_valid or sample_client_num_valid + non_sample_case = cfg.federate.method in ["local", "global"] + if non_sample_case and sample_cfg_valid: + logger.warning("In local/global training mode, " + "the sampling related configs are in-valid, " + "we will use all clients. ") + + if cfg.federate.method == "global": + logger.info( + "In global training mode, we will put all data in a proxy client. " + ) + if cfg.federate.make_global_eval: + cfg.federate.make_global_eval = False + logger.warning( + "In global training mode, we will conduct global evaluation " + "in a proxy client rather than the server. The configuration " + "cfg.federate.make_global_eval will be False.") + + if non_sample_case or not sample_cfg_valid: + # (a) use all clients + # in standalone mode, federate.client_num may be modified from 0 to + # num_of_all_clients after loading the data + if cfg.federate.client_num != 0: + cfg.federate.sample_client_num = participated_client_num + else: + # (b) sampling case + if sample_client_rate_valid: + # (b.1) use sample_client_rate + old_sample_client_num = cfg.federate.sample_client_num + cfg.federate.sample_client_num = max( + 1, + int(cfg.federate.sample_client_rate * participated_client_num)) + if sample_client_num_valid: + logger.warning( + f"Users specify both valid sample_client_rate as" + f" {cfg.federate.sample_client_rate} " + f"and sample_client_num as {old_sample_client_num}.\n" + f"\t\tWe will use the sample_client_rate value to " + f"calculate " + f"the actual number of participated clients as" + f" {cfg.federate.sample_client_num}.") + # (b.2) use sample_client_num, commented since the below two + # lines do not change anything + # elif sample_client_num_valid: + # cfg.federate.sample_client_num = \ + # cfg.federate.sample_client_num + + if cfg.federate.use_ss: + assert cfg.federate.client_num == cfg.federate.sample_client_num, \ + "Currently, we support secret sharing only in " \ + "all-client-participation case" + + assert cfg.federate.method != "local", \ + "Secret sharing is not supported in local training mode" + + # ============= aggregator related ================ + assert (not cfg.federate.online_aggr) or ( + not cfg.federate.use_ss + ), "Have not supported to use online aggregator and secrete sharing at " \ + "the same time" + + assert not cfg.federate.merge_test_data or ( + cfg.federate.merge_test_data and cfg.federate.mode == 'standalone' + ), "The operation of merging test data can only used in standalone for " \ + 
"efficient simulation, please change 'federate.merge_test_data' to " \ + "False or change 'federate.mode' to 'distributed'." + if cfg.federate.merge_test_data and not cfg.federate.make_global_eval: + cfg.federate.make_global_eval = True + logger.warning('Set cfg.federate.make_global_eval=True since ' + 'cfg.federate.merge_test_data=True') + + +register_config("fl_setting", extend_fl_setting_cfg) diff --git a/fgssl/core/configs/cfg_hpo.py b/fgssl/core/configs/cfg_hpo.py new file mode 100644 index 0000000..8ed8e7e --- /dev/null +++ b/fgssl/core/configs/cfg_hpo.py @@ -0,0 +1,85 @@ +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + + +def extend_hpo_cfg(cfg): + + # ---------------------------------------------------------------------- # + # hpo related options + # ---------------------------------------------------------------------- # + cfg.hpo = CN() + cfg.hpo.working_folder = 'hpo' + cfg.hpo.ss = '' + cfg.hpo.num_workers = 0 + cfg.hpo.init_cand_num = 16 + cfg.hpo.larger_better = False + cfg.hpo.scheduler = 'rs' + cfg.hpo.metric = 'client_summarized_weighted_avg.val_loss' + + # SHA + cfg.hpo.sha = CN() + cfg.hpo.sha.elim_rate = 3 + cfg.hpo.sha.budgets = [] + cfg.hpo.sha.iter = 0 + + # PBT + cfg.hpo.pbt = CN() + cfg.hpo.pbt.max_stage = 5 + cfg.hpo.pbt.perf_threshold = 0.1 + + # FedEx + cfg.hpo.fedex = CN() + cfg.hpo.fedex.use = False + cfg.hpo.fedex.ss = '' + cfg.hpo.fedex.flatten_ss = True + # If <= .0, use 'auto' + cfg.hpo.fedex.eta0 = -1.0 + cfg.hpo.fedex.sched = 'auto' + # cutoff: entropy level below which to stop updating the config + # probability and use MLE + cfg.hpo.fedex.cutoff = .0 + # discount factor; 0.0 is most recent, 1.0 is mean + cfg.hpo.fedex.gamma = .0 + cfg.hpo.fedex.diff = False + + # Table + cfg.hpo.table = CN() + cfg.hpo.table.eps = 0.1 + cfg.hpo.table.num = 27 + cfg.hpo.table.idx = 0 + + +def assert_hpo_cfg(cfg): + # HPO related + # assert cfg.hpo.init_strategy in [ + # 'full', 'grid', 'random' + # ], "initialization strategy for HPO should be \"full\", \"grid\", + # or \"random\", but the given choice is {}".format( + # cfg.hpo.init_strategy) + assert cfg.hpo.scheduler in ['rs', 'sha', + 'pbt'], "No HPO scheduler named {}".format( + cfg.hpo.scheduler) + assert cfg.hpo.num_workers >= 0, "#worker should be non-negative but " \ + "given {}".format(cfg.hpo.num_workers) + assert len(cfg.hpo.sha.budgets) > 0, \ + "Either do NOT specify the budgets or specify the budget for each " \ + "SHA iteration, but the given budgets is {}".format( + cfg.hpo.sha.budgets) + + assert not (cfg.hpo.fedex.use and cfg.federate.use_ss + ), "Cannot use secret sharing and FedEx at the same time" + assert cfg.train.optimizer.type == 'SGD' or not cfg.hpo.fedex.use, \ + "SGD is required if FedEx is considered" + assert cfg.hpo.fedex.sched in [ + 'adaptive', 'aggressive', 'auto', 'constant', 'scale' + ], "schedule of FedEx must be choice from {}".format( + ['adaptive', 'aggressive', 'auto', 'constant', 'scale']) + assert cfg.hpo.fedex.gamma >= .0 and cfg.hpo.fedex.gamma <= 1.0, \ + "{} must be in [0, 1]".format(cfg.hpo.fedex.gamma) + assert cfg.hpo.fedex.use == cfg.federate.use_diff, "Once FedEx is " \ + "adopted, " \ + "federate.use_diff " \ + "must be True." 
+ + +register_config("hpo", extend_hpo_cfg) diff --git a/fgssl/core/configs/cfg_model.py b/fgssl/core/configs/cfg_model.py new file mode 100644 index 0000000..b11b082 --- /dev/null +++ b/fgssl/core/configs/cfg_model.py @@ -0,0 +1,50 @@ +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + + +def extend_model_cfg(cfg): + # ---------------------------------------------------------------------- # + # Model related options + # ---------------------------------------------------------------------- # + cfg.model = CN() + + cfg.model.model_num_per_trainer = 1 # some methods may leverage more + # than one model in each trainer + cfg.model.type = 'lr' + cfg.model.use_bias = True + cfg.model.task = 'node' + cfg.model.hidden = 256 + cfg.model.dropout = 0.5 + cfg.model.in_channels = 0 # If 0, model will be built by data.shape + cfg.model.out_channels = 1 + cfg.model.layer = 2 # In GPR-GNN, K = layer + cfg.model.graph_pooling = 'mean' + cfg.model.embed_size = 8 + cfg.model.num_item = 0 + cfg.model.num_user = 0 + cfg.model.input_shape = () # A tuple, e.g., (in_channel, h, w) + + # ---------------------------------------------------------------------- # + # Criterion related options + # ---------------------------------------------------------------------- # + cfg.criterion = CN() + + cfg.criterion.type = 'MSELoss' + + # ---------------------------------------------------------------------- # + # regularizer related options + # ---------------------------------------------------------------------- # + cfg.regularizer = CN() + + cfg.regularizer.type = '' + cfg.regularizer.mu = 0. + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_model_cfg) + + +def assert_model_cfg(cfg): + pass + + +register_config("model", extend_model_cfg) diff --git a/fgssl/core/configs/cfg_training.py b/fgssl/core/configs/cfg_training.py new file mode 100644 index 0000000..e4089a1 --- /dev/null +++ b/fgssl/core/configs/cfg_training.py @@ -0,0 +1,104 @@ +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + + +def extend_training_cfg(cfg): + # ---------------------------------------------------------------------- # + # Trainer related options + # ---------------------------------------------------------------------- # + cfg.trainer = CN() + + cfg.trainer.type = 'general' + + # ---------------------------------------------------------------------- # + # Training related options + # ---------------------------------------------------------------------- # + cfg.train = CN() + + cfg.train.local_update_steps = 1 + cfg.train.batch_or_epoch = 'batch' + + cfg.train.optimizer = CN(new_allowed=True) + cfg.train.optimizer.type = 'SGD' + cfg.train.optimizer.lr = 0.1 + + # you can add new arguments 'aa' by `cfg.train.scheduler.aa = 'bb'` + cfg.train.scheduler = CN(new_allowed=True) + cfg.train.scheduler.type = '' + + # ---------------------------------------------------------------------- # + # Finetune related options + # ---------------------------------------------------------------------- # + cfg.finetune = CN() + + cfg.finetune.before_eval = False + cfg.finetune.local_update_steps = 1 + cfg.finetune.batch_or_epoch = 'epoch' + cfg.finetune.freeze_param = "" + + cfg.finetune.optimizer = CN(new_allowed=True) + cfg.finetune.optimizer.type = 'SGD' + cfg.finetune.optimizer.lr = 0.1 + + cfg.finetune.scheduler = CN(new_allowed=True) + cfg.finetune.scheduler.type = '' + + # 
---------------------------------------------------------------------- # + # Gradient related options + # ---------------------------------------------------------------------- # + cfg.grad = CN() + cfg.grad.grad_clip = -1.0 # negative numbers indicate we do not clip grad + + # ---------------------------------------------------------------------- # + # Early stopping related options + # ---------------------------------------------------------------------- # + cfg.early_stop = CN() + + # patience (int): How long to wait after last time the monitored metric + # improved. + # Note that the actual_checking_round = patience * cfg.eval.freq + # To disable the early stop, set the early_stop.patience a integer <=0 + cfg.early_stop.patience = 5 + # delta (float): Minimum change in the monitored metric to indicate an + # improvement. + cfg.early_stop.delta = 0.0 + # Early stop when no improve to last `patience` round, in ['mean', 'best'] + cfg.early_stop.improve_indicator_mode = 'best' + cfg.early_stop.the_smaller_the_better = True + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_training_cfg) + + +def assert_training_cfg(cfg): + if cfg.train.batch_or_epoch not in ['batch', 'epoch']: + raise ValueError( + "Value of 'cfg.train.batch_or_epoch' must be chosen from [" + "'batch', 'epoch'].") + + if cfg.finetune.batch_or_epoch not in ['batch', 'epoch']: + raise ValueError( + "Value of 'cfg.finetune.batch_or_epoch' must be chosen from [" + "'batch', 'epoch'].") + + # TODO: should not be here? + if cfg.backend not in ['torch', 'tensorflow']: + raise ValueError( + "Value of 'cfg.backend' must be chosen from ['torch', " + "'tensorflow'].") + if cfg.backend == 'tensorflow' and cfg.federate.mode == 'standalone': + raise ValueError( + "We only support run with distribued mode when backend is " + "tensorflow") + if cfg.backend == 'tensorflow' and cfg.use_gpu is True: + raise ValueError( + "We only support run with cpu when backend is tensorflow") + + if cfg.finetune.before_eval is False and cfg.finetune.local_update_steps\ + <= 0: + raise ValueError( + f"When adopting fine-tuning, please set a valid local fine-tune " + f"steps, got {cfg.finetune.local_update_steps}") + + +register_config("fl_training", extend_training_cfg) diff --git a/fgssl/core/configs/config.py b/fgssl/core/configs/config.py new file mode 100644 index 0000000..48b71b4 --- /dev/null +++ b/fgssl/core/configs/config.py @@ -0,0 +1,293 @@ +import copy +import logging +import os + +from pathlib import Path + +import federatedscope.register as register +from federatedscope.core.configs.yacs_config import CfgNode, _merge_a_into_b, \ + Argument + +logger = logging.getLogger(__name__) + + +def set_help_info(cn_node, help_info_dict, prefix=""): + for k, v in cn_node.items(): + if isinstance(v, Argument) and k not in help_info_dict: + help_info_dict[prefix + k] = v.description + elif isinstance(v, CN): + set_help_info(v, + help_info_dict, + prefix=f"{k}." if prefix == "" else f"{prefix}{k}.") + + +class CN(CfgNode): + """ + An extended configuration system based on [yacs]( + https://github.com/rbgirshick/yacs). + The two-level tree structure consists of several internal dict-like + containers to allow simple key-value access and management. 
+ + """ + def __init__(self, init_dict=None, key_list=None, new_allowed=False): + init_dict = super().__init__(init_dict, key_list, new_allowed) + self.__cfg_check_funcs__ = list() # to check the config values + # validity + self.__help_info__ = dict() # build the help dict + + self.is_ready_for_run = False # whether this CfgNode has checked its + # validity, completeness and clean some un-useful info + + if init_dict: + for k, v in init_dict.items(): + if isinstance(v, Argument): + self.__help_info__[k] = v.description + elif isinstance(v, CN) and "help_info" in v: + for name, des in v.__help_info__.items(): + self.__help_info__[name] = des + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __delattr__(self, name): + if name in self: + del self[name] + else: + raise AttributeError(name) + + def clear_aux_info(self): + if hasattr(self, "__cfg_check_funcs__"): + delattr(self, "__cfg_check_funcs__") + if hasattr(self, "__help_info__"): + delattr(self, "__help_info__") + if hasattr(self, "is_ready_for_run"): + delattr(self, "is_ready_for_run") + for v in self.values(): + if isinstance(v, CN): + v.clear_aux_info() + + def print_help(self, arg_name=""): + """ + print help info for a specific given `arg_name` or + for all arguments if not given `arg_name` + :param arg_name: + :return: + """ + if arg_name != "" and arg_name in self.__help_info__: + print(f" --{arg_name} \t {self.__help_info__[arg_name]}") + else: + for k, v in self.__help_info__.items(): + print(f" --{k} \t {v}") + + def register_cfg_check_fun(self, cfg_check_fun): + self.__cfg_check_funcs__.append(cfg_check_fun) + + def merge_from_file(self, cfg_filename, check_cfg=True): + """ + load configs from a yaml file, another cfg instance or a list + stores the keys and values. + + :param cfg_filename (string): + :return: + """ + cfg_check_funcs = copy.copy(self.__cfg_check_funcs__) + with open(cfg_filename, "r") as f: + cfg = self.load_cfg(f) + self.merge_from_other_cfg(cfg) + self.__cfg_check_funcs__.clear() + self.__cfg_check_funcs__.extend(cfg_check_funcs) + self.assert_cfg(check_cfg) + set_help_info(self, self.__help_info__) + + def merge_from_other_cfg(self, cfg_other, check_cfg=True): + """ + load configs from another cfg instance + + :param cfg_other (CN): + :return: + """ + + cfg_check_funcs = copy.copy(self.__cfg_check_funcs__) + _merge_a_into_b(cfg_other, self, self, []) + self.__cfg_check_funcs__.clear() + self.__cfg_check_funcs__.extend(cfg_check_funcs) + self.assert_cfg(check_cfg) + set_help_info(self, self.__help_info__) + + def merge_from_list(self, cfg_list, check_cfg=True): + """ + load configs from a list stores the keys and values. 
+ modified `merge_from_list` in `yacs.config.py` to allow adding + new keys if `is_new_allowed()` returns True + + :param cfg_list (list): + :return: + """ + cfg_check_funcs = copy.copy(self.__cfg_check_funcs__) + super().merge_from_list(cfg_list) + self.__cfg_check_funcs__.clear() + self.__cfg_check_funcs__.extend(cfg_check_funcs) + self.assert_cfg(check_cfg) + set_help_info(self, self.__help_info__) + + def assert_cfg(self, check_cfg=True): + """ + check the validness of the configuration instance + + :return: + """ + if check_cfg: + for check_func in self.__cfg_check_funcs__: + check_func(self) + + def clean_unused_sub_cfgs(self): + """ + Clean the un-used secondary-level CfgNode, whose `.use` + attribute is `True` + + :return: + """ + for v in self.values(): + if isinstance(v, CfgNode) or isinstance(v, CN): + # sub-config + if hasattr(v, "use") and v.use is False: + for k in copy.deepcopy(v).keys(): + # delete the un-used attributes + if k == "use": + continue + else: + del v[k] + + def check_required_args(self): + for k, v in self.items(): + if isinstance(v, CN): + v.check_required_args() + if isinstance(v, Argument) and v.required and v.value is None: + logger.warning(f"You have not set the required argument {k}") + + def de_arguments(self): + """ + some config values are managed via `Argument` class, this function + is used to make these values clean without the `Argument` class, + such that the potential type-specific methods work correctly, + e.g., len(cfg.federate.method) for a string config + :return: + """ + for k, v in copy.deepcopy(self).items(): + if isinstance(v, CN): + self[k].de_arguments() + if isinstance(v, Argument): + self[k] = v.value + + def ready_for_run(self, check_cfg=True): + self.assert_cfg(check_cfg) + self.clean_unused_sub_cfgs() + self.check_required_args() + self.de_arguments() + self.is_ready_for_run = True + + def freeze(self, inform=True, save=True, check_cfg=True): + """ + 1) make the cfg attributes immutable; + 2) if save=True, save the frozen cfg_check_funcs into + "self.outdir/config.yaml" for better reproducibility; + 3) if self.wandb.use=True, update the frozen config + + :return: + """ + self.ready_for_run(check_cfg) + super(CN, self).freeze() + + if save: # save the final cfg + Path(self.outdir).mkdir(parents=True, exist_ok=True) + with open(os.path.join(self.outdir, "config.yaml"), + 'w') as outfile: + from contextlib import redirect_stdout + with redirect_stdout(outfile): + tmp_cfg = copy.deepcopy(self) + tmp_cfg.clear_aux_info() + print(tmp_cfg.dump()) + if self.wandb.use: + # update the frozen config + try: + import wandb + import yaml + cfg_yaml = yaml.safe_load(tmp_cfg.dump()) + wandb.config.update(cfg_yaml, allow_val_change=True) + except ImportError: + logger.error( + "cfg.wandb.use=True but not install the wandb " + "package") + exit() + + if inform: + logger.info("the used configs are: \n" + str(tmp_cfg)) + + +# to ensure the sub-configs registered before set up the global config +from federatedscope.core.configs import all_sub_configs + +for sub_config in all_sub_configs: + __import__("federatedscope.core.configs." + sub_config) + +from federatedscope.contrib.configs import all_sub_configs_contrib + +for sub_config in all_sub_configs_contrib: + __import__("federatedscope.contrib.configs." + sub_config) + +# Global config object +global_cfg = CN() + + +def init_global_cfg(cfg): + r''' + This function sets the default config value. 
+ 1) Note that for an experiment, only part of the arguments will be used + The remaining unused arguments won't affect anything. + So feel free to register any argument in graphgym.contrib.config + 2) We support more than one levels of configs, e.g., cfg.dataset.name + + :return: configuration use by the experiment. + ''' + + # ---------------------------------------------------------------------- # + # Basic options, first level configs + # ---------------------------------------------------------------------- # + + cfg.backend = 'torch' + + # Whether to use GPU + cfg.use_gpu = False + + # Whether to print verbose logging info + cfg.verbose = 1 + + # How many decimal places we print out using logger + cfg.print_decimal_digits = 6 + + # Specify the device + cfg.device = -1 + + # Random seed + cfg.seed = 0 + + # Path of configuration file + cfg.cfg_file = '' + + # The dir used to save log, exp_config, models, etc,. + cfg.outdir = 'exp' + cfg.expname = '' # detailed exp name to distinguish different sub-exp + cfg.expname_tag = '' # detailed exp tag to distinguish different + # sub-exp with the same expname + + # extend user customized configs + for func in register.config_dict.values(): + func(cfg) + + set_help_info(cfg, cfg.__help_info__) + + +init_global_cfg(global_cfg) diff --git a/fgssl/core/configs/constants.py b/fgssl/core/configs/constants.py new file mode 100644 index 0000000..93c8ace --- /dev/null +++ b/fgssl/core/configs/constants.py @@ -0,0 +1,46 @@ +"""Configuration file for composition of different aggregators, messages, +handlers, etc. + + - The method `local` indicates that the clients only locally + train their model without sharing any training related information + - The method `global` indicates that the only one client locally trains + using all data + +""" + +AGGREGATOR_TYPE = { + "local": "no_communication", # the clients locally train their model + # without sharing any training related info + "global": "no_communication", # only one client locally train all data, + # i.e., totally global training + "fedavg": "clients_avg", # FedAvg + "pfedme": "server_clients_interpolation", # pFedMe, + server-clients + # interpolation + "ditto": "clients_avg", # Ditto + "fedsageplus": "clients_avg", + "gcflplus": "clients_avg", + "fedopt": "fedopt" +} + +CLIENTS_TYPE = { + "local": "normal", + "fedavg": "normal", # FedAvg + "pfedme": "normal_loss_regular", # pFedMe, + regularization-based local + # loss + "ditto": "normal", # Ditto, + local training for distinct personalized + # models + "fedsageplus": "fedsageplus", # FedSage+ for graph data + "gcflplus": "gcflplus", # GCFL+ for graph data + "gradascent": "gradascent", + "fgcl": "fgcl" +} + +SERVER_TYPE = { + "local": "normal", + "fedavg": "normal", # FedAvg + "pfedme": "normal", # pFedMe, + regularization-based local loss + "ditto": "normal", # Ditto, + local training for distinct personalized + # models + "fedsageplus": "fedsageplus", # FedSage+ for graph data + "gcflplus": "gcflplus", # GCFL+ for graph data +} diff --git a/fgssl/core/configs/yacs_config.py b/fgssl/core/configs/yacs_config.py new file mode 100644 index 0000000..cd1c6ae --- /dev/null +++ b/fgssl/core/configs/yacs_config.py @@ -0,0 +1,605 @@ +# An extended configuration system based on [yacs], +# modified by Alibaba Group +# +# +# Yacs' license is as follows: +# +# Copyright (c) 2018-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +""" + An extended configuration system based on [yacs]( + https://github.com/rbgirshick/yacs). + + We enhance yacs with more functionalities to support FederatedScope + configureation, such as more flexible type conversion, help info, + allowing required attributes, etc., + +""" + +import copy +import io +import logging +import os +import sys +from ast import literal_eval + +import yaml + +logger = logging.getLogger(__name__) + +# Flag for py2 and py3 compatibility to use when separate code paths are +# necessary. When _PY2 is False, we assume Python 3 is in use +_PY2 = sys.version_info.major == 2 + +# Filename extensions for loading configs from files +_YAML_EXTS = {"", ".yaml", ".yml"} +_PY_EXTS = {".py"} +_FILE_TYPES = (io.IOBase, ) + +# py2 and py3 compatibility for checking file object type +# We simply use this to infer py2 vs py3 + +# Utilities for importing modules from file paths +if _PY2: + # imp is available in both py2 and py3 for now, but is deprecated in py3 + import imp +else: + import importlib.util + +logger = logging.getLogger(__name__) + + +class Argument: + def __init__(self, + default_value=None, + description='', + required_type=None, + required=False): + self.value = default_value + self.description = description + self.required = required + if required_type is not None: + assert required_type in _VALID_TYPES, "" + self.type = required_type + else: + assert type(default_value) in _VALID_TYPES, "" + self.type = type(default_value) + + def __str__(self): + return 'Required({})'.format(self.type.__name__) \ + if self.required else str(self.value) + + def __repr__(self): + return 'Required({})'.format(self.type.__name__) \ + if self.required else str(self.value) + + +# CfgNodes can only contain a limited set of valid types +_VALID_TYPES = {tuple, list, str, int, float, bool, type(None), Argument, dict} +# py2 allow for str and unicode +if _PY2: + _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 + +# allow int <-> float conversation +casts = [(tuple, list), (list, tuple), (int, float), (float, int)] + + +class CfgNode(dict): + """ + CfgNode represents an internal node in the configuration tree. It's a + simple dict-like container that allows for attribute-based access to keys. + """ + + IMMUTABLE = "__immutable__" + DEPRECATED_KEYS = "__deprecated_keys__" + RENAMED_KEYS = "__renamed_keys__" + NEW_ALLOWED = "__new_allowed__" + + def __init__(self, init_dict=None, key_list=None, new_allowed=False): + """ + Args: + init_dict (dict): the possibly-nested dictionary to initailize the + CfgNode. + key_list (list[str]): a list of names which index this CfgNode from + the root. + Currently only used for logging purposes. + new_allowed (bool): whether adding new key is allowed when merging + with other configs. 
+ """ + # Recursively convert nested dictionaries in init_dict into CfgNodes + init_dict = {} if init_dict is None else init_dict + key_list = [] if key_list is None else key_list + init_dict = self._create_config_tree_from_dict(init_dict, key_list) + super(CfgNode, self).__init__(init_dict) + # Manage if the CfgNode is frozen or not + self.__dict__[CfgNode.IMMUTABLE] = False + # Deprecated options + # If an option is removed from the code and you don't want to break + # existing yaml configs, you can add the full config key as a string + # to the set below. + self.__dict__[CfgNode.DEPRECATED_KEYS] = set() + # Renamed options + # If you rename a config option, record the mapping from the old name + # to the new name in the dictionary below. Optionally, + # if the type also changed, you can + # make the value a tuple that specifies first the renamed key and then + # instructions for how to edit the config file. + self.__dict__[CfgNode.RENAMED_KEYS] = { + # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow + # 'EXAMPLE.OLD.KEY': ( + # # A more complex example to follow + # 'EXAMPLE.NEW.KEY', + # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or " + # + "'foo:bar' -> ('foo', 'bar')" + # ), + } + + # Allow new attributes after initialisation + self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed + # return init_dict + + @classmethod + def _create_config_tree_from_dict(cls, dic, key_list): + """ + Create a configuration tree using the given dict. + Any dict-like objects inside dict will be treated as a new CfgNode. + + Args: + dic (dict): + key_list (list[str]): a list of names which index this CfgNode + from the root. Currently only used for logging purposes. + """ + dic = copy.deepcopy(dic) + for k, v in dic.items(): + if isinstance(v, dict): + # Convert dict to CfgNode + dic[k] = cls(v, key_list=key_list + [k]) + else: + # Check for valid leaf type or nested CfgNode + _assert_with_logging( + _valid_type(v, allow_cfg_node=False), + "Key {} with value {} is not a valid type; valid types: {}" + .format(".".join(key_list + [str(k)]), type(v), + _VALID_TYPES), + ) + return dic + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + if self.is_frozen(): + raise AttributeError( + "Attempted to set {} to {}, but CfgNode is immutable".format( + name, value)) + + _assert_with_logging( + name not in self.__dict__, + "Invalid attempt to modify internal CfgNode state: {}".format( + name), + ) + _assert_with_logging( + _valid_type(value, allow_cfg_node=True), + "Invalid type {} for key {}; valid types = {}".format( + type(value), name, _VALID_TYPES), + ) + + self[name] = value + + def __str__(self): + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + r = "" + s = [] + for k, v in sorted(self.items()): + seperator = "\n" if isinstance(v, CfgNode) else " " + attr_str = "{}:{}{}".format(str(k), seperator, str(v)) + attr_str = _indent(attr_str, 2) + s.append(attr_str) + r += "\n".join(s) + return r + + def __repr__(self): + return "{}({})".format(self.__class__.__name__, + super(CfgNode, self).__repr__()) + + def dump(self, **kwargs): + """Dump to a string.""" + def convert_to_dict(cfg_node, key_list): + if not isinstance(cfg_node, CfgNode): + _assert_with_logging( + _valid_type(cfg_node), + "Key {} with value {} is not a valid type; valid types: {}" + 
.format(".".join(key_list), type(cfg_node), _VALID_TYPES), + ) + return cfg_node + else: + cfg_dict = dict(cfg_node) + for k, v in cfg_dict.items(): + cfg_dict[k] = convert_to_dict(v, key_list + [k]) + return cfg_dict + + self_as_dict = convert_to_dict(self, []) + return yaml.safe_dump(self_as_dict, **kwargs) + + def merge_from_file(self, cfg_filename): + """Load a yaml config file and merge it this CfgNode.""" + with open(cfg_filename, "r") as f: + cfg = self.load_cfg(f) + self.merge_from_other_cfg(cfg) + + def merge_from_other_cfg(self, cfg_other): + """Merge `cfg_other` into this CfgNode.""" + _merge_a_into_b(cfg_other, self, self, []) + + def merge_from_list(self, cfg_list): + """Merge config (keys, values) in a list (e.g., from command line) into + this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`. + """ + _assert_with_logging( + len(cfg_list) % 2 == 0, + "Override list has odd length: {}; it must be a list of pairs". + format(cfg_list), + ) + root = self + for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): + if root.key_is_deprecated(full_key): + continue + if root.key_is_renamed(full_key): + root.raise_key_rename_error(full_key) + key_list = full_key.split(".") + d = self + for subkey in key_list[:-1]: + _assert_with_logging(subkey in d, + "Non-existent key: {}".format(full_key)) + d = d[subkey] + subkey = key_list[-1] + _assert_with_logging(subkey in d or d.is_new_allowed(), + "Non-existent key: {}".format(full_key)) + value = self._decode_cfg_value(v) + value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, + full_key) + d[subkey] = value + + def freeze(self): + """Make this CfgNode and all of its children immutable.""" + self._immutable(True) + + def defrost(self): + """Make this CfgNode and all of its children mutable.""" + self._immutable(False) + + def is_frozen(self): + """Return mutability.""" + return self.__dict__[CfgNode.IMMUTABLE] + + def _immutable(self, is_immutable): + """Set immutability to is_immutable and recursively apply the setting + to all nested CfgNodes. + """ + self.__dict__[CfgNode.IMMUTABLE] = is_immutable + # Recursively set immutable state + for v in self.__dict__.values(): + if isinstance(v, CfgNode): + v._immutable(is_immutable) + for v in self.values(): + if isinstance(v, CfgNode): + v._immutable(is_immutable) + + def clone(self): + """Recursively copy this CfgNode.""" + return copy.deepcopy(self) + + def register_deprecated_key(self, key): + """Register key (e.g. `FOO.BAR`) a deprecated option. + When merging deprecated keys a warning is generated and the key is + ignored. + """ + _assert_with_logging( + key not in self.__dict__[CfgNode.DEPRECATED_KEYS], + "key {} is already registered as a deprecated key".format(key), + ) + self.__dict__[CfgNode.DEPRECATED_KEYS].add(key) + + def register_renamed_key(self, old_name, new_name, message=None): + """Register a key as having been renamed from `old_name` to `new_name`. + When merging a renamed key, an exception is thrown alerting to user to + the fact that the key has been renamed. 
+ """ + _assert_with_logging( + old_name not in self.__dict__[CfgNode.RENAMED_KEYS], + "key {} is already registered as a renamed cfg key".format( + old_name), + ) + value = new_name + if message: + value = (new_name, message) + self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value + + def key_is_deprecated(self, full_key): + """Test if a key is deprecated.""" + if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]: + logger.warning( + "Deprecated config key (ignoring): {}".format(full_key)) + return True + return False + + def key_is_renamed(self, full_key): + """Test if a key is renamed.""" + return full_key in self.__dict__[CfgNode.RENAMED_KEYS] + + def raise_key_rename_error(self, full_key): + new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key] + if isinstance(new_key, tuple): + msg = " Note: " + new_key[1] + new_key = new_key[0] + else: + msg = "" + raise KeyError( + "Key {} was renamed to {}; please update your config.{}".format( + full_key, new_key, msg)) + + def is_new_allowed(self): + return self.__dict__[CfgNode.NEW_ALLOWED] + + def set_new_allowed(self, is_new_allowed): + """ + Set this config (and recursively its subconfigs) to allow merging + new keys from other configs. + """ + self.__dict__[CfgNode.NEW_ALLOWED] = is_new_allowed + # Recursively set new_allowed state + for v in self.__dict__.values(): + if isinstance(v, CfgNode): + v.set_new_allowed(is_new_allowed) + for v in self.values(): + if isinstance(v, CfgNode): + v.set_new_allowed(is_new_allowed) + + @classmethod + def load_cfg(cls, cfg_file_obj_or_str): + """ + Load a cfg. + Args: + cfg_file_obj_or_str (str or file): + Supports loading from: + - A file object backed by a YAML file + - A file object backed by a Python source file that exports an + attribute "cfg" that is either a dict or a CfgNode + - A string that can be parsed as valid YAML + """ + _assert_with_logging( + isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str, )), + "Expected first argument to be of type {} or {}, but it was {}". + format(_FILE_TYPES, str, type(cfg_file_obj_or_str)), + ) + if isinstance(cfg_file_obj_or_str, str): + return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str) + elif isinstance(cfg_file_obj_or_str, _FILE_TYPES): + return cls._load_cfg_from_file(cfg_file_obj_or_str) + else: + raise NotImplementedError( + "Impossible to reach here (unless there's a bug)") + + @classmethod + def _load_cfg_from_file(cls, file_obj): + """Load a config from a YAML file or a Python source file.""" + _, file_extension = os.path.splitext(file_obj.name) + if file_extension in _YAML_EXTS: + return cls._load_cfg_from_yaml_str(file_obj.read()) + elif file_extension in _PY_EXTS: + return cls._load_cfg_py_source(file_obj.name) + else: + raise Exception( + "Attempt to load from an unsupported file type {}; " + "only {} are supported".format(file_obj, + _YAML_EXTS.union(_PY_EXTS))) + + @classmethod + def _load_cfg_from_yaml_str(cls, str_obj): + """Load a config from a YAML string encoding.""" + cfg_as_dict = yaml.safe_load(str_obj) + return cls(cfg_as_dict) + + @classmethod + def _load_cfg_py_source(cls, filename): + """Load a config from a Python source file.""" + module = _load_module_from_file("yacs.config.override", filename) + _assert_with_logging( + hasattr(module, "cfg"), + "Python module from file {} must have 'cfg' attr".format(filename), + ) + VALID_ATTR_TYPES = {dict, CfgNode} + _assert_with_logging( + type(module.cfg) in VALID_ATTR_TYPES, + "Imported module 'cfg' attr must be in {} but is {} instead". 
+ format(VALID_ATTR_TYPES, type(module.cfg)), + ) + return cls(module.cfg) + + @classmethod + def _decode_cfg_value(cls, value): + """ + Decodes a raw config value (e.g., from a yaml config files or command + line argument) into a Python object. + + If the value is a dict, it will be interpreted as a new CfgNode. + If the value is a str, it will be evaluated as literals. + Otherwise it is returned as-is. + """ + # Configs parsed from raw yaml will contain dictionary keys that need + # to be converted to CfgNode objects + if isinstance(value, dict): + return cls(value) + # All remaining processing is only applied to strings + if not isinstance(value, str): + return value + # Try to interpret `value` as a: + # string, number, tuple, list, dict, boolean, or None + try: + value = literal_eval(value) + # The following two excepts allow v to pass through when it represents + # a string. + # + # Longer explanation: + # The type of v is always a string (before calling literal_eval), but + # sometimes it *represents* a string and other times a data structure, + # like a list. In the case that v represents a string, what we got + # back from the yaml parser is 'foo' *without quotes* (so, + # not '"foo"'). literal_eval is ok with '"foo"', but will raise a + # ValueError if given 'foo'. In other cases, like paths (v = + # 'foo/bar' and not v = '"foo/bar"'), literal_eval + # will raise a SyntaxError. + except ValueError: + pass + except SyntaxError: + pass + return value + + +load_cfg = (CfgNode.load_cfg + ) # keep this function in global scope for backward compatibility + + +def _valid_type(value, allow_cfg_node=False): + return (type(value) in _VALID_TYPES) or (allow_cfg_node + and isinstance(value, CfgNode)) + + +def _merge_a_into_b(a, b, root, key_list): + """ + [Modified from yacs, to allow int <-> float conversation] + + Merge config dictionary a into config dictionary b, clobbering the + options in b whenever they are also specified in a. + """ + _assert_with_logging( + isinstance(a, CfgNode), + "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode), + ) + _assert_with_logging( + isinstance(b, CfgNode), + "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode), + ) + + for k, v_ in a.items(): + full_key = ".".join(key_list + [k]) + + v = copy.deepcopy(v_) + v = b._decode_cfg_value(v) + + if k in ['help_info', "__help_info__"] and k in b: + for name, info in v_.items(): + b[k][name] = info + elif k in b: + v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) + # Recursively merge dicts + if isinstance(v, CfgNode): + try: + _merge_a_into_b(v, b[k], root, key_list + [k]) + except BaseException: + raise + else: + b[k] = v + elif b.is_new_allowed(): + b[k] = v + else: + if root.key_is_deprecated(full_key): + continue + elif root.key_is_renamed(full_key): + root.raise_key_rename_error(full_key) + else: + raise KeyError("Non-existent config key: {}".format(full_key)) + + +def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): + """ + [Modified from yacs, to allow int <-> float conversation] + + Checks that `replacement`, which is intended to replace `original` is of + the right type. The type is correct if it matches exactly or is one of a + few cases in which the type can be easily coerced. 
+ """ + original_type = type(original) + replacement_type = type(replacement) + if original_type is Argument: + original_type = original.type + original = original.value + if replacement_type is Argument: + replacement_type = replacement.type + replacement = replacement.value + + # The types must match (with some exceptions) + if replacement_type == original_type: + return replacement + + # If either of them is None, + # allow type conversion to one of the valid types + if (replacement_type is None and original_type in _VALID_TYPES) or ( + original_type is None and replacement_type in _VALID_TYPES): + return replacement + + # Cast replacement from from_type to to_type + # if the replacement and original types match from_type and to_type + def conditional_cast(from_type, to_type): + if replacement_type == from_type and original_type == to_type: + return True, to_type(replacement) + else: + return False, None + + # Conditionally casts + # list <-> tuple + # For py2: allow converting from str (bytes) to a unicode string + try: + casts.append((str, unicode)) # noqa: F821 + except Exception: + pass + + for (from_type, to_type) in casts: + converted, converted_value = conditional_cast(from_type, to_type) + if converted: + return converted_value + + raise ValueError( + "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " + "key: {}".format(original_type, replacement_type, original, + replacement, full_key)) + + +def _assert_with_logging(cond, msg): + if not cond: + logger.debug(msg) + assert cond, msg + + +def _load_module_from_file(name, filename): + if _PY2: + module = imp.load_source(name, filename) + else: + spec = importlib.util.spec_from_file_location(name, filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module diff --git a/fgssl/core/data/README.md b/fgssl/core/data/README.md new file mode 100644 index 0000000..e63e062 --- /dev/null +++ b/fgssl/core/data/README.md @@ -0,0 +1,155 @@ +# DataZoo + +FederatedScope provides a rich collection of federated datasets for researchers, including images, texts, graphs, recommendation systems, and speeches, as well as utility classes `BaseDataTranslator` for building your own FS datasets. + +## Built-in FS data + +All datasets can be accessed from [`federatedscope.core.auxiliaries.data_builder.get_data`](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/auxiliaries/data_builder.py), which are built to [`federatedscope.core.data.StandaloneDataDict`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/core/data/base_data.py) (for more details, see [[DataZoo advanced]](#advanced)). By setting `cfg.data.type = DATASET_NAME`, FS would download and pre-process a specific dataset to be passed to `FedRunner`. 
For example: + +```python +# Source: federatedscope/main.py + +data, cfg = get_data(cfg) +runner = FedRunner(data=data, + server_class=get_server_cls(cfg), + client_class=get_client_cls(cfg), + config=cfg.clone()) +``` + +We provide a **look-up table** for you to get started with our DataZoo: + +| `cfg.data.type` | Domain | +| ------------------------------------------------------------ | ------------------- | +| [FEMNIST](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/cv/dataset/leaf_cv.py) | CV | +| [Celeba](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/cv/dataset/leaf_cv.py) | CV | +| [{DNAME}@torchvision](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/auxiliaries/data_builder.py) | CV | +| [Shakespeare](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/nlp/dataset/leaf_nlp.py) | NLP | +| [SubReddit](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/nlp/dataset/leaf_nlp.py) | NLP | +| [Twitter (Sentiment140)](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/nlp/dataset/leaf_twitter.py) | NLP | +| [{DNAME}@torchtext](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/auxiliaries/data_builder.py) | NLP | +| [{DNAME}@huggingface_datasets](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/auxiliaries/data_builder.py) | NLP | +| [Cora](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_node.py) | Graph (node-level) | +| [CiteSeer](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_node.py) | Graph (node-level) | +| [PubMed](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_node.py) | Graph (node-level) | +| [DBLP_conf](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataset/dblp_new.py) | Graph (node-level) | +| [DBLP_org](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataset/dblp_new.py) | Graph (node-level) | +| [csbm](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataset/cSBM_dataset.py) | Graph (node-level) | +| [Epinions](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataset/recsys.py) | Graph (link-level) | +| [Ciao](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataset/recsys.py) | Graph (link-level) | +| [FB15k](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_link.py) | Graph (link-level) | +| [FB15k-237](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_link.py) | Graph (link-level) | +| [WN18](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_link.py) | Graph (link-level) | +| [MUTAG](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [BZR](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [COX2](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [DHFR](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| 
[PTC_MR](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [AIDS](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [NCI1](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [ENZYMES](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [DD](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [PROTEINS](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [COLLAB](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [IMDB-BINARY](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [IMDB-MULTI](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [REDDIT-BINARY](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [HIV](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [ESOL](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [FREESOLV](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [LIPO](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [PCBA](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [MUV](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [BACE](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [BBBP](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [TOX21](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [TOXCAST](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [SIDER](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [CLINTOX](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [graph_multi_domain_mol](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [graph_multi_domain_small](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | Graph (graph-level) | +| [graph_multi_domain_biochem](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataloader/dataloader_graph.py) | 
Graph (graph-level) | +| [cikmcup](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/gfl/dataset/cikm_cup.py) | Graph (graph-level) | +| [toy](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/auxiliaries/data_builder.py) | Tabular | +| [synthetic](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/nlp/dataset/leaf_synthetic.py) | Tabular | +| [quadratic](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/tabular/dataloader/quadratic.py) | Tabular | +| [{DNAME}openml](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/auxiliaries/data_builder.py) | Tabular | +| [vertical_fl_data](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/vertical_fl/dataloader/dataloader.py) | Tabular(vertical) | +| [VFLMovieLens1M](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/mf/dataset/movielens.py) | Recommendation | +| [VFLMovieLens10M](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/mf/dataset/movielens.py) | Recommendation | +| [HFLMovieLens1M](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/mf/dataset/movielens.py) | Recommendation | +| [HFLMovieLens10M](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/mf/dataset/movielens.py) | Recommendation | +| [VFLNetflix](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/mf/dataset/netflix.py) | Recommendation | +| [HFLNetflix](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/mf/dataset/netflix.py) | Recommendation | + +## DataZoo Advanced + +In this section, we will introduce key concepts and tools to help you understand how FS data works and how to use it to build your own data in FS. + +Concepts: + +* [`federatedscope.core.data.ClientData`](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/data/base_data.py) + + * `ClientData` is a subclass of `dict`. In federated learning, each client (server) owns a `ClientData` for training, validating, or testing. Thus, each `ClientData` has one or more of `train`, `val`, and `test` as keys, and `DataLoader` accordingly. + + * The `DataLoader` of each key is created by `setup()` method, which specifies the arguments of `DataLoader`, such as `batch_size`, `shuffle` of `cfg`. + + Example: + + ```python + # Instantiate client_data for each Client + client_data = ClientData(DataLoader, + cfg, + train=train_data, + val=None, + test=test_data) + # other_cfg with different batch size + client_data.setup(other_cfg) + print(client_data) + + >> {'train': DataLoader(train_data), 'test': DataLoader(test_data)} + ``` + +* [`federatedscope.core.data.StandaloneDataDict`](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/data/base_data.py) + * `StandaloneDataDict` is a subclass of `dict`. As the name implies, `StandaloneDataDict` consists of all `ClientData` with client index as key (`0`, `1`, `2`, ...) in standalone mode. The key `0` is the data of the server for global evaluation or other usages. + * The method `preprocess()` in `StandaloneDataDict` makes changes to inner `ClientData` when `cfg` changes, such as in global mode, we set `cfg.federate.method == "global"`, and `StandaloneDataDict` will merge all `ClientData` to one client to perform global training. 
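+
+  Example (a minimal, illustrative sketch: the per-client `ClientData` objects
+  `client1_data`/`client2_data` and `new_global_cfg` are assumed to be prepared
+  elsewhere, e.g., by a translator):
+
+  ```python
+  # Build a StandaloneDataDict from per-client ClientData (keys are client ids)
+  fs_data = StandaloneDataDict({1: client1_data, 2: client2_data}, global_cfg)
+
+  # Key `0`, if present, holds the server-side data, e.g., the merged test
+  # data when cfg.federate.merge_test_data=True
+  print(sorted(fs_data.keys()))
+
+  # When the config changes (e.g., during HPO), rebuild the inner DataLoaders
+  fs_data.resetup(new_global_cfg)
+  ```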
+
+Tools
+
+* [`federatedscope.core.data.BaseDataTranslator`](https://github.com/alibaba/FederatedScope/blob/master/federatedscope/core/data/base_translator.py)
+
+  * `BaseDataTranslator` converts a [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) or a `dict` of data splits into a `StandaloneDataDict` according to `cfg`. After translating, it can be directly passed to `FedRunner` to launch an FL course.
+
+  * `BaseDataTranslator` first splits the data into `train`, `val`, and `test` by `cfg.data.splits` (**ML split**), and then uses a `Splitter` to distribute each split across the clients (**FL split**). In order to use `BaseDataTranslator`, `cfg.data.splitter`, `cfg.federate.client_num`, and the other arguments of the `Splitter` must be specified.
+
+  Example:
+
+  ```python
+  cfg.data.splitter = 'lda'
+  cfg.federate.client_num = 5
+  cfg.data.splitter_args = [{'alpha': 0.2}]
+
+  translator = BaseDataTranslator(global_cfg)
+  raw_data = CIFAR10()
+  fs_data = translator(raw_data)
+
+  runner = FedRunner(data=fs_data,
+                     server_class=get_server_cls(cfg),
+                     client_class=get_client_cls(cfg),
+                     config=cfg.clone())
+  ```
+
+* [`federatedscope.core.splitters`](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/core/splitters)
+
+  * To generate simulated federation datasets, we provide `Splitter` classes that are responsible for dispersing a given standalone dataset across multiple clients, with configurable statistical heterogeneity among them.
+
+  We provide a **look-up table** for you to get started with our `Splitter`:
+
+  | `cfg.data.splitter` | Domain              | Arguments                                        |
+  | :------------------ | ------------------- | :----------------------------------------------- |
+  | LDA                 | Generic             | `alpha`                                          |
+  | Louvain             | Graph (node-level)  | `delta`                                          |
+  | Random              | Graph (node-level)  | `sampling_rate`, `overlapping_rate`, `drop_edge` |
+  | rel_type            | Graph (link-level)  | `alpha`                                          |
+  | Scaffold            | Molecular           | -                                                |
+  | Scaffold_lda        | Molecular           | `alpha`                                          |
+  | Rand_chunk          | Graph (graph-level) | -                                                |
\ No newline at end of file
diff --git a/fgssl/core/data/__init__.py b/fgssl/core/data/__init__.py
new file mode 100644
index 0000000..be0c014
--- /dev/null
+++ b/fgssl/core/data/__init__.py
@@ -0,0 +1,8 @@
+from federatedscope.core.data.base_data import StandaloneDataDict, ClientData
+from federatedscope.core.data.base_translator import BaseDataTranslator
+from federatedscope.core.data.dummy_translator import DummyDataTranslator
+
+__all__ = [
+    'StandaloneDataDict', 'ClientData', 'BaseDataTranslator',
+    'DummyDataTranslator'
+]
diff --git a/fgssl/core/data/base_data.py b/fgssl/core/data/base_data.py
new file mode 100644
index 0000000..c219f50
--- /dev/null
+++ b/fgssl/core/data/base_data.py
@@ -0,0 +1,174 @@
+import logging
+from federatedscope.core.data.utils import merge_data
+from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
+
+logger = logging.getLogger(__name__)
+
+
+class StandaloneDataDict(dict):
+    """
+    `StandaloneDataDict` maintains several `ClientData` instances.
+    """
+    def __init__(self, datadict, global_cfg):
+        """
+
+        Args:
+            datadict: `Dict` with `client_id` as key, `ClientData` as value.
+            global_cfg: global CfgNode
+        """
+        self.global_cfg = global_cfg
+        self.client_cfgs = None
+        datadict = self.preprocess(datadict)
+        super(StandaloneDataDict, self).__init__(datadict)
+
+    def resetup(self, global_cfg, client_cfgs=None):
+        """
+        Resetup new configs for `ClientData`, which might be used in HPO.
+ + Args: + global_cfg: enable new config for `ClientData` + client_cfgs: enable new client-specific config for `ClientData` + """ + self.global_cfg, self.client_cfgs = global_cfg, client_cfgs + for client_id, client_data in self.items(): + if isinstance(client_data, ClientData): + if client_cfgs is not None: + client_cfg = global_cfg.clone() + client_cfg.merge_from_other_cfg( + client_cfgs.get(f'client_{client_id}')) + else: + client_cfg = global_cfg + client_data.setup(client_cfg) + else: + logger.warning('`client_data` is not subclass of ' + '`ClientData`, and cannot re-setup ' + 'DataLoader with new configs.') + + def preprocess(self, datadict): + """ + Preprocess for StandaloneDataDict for: + 1. Global evaluation (merge test data). + 2. Global mode (train with centralized setting, merge all data). + + Args: + datadict: dict with `client_id` as key, `ClientData` as value. + """ + if self.global_cfg.federate.merge_test_data: + server_data = merge_data( + all_data=datadict, + merged_max_data_id=self.global_cfg.federate.client_num, + specified_dataset_name=['test']) + # `0` indicate Server + datadict[0] = server_data + + if self.global_cfg.federate.method == "global": + if self.global_cfg.federate.client_num != 1: + if self.global_cfg.data.server_holds_all: + assert datadict[0] is not None \ + and len(datadict[0]) != 0, \ + "You specified cfg.data.server_holds_all=True " \ + "but data[0] is None. Please check whether you " \ + "pre-process the data[0] correctly" + datadict[1] = datadict[0] + else: + logger.info(f"Will merge data from clients whose ids in " + f"[1, {self.global_cfg.federate.client_num}]") + datadict[1] = merge_data( + all_data=datadict, + merged_max_data_id=self.global_cfg.federate.client_num) + datadict = self.attack(datadict) + return datadict + + def attack(self, datadict): + """ + Apply attack to `StandaloneDataDict`. + + """ + if 'backdoor' in self.global_cfg.attack.attack_method and 'edge' in \ + self.global_cfg.attack.trigger_type: + import os + import torch + from federatedscope.attack.auxiliary import \ + create_ardis_poisoned_dataset, create_ardis_test_dataset + if not os.path.exists(self.global_cfg.attack.edge_path): + os.makedirs(self.global_cfg.attack.edge_path) + poisoned_edgeset = create_ardis_poisoned_dataset( + data_path=self.global_cfg.attack.edge_path) + + ardis_test_dataset = create_ardis_test_dataset( + self.global_cfg.attack.edge_path) + + logger.info("Writing poison_data to: {}".format( + self.global_cfg.attack.edge_path)) + + with open( + self.global_cfg.attack.edge_path + + "poisoned_edgeset_training", "wb") as saved_data_file: + torch.save(poisoned_edgeset, saved_data_file) + + with open( + self.global_cfg.attack.edge_path + + "ardis_test_dataset.pt", "wb") as ardis_data_file: + torch.save(ardis_test_dataset, ardis_data_file) + logger.warning( + 'please notice: downloading the poisoned dataset \ + on cifar-10 from \ + https://github.com/ksreenivasan/OOD_Federated_Learning' + ) + + if 'backdoor' in self.global_cfg.attack.attack_method: + from federatedscope.attack.auxiliary import poisoning + poisoning(datadict, self.global_cfg) + return datadict + + +class ClientData(dict): + """ + `ClientData` converts dataset to train/val/test DataLoader. + Key `data` in `ClientData` is the raw dataset. 
+ """ + def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs): + """ + + Args: + loader: Dataloader class or data dict which have been built + client_cfg: client-specific CfgNode + data: raw dataset, which will stay raw + train: train dataset, which will be converted to DataLoader + val: valid dataset, which will be converted to DataLoader + test: test dataset, which will be converted to DataLoader + """ + self.client_cfg = None + self.train = train + self.val = val + self.test = test + self.setup(client_cfg) + if kwargs is not None: + for key in kwargs: + self[key] = kwargs[key] + super(ClientData, self).__init__() + + def setup(self, new_client_cfg=None): + """ + + Args: + new_client_cfg: new client-specific CfgNode + + Returns: + Status: indicate whether the client_cfg is updated + """ + # if `batch_size` or `shuffle` change, reinstantiate DataLoader + if self.client_cfg is not None: + if dict(self.client_cfg.dataloader) == dict( + new_client_cfg.dataloader): + return False + + self.client_cfg = new_client_cfg + if self.train is not None: + self['train'] = get_dataloader(self.train, self.client_cfg, + 'train') + if self.val is not None: + self['val'] = get_dataloader(self.val, self.client_cfg, 'val') + if self.test is not None: + self['test'] = get_dataloader(self.test, self.client_cfg, 'test') + return True diff --git a/fgssl/core/data/base_translator.py b/fgssl/core/data/base_translator.py new file mode 100644 index 0000000..4fc82a0 --- /dev/null +++ b/fgssl/core/data/base_translator.py @@ -0,0 +1,129 @@ +import logging +import numpy as np + +from federatedscope.core.auxiliaries.splitter_builder import get_splitter +from federatedscope.core.data import ClientData, StandaloneDataDict + +logger = logging.getLogger(__name__) + + +class BaseDataTranslator: + """ + Perform process: + Dataset -> ML split -> FL split -> Data (passed to FedRunner) + + """ + def __init__(self, global_cfg, client_cfgs=None): + """ + Convert data to `StandaloneDataDict`. + + Args: + global_cfg: global CfgNode + client_cfgs: client cfg `Dict` + """ + self.global_cfg = global_cfg + self.client_cfgs = client_cfgs + self.splitter = get_splitter(global_cfg) + + def __call__(self, dataset): + """ + + Args: + dataset: `torch.utils.data.Dataset`, `List` of (feature, label) + or split dataset tuple of (train, val, test) or Tuple of + split dataset with [train, val, test] + + Returns: + datadict: instance of `StandaloneDataDict`, which is a subclass of + `dict`. + """ + datadict = self.split(dataset) + datadict = StandaloneDataDict(datadict, self.global_cfg) + + return datadict + + def split(self, dataset): + """ + Perform ML split and FL split. + + Returns: + dict of `ClientData` with client_idx as key. + + """ + train, val, test = self.split_train_val_test(dataset) + datadict = self.split_to_client(train, val, test) + return datadict + + def split_train_val_test(self, dataset): + """ + Split dataset to train, val, test if not provided. + + Returns: + split_data (List): List of split dataset, [train, val, test] + + """ + splits = self.global_cfg.data.splits + if isinstance(dataset, tuple): + # No need to split train/val/test for tuple dataset. + error_msg = 'If dataset is tuple, it must contains ' \ + 'train, valid and test split.' 
+ assert len(dataset) == len(['train', 'val', 'test']), error_msg + return [dataset[0], dataset[1], dataset[2]] + + index = np.random.permutation(np.arange(len(dataset))) + train_size = int(splits[0] * len(dataset)) + val_size = int(splits[1] * len(dataset)) + + train_dataset = [dataset[x] for x in index[:train_size]] + val_dataset = [ + dataset[x] for x in index[train_size:train_size + val_size] + ] + test_dataset = [dataset[x] for x in index[train_size + val_size:]] + return train_dataset, val_dataset, test_dataset + + def split_to_client(self, train, val, test): + """ + Split dataset to clients and build `ClientData`. + + Returns: + data_dict (dict): dict of `ClientData` with client_idx as key. + + """ + + # Initialization + client_num = self.global_cfg.federate.client_num + split_train, split_val, split_test = [[None] * client_num] * 3 + train_label_distribution = None + + # Split train/val/test to client + if len(train) > 0: + split_train = self.splitter(train) + if self.global_cfg.data.consistent_label_distribution: + try: + train_label_distribution = [[j[1] for j in x] + for x in split_train] + except: + logger.warning( + 'Cannot access train label distribution for ' + 'splitter.') + if len(val) > 0: + split_val = self.splitter(val, prior=train_label_distribution) + if len(test) > 0: + split_test = self.splitter(test, prior=train_label_distribution) + + # Build data dict with `ClientData`, key `0` for server. + data_dict = { + 0: ClientData(self.global_cfg, train=train, val=val, test=test) + } + for client_id in range(1, client_num + 1): + if self.client_cfgs is not None: + client_cfg = self.global_cfg.clone() + client_cfg.merge_from_other_cfg( + self.client_cfgs.get(f'client_{client_id}')) + else: + client_cfg = self.global_cfg + data_dict[client_id] = ClientData(client_cfg, + train=split_train[client_id - 1], + val=split_val[client_id - 1], + test=split_test[client_id - 1]) + return data_dict diff --git a/fgssl/core/data/dummy_translator.py b/fgssl/core/data/dummy_translator.py new file mode 100644 index 0000000..640a80e --- /dev/null +++ b/fgssl/core/data/dummy_translator.py @@ -0,0 +1,38 @@ +from federatedscope.core.data.base_translator import BaseDataTranslator +from federatedscope.core.data.base_data import ClientData + + +class DummyDataTranslator(BaseDataTranslator): + """ + DummyDataTranslator convert FL dataset to DataLoader. + Do not perform FL split. 
+ """ + def split(self, dataset): + if not isinstance(dataset, dict): + raise TypeError(f'Not support data type {type(dataset)}') + datadict = {} + for client_id in dataset.keys(): + if self.client_cfgs is not None: + client_cfg = self.global_cfg.clone() + client_cfg.merge_from_other_cfg( + self.client_cfgs.get(f'client_{client_id}')) + else: + client_cfg = self.global_cfg + + if isinstance(dataset[client_id], dict): + datadict[client_id] = ClientData(client_cfg, + **dataset[client_id]) + else: + # Do not have train/val/test + train, val, test = self.split_train_val_test( + dataset[client_id]) + tmp_dict = dict(train=train, val=val, test=test) + # Only for graph-level task, get number of graph labels + if client_cfg.model.task.startswith('graph') and \ + client_cfg.model.out_channels == 0: + s = set() + for g in dataset[client_id]: + s.add(g.y.item()) + tmp_dict['num_label'] = len(s) + datadict[client_id] = ClientData(client_cfg, **tmp_dict) + return datadict diff --git a/fgssl/core/data/utils.py b/fgssl/core/data/utils.py new file mode 100644 index 0000000..1b46cac --- /dev/null +++ b/fgssl/core/data/utils.py @@ -0,0 +1,617 @@ +import copy +import inspect +import logging +import os +import re +from collections import defaultdict + +import numpy as np +from random import shuffle + +logger = logging.getLogger(__name__) + + +class RegexInverseMap: + def __init__(self, n_dic, val): + self._items = {} + for key, values in n_dic.items(): + for value in values: + self._items[value] = key + self.__val = val + + def __getitem__(self, key): + for regex in self._items.keys(): + if re.compile(regex).match(key): + return self._items[regex] + return self.__val + + def __repr__(self): + return str(self._items.items()) + + +def load_dataset(config): + if config.data.type.lower() == 'toy': + from federatedscope.tabular.dataloader.toy import load_toy_data + dataset, modified_config = load_toy_data(config) + elif config.data.type.lower() == 'quadratic': + from federatedscope.tabular.dataloader import load_quadratic_dataset + dataset, modified_config = load_quadratic_dataset(config) + elif config.data.type.lower() in ['femnist', 'celeba']: + from federatedscope.cv.dataloader import load_cv_dataset + dataset, modified_config = load_cv_dataset(config) + elif config.data.type.lower() in [ + 'shakespeare', 'twitter', 'subreddit', 'synthetic' + ]: + from federatedscope.nlp.dataloader import load_nlp_dataset + dataset, modified_config = load_nlp_dataset(config) + elif config.data.type.lower() in [ + 'cora', + 'citeseer', + 'pubmed', + 'dblp_conf', + 'dblp_org', + ] or config.data.type.lower().startswith('csbm'): + from federatedscope.gfl.dataloader import load_nodelevel_dataset + dataset, modified_config = load_nodelevel_dataset(config) + elif config.data.type.lower() in ['ciao', 'epinions', 'fb15k-237', 'wn18']: + from federatedscope.gfl.dataloader import load_linklevel_dataset + dataset, modified_config = load_linklevel_dataset(config) + elif config.data.type.lower() in [ + 'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace', 'sider', + 'clintox', 'esol', 'freesolv', 'lipo', 'cikmcup' + ] or config.data.type.startswith('graph_multi_domain'): + from federatedscope.gfl.dataloader import load_graphlevel_dataset + dataset, modified_config = load_graphlevel_dataset(config) + elif config.data.type.lower() == 'vertical_fl_data': + from federatedscope.vertical_fl.dataloader import load_vertical_data + dataset, modified_config = load_vertical_data(config, generate=True) + elif 'movielens' in config.data.type.lower( 
+ ) or 'netflix' in config.data.type.lower(): + from federatedscope.mf.dataloader import load_mf_dataset + dataset, modified_config = load_mf_dataset(config) + elif '@' in config.data.type.lower(): + from federatedscope.core.data.utils import load_external_data + dataset, modified_config = load_external_data(config) + elif config.data.type is None or config.data.type == "": + # The participant (only for server in this version) does not own data + dataset = None + modified_config = config + else: + raise ValueError('Dataset {} not found.'.format(config.data.type)) + return dataset, modified_config + + +def load_external_data(config=None): + r""" Based on the configuration file, this function imports external + datasets and applies train/valid/test splits and split by some specific + `splitter` into the standard FederatedScope input data format. + + Args: + config: `CN` from `federatedscope/core/configs/config.py` + + Returns: + data_local_dict: dict of split dataloader. + Format: + { + 'client_id': { + 'train': DataLoader(), + 'test': DataLoader(), + 'val': DataLoader() + } + } + modified_config: `CN` from `federatedscope/core/configs/config.py`, + which might be modified in the function. + + """ + + import torch + from importlib import import_module + from torch.utils.data import DataLoader + from federatedscope.core.auxiliaries.transform_builder import get_transform + + def load_torchvision_data(name, splits=None, config=None): + dataset_func = getattr(import_module('torchvision.datasets'), name) + transform_funcs = get_transform(config, 'torchvision') + if config.data.args: + raw_args = config.data.args[0] + else: + raw_args = {} + if 'download' not in raw_args.keys(): + raw_args.update({'download': True}) + filtered_args = filter_dict(dataset_func.__init__, raw_args) + func_args = get_func_args(dataset_func.__init__) + + # Perform split on different dataset + if 'train' in func_args: + # Split train to (train, val) + dataset_train = dataset_func(root=config.data.root, + train=True, + **filtered_args, + **transform_funcs) + dataset_val = None + dataset_test = dataset_func(root=config.data.root, + train=False, + **filtered_args, + **transform_funcs) + if splits: + train_size = int(splits[0] * len(dataset_train)) + val_size = len(dataset_train) - train_size + lengths = [train_size, val_size] + dataset_train, dataset_val = \ + torch.utils.data.dataset.random_split(dataset_train, + lengths) + + elif 'split' in func_args: + # Use raw split + dataset_train = dataset_func(root=config.data.root, + split='train', + **filtered_args, + **transform_funcs) + dataset_val = dataset_func(root=config.data.root, + split='valid', + **filtered_args, + **transform_funcs) + dataset_test = dataset_func(root=config.data.root, + split='test', + **filtered_args, + **transform_funcs) + elif 'classes' in func_args: + # Use raw split + dataset_train = dataset_func(root=config.data.root, + classes='train', + **filtered_args, + **transform_funcs) + dataset_val = dataset_func(root=config.data.root, + classes='valid', + **filtered_args, + **transform_funcs) + dataset_test = dataset_func(root=config.data.root, + classes='test', + **filtered_args, + **transform_funcs) + else: + # Use config.data.splits + dataset = dataset_func(root=config.data.root, + **filtered_args, + **transform_funcs) + train_size = int(splits[0] * len(dataset)) + val_size = int(splits[1] * len(dataset)) + test_size = len(dataset) - train_size - val_size + lengths = [train_size, val_size, test_size] + dataset_train, dataset_val, dataset_test = \ + 
torch.utils.data.dataset.random_split(dataset, lengths) + + data_split_dict = { + 'train': dataset_train, + 'val': dataset_val, + 'test': dataset_test + } + + return data_split_dict + + def load_torchtext_data(name, splits=None, config=None): + from torch.nn.utils.rnn import pad_sequence + from federatedscope.nlp.dataset.utils import label_to_index + + dataset_func = getattr(import_module('torchtext.datasets'), name) + if config.data.args: + raw_args = config.data.args[0] + else: + raw_args = {} + assert 'max_len' in raw_args, "Miss key 'max_len' in " \ + "`config.data.args`." + filtered_args = filter_dict(dataset_func.__init__, raw_args) + dataset = dataset_func(root=config.data.root, **filtered_args) + + # torchtext.transforms requires >= 0.12.0 and torch = 1.11.0, + # so we do not use `get_transform` in torchtext. + + # Merge all data and tokenize + x_list = [] + y_list = [] + for data_iter in dataset: + data, targets = [], [] + for i, item in enumerate(data_iter): + data.append(item[1]) + targets.append(item[0]) + x_list.append(data) + y_list.append(targets) + + x_all, y_all = [], [] + for i in range(len(x_list)): + x_all += x_list[i] + y_all += y_list[i] + + if config.model.type.endswith('transformers'): + from transformers import AutoTokenizer + cache_path = os.path.join(os.getcwd(), "huggingface") + try: + tokenizer = AutoTokenizer.from_pretrained( + config.model.type.split('@')[0], + local_files_only=True, + cache_dir=cache_path) + except Exception as e: + logging.error(f"When loading cached file form " + f"{cache_path}, we faced the exception: \n " + f"{str(e)}") + + x_all = tokenizer(x_all, + return_tensors='pt', + padding=True, + truncation=True, + max_length=raw_args['max_len']) + data = [{key: value[i] + for key, value in x_all.items()} + for i in range(len(next(iter(x_all.values()))))] + if 'classification' in config.model.task.lower(): + targets = label_to_index(y_all) + else: + y_all = tokenizer(y_all, + return_tensors='pt', + padding=True, + truncation=True, + max_length=raw_args['max_len']) + targets = [{key: value[i] + for key, value in y_all.items()} + for i in range(len(next(iter(y_all.values()))))] + else: + from torchtext.data import get_tokenizer + tokenizer = get_tokenizer("basic_english") + if len(config.data.transform) == 0: + raise ValueError( + "`transform` must be one pretrained Word Embeddings from \ + ['GloVe', 'FastText', 'CharNGram']") + if len(config.data.transform) == 1: + config.data.transform.append({}) + vocab = getattr(import_module('torchtext.vocab'), + config.data.transform[0])( + dim=config.model.in_channels, + **config.data.transform[1]) + + if 'classification' in config.model.task.lower(): + data = [ + vocab.get_vecs_by_tokens(tokenizer(x), + lower_case_backup=True) + for x in x_all + ] + targets = label_to_index(y_all) + else: + data = [ + vocab.get_vecs_by_tokens(tokenizer(x), + lower_case_backup=True) + for x in x_all + ] + targets = [ + vocab.get_vecs_by_tokens(tokenizer(y), + lower_case_backup=True) + for y in y_all + ] + targets = pad_sequence(targets).transpose( + 0, 1)[:, :raw_args['max_len'], :] + data = pad_sequence(data).transpose(0, + 1)[:, :raw_args['max_len'], :] + # Split data to raw + num_items = [len(ds) for ds in x_list] + data_list, cnt = [], 0 + for num in num_items: + data_list.append([ + (x, y) + for x, y in zip(data[cnt:cnt + num], targets[cnt:cnt + num]) + ]) + cnt += num + + if len(data_list) == 3: + # Use raw splits + data_split_dict = { + 'train': data_list[0], + 'val': data_list[1], + 'test': data_list[2] + } + elif 
len(data_list) == 2: + # Split train to (train, val) + data_split_dict = { + 'train': data_list[0], + 'val': None, + 'test': data_list[1] + } + if splits: + train_size = int(splits[0] * len(data_split_dict['train'])) + val_size = len(data_split_dict['train']) - train_size + lengths = [train_size, val_size] + data_split_dict['train'], data_split_dict[ + 'val'] = torch.utils.data.dataset.random_split( + data_split_dict['train'], lengths) + else: + # Use config.data.splits + data_split_dict = {} + train_size = int(splits[0] * len(data_list[0])) + val_size = int(splits[1] * len(data_list[0])) + test_size = len(data_list[0]) - train_size - val_size + lengths = [train_size, val_size, test_size] + data_split_dict['train'], data_split_dict['val'], data_split_dict[ + 'test'] = torch.utils.data.dataset.random_split( + data_list[0], lengths) + + return data_split_dict + + def load_torchaudio_data(name, splits=None, config=None): + + # dataset_func = getattr(import_module('torchaudio.datasets'), name) + raise NotImplementedError + + def load_huggingface_datasets_data(name, splits=None, config=None): + from datasets import load_dataset, load_from_disk + + if config.data.args: + raw_args = config.data.args[0] + else: + raw_args = {} + assert 'max_len' in raw_args, "Miss key 'max_len' in " \ + "`config.data.args`." + filtered_args = filter_dict(load_dataset, raw_args) + logger.info("Begin to load huggingface dataset") + if "hg_cache_dir" in raw_args: + hugging_face_path = raw_args["hg_cache_dir"] + else: + hugging_face_path = os.getcwd() + + if "load_disk_dir" in raw_args: + load_path = raw_args["load_disk_dir"] + try: + dataset = load_from_disk(load_path) + except Exception as e: + logging.error(f"When loading cached dataset form " + f"{load_path}, we faced the exception: \n " + f"{str(e)}") + else: + dataset = load_dataset(path=config.data.root, + name=name, + **filtered_args) + if config.model.type.endswith('transformers'): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + from transformers import AutoTokenizer + logger.info("To load huggingface tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + config.model.type.split('@')[0], + local_files_only=True, + cache_dir=os.path.join(hugging_face_path, "transformers")) + + for split in dataset: + x_all = [i['sentence'] for i in dataset[split]] + targets = [i['label'] for i in dataset[split]] + + if split == "train" and "used_train_ratio" in raw_args and \ + 1 > raw_args['used_train_ratio'] > 0: + selected_idx = [i for i in range(len(dataset[split]))] + shuffle(selected_idx) + selected_idx = selected_idx[:int( + len(selected_idx) * raw_args['used_train_ratio'])] + x_all = [ + element for i, element in enumerate(x_all) + if i in selected_idx + ] + targets = [ + element for i, element in enumerate(targets) + if i in selected_idx + ] + + x_all = tokenizer(x_all, + return_tensors='pt', + padding=True, + truncation=True, + max_length=raw_args['max_len']) + data = [{key: value[i] + for key, value in x_all.items()} + for i in range(len(next(iter(x_all.values()))))] + dataset[split] = (data, targets) + data_split_dict = { + 'train': [(x, y) + for x, y in zip(dataset['train'][0], dataset['train'][1]) + ], + 'val': [(x, y) for x, y in zip(dataset['validation'][0], + dataset['validation'][1])], + 'test': [ + (x, y) for x, y in zip(dataset['test'][0], dataset['test'][1]) + ] if (set(dataset['test'][1]) - set([-1])) else None, + } + original_train_size = len(data_split_dict["train"]) + + if "half_val_dummy_test" in raw_args and raw_args[ + "half_val_dummy_test"]: 
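+            # Illustrative config that reaches this branch (keys assumed to
+            # be provided via `config.data.args`):
+            #     data.args = [{'max_len': 128, 'half_val_dummy_test': True}]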
+ # since the "test" set from GLUE dataset may be masked, we need to + # submit to get the ground-truth, for fast FL experiments, + # we split the validation set into two parts with the same size as + # new test/val data + original_val = [(x, y) for x, y in zip(dataset['validation'][0], + dataset['validation'][1])] + data_split_dict["val"], data_split_dict[ + "test"] = original_val[:len(original_val) // + 2], original_val[len(original_val) // + 2:] + if "val_as_dummy_test" in raw_args and raw_args["val_as_dummy_test"]: + # use the validation set as tmp test set, + # and partial training set as validation set + data_split_dict["test"] = data_split_dict["val"] + data_split_dict["val"] = [] + if "part_train_dummy_val" in raw_args and 1 > raw_args[ + "part_train_dummy_val"] > 0: + new_val_part = int(original_train_size * + raw_args["part_train_dummy_val"]) + data_split_dict["val"].extend( + data_split_dict["train"][:new_val_part]) + data_split_dict["train"] = data_split_dict["train"][new_val_part:] + if "part_train_dummy_test" in raw_args and 1 > raw_args[ + "part_train_dummy_test"] > 0: + new_test_part = int(original_train_size * + raw_args["part_train_dummy_test"]) + data_split_dict["test"] = data_split_dict["val"] + if data_split_dict["test"] is not None: + data_split_dict["test"].extend( + data_split_dict["train"][:new_test_part]) + else: + data_split_dict["test"] = ( + data_split_dict["train"][:new_test_part]) + data_split_dict["train"] = data_split_dict["train"][new_test_part:] + + return data_split_dict + + def load_openml_data(tid, splits=None, config=None): + import openml + from sklearn.model_selection import train_test_split + + task = openml.tasks.get_task(int(tid)) + did = task.dataset_id + dataset = load_dataset(did) + data, targets, _, _ = dataset.get_data( + dataset_format="array", target=dataset.default_target_attribute) + + train_data, test_data, train_targets, test_targets = train_test_split( + data, targets, train_size=splits[0], random_state=config.seed) + val_data, test_data, val_targets, test_targets = train_test_split( + test_data, + test_targets, + train_size=splits[1] / (1. 
- splits[0]), + random_state=config.seed) + data_split_dict = { + 'train': [(x, y) for x, y in zip(train_data, train_targets)], + 'val': [(x, y) for x, y in zip(val_data, val_targets)], + 'test': [(x, y) for x, y in zip(test_data, test_targets)] + } + return data_split_dict + + DATA_LOAD_FUNCS = { + 'torchvision': load_torchvision_data, + 'torchtext': load_torchtext_data, + 'torchaudio': load_torchaudio_data, + 'huggingface_datasets': load_huggingface_datasets_data, + 'openml': load_openml_data + } + + modified_config = config.clone() + + # Load dataset + splits = modified_config.data.splits + name, package = modified_config.data.type.split('@') + + # Comply with the original train/val/test + dataset = DATA_LOAD_FUNCS[package.lower()](name, splits, modified_config) + data_split_tuple = (dataset.get('train'), dataset.get('val'), + dataset.get('test')) + + return data_split_tuple, modified_config + + +def convert_data_mode(data, config): + if config.federate.mode.lower() == 'standalone': + return data + else: + # Invalid data_idx + if config.distribute.data_idx == -1: + return data + elif config.distribute.data_idx not in data.keys(): + data_idx = np.random.choice(list(data.keys())) + logger.warning( + f"The provided data_idx={config.distribute.data_idx} is " + f"invalid, so that we randomly sample a data_idx as {data_idx}" + ) + else: + data_idx = config.distribute.data_idx + return data[data_idx] + + +def get_func_args(func): + sign = inspect.signature(func).parameters.values() + sign = set([val.name for val in sign]) + return sign + + +def filter_dict(func, kwarg): + sign = get_func_args(func) + common_args = sign.intersection(kwarg.keys()) + filtered_dict = {key: kwarg[key] for key in common_args} + return filtered_dict + + +def merge_data(all_data, merged_max_data_id=None, specified_dataset_name=None): + """ + Merge data from client 1 to `merged_max_data_id` contained in given + `all_data`. + :param all_data: + :param merged_max_data_id: + :param specified_dataset_name: + :return: + """ + import torch.utils.data + from federatedscope.core.data.wrap_dataset import WrapDataset + + # Assert + if merged_max_data_id is None: + merged_max_data_id = len(all_data) - 1 + assert merged_max_data_id >= 1 + if specified_dataset_name is None: + dataset_names = list(all_data[1].keys()) # e.g., train, test, val + else: + if not isinstance(specified_dataset_name, list): + specified_dataset_name = [specified_dataset_name] + dataset_names = specified_dataset_name + assert len(dataset_names) >= 1, \ + "At least one sub-dataset is required in client 1" + + data_name = "test" if "test" in dataset_names else dataset_names[0] + id_contain_all_dataset_key = -1 + # check the existence of the data to be merged + for client_id in range(1, merged_max_data_id + 1): + contain_all_dataset_key = True + for dataset_name in dataset_names: + if dataset_name not in all_data[client_id]: + contain_all_dataset_key = False + logger.warning(f'Client {client_id} does not contain ' + f'dataset key {dataset_name}.') + if id_contain_all_dataset_key == -1 and contain_all_dataset_key: + id_contain_all_dataset_key = client_id + assert id_contain_all_dataset_key != -1, \ + "At least one client within [1, merged_max_data_id] should contain " \ + "all the key for expected dataset names." 
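+    # Two layouts are handled below: (1) a DataLoader wrapping a WrapDataset
+    # whose underlying dict stores arrays per field (merged with
+    # np.concatenate), and (2) plain DataLoaders whose datasets support
+    # `.extend()`; any other form raises NotImplementedError.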
+ + if issubclass(type(all_data[id_contain_all_dataset_key][data_name]), + torch.utils.data.DataLoader): + if isinstance(all_data[id_contain_all_dataset_key][data_name].dataset, + WrapDataset): + data_elem_names = list(all_data[id_contain_all_dataset_key] + [data_name].dataset.dataset.keys()) # + # e.g., x, y + merged_data = {name: defaultdict(list) for name in dataset_names} + for data_id in range(1, merged_max_data_id + 1): + for d_name in dataset_names: + if d_name not in all_data[data_id]: + continue + for elem_name in data_elem_names: + merged_data[d_name][elem_name].append( + all_data[data_id] + [d_name].dataset.dataset[elem_name]) + for d_name in dataset_names: + for elem_name in data_elem_names: + merged_data[d_name][elem_name] = np.concatenate( + merged_data[d_name][elem_name]) + for name in all_data[id_contain_all_dataset_key]: + all_data[id_contain_all_dataset_key][ + name].dataset.dataset = merged_data[name] + else: + merged_data = copy.deepcopy(all_data[id_contain_all_dataset_key]) + for data_id in range(1, merged_max_data_id + 1): + if data_id == id_contain_all_dataset_key: + continue + for d_name in dataset_names: + if d_name not in all_data[data_id]: + continue + merged_data[d_name].dataset.extend( + all_data[data_id][d_name].dataset) + else: + raise NotImplementedError( + "Un-supported type when merging data across different clients." + f"Your data type is " + f"{type(all_data[id_contain_all_dataset_key][data_name])}. " + f"Currently we only support the following forms: " + " 1): {data_id: {train: {x:ndarray, y:ndarray}} }" + " 2): {data_id: {train: DataLoader }") + return merged_data diff --git a/fgssl/core/data/wrap_dataset.py b/fgssl/core/data/wrap_dataset.py new file mode 100644 index 0000000..72ed832 --- /dev/null +++ b/fgssl/core/data/wrap_dataset.py @@ -0,0 +1,31 @@ +import torch +import numpy as np +from torch.utils.data import Dataset + + +class WrapDataset(Dataset): + """Wrap raw data into pytorch Dataset + + Arguments: + dataset (dict): raw data dictionary contains "x" and "y" + + """ + def __init__(self, dataset): + super(WrapDataset, self).__init__() + self.dataset = dataset + + def __getitem__(self, idx): + if isinstance(self.dataset["x"][idx], torch.Tensor): + return self.dataset["x"][idx], self.dataset["y"][idx] + elif isinstance(self.dataset["x"][idx], np.ndarray): + return torch.from_numpy( + self.dataset["x"][idx]).float(), torch.from_numpy( + self.dataset["y"][idx]).float() + elif isinstance(self.dataset["x"][idx], list): + return torch.FloatTensor(self.dataset["x"][idx]), \ + torch.FloatTensor(self.dataset["y"][idx]) + else: + raise TypeError + + def __len__(self): + return len(self.dataset["y"]) diff --git a/fgssl/core/fed_runner.py b/fgssl/core/fed_runner.py new file mode 100644 index 0000000..dac9e66 --- /dev/null +++ b/fgssl/core/fed_runner.py @@ -0,0 +1,407 @@ +import logging + +from collections import deque +import heapq + +import numpy as np + +from federatedscope.core.workers import Server, Client +from federatedscope.core.gpu_manager import GPUManager +from federatedscope.core.auxiliaries.model_builder import get_model +from federatedscope.core.auxiliaries.utils import get_resource_info +from federatedscope.core.data.utils import merge_data + +logger = logging.getLogger(__name__) + + +class FedRunner(object): + """ + This class is used to construct an FL course, which includes `_set_up` + and `run`. + + Arguments: + data: The data used in the FL courses, which are formatted as { + 'ID':data} for standalone mode. 
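+            For example (illustrative), {0: server_data, 1: client_1_data,
+            2: client_2_data}, where each value is typically a dict-like
+            `ClientData` holding 'train'/'val'/'test' dataloaders.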
More details can be found in + federatedscope.core.auxiliaries.data_builder . + server_class: The server class is used for instantiating a ( + customized) server. + client_class: The client class is used for instantiating a ( + customized) client. + config: The configurations of the FL course. + client_configs: The clients' configurations. + """ + def __init__(self, + data, + server_class=Server, + client_class=Client, + config=None, + client_configs=None): + self.data = data + self.server_class = server_class + self.client_class = client_class + assert config is not None, \ + "When using FedRunner, you should specify the `config` para" + if not config.is_ready_for_run: + config.ready_for_run() + self.cfg = config + self.client_cfgs = client_configs + + self.mode = self.cfg.federate.mode.lower() + self.gpu_manager = GPUManager(gpu_available=self.cfg.use_gpu, + specified_device=self.cfg.device) + + self.unseen_clients_id = [] + if self.cfg.federate.unseen_clients_rate > 0: + self.unseen_clients_id = np.random.choice( + np.arange(1, self.cfg.federate.client_num + 1), + size=max( + 1, + int(self.cfg.federate.unseen_clients_rate * + self.cfg.federate.client_num)), + replace=False).tolist() + # get resource information + self.resource_info = get_resource_info( + config.federate.resource_info_file) + + if self.mode == 'standalone': + self.shared_comm_queue = deque() + self._setup_for_standalone() + # in standalone mode, by default, we print the trainer info only + # once for better logs readability + trainer_representative = self.client[1].trainer + if trainer_representative is not None: + trainer_representative.print_trainer_meta_info() + elif self.mode == 'distributed': + self._setup_for_distributed() + + def _setup_for_standalone(self): + """ + To set up server and client for standalone mode. + """ + if self.cfg.backend == 'torch': + import torch + torch.set_num_threads(1) + + assert self.cfg.federate.client_num != 0, \ + "In standalone mode, self.cfg.federate.client_num should be " \ + "non-zero. 
" \ + "This is usually cased by using synthetic data and users not " \ + "specify a non-zero value for client_num" + + if self.cfg.federate.method == "global": + self.cfg.defrost() + self.cfg.federate.client_num = 1 + self.cfg.federate.sample_client_num = 1 + self.cfg.freeze() + + # sample resource information + if self.resource_info is not None: + if len(self.resource_info) < self.cfg.federate.client_num + 1: + replace = True + logger.warning( + f"Because the provided the number of resource information " + f"{len(self.resource_info)} is less than the number of " + f"participants {self.cfg.federate.client_num+1}, one " + f"candidate might be selected multiple times.") + else: + replace = False + sampled_index = np.random.choice( + list(self.resource_info.keys()), + size=self.cfg.federate.client_num + 1, + replace=replace) + server_resource_info = self.resource_info[sampled_index[0]] + client_resource_info = [ + self.resource_info[x] for x in sampled_index[1:] + ] + else: + server_resource_info = None + client_resource_info = None + + self.server = self._setup_server( + resource_info=server_resource_info, + client_resource_info=client_resource_info) + + self.client = dict() + + # assume the client-wise data are consistent in their input&output + # shape + self._shared_client_model = get_model( + self.cfg.model, self.data[1], backend=self.cfg.backend + ) if self.cfg.federate.share_local_model else None + + for client_id in range(1, self.cfg.federate.client_num + 1): + self.client[client_id] = self._setup_client( + client_id=client_id, + client_model=self._shared_client_model, + resource_info=client_resource_info[client_id - 1] + if client_resource_info is not None else None) + + def _setup_for_distributed(self): + """ + To set up server or client for distributed mode. + """ + + # sample resource information + if self.resource_info is not None: + sampled_index = np.random.choice(list(self.resource_info.keys())) + sampled_resource = self.resource_info[sampled_index] + else: + sampled_resource = None + + self.server_address = { + 'host': self.cfg.distribute.server_host, + 'port': self.cfg.distribute.server_port + } + if self.cfg.distribute.role == 'server': + self.server = self._setup_server(resource_info=sampled_resource) + elif self.cfg.distribute.role == 'client': + # When we set up the client in the distributed mode, we assume + # the server has been set up and number with #0 + self.client_address = { + 'host': self.cfg.distribute.client_host, + 'port': self.cfg.distribute.client_port + } + self.client = self._setup_client(resource_info=sampled_resource) + + def run(self): + """ + To run an FL course, which is called after server/client has been + set up. + For the standalone mode, a shared message queue will be set up to + simulate ``receiving message``. + """ + if self.mode == 'standalone': + # trigger the FL course + for each_client in self.client: + self.client[each_client].join_in() + + if self.cfg.federate.online_aggr: + # any broadcast operation would be executed client-by-client + # to avoid the existence of #clients messages at the same time. 
+ # currently, only consider centralized topology + self._run_simulation_online() + + else: + self._run_simulation() + + self.server._monitor.finish_fed_runner(fl_mode=self.mode) + + return self.server.best_results + + elif self.mode == 'distributed': + if self.cfg.distribute.role == 'server': + self.server.run() + return self.server.best_results + elif self.cfg.distribute.role == 'client': + self.client.join_in() + self.client.run() + + def _run_simulation_online(self): + def is_broadcast(msg): + return len(msg.receiver) >= 1 and msg.sender == 0 + + cached_bc_msgs = [] + cur_idx = 0 + while True: + if len(self.shared_comm_queue) > 0: + msg = self.shared_comm_queue.popleft() + if is_broadcast(msg): + cached_bc_msgs.append(msg) + # assume there is at least one client + msg = cached_bc_msgs[0] + self._handle_msg(msg, rcv=msg.receiver[cur_idx]) + cur_idx += 1 + if cur_idx >= len(msg.receiver): + del cached_bc_msgs[0] + cur_idx = 0 + else: + self._handle_msg(msg) + elif len(cached_bc_msgs) > 0: + msg = cached_bc_msgs[0] + self._handle_msg(msg, rcv=msg.receiver[cur_idx]) + cur_idx += 1 + if cur_idx >= len(msg.receiver): + del cached_bc_msgs[0] + cur_idx = 0 + else: + # finished + break + + def _run_simulation(self): + + server_msg_cache = list() + while True: + if len(self.shared_comm_queue) > 0: + msg = self.shared_comm_queue.popleft() + if msg.receiver == [self.server_id]: + # For the server, move the received message to a + # cache for reordering the messages according to + # the timestamps + heapq.heappush(server_msg_cache, msg) + else: + self._handle_msg(msg) + elif len(server_msg_cache) > 0: + msg = heapq.heappop(server_msg_cache) + if self.cfg.asyn.use and self.cfg.asyn.aggregator \ + == 'time_up': + # When the timestamp of the received message beyond + # the deadline for the currency round, trigger the + # time up event first and push the message back to + # the cache + if self.server.trigger_for_time_up(msg.timestamp): + heapq.heappush(server_msg_cache, msg) + else: + self._handle_msg(msg) + else: + self._handle_msg(msg) + else: + if self.cfg.asyn.use and self.cfg.asyn.aggregator \ + == 'time_up': + self.server.trigger_for_time_up() + if len(self.shared_comm_queue) == 0 and \ + len(server_msg_cache) == 0: + break + else: + # terminate when shared_comm_queue and + # server_msg_cache are all empty + break + + def _setup_server(self, resource_info=None, client_resource_info=None): + """ + Set up the server + """ + self.server_id = 0 + if self.mode == 'standalone': + if self.server_id in self.data: + server_data = self.data[self.server_id] + model = get_model(self.cfg.model, + server_data, + backend=self.cfg.backend) + else: + server_data = None + data_representative = self.data[1] + model = get_model( + self.cfg.model, + data_representative, + backend=self.cfg.backend + ) # get the model according to client's data if the server + # does not own data + kw = { + 'shared_comm_queue': self.shared_comm_queue, + 'resource_info': resource_info, + 'client_resource_info': client_resource_info + } + elif self.mode == 'distributed': + server_data = self.data + model = get_model(self.cfg.model, + server_data, + backend=self.cfg.backend) + kw = self.server_address + kw.update({'resource_info': resource_info}) + else: + raise ValueError('Mode {} is not provided'.format( + self.cfg.mode.type)) + + if self.server_class: + self._server_device = self.gpu_manager.auto_choice() + server = self.server_class( + ID=self.server_id, + config=self.cfg, + data=server_data, + model=model, + 
client_num=self.cfg.federate.client_num, + total_round_num=self.cfg.federate.total_round_num, + device=self._server_device, + unseen_clients_id=self.unseen_clients_id, + **kw) + + if self.cfg.nbafl.use: + from federatedscope.core.trainers.trainer_nbafl import \ + wrap_nbafl_server + wrap_nbafl_server(server) + + else: + raise ValueError + + logger.info('Server has been set up ... ') + + return server + + def _setup_client(self, + client_id=-1, + client_model=None, + resource_info=None): + """ + Set up the client + """ + self.server_id = 0 + if self.mode == 'standalone': + client_data = self.data[client_id] + kw = { + 'shared_comm_queue': self.shared_comm_queue, + 'resource_info': resource_info + } + elif self.mode == 'distributed': + client_data = self.data + kw = self.client_address + kw['server_host'] = self.server_address['host'] + kw['server_port'] = self.server_address['port'] + kw['resource_info'] = resource_info + else: + raise ValueError('Mode {} is not provided'.format( + self.cfg.mode.type)) + + if self.client_class: + client_specific_config = self.cfg.clone() + if self.client_cfgs: + client_specific_config.defrost() + client_specific_config.merge_from_other_cfg( + self.client_cfgs.get('client_{}'.format(client_id))) + client_specific_config.freeze() + client_device = self._server_device if \ + self.cfg.federate.share_local_model else \ + self.gpu_manager.auto_choice() + client = self.client_class( + ID=client_id, + server_id=self.server_id, + config=client_specific_config, + data=client_data, + model=client_model or get_model(client_specific_config.model, + client_data, + backend=self.cfg.backend), + device=client_device, + is_unseen_client=client_id in self.unseen_clients_id, + **kw) + else: + raise ValueError + + if client_id == -1: + logger.info('Client (address {}:{}) has been set up ... '.format( + self.client_address['host'], self.client_address['port'])) + else: + logger.info(f'Client {client_id} has been set up ... 
') + + return client + + def _handle_msg(self, msg, rcv=-1): + """ + To simulate the message handling process (used only for the + standalone mode) + """ + if rcv != -1: + # simulate broadcast one-by-one + self.client[rcv].msg_handlers[msg.msg_type](msg) + return + + _, receiver = msg.sender, msg.receiver + download_bytes, upload_bytes = msg.count_bytes() + if not isinstance(receiver, list): + receiver = [receiver] + for each_receiver in receiver: + if each_receiver == 0: + self.server.msg_handlers[msg.msg_type](msg) + self.server._monitor.track_download_bytes(download_bytes) + else: + self.client[each_receiver].msg_handlers[msg.msg_type](msg) + self.client[each_receiver]._monitor.track_download_bytes( + download_bytes) diff --git a/fgssl/core/gRPC_server.py b/fgssl/core/gRPC_server.py new file mode 100644 index 0000000..9ce82d5 --- /dev/null +++ b/fgssl/core/gRPC_server.py @@ -0,0 +1,21 @@ +import queue +from collections import deque + +from federatedscope.core.proto import gRPC_comm_manager_pb2, \ + gRPC_comm_manager_pb2_grpc + + +class gRPCComServeFunc(gRPC_comm_manager_pb2_grpc.gRPCComServeFuncServicer): + def __init__(self): + self.msg_queue = deque() + + def sendMessage(self, request, context): + self.msg_queue.append(request) + + return gRPC_comm_manager_pb2.MessageResponse(msg='ACK') + + def receive(self): + while len(self.msg_queue) == 0: + continue + msg = self.msg_queue.popleft() + return msg diff --git a/fgssl/core/gpu_manager.py b/fgssl/core/gpu_manager.py new file mode 100644 index 0000000..452dff7 --- /dev/null +++ b/fgssl/core/gpu_manager.py @@ -0,0 +1,90 @@ +import os + + +def check_gpus(): + if not 'NVIDIA System Management' in os.popen('nvidia-smi -h').read(): + print("'nvidia-smi' tool not found.") + return False + return True + + +class GPUManager(): + """ + To automatic allocate the gpu, which returns the gpu with the largest + free memory rate, unless the specified_device has been set up + When gpus is unavailable, return 'cpu'; + The implementation of GPUManager is referred to + https://github.com/QuantumLiu/tf_gpu_manager + """ + def __init__(self, gpu_available=False, specified_device=-1): + self.gpu_avaiable = gpu_available and check_gpus() + self.specified_device = specified_device + if self.gpu_avaiable: + self.gpus = self._query_gpus() + for gpu in self.gpus: + gpu['allocated'] = False + else: + self.gpus = None + + def _sort_by_memory(self, gpus, by_size=False): + if by_size: + return sorted(gpus, key=lambda d: d['memory.free'], reverse=True) + else: + print('Sorted by free memory rate') + return sorted( + gpus, + key=lambda d: float(d['memory.free']) / d['memory.total'], + reverse=True) + + def _query_gpus(self): + args = ['index', 'gpu_name', 'memory.free', 'memory.total'] + cmd = 'nvidia-smi --query-gpu={} --format=csv,noheader'.format( + ','.join(args)) + results = os.popen(cmd).readlines() + return [self._parse(line, args) for line in results] + + def _parse(self, line, args): + numberic_args = ['memory.free', 'memory.total'] + to_numberic = lambda v: float(v.upper().strip().replace('MIB', ''). 
+ replace('W', '')) + process = lambda k, v: (int(to_numberic(v)) + if k in numberic_args else v.strip()) + return { + k: process(k, v) + for k, v in zip(args, + line.strip().split(',')) + } + + def auto_choice(self): + """ + To allocate a device + """ + if self.gpus is None: + return 'cpu' + elif self.specified_device >= 0: + # allow users to specify the device + return 'cuda:{}'.format(self.specified_device) + else: + for old_infos, new_infos in zip(self.gpus, self._query_gpus()): + old_infos.update(new_infos) + unallocated_gpus = [ + gpu for gpu in self.gpus if not gpu['allocated'] + ] + if len(unallocated_gpus) == 0: + # reset when all gpus have been allocated + unallocated_gpus = self.gpus + for gpu in self.gpus: + gpu['allocated'] = False + + chosen_gpu = self._sort_by_memory(unallocated_gpus, True)[0] + chosen_gpu['allocated'] = True + index = chosen_gpu['index'] + return 'cuda:{:s}'.format(index) + + +# for testing +if __name__ == '__main__': + + gpu_manager = GPUManager(gpu_available=True, specified_device=0) + for i in range(20): + print(gpu_manager.auto_choice()) diff --git a/fgssl/core/lr.py b/fgssl/core/lr.py new file mode 100644 index 0000000..16e8462 --- /dev/null +++ b/fgssl/core/lr.py @@ -0,0 +1,10 @@ +import torch + + +class LogisticRegression(torch.nn.Module): + def __init__(self, in_channels, class_num, use_bias=True): + super(LogisticRegression, self).__init__() + self.fc = torch.nn.Linear(in_channels, class_num, bias=use_bias) + + def forward(self, x): + return self.fc(x) diff --git a/fgssl/core/message.py b/fgssl/core/message.py new file mode 100644 index 0000000..87e3a2a --- /dev/null +++ b/fgssl/core/message.py @@ -0,0 +1,255 @@ +import json +import numpy as np +from federatedscope.core.proto import gRPC_comm_manager_pb2 + + +class Message(object): + """ + The data exchanged during an FL course are abstracted as 'Message' in + FederatedScope. + A message object includes: + msg_type: The type of message, which is used to trigger the + corresponding handlers of server/client + sender: The sender's ID + receiver: The receiver's ID + state: The training round of the message, which is determined by + the sender and used to filter out the outdated messages. 
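+        content: The payload of the message, e.g., model parameters or
+            evaluation results
+        timestamp: The (virtual) sending time of the message, an int or
+            float used to order messages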
+ strategy: redundant attribute + """ + def __init__(self, + msg_type=None, + sender=0, + receiver=0, + state=0, + content=None, + timestamp=0, + strategy=None): + self._msg_type = msg_type + self._sender = sender + self._receiver = receiver + self._state = state + self._content = content + self._timestamp = timestamp + self._strategy = strategy + + @property + def msg_type(self): + return self._msg_type + + @msg_type.setter + def msg_type(self, value): + self._msg_type = value + + @property + def sender(self): + return self._sender + + @sender.setter + def sender(self, value): + self._sender = value + + @property + def receiver(self): + return self._receiver + + @receiver.setter + def receiver(self, value): + self._receiver = value + + @property + def state(self): + return self._state + + @state.setter + def state(self, value): + self._state = value + + @property + def content(self): + return self._content + + @content.setter + def content(self, value): + self._content = value + + @property + def timestamp(self): + return self._timestamp + + @timestamp.setter + def timestamp(self, value): + assert isinstance(value, int) or isinstance(value, float), \ + "We only support an int or a float value for timestamp" + self._timestamp = value + + @property + def strategy(self): + return self._strategy + + @strategy.setter + def strategy(self, value): + self._strategy = value + + def __lt__(self, other): + if self.timestamp != other.timestamp: + return self.timestamp < other.timestamp + else: + return self.state < other.state + + def transform_to_list(self, x): + if isinstance(x, list) or isinstance(x, tuple): + return [self.transform_to_list(each_x) for each_x in x] + elif isinstance(x, dict): + for key in x.keys(): + x[key] = self.transform_to_list(x[key]) + return x + else: + if hasattr(x, 'tolist'): + return x.tolist() + else: + return x + + def msg_to_json(self, to_list=False): + if to_list: + self.content = self.transform_to_list(self.content) + + json_msg = { + 'msg_type': self.msg_type, + 'sender': self.sender, + 'receiver': self.receiver, + 'state': self.state, + 'content': self.content, + 'timestamp': self.timestamp, + 'strategy': self.strategy, + } + return json.dumps(json_msg) + + def json_to_msg(self, json_string): + json_msg = json.loads(json_string) + self.msg_type = json_msg['msg_type'] + self.sender = json_msg['sender'] + self.receiver = json_msg['receiver'] + self.state = json_msg['state'] + self.content = json_msg['content'] + self.timestamp = json_msg['timestamp'] + self.strategy = json_msg['strategy'] + + def create_by_type(self, value, nested=False): + if isinstance(value, dict): + if isinstance(list(value.keys())[0], str): + m_dict = gRPC_comm_manager_pb2.mDict_keyIsString() + key_type = 'string' + else: + m_dict = gRPC_comm_manager_pb2.mDict_keyIsInt() + key_type = 'int' + + for key in value.keys(): + m_dict.dict_value[key].MergeFrom( + self.create_by_type(value[key], nested=True)) + if nested: + msg_value = gRPC_comm_manager_pb2.MsgValue() + if key_type == 'string': + msg_value.dict_msg_stringkey.MergeFrom(m_dict) + else: + msg_value.dict_msg_intkey.MergeFrom(m_dict) + return msg_value + else: + return m_dict + elif isinstance(value, list) or isinstance(value, tuple): + m_list = gRPC_comm_manager_pb2.mList() + for each in value: + m_list.list_value.append(self.create_by_type(each, + nested=True)) + if nested: + msg_value = gRPC_comm_manager_pb2.MsgValue() + msg_value.list_msg.MergeFrom(m_list) + return msg_value + else: + return m_list + else: + m_single = 
gRPC_comm_manager_pb2.mSingle() + if type(value) in [int, np.int32]: + m_single.int_value = value + elif type(value) in [str]: + m_single.str_value = value + elif type(value) in [float, np.float32]: + m_single.float_value = value + else: + raise ValueError( + 'The data type {} has not been supported.'.format( + type(value))) + + if nested: + msg_value = gRPC_comm_manager_pb2.MsgValue() + msg_value.single_msg.MergeFrom(m_single) + return msg_value + else: + return m_single + + def build_msg_value(self, value): + msg_value = gRPC_comm_manager_pb2.MsgValue() + + if isinstance(value, list) or isinstance(value, tuple): + msg_value.list_msg.MergeFrom(self.create_by_type(value)) + elif isinstance(value, dict): + if isinstance(list(value.keys())[0], str): + msg_value.dict_msg_stringkey.MergeFrom( + self.create_by_type(value)) + else: + msg_value.dict_msg_intkey.MergeFrom(self.create_by_type(value)) + else: + msg_value.single_msg.MergeFrom(self.create_by_type(value)) + + return msg_value + + def transform(self, to_list=False): + if to_list: + self.content = self.transform_to_list(self.content) + + splited_msg = gRPC_comm_manager_pb2.MessageRequest() # map/dict + splited_msg.msg['sender'].MergeFrom(self.build_msg_value(self.sender)) + splited_msg.msg['receiver'].MergeFrom( + self.build_msg_value(self.receiver)) + splited_msg.msg['state'].MergeFrom(self.build_msg_value(self.state)) + splited_msg.msg['msg_type'].MergeFrom( + self.build_msg_value(self.msg_type)) + splited_msg.msg['content'].MergeFrom(self.build_msg_value( + self.content)) + splited_msg.msg['timestamp'].MergeFrom( + self.build_msg_value(self.timestamp)) + return splited_msg + + def _parse_msg(self, value): + if isinstance(value, gRPC_comm_manager_pb2.MsgValue) or isinstance( + value, gRPC_comm_manager_pb2.mSingle): + return self._parse_msg(getattr(value, value.WhichOneof("type"))) + elif isinstance(value, gRPC_comm_manager_pb2.mList): + return [self._parse_msg(each) for each in value.list_value] + elif isinstance(value, gRPC_comm_manager_pb2.mDict_keyIsString) or \ + isinstance(value, gRPC_comm_manager_pb2.mDict_keyIsInt): + return { + k: self._parse_msg(value.dict_value[k]) + for k in value.dict_value + } + else: + return value + + def parse(self, received_msg): + self.sender = self._parse_msg(received_msg['sender']) + self.receiver = self._parse_msg(received_msg['receiver']) + self.msg_type = self._parse_msg(received_msg['msg_type']) + self.state = self._parse_msg(received_msg['state']) + self.content = self._parse_msg(received_msg['content']) + self.timestamp = self._parse_msg(received_msg['timestamp']) + + def count_bytes(self): + """ + calculate the message bytes to be sent/received + :return: tuple of bytes of the message to be sent and received + """ + from pympler import asizeof + download_bytes = asizeof.asizeof(self.content) + upload_cnt = len(self.receiver) if isinstance(self.receiver, + list) else 1 + upload_bytes = download_bytes * upload_cnt + return download_bytes, upload_bytes diff --git a/fgssl/core/mlp.py b/fgssl/core/mlp.py new file mode 100644 index 0000000..a71b76e --- /dev/null +++ b/fgssl/core/mlp.py @@ -0,0 +1,40 @@ +import torch +import torch.nn.functional as F +from torch.nn import Linear, ModuleList +from torch.nn import BatchNorm1d, Identity + + +class MLP(torch.nn.Module): + """ + Multilayer Perceptron + """ + def __init__(self, + channel_list, + dropout=0., + batch_norm=True, + relu_first=False): + super().__init__() + assert len(channel_list) >= 2 + self.channel_list = channel_list + self.dropout = dropout 
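+        # `relu_first` decides whether ReLU is applied before or after the
+        # (optional) BatchNorm in each hidden block; see `forward` below.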
+ self.relu_first = relu_first + + self.linears = ModuleList() + self.norms = ModuleList() + for in_channel, out_channel in zip(channel_list[:-1], + channel_list[1:]): + self.linears.append(Linear(in_channel, out_channel)) + self.norms.append( + BatchNorm1d(out_channel) if batch_norm else Identity()) + + def forward(self, x): + x = self.linears[0](x) + for layer, norm in zip(self.linears[1:], self.norms[:-1]): + if self.relu_first: + x = F.relu(x) + x = norm(x) + if not self.relu_first: + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = layer.forward(x) + return x diff --git a/fgssl/core/monitors/__init__.py b/fgssl/core/monitors/__init__.py new file mode 100644 index 0000000..3f945b5 --- /dev/null +++ b/fgssl/core/monitors/__init__.py @@ -0,0 +1,5 @@ +from federatedscope.core.monitors.early_stopper import EarlyStopper +from federatedscope.core.monitors.metric_calculator import MetricCalculator +from federatedscope.core.monitors.monitor import Monitor + +__all__ = ['EarlyStopper', 'MetricCalculator', 'Monitor'] diff --git a/fgssl/core/monitors/early_stopper.py b/fgssl/core/monitors/early_stopper.py new file mode 100644 index 0000000..ab35163 --- /dev/null +++ b/fgssl/core/monitors/early_stopper.py @@ -0,0 +1,103 @@ +import operator +import numpy as np + + +# TODO: make this as a sub-module of monitor class +class EarlyStopper(object): + """ + Track the history of metric (e.g., validation loss), + check whether should stop (training) process if the metric doesn't + improve after a given patience. + """ + def __init__(self, + patience=5, + delta=0, + improve_indicator_mode='best', + the_smaller_the_better=True): + """ + Args: + patience (int): How long to wait after last time the monitored + metric improved. + Note that the + actual_checking_round = patience * cfg.eval.freq + Default: 5 + delta (float): Minimum change in the monitored metric to + indicate an improvement. 
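+                For example (illustrative), with delta=0.01 and
+                the_smaller_the_better=True, a new loss only counts as an
+                improvement if it is at least 0.01 lower than the best
+                value seen so far.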
+ Default: 0 + improve_indicator_mode (str): Early stop when no improve to + last `patience` round, in ['mean', 'best'] + """ + assert 0 <= patience == int( + patience + ), "Please use a non-negtive integer to indicate the patience" + assert delta >= 0, "Please use a positive value to indicate the change" + assert improve_indicator_mode in [ + 'mean', 'best' + ], "Please make sure `improve_indicator_mode` is 'mean' or 'best']" + + self.patience = patience + self.counter_no_improve = 0 + self.best_metric = None + self.early_stopped = False + self.the_smaller_the_better = the_smaller_the_better + self.delta = delta + self.improve_indicator_mode = improve_indicator_mode + # For expansion usages of comparisons + self.comparator = operator.lt + self.improvement_operator = operator.add + + def __track_and_check_dummy(self, new_result): + self.early_stopped = False + return self.early_stopped + + def __track_and_check_best(self, history_result): + new_result = history_result[-1] + if self.best_metric is None: + self.best_metric = new_result + elif self.the_smaller_the_better and self.comparator( + self.improvement_operator(self.best_metric, -self.delta), + new_result): + # add(best_metric, -delta) < new_result + self.counter_no_improve += 1 + elif not self.the_smaller_the_better and self.comparator( + new_result, + self.improvement_operator(self.best_metric, self.delta)): + # new_result < add(best_metric, delta) + self.counter_no_improve += 1 + else: + self.best_metric = new_result + self.counter_no_improve = 0 + + self.early_stopped = self.counter_no_improve >= self.patience + return self.early_stopped + + def __track_and_check_mean(self, history_result): + new_result = history_result[-1] + if len(history_result) > self.patience: + if self.the_smaller_the_better and self.comparator( + self.improvement_operator( + np.mean(history_result[-self.patience - 1:-1]), + -self.delta), new_result): + self.early_stopped = True + elif not self.the_smaller_the_better and self.comparator( + new_result, + self.improvement_operator( + np.mean(history_result[-self.patience - 1:-1]), + self.delta)): + self.early_stopped = True + else: + self.early_stopped = False + + return self.early_stopped + + def track_and_check(self, new_result): + + track_method = self.__track_and_check_dummy # do nothing + if self.patience == 0: + track_method = self.__track_and_check_dummy + elif self.improve_indicator_mode == 'best': + track_method = self.__track_and_check_best + elif self.improve_indicator_mode == 'mean': + track_method = self.__track_and_check_mean + + return track_method(new_result) diff --git a/fgssl/core/monitors/metric_calculator.py b/fgssl/core/monitors/metric_calculator.py new file mode 100644 index 0000000..6d32122 --- /dev/null +++ b/fgssl/core/monitors/metric_calculator.py @@ -0,0 +1,235 @@ +import logging +from typing import Optional, Union, List, Set + +import numpy as np +from scipy.special import softmax +from sklearn.metrics import roc_auc_score, average_precision_score, f1_score + +from federatedscope.core.auxiliaries.metric_builder import get_metric + +# Blind torch +try: + import torch +except ImportError: + torch = None + +logger = logging.getLogger(__name__) + + +# TODO: make this as a sub-module of monitor class +class MetricCalculator(object): + def __init__(self, eval_metric: Union[Set[str], List[str], str]): + + # Add personalized metrics + if isinstance(eval_metric, str): + eval_metric = {eval_metric} + elif isinstance(eval_metric, list): + eval_metric = set(eval_metric) + + # Default metric is 
{'loss', 'avg_loss', 'total'} + self.eval_metric = self.get_metric_funcs(eval_metric) + + def get_metric_funcs(self, eval_metric): + metric_buildin = { + metric: SUPPORT_METRICS[metric] + for metric in {'loss', 'avg_loss', 'total'} | eval_metric + if metric in SUPPORT_METRICS + } + metric_register = get_metric(eval_metric - set(SUPPORT_METRICS.keys())) + return {**metric_buildin, **metric_register} + + def eval(self, ctx): + results = {} + y_true, y_pred, y_prob = self._check_and_parse(ctx) + for metric, func in self.eval_metric.items(): + results["{}_{}".format(ctx.cur_split, + metric)] = func(ctx=ctx, + y_true=y_true, + y_pred=y_pred, + y_prob=y_prob, + metric=metric) + return results + + def _check_and_parse(self, ctx): + """Check the format of the prediction and labels + + Args: + ctx: + + Returns: + y_true: The ground truth labels + y_pred: The prediction categories for classification task + y_prob: The output of the model + + """ + if ctx.get('ys_true', None) is None: + raise KeyError('Missing key ys_true!') + if ctx.get('ys_prob', None) is None: + raise KeyError('Missing key ys_prob!') + + y_true = ctx.ys_true + y_prob = ctx.ys_prob + + if torch is not None and isinstance(y_true, torch.Tensor): + y_true = y_true.detach().cpu().numpy() + if torch is not None and isinstance(y_prob, torch.Tensor): + y_prob = y_prob.detach().cpu().numpy() + + if 'regression' in ctx.cfg.model.task.lower(): + y_pred = None + else: + # classification task + if y_true.ndim == 1: + y_true = np.expand_dims(y_true, axis=-1) + if y_prob.ndim == 2: + y_prob = np.expand_dims(y_prob, axis=-1) + + # if len(y_prob.shape) > len(y_true.shape): + y_pred = np.argmax(y_prob, axis=1) + + # check shape and type + if not isinstance(y_true, np.ndarray): + raise RuntimeError('Type not support!') + if not y_true.shape == y_pred.shape: + raise RuntimeError('Shape not match!') + if not y_true.ndim == 2: + raise RuntimeError( + 'y_true must be 2-dim array, {}-dim given'.format( + y_true.ndim)) + + return y_true, y_pred, y_prob + + +def eval_correct(y_true, y_pred, **kwargs): + correct_list = [] + + for i in range(y_true.shape[1]): + is_labeled = y_true[:, i] == y_true[:, i] + correct = y_true[is_labeled, i] == y_pred[is_labeled, i] + correct_list.append(np.sum(correct)) + return sum(correct_list) / len(correct_list) + + +def eval_acc(y_true, y_pred, **kwargs): + acc_list = [] + + for i in range(y_true.shape[1]): + is_labeled = y_true[:, i] == y_true[:, i] + correct = y_true[is_labeled, i] == y_pred[is_labeled, i] + acc_list.append(float(np.sum(correct)) / len(correct)) + return sum(acc_list) / len(acc_list) + + +def eval_ap(y_true, y_pred, **kwargs): + ap_list = [] + + for i in range(y_true.shape[1]): + # AUC is only defined when there is at least one positive data. + if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: + # ignore nan values + is_labeled = y_true[:, i] == y_true[:, i] + ap = average_precision_score(y_true[is_labeled, i], + y_pred[is_labeled, i]) + + ap_list.append(ap) + + if len(ap_list) == 0: + logger.warning('No positively labeled data available. 
') + return 0.0 + + return sum(ap_list) / len(ap_list) + + +def eval_f1_score(y_true, y_pred, **kwargs): + return f1_score(y_true, y_pred, average='macro') + + +def eval_hits(y_true, y_prob, metric, **kwargs): + n = int(metric.split('@')[1]) + hits_list = [] + for i in range(y_true.shape[1]): + idx = np.argsort(-y_prob[:, :, i], axis=1) + pred_rank = idx.argsort(axis=1) + # Obtain the label rank + arg = np.arange(0, pred_rank.shape[0]) + rank = pred_rank[arg, y_true[:, i]] + 1 + hits_num = (rank <= n).sum().item() + hits_list.append(float(hits_num) / len(rank)) + + return sum(hits_list) / len(hits_list) + + +def eval_roc_auc(y_true, y_prob, **kwargs): + rocauc_list = [] + + for i in range(y_true.shape[1]): + # AUC is only defined when there is at least one positive data. + if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: + # ignore nan values + is_labeled = y_true[:, i] == y_true[:, i] + y_true_one_hot = np.eye(y_prob.shape[1])[y_true[is_labeled, i]] + rocauc_list.append( + roc_auc_score(y_true_one_hot, + softmax(y_prob[is_labeled, :, i], axis=-1))) + if len(rocauc_list) == 0: + logger.warning('No positively labeled data available.') + return 0.5 + + return sum(rocauc_list) / len(rocauc_list) + + +def eval_rmse(y_true, y_prob, **kwargs): + return np.sqrt(np.mean(np.power(y_true - y_prob, 2))) + + +def eval_mse(y_true, y_prob, **kwargs): + return np.mean(np.power(y_true - y_prob, 2)) + + +def eval_loss(ctx, **kwargs): + return ctx.loss_batch_total + + +def eval_avg_loss(ctx, **kwargs): + return ctx.loss_batch_total / ctx.num_samples + + +def eval_total(ctx, **kwargs): + return ctx.num_samples + + +def eval_regular(ctx, **kwargs): + return ctx.loss_regular_total + + +def eval_imp_ratio(ctx, y_true, y_prob, y_pred, **kwargs): + if not hasattr(ctx.cfg.eval, 'base') or ctx.cfg.eval.base <= 0: + logger.info( + "To use the metric `imp_rato`, please set `eval.base` as the " + "basic performance and it must be greater than zero.") + return 0. + + base = ctx.cfg.eval.base + task = ctx.cfg.model.task.lower() + if 'regression' in task: + perform = eval_mse(y_true, y_prob) + elif 'classification' in task: + perform = 1 - eval_acc(y_true, y_pred) + return (base - perform) / base * 100. + + +SUPPORT_METRICS = { + 'loss': eval_loss, + 'avg_loss': eval_avg_loss, + 'total': eval_total, + 'correct': eval_correct, + 'acc': eval_acc, + 'ap': eval_ap, + 'f1': eval_f1_score, + 'roc_auc': eval_roc_auc, + 'rmse': eval_rmse, + 'mse': eval_mse, + 'loss_regular': eval_regular, + 'imp_ratio': eval_imp_ratio, + **dict.fromkeys([f'hits@{n}' for n in range(1, 101)], eval_hits) +} diff --git a/fgssl/core/monitors/monitor.py b/fgssl/core/monitors/monitor.py new file mode 100644 index 0000000..0f48321 --- /dev/null +++ b/fgssl/core/monitors/monitor.py @@ -0,0 +1,655 @@ +import copy +import json +import logging +import os +import gzip +import shutil +import datetime +from collections import defaultdict + +import numpy as np + +from federatedscope.core.auxiliaries.logging import logline_2_wandb_dict + +try: + import torch +except ImportError: + torch = None + +logger = logging.getLogger(__name__) + +global_all_monitors = [ +] # used in standalone mode, to merge sys metric results for all workers + + +class Monitor(object): + """ + Provide the monitoring functionalities such as formatting the + evaluation results into diverse metrics. 
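+    Aggregated results can be reported in the forms listed in
+    ``SUPPORTED_FORMS`` below ('weighted_avg', 'avg', 'fairness', 'raw').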
+ Besides the prediction related performance, the monitor also can + track efficiency related metrics for a worker + """ + SUPPORTED_FORMS = ['weighted_avg', 'avg', 'fairness', 'raw'] + + def __init__(self, cfg, monitored_object=None): + self.log_res_best = {} + self.outdir = cfg.outdir + self.use_wandb = cfg.wandb.use + self.wandb_online_track = cfg.wandb.online_track if cfg.wandb.use \ + else False + # self.use_tensorboard = cfg.use_tensorboard + + self.monitored_object = monitored_object + + # ======= efficiency indicators of the worker to be monitored ======= + # leveraged the flops counter provided by [fvcore]( + # https://github.com/facebookresearch/fvcore) + self.total_model_size = 0 # model size used in the worker, in terms + # of number of parameters + self.flops_per_sample = 0 # average flops for forwarding each data + # sample + self.flop_count = 0 # used to calculated the running mean for + # flops_per_sample + self.total_flops = 0 # total computation flops to convergence until + # current fl round + self.total_upload_bytes = 0 # total upload space cost in bytes + # until current fl round + self.total_download_bytes = 0 # total download space cost in bytes + # until current fl round + self.fl_begin_wall_time = datetime.datetime.now() + self.fl_end_wall_time = 0 + # for the metrics whose names includes "convergence", 0 indicates + # the worker does not converge yet + # Note: + # 1) the convergence wall time is prone to fluctuations due to + # possible resource competition during FL courses + # 2) the global/local indicates whether the early stopping triggered + # with global-aggregation/local-training + self.global_convergence_round = 0 # total fl rounds to convergence + self.global_convergence_wall_time = 0 + self.local_convergence_round = 0 # total fl rounds to convergence + self.local_convergence_wall_time = 0 + + if self.wandb_online_track: + global_all_monitors.append(self) + if self.use_wandb: + try: + import wandb + except ImportError: + logger.error( + "cfg.wandb.use=True but not install the wandb package") + exit() + + def global_converged(self): + self.global_convergence_wall_time = datetime.datetime.now( + ) - self.fl_begin_wall_time + self.global_convergence_round = self.monitored_object.state + + def local_converged(self): + self.local_convergence_wall_time = datetime.datetime.now( + ) - self.fl_begin_wall_time + self.local_convergence_round = self.monitored_object.state + + def finish_fl(self): + self.fl_end_wall_time = datetime.datetime.now( + ) - self.fl_begin_wall_time + + system_metrics = self.get_sys_metrics() + sys_metric_f_name = os.path.join(self.outdir, "system_metrics.log") + with open(sys_metric_f_name, "a") as f: + f.write(json.dumps(system_metrics) + "\n") + + def get_sys_metrics(self, verbose=True): + system_metrics = { + "id": self.monitored_object.ID, + "fl_end_time_minutes": self.fl_end_wall_time.total_seconds() / + 60 if isinstance(self.fl_end_wall_time, datetime.timedelta) else 0, + "total_model_size": self.total_model_size, + "total_flops": self.total_flops, + "total_upload_bytes": self.total_upload_bytes, + "total_download_bytes": self.total_download_bytes, + "global_convergence_round": self.global_convergence_round, + "local_convergence_round": self.local_convergence_round, + "global_convergence_time_minutes": self. + global_convergence_wall_time.total_seconds() / 60 if isinstance( + self.global_convergence_wall_time, datetime.timedelta) else 0, + "local_convergence_time_minutes": self.local_convergence_wall_time. 
+ total_seconds() / 60 if isinstance( + self.local_convergence_wall_time, datetime.timedelta) else 0, + } + if verbose: + logger.info( + f"In worker #{self.monitored_object.ID}, the system-related " + f"metrics are: {str(system_metrics)}") + return system_metrics + + def merge_system_metrics_simulation_mode(self, + file_io=True, + from_global_monitors=False): + """ + average the system metrics recorded in "system_metrics.json" by + all workers + :return: + """ + + all_sys_metrics = defaultdict(list) + avg_sys_metrics = defaultdict() + std_sys_metrics = defaultdict() + + if file_io: + sys_metric_f_name = os.path.join(self.outdir, "system_metrics.log") + if not os.path.exists(sys_metric_f_name): + logger.warning( + "You have not tracked the workers' system metrics in " + "$outdir$/system_metrics.log, " + "we will skip the merging. Plz check whether you do not " + "want to call monitor.finish_fl()") + return + with open(sys_metric_f_name, "r") as f: + for line in f: + res = json.loads(line) + if all_sys_metrics is None: + all_sys_metrics = res + all_sys_metrics["id"] = "all" + else: + for k, v in res.items(): + all_sys_metrics[k].append(v) + id_to_be_merged = all_sys_metrics["id"] + if len(id_to_be_merged) != len(set(id_to_be_merged)): + logger.warning( + f"The sys_metric_file ({sys_metric_f_name}) contains " + f"duplicated tracked sys-results with these ids: " + f"f{id_to_be_merged} " + f"We will skip the merging as the merge is invalid. " + f"Plz check whether you specify the 'outdir' " + f"as the same as the one of another older experiment.") + return + elif from_global_monitors: + for monitor in global_all_monitors: + res = monitor.get_sys_metrics(verbose=False) + if all_sys_metrics is None: + all_sys_metrics = res + all_sys_metrics["id"] = "all" + else: + for k, v in res.items(): + all_sys_metrics[k].append(v) + else: + raise ValueError("file_io or from_monitors should be True: " + f"but got file_io={file_io}, from_monitors" + f"={from_global_monitors}") + + for k, v in all_sys_metrics.items(): + if k == "id": + avg_sys_metrics[k] = "sys_avg" + std_sys_metrics[k] = "sys_std" + else: + v = np.array(v).astype("float") + mean_res = np.mean(v) + std_res = np.std(v) + if "flops" in k or "bytes" in k or "size" in k: + mean_res = self.convert_size(mean_res) + std_res = self.convert_size(std_res) + avg_sys_metrics[f"sys_avg/{k}"] = mean_res + std_sys_metrics[f"sys_std/{k}"] = std_res + + logger.info( + f"After merging the system metrics from all works, we got avg:" + f" {avg_sys_metrics}") + logger.info( + f"After merging the system metrics from all works, we got std:" + f" {std_sys_metrics}") + if file_io: + with open(sys_metric_f_name, "a") as f: + f.write(json.dumps(avg_sys_metrics) + "\n") + f.write(json.dumps(std_sys_metrics) + "\n") + + if self.use_wandb and self.wandb_online_track: + try: + import wandb + # wandb.log(avg_sys_metrics) + # wandb.log(std_sys_metrics) + for k, v in avg_sys_metrics.items(): + wandb.summary[k] = v + for k, v in std_sys_metrics.items(): + wandb.summary[k] = v + except ImportError: + logger.error( + "cfg.wandb.use=True but not install the wandb package") + exit() + + def save_formatted_results(self, + formatted_res, + save_file_name="eval_results.log"): + line = str(formatted_res) + "\n" + if save_file_name != "": + with open(os.path.join(self.outdir, save_file_name), + "a") as outfile: + outfile.write(line) + if self.use_wandb and self.wandb_online_track: + try: + import wandb + exp_stop_normal = False + exp_stop_normal, log_res = logline_2_wandb_dict( + 
exp_stop_normal, line, self.log_res_best, raw_out=False) + wandb.log(log_res) + except ImportError: + logger.error( + "cfg.wandb.use=True but not install the wandb package") + exit() + + def finish_fed_runner(self, fl_mode=None): + self.compress_raw_res_file() + if fl_mode == "standalone": + self.merge_system_metrics_simulation_mode() + + if self.use_wandb and not self.wandb_online_track: + try: + import wandb + except ImportError: + logger.error( + "cfg.wandb.use=True but not install the wandb package") + exit() + + from federatedscope.core.auxiliaries.logging import \ + logfile_2_wandb_dict + with open(os.path.join(self.outdir, "eval_results.log"), + "r") as exp_log_f: + # track the prediction related performance + all_log_res, exp_stop_normal, last_line, log_res_best = \ + logfile_2_wandb_dict(exp_log_f, raw_out=False) + for log_res in all_log_res: + wandb.log(log_res) + wandb.log(log_res_best) + + # track the system related performance + sys_metric_f_name = os.path.join(self.outdir, + "system_metrics.log") + with open(sys_metric_f_name, "r") as f: + for line in f: + res = json.loads(line) + if res["id"] in ["sys_avg", "sys_std"]: + # wandb.log(res) + for k, v in res.items(): + wandb.summary[k] = v + + def compress_raw_res_file(self): + old_f_name = os.path.join(self.outdir, "eval_results.raw") + if os.path.exists(old_f_name): + logger.info( + "We will compress the file eval_results.raw into a .gz file, " + "and delete the old one") + with open(old_f_name, 'rb') as f_in: + with gzip.open(old_f_name + ".gz", 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + os.remove(old_f_name) + + def format_eval_res(self, + results, + rnd, + role=-1, + forms=None, + return_raw=False): + """ + format the evaluation results from trainer.ctx.eval_results + + Args: + results (dict): a dict to store the evaluation results {metric: + value} + rnd (int|string): FL round + role (int|string): the output role + forms (list): format type + return_raw (bool): return either raw results, or other results + + Returns: + round_formatted_results (dict): a formatted results with + different forms and roles, + e.g., + { + 'Role': 'Server #', + 'Round': 200, + 'Results_weighted_avg': { + 'test_avg_loss': 0.58, 'test_acc': 0.67, 'test_correct': + 3356, 'test_loss': 2892, 'test_total': 5000 + }, + 'Results_avg': { + 'test_avg_loss': 0.57, 'test_acc': 0.67, 'test_correct': + 3356, 'test_loss': 2892, 'test_total': 5000 + }, + 'Results_fairness': { + 'test_total': 33.99, 'test_correct': 27.185, + 'test_avg_loss_std': 0.433551, + 'test_avg_loss_bottom_decile': 0.356503, + 'test_avg_loss_top_decile': 1.212492, + 'test_avg_loss_min': 0.198317, 'test_avg_loss_max': 3.603567, + 'test_avg_loss_bottom10%': 0.276681, 'test_avg_loss_top10%': + 1.686649, + 'test_avg_loss_cos1': 0.867932, 'test_avg_loss_entropy': 5.164172, + 'test_loss_std': 13.686828, 'test_loss_bottom_decile': 11.822035, + 'test_loss_top_decile': 39.727236, 'test_loss_min': 7.337724, + 'test_loss_max': 100.899873, 'test_loss_bottom10%': 9.618685, + 'test_loss_top10%': 54.96769, 'test_loss_cos1': 0.880356, + 'test_loss_entropy': 5.175803, 'test_acc_std': 0.123823, + 'test_acc_bottom_decile': 0.676471, 'test_acc_top_decile': + 0.916667, + 'test_acc_min': 0.071429, 'test_acc_max': 0.972973, + 'test_acc_bottom10%': 0.527482, 'test_acc_top10%': 0.94486, + 'test_acc_cos1': 0.988134, 'test_acc_entropy': 5.283755 + }, + } + """ + if forms is None: + forms = ['weighted_avg', 'avg', 'fairness', 'raw'] + round_formatted_results = {'Role': role, 'Round': rnd} + 
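+        # Each requested form adds a `Results_{form}` entry below:
+        #   'weighted_avg' weights every client's metric by its sample count,
+        #   'avg' takes the unweighted mean across clients,
+        #   'fairness' reports distribution statistics (std, deciles, min/max,
+        #       bottom/top 10%, cos1 and entropy), and
+        #   'raw' keeps the per-client values unchanged.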
round_formatted_results_raw = {'Role': role, 'Round': rnd} + for form in forms: + new_results = copy.deepcopy(results) + if not role.lower().startswith('server') or form == 'raw': + round_formatted_results_raw['Results_raw'] = new_results + elif form not in Monitor.SUPPORTED_FORMS: + continue + else: + for key in results.keys(): + dataset_name = key.split("_")[0] + if f'{dataset_name}_total' not in results: + raise ValueError( + "Results to be formatted should be include the " + "dataset_num in the dict," + f"with key = {dataset_name}_total") + else: + dataset_num = np.array( + results[f'{dataset_name}_total']) + if key in [ + f'{dataset_name}_total', + f'{dataset_name}_correct' + ]: + new_results[key] = np.mean(new_results[key]) + + if key in [ + f'{dataset_name}_total', f'{dataset_name}_correct' + ]: + new_results[key] = np.mean(new_results[key]) + else: + all_res = np.array(copy.copy(results[key])) + if form == 'weighted_avg': + new_results[key] = np.sum( + np.array(new_results[key]) * + dataset_num) / np.sum(dataset_num) + if form == "avg": + new_results[key] = np.mean(new_results[key]) + if form == "fairness" and all_res.size > 1: + # by default, log the std and decile + new_results.pop( + key, None) # delete the redundant original one + all_res.sort() + new_results[f"{key}_std"] = np.std( + np.array(all_res)) + new_results[f"{key}_bottom_decile"] = all_res[ + all_res.size // 10] + new_results[f"{key}_top_decile"] = all_res[ + all_res.size * 9 // 10] + # log more fairness metrics + # min and max + new_results[f"{key}_min"] = all_res[0] + new_results[f"{key}_max"] = all_res[-1] + # bottom and top 10% + new_results[f"{key}_bottom10%"] = np.mean( + all_res[:all_res.size // 10]) + new_results[f"{key}_top10%"] = np.mean( + all_res[all_res.size * 9 // 10:]) + # cosine similarity between the performance + # distribution and 1 + new_results[f"{key}_cos1"] = np.mean(all_res) / ( + np.sqrt(np.mean(all_res**2))) + # entropy of performance distribution + all_res_preprocessed = all_res + 1e-9 + new_results[f"{key}_entropy"] = np.sum( + -all_res_preprocessed / + np.sum(all_res_preprocessed) * (np.log( + (all_res_preprocessed) / + np.sum(all_res_preprocessed)))) + round_formatted_results[f'Results_{form}'] = new_results + + with open(os.path.join(self.outdir, "eval_results.raw"), + "a") as outfile: + outfile.write(str(round_formatted_results_raw) + "\n") + + return round_formatted_results_raw if return_raw else \ + round_formatted_results + + def calc_blocal_dissim(self, last_model, local_updated_models): + ''' + Arguments: + last_model (dict): the state of last round. + local_updated_models (list): each element is ooxx. + Returns: + b_local_dissimilarity (dict): the measurements proposed in + "Tian Li, Anit Kumar Sahu, Manzil Zaheer, and et al. Federated + Optimization in Heterogeneous Networks". 
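+            As a sketch of what the code below computes: for each parameter
+            tensor k, the reported value is
+                sqrt( sum_i w_i * ||g_ik||^2 / || sum_i w_i * g_ik ||^2 ),
+            where g_ik = local_model_i[k] - last_model[k] and w_i are the
+            sample weights normalized to sum to one.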
+ ''' + # for k, v in last_model.items(): + # print(k, v) + # for i, elem in enumerate(local_updated_models): + # print(i, elem) + local_grads = [] + weights = [] + local_gnorms = [] + for tp in local_updated_models: + weights.append(tp[0]) + grads = dict() + gnorms = dict() + for k, v in tp[1].items(): + grad = v - last_model[k] + grads[k] = grad + gnorms[k] = torch.sum(grad**2) + local_grads.append(grads) + local_gnorms.append(gnorms) + weights = np.asarray(weights) + weights = weights / np.sum(weights) + avg_gnorms = dict() + global_grads = dict() + for i in range(len(local_updated_models)): + gnorms = local_gnorms[i] + for k, v in gnorms.items(): + if k not in avg_gnorms: + avg_gnorms[k] = .0 + avg_gnorms[k] += weights[i] * v + grads = local_grads[i] + for k, v in grads.items(): + if k not in global_grads: + global_grads[k] = torch.zeros_like(v) + global_grads[k] += weights[i] * v + b_local_dissimilarity = dict() + for k in avg_gnorms: + b_local_dissimilarity[k] = np.sqrt( + avg_gnorms[k].item() / torch.sum(global_grads[k]**2).item()) + return b_local_dissimilarity + + def convert_size(self, size_bytes): + import math + if size_bytes <= 0: + return str(size_bytes) + size_name = ("", "K", "M", "G", "T", "P", "E", "Z", "Y") + i = int(math.floor(math.log(size_bytes, 1024))) + p = math.pow(1024, i) + s = round(size_bytes / p, 2) + return f"{s}{size_name[i]}" + + def track_model_size(self, models): + """ + calculate the total model size given the models hold by the + worker/trainer + + :param models: torch.nn.Module or list of torch.nn.Module + :return: + """ + if self.total_model_size != 0: + logger.warning( + "the total_model_size is not zero. You may have been " + "calculated the total_model_size before") + + if not hasattr(models, '__iter__'): + models = [models] + for model in models: + assert isinstance(model, torch.nn.Module), \ + f"the `model` should be type torch.nn.Module when " \ + f"calculating its size, but got {type(model)}" + for name, para in model.named_parameters(): + self.total_model_size += para.numel() + + def track_avg_flops(self, flops, sample_num=1): + """ + update the average flops for forwarding each data sample, + for most models and tasks, + the averaging is not needed as the input shape is fixed + + :param flops: flops/ + :param sample_num: + :return: + """ + + self.flops_per_sample = (self.flops_per_sample * self.flop_count + + flops) / (self.flop_count + sample_num) + self.flop_count += 1 + + def track_upload_bytes(self, bytes): + self.total_upload_bytes += bytes + + def track_download_bytes(self, bytes): + self.total_download_bytes += bytes + + def update_best_result(self, + best_results, + new_results, + results_type, + round_wise_update_key="val_loss"): + """ + update best evaluation results. 
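+        Metrics whose names contain 'loss' or 'std' are treated as
+        the-smaller-the-better, and metrics containing 'acc' as
+        the-larger-the-better;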
+ by default, the update is based on validation loss with + `round_wise_update_key="val_loss" ` + """ + update_best_this_round = False + if not isinstance(new_results, dict): + raise ValueError( + f"update best results require `results` a dict, but got" + f" {type(new_results)}") + else: + if results_type not in best_results: + best_results[results_type] = dict() + best_result = best_results[results_type] + # update different keys separately: the best values can be in + # different rounds + if round_wise_update_key is None: + for key in new_results: + cur_result = new_results[key] + if 'loss' in key or 'std' in key: # the smaller, + # the better + if results_type in [ + "client_best_individual", + "unseen_client_best_individual" + ]: + cur_result = min(cur_result) + if key not in best_result or cur_result < best_result[ + key]: + best_result[key] = cur_result + update_best_this_round = True + + elif 'acc' in key: # the larger, the better + if results_type in [ + "client_best_individual", + "unseen_client_best_individual" + ]: + cur_result = max(cur_result) + if key not in best_result or cur_result > best_result[ + key]: + best_result[key] = cur_result + update_best_this_round = True + else: + # unconcerned metric + pass + # update different keys round-wise: if find better + # round_wise_update_key, update others at the same time + else: + if round_wise_update_key not in [ + "val_loss", "test_loss", "loss", "val_avg_loss", + "test_avg_loss", "avg_loss", "test_acc", "test_std", + "val_acc", "val_std", "val_imp_ratio", "train_loss", + "train_avg_loss" + ]: + raise NotImplementedError( + f"We currently support round_wise_update_key as one " + f"of ['val_loss', 'test_loss', 'loss', " + f"'val_avg_loss', 'test_avg_loss', 'avg_loss," + f"''val_acc', 'val_std', 'test_acc', 'test_std', " + f"'val_imp_ratio'] for round-wise best results " + f" update, but got {round_wise_update_key}.") + + found_round_wise_update_key = False + sorted_keys = [] + for key in new_results: + if round_wise_update_key in key: + sorted_keys.insert(0, key) + found_round_wise_update_key = True + else: + sorted_keys.append(key) + if not found_round_wise_update_key: + raise ValueError( + "Your specified eval.best_res_update_round_wise_key " + "is not in target results, " + "use another key or check the name. 
\n" + f"Got eval.best_res_update_round_wise_key" + f"={round_wise_update_key}, " + f"the keys of results are {list(new_results.keys())}") + + for key in sorted_keys: + cur_result = new_results[key] + if update_best_this_round or \ + ('loss' in round_wise_update_key and 'loss' in + key) or \ + ('std' in round_wise_update_key and 'std' in key): + # The smaller the better + if results_type in [ + "client_best_individual", + "unseen_client_best_individual" + ]: + cur_result = min(cur_result) + if update_best_this_round or \ + key not in best_result or cur_result < \ + best_result[key]: + best_result[key] = cur_result + update_best_this_round = True + elif update_best_this_round or \ + 'acc' in round_wise_update_key and 'acc' in key: + # The larger the better + if results_type in [ + "client_best_individual", + "unseen_client_best_individual" + ]: + cur_result = max(cur_result) + if update_best_this_round or \ + key not in best_result or cur_result > \ + best_result[key]: + best_result[key] = cur_result + update_best_this_round = True + else: + # unconcerned metric + pass + + if update_best_this_round: + line = f"Find new best result: {best_results}" + logging.info(line) + if self.use_wandb and self.wandb_online_track: + try: + import wandb + exp_stop_normal = False + exp_stop_normal, log_res = logline_2_wandb_dict( + exp_stop_normal, + line, + self.log_res_best, + raw_out=False) + # wandb.log(self.log_res_best) + for k, v in self.log_res_best.items(): + wandb.summary[k] = v + except ImportError: + logger.error( + "cfg.wandb.use=True but not install the wandb package") + exit() diff --git a/fgssl/core/optimizer.py b/fgssl/core/optimizer.py new file mode 100644 index 0000000..6e54836 --- /dev/null +++ b/fgssl/core/optimizer.py @@ -0,0 +1,59 @@ +import copy +from typing import Dict, List + + +def wrap_regularized_optimizer(base_optimizer, regular_weight): + base_optimizer_type = type(base_optimizer) + internal_base_optimizer = copy.copy( + base_optimizer) # shallow copy to link the underlying model para + + class ParaRegularOptimizer(base_optimizer_type): + """ + Regularization-based optimizer wrapper + """ + def __init__(self, base_optimizer, regular_weight): + # inherit all the attributes of base optimizer + self.__dict__.update(base_optimizer.__dict__) + + # attributes used in the wrapper + self.optimizer = base_optimizer # internal torch optimizer + self.param_groups = self.optimizer.param_groups # link the para + # of internal optimizer with the wrapper + self.regular_weight = regular_weight + self.compared_para_groups = None + + def set_compared_para_group(self, compared_para_dict: List[Dict]): + if not (isinstance(compared_para_dict, list) + and isinstance(compared_para_dict[0], dict) + and 'params' in compared_para_dict[0]): + raise ValueError( + "compared_para_dict should be a torch style para group, " + "i.e., list[dict], " + "in which the dict stores the para with key `params`") + self.compared_para_groups = copy.deepcopy(compared_para_dict) + + def reset_compared_para_group(self, target=None): + # by default, del stale compared_para to free memory + self.compared_para_groups = target + + def regularize_by_para_diff(self): + """ + before optim.step(), regularize the gradients based on para + diff + """ + for group, compared_group in zip(self.param_groups, + self.compared_para_groups): + for p, compared_weight in zip(group['params'], + compared_group['params']): + if p.grad is not None: + if compared_weight.device != p.device: + # For Tensor, the to() is not in-place operation + 
compared_weight = compared_weight.to(p.device) + p.grad.data = p.grad.data + self.regular_weight * ( + p.data - compared_weight.data) + + def step(self): + self.regularize_by_para_diff() # key action + self.optimizer.step() + + return ParaRegularOptimizer(internal_base_optimizer, regular_weight) diff --git a/fgssl/core/optimizers/__init__.py b/fgssl/core/optimizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fgssl/core/proto/__init__.py b/fgssl/core/proto/__init__.py new file mode 100644 index 0000000..4deaa8e --- /dev/null +++ b/fgssl/core/proto/__init__.py @@ -0,0 +1,2 @@ +from federatedscope.core.proto.gRPC_comm_manager_pb2 import * +from federatedscope.core.proto.gRPC_comm_manager_pb2_grpc import * diff --git a/fgssl/core/proto/gRPC_comm_manager.proto b/fgssl/core/proto/gRPC_comm_manager.proto new file mode 100644 index 0000000..f7e5627 --- /dev/null +++ b/fgssl/core/proto/gRPC_comm_manager.proto @@ -0,0 +1,42 @@ +syntax = "proto3"; + +service gRPCComServeFunc { + rpc sendMessage (MessageRequest) returns (MessageResponse) {}; +} + +message MessageRequest{ + map msg = 1; +} + +message MsgValue{ + oneof type { + mSingle single_msg = 1; + mList list_msg = 2; + mDict_keyIsString dict_msg_stringkey = 3; + mDict_keyIsInt dict_msg_intkey = 4; + } +} + +message mSingle{ + oneof type { + float float_value = 1; + int32 int_value = 2; + string str_value = 3; + } +} + +message mList{ + repeated MsgValue list_value = 1; +} + +message mDict_keyIsString{ + map dict_value = 1; +} + +message mDict_keyIsInt{ + map dict_value = 1; +} + +message MessageResponse{ + string msg = 1; +} diff --git a/fgssl/core/proto/gRPC_comm_manager_pb2.py b/fgssl/core/proto/gRPC_comm_manager_pb2.py new file mode 100644 index 0000000..f35b1ed --- /dev/null +++ b/fgssl/core/proto/gRPC_comm_manager_pb2.py @@ -0,0 +1,760 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: gRPC_comm_manager.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + +DESCRIPTOR = _descriptor.FileDescriptor( + name='gRPC_comm_manager.proto', + package='', + syntax='proto3', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb= + b'\n\x17gRPC_comm_manager.proto\"n\n\x0eMessageRequest\x12%\n\x03msg\x18\x01 \x03(\x0b\x32\x18.MessageRequest.MsgEntry\x1a\x35\n\x08MsgEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x18\n\x05value\x18\x02 \x01(\x0b\x32\t.MsgValue:\x02\x38\x01\"\xac\x01\n\x08MsgValue\x12\x1e\n\nsingle_msg\x18\x01 \x01(\x0b\x32\x08.mSingleH\x00\x12\x1a\n\x08list_msg\x18\x02 \x01(\x0b\x32\x06.mListH\x00\x12\x30\n\x12\x64ict_msg_stringkey\x18\x03 \x01(\x0b\x32\x12.mDict_keyIsStringH\x00\x12*\n\x0f\x64ict_msg_intkey\x18\x04 \x01(\x0b\x32\x0f.mDict_keyIsIntH\x00\x42\x06\n\x04type\"R\n\x07mSingle\x12\x15\n\x0b\x66loat_value\x18\x01 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x13\n\tstr_value\x18\x03 \x01(\tH\x00\x42\x06\n\x04type\"&\n\x05mList\x12\x1d\n\nlist_value\x18\x01 \x03(\x0b\x32\t.MsgValue\"\x87\x01\n\x11mDict_keyIsString\x12\x35\n\ndict_value\x18\x01 \x03(\x0b\x32!.mDict_keyIsString.DictValueEntry\x1a;\n\x0e\x44ictValueEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x18\n\x05value\x18\x02 \x01(\x0b\x32\t.MsgValue:\x02\x38\x01\"\x81\x01\n\x0emDict_keyIsInt\x12\x32\n\ndict_value\x18\x01 \x03(\x0b\x32\x1e.mDict_keyIsInt.DictValueEntry\x1a;\n\x0e\x44ictValueEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x18\n\x05value\x18\x02 \x01(\x0b\x32\t.MsgValue:\x02\x38\x01\"\x1e\n\x0fMessageResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t2F\n\x10gRPCComServeFunc\x12\x32\n\x0bsendMessage\x12\x0f.MessageRequest\x1a\x10.MessageResponse\"\x00\x62\x06proto3' +) + +_MESSAGEREQUEST_MSGENTRY = _descriptor.Descriptor( + name='MsgEntry', + full_name='MessageRequest.MsgEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='key', + full_name='MessageRequest.MsgEntry.key', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='value', + full_name='MessageRequest.MsgEntry.value', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=84, + serialized_end=137, +) + +_MESSAGEREQUEST = _descriptor.Descriptor( + name='MessageRequest', + full_name='MessageRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + 
name='msg', + full_name='MessageRequest.msg', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[ + _MESSAGEREQUEST_MSGENTRY, + ], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=27, + serialized_end=137, +) + +_MSGVALUE = _descriptor.Descriptor( + name='MsgValue', + full_name='MsgValue', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='single_msg', + full_name='MsgValue.single_msg', + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='list_msg', + full_name='MsgValue.list_msg', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='dict_msg_stringkey', + full_name='MsgValue.dict_msg_stringkey', + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='dict_msg_intkey', + full_name='MsgValue.dict_msg_intkey', + index=3, + number=4, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='type', + full_name='MsgValue.type', + index=0, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[]), + ], + serialized_start=140, + serialized_end=312, +) + +_MSINGLE = _descriptor.Descriptor( + name='mSingle', + full_name='mSingle', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='float_value', + full_name='mSingle.float_value', + index=0, + number=1, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='int_value', + full_name='mSingle.int_value', + index=1, + number=2, + type=5, + cpp_type=1, + label=1, + 
has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='str_value', + full_name='mSingle.str_value', + index=2, + number=3, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='type', + full_name='mSingle.type', + index=0, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[]), + ], + serialized_start=314, + serialized_end=396, +) + +_MLIST = _descriptor.Descriptor( + name='mList', + full_name='mList', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='list_value', + full_name='mList.list_value', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=398, + serialized_end=436, +) + +_MDICT_KEYISSTRING_DICTVALUEENTRY = _descriptor.Descriptor( + name='DictValueEntry', + full_name='mDict_keyIsString.DictValueEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='key', + full_name='mDict_keyIsString.DictValueEntry.key', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='value', + full_name='mDict_keyIsString.DictValueEntry.value', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=515, + serialized_end=574, +) + +_MDICT_KEYISSTRING = _descriptor.Descriptor( + name='mDict_keyIsString', + full_name='mDict_keyIsString', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='dict_value', + full_name='mDict_keyIsString.dict_value', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + 
default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[ + _MDICT_KEYISSTRING_DICTVALUEENTRY, + ], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=439, + serialized_end=574, +) + +_MDICT_KEYISINT_DICTVALUEENTRY = _descriptor.Descriptor( + name='DictValueEntry', + full_name='mDict_keyIsInt.DictValueEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='key', + full_name='mDict_keyIsInt.DictValueEntry.key', + index=0, + number=1, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='value', + full_name='mDict_keyIsInt.DictValueEntry.value', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=647, + serialized_end=706, +) + +_MDICT_KEYISINT = _descriptor.Descriptor( + name='mDict_keyIsInt', + full_name='mDict_keyIsInt', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='dict_value', + full_name='mDict_keyIsInt.dict_value', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[ + _MDICT_KEYISINT_DICTVALUEENTRY, + ], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=577, + serialized_end=706, +) + +_MESSAGERESPONSE = _descriptor.Descriptor( + name='MessageResponse', + full_name='MessageResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='msg', + full_name='MessageResponse.msg', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=708, + serialized_end=738, +) + +_MESSAGEREQUEST_MSGENTRY.fields_by_name['value'].message_type = _MSGVALUE 
+_MESSAGEREQUEST_MSGENTRY.containing_type = _MESSAGEREQUEST +_MESSAGEREQUEST.fields_by_name['msg'].message_type = _MESSAGEREQUEST_MSGENTRY +_MSGVALUE.fields_by_name['single_msg'].message_type = _MSINGLE +_MSGVALUE.fields_by_name['list_msg'].message_type = _MLIST +_MSGVALUE.fields_by_name[ + 'dict_msg_stringkey'].message_type = _MDICT_KEYISSTRING +_MSGVALUE.fields_by_name['dict_msg_intkey'].message_type = _MDICT_KEYISINT +_MSGVALUE.oneofs_by_name['type'].fields.append( + _MSGVALUE.fields_by_name['single_msg']) +_MSGVALUE.fields_by_name[ + 'single_msg'].containing_oneof = _MSGVALUE.oneofs_by_name['type'] +_MSGVALUE.oneofs_by_name['type'].fields.append( + _MSGVALUE.fields_by_name['list_msg']) +_MSGVALUE.fields_by_name[ + 'list_msg'].containing_oneof = _MSGVALUE.oneofs_by_name['type'] +_MSGVALUE.oneofs_by_name['type'].fields.append( + _MSGVALUE.fields_by_name['dict_msg_stringkey']) +_MSGVALUE.fields_by_name[ + 'dict_msg_stringkey'].containing_oneof = _MSGVALUE.oneofs_by_name['type'] +_MSGVALUE.oneofs_by_name['type'].fields.append( + _MSGVALUE.fields_by_name['dict_msg_intkey']) +_MSGVALUE.fields_by_name[ + 'dict_msg_intkey'].containing_oneof = _MSGVALUE.oneofs_by_name['type'] +_MSINGLE.oneofs_by_name['type'].fields.append( + _MSINGLE.fields_by_name['float_value']) +_MSINGLE.fields_by_name[ + 'float_value'].containing_oneof = _MSINGLE.oneofs_by_name['type'] +_MSINGLE.oneofs_by_name['type'].fields.append( + _MSINGLE.fields_by_name['int_value']) +_MSINGLE.fields_by_name[ + 'int_value'].containing_oneof = _MSINGLE.oneofs_by_name['type'] +_MSINGLE.oneofs_by_name['type'].fields.append( + _MSINGLE.fields_by_name['str_value']) +_MSINGLE.fields_by_name[ + 'str_value'].containing_oneof = _MSINGLE.oneofs_by_name['type'] +_MLIST.fields_by_name['list_value'].message_type = _MSGVALUE +_MDICT_KEYISSTRING_DICTVALUEENTRY.fields_by_name[ + 'value'].message_type = _MSGVALUE +_MDICT_KEYISSTRING_DICTVALUEENTRY.containing_type = _MDICT_KEYISSTRING +_MDICT_KEYISSTRING.fields_by_name[ + 'dict_value'].message_type = _MDICT_KEYISSTRING_DICTVALUEENTRY +_MDICT_KEYISINT_DICTVALUEENTRY.fields_by_name['value'].message_type = _MSGVALUE +_MDICT_KEYISINT_DICTVALUEENTRY.containing_type = _MDICT_KEYISINT +_MDICT_KEYISINT.fields_by_name[ + 'dict_value'].message_type = _MDICT_KEYISINT_DICTVALUEENTRY +DESCRIPTOR.message_types_by_name['MessageRequest'] = _MESSAGEREQUEST +DESCRIPTOR.message_types_by_name['MsgValue'] = _MSGVALUE +DESCRIPTOR.message_types_by_name['mSingle'] = _MSINGLE +DESCRIPTOR.message_types_by_name['mList'] = _MLIST +DESCRIPTOR.message_types_by_name['mDict_keyIsString'] = _MDICT_KEYISSTRING +DESCRIPTOR.message_types_by_name['mDict_keyIsInt'] = _MDICT_KEYISINT +DESCRIPTOR.message_types_by_name['MessageResponse'] = _MESSAGERESPONSE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +MessageRequest = _reflection.GeneratedProtocolMessageType( + 'MessageRequest', + (_message.Message, ), + { + 'MsgEntry': _reflection.GeneratedProtocolMessageType( + 'MsgEntry', + (_message.Message, ), + { + 'DESCRIPTOR': _MESSAGEREQUEST_MSGENTRY, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:MessageRequest.MsgEntry) + }), + 'DESCRIPTOR': _MESSAGEREQUEST, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:MessageRequest) + }) +_sym_db.RegisterMessage(MessageRequest) +_sym_db.RegisterMessage(MessageRequest.MsgEntry) + +MsgValue = _reflection.GeneratedProtocolMessageType( + 'MsgValue', + (_message.Message, ), + { + 'DESCRIPTOR': _MSGVALUE, + '__module__': 'gRPC_comm_manager_pb2' + # 
@@protoc_insertion_point(class_scope:MsgValue) + }) +_sym_db.RegisterMessage(MsgValue) + +mSingle = _reflection.GeneratedProtocolMessageType( + 'mSingle', + (_message.Message, ), + { + 'DESCRIPTOR': _MSINGLE, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mSingle) + }) +_sym_db.RegisterMessage(mSingle) + +mList = _reflection.GeneratedProtocolMessageType( + 'mList', + (_message.Message, ), + { + 'DESCRIPTOR': _MLIST, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mList) + }) +_sym_db.RegisterMessage(mList) + +mDict_keyIsString = _reflection.GeneratedProtocolMessageType( + 'mDict_keyIsString', + (_message.Message, ), + { + 'DictValueEntry': _reflection.GeneratedProtocolMessageType( + 'DictValueEntry', + (_message.Message, ), + { + 'DESCRIPTOR': _MDICT_KEYISSTRING_DICTVALUEENTRY, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mDict_keyIsString.DictValueEntry) + }), + 'DESCRIPTOR': _MDICT_KEYISSTRING, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mDict_keyIsString) + }) +_sym_db.RegisterMessage(mDict_keyIsString) +_sym_db.RegisterMessage(mDict_keyIsString.DictValueEntry) + +mDict_keyIsInt = _reflection.GeneratedProtocolMessageType( + 'mDict_keyIsInt', + (_message.Message, ), + { + 'DictValueEntry': _reflection.GeneratedProtocolMessageType( + 'DictValueEntry', + (_message.Message, ), + { + 'DESCRIPTOR': _MDICT_KEYISINT_DICTVALUEENTRY, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mDict_keyIsInt.DictValueEntry) + }), + 'DESCRIPTOR': _MDICT_KEYISINT, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mDict_keyIsInt) + }) +_sym_db.RegisterMessage(mDict_keyIsInt) +_sym_db.RegisterMessage(mDict_keyIsInt.DictValueEntry) + +MessageResponse = _reflection.GeneratedProtocolMessageType( + 'MessageResponse', + (_message.Message, ), + { + 'DESCRIPTOR': _MESSAGERESPONSE, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:MessageResponse) + }) +_sym_db.RegisterMessage(MessageResponse) + +_MESSAGEREQUEST_MSGENTRY._options = None +_MDICT_KEYISSTRING_DICTVALUEENTRY._options = None +_MDICT_KEYISINT_DICTVALUEENTRY._options = None + +_GRPCCOMSERVEFUNC = _descriptor.ServiceDescriptor( + name='gRPCComServeFunc', + full_name='gRPCComServeFunc', + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=740, + serialized_end=810, + methods=[ + _descriptor.MethodDescriptor( + name='sendMessage', + full_name='gRPCComServeFunc.sendMessage', + index=0, + containing_service=None, + input_type=_MESSAGEREQUEST, + output_type=_MESSAGERESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + ]) +_sym_db.RegisterServiceDescriptor(_GRPCCOMSERVEFUNC) + +DESCRIPTOR.services_by_name['gRPCComServeFunc'] = _GRPCCOMSERVEFUNC + +# @@protoc_insertion_point(module_scope) diff --git a/fgssl/core/proto/gRPC_comm_manager_pb2_grpc.py b/fgssl/core/proto/gRPC_comm_manager_pb2_grpc.py new file mode 100644 index 0000000..9405549 --- /dev/null +++ b/fgssl/core/proto/gRPC_comm_manager_pb2_grpc.py @@ -0,0 +1,69 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from federatedscope.core.proto import gRPC_comm_manager_pb2 \ + as gRPC__comm__manager__pb2 + + +class gRPCComServeFuncStub(object): + """Missing associated documentation comment in .proto file.""" + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.sendMessage = channel.unary_unary( + '/gRPCComServeFunc/sendMessage', + request_serializer=gRPC__comm__manager__pb2.MessageRequest. + SerializeToString, + response_deserializer=gRPC__comm__manager__pb2.MessageResponse. + FromString, + ) + + +class gRPCComServeFuncServicer(object): + """Missing associated documentation comment in .proto file.""" + def sendMessage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_gRPCComServeFuncServicer_to_server(servicer, server): + rpc_method_handlers = { + 'sendMessage': grpc.unary_unary_rpc_method_handler( + servicer.sendMessage, + request_deserializer=gRPC__comm__manager__pb2.MessageRequest. + FromString, + response_serializer=gRPC__comm__manager__pb2.MessageResponse. + SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'gRPCComServeFunc', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler, )) + + +# This class is part of an EXPERIMENTAL API. +class gRPCComServeFunc(object): + """Missing associated documentation comment in .proto file.""" + @staticmethod + def sendMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, '/gRPCComServeFunc/sendMessage', + gRPC__comm__manager__pb2.MessageRequest.SerializeToString, + gRPC__comm__manager__pb2.MessageResponse.FromString, options, + channel_credentials, insecure, call_credentials, compression, + wait_for_ready, timeout, metadata) diff --git a/fgssl/core/regularizer/__init__.py b/fgssl/core/regularizer/__init__.py new file mode 100644 index 0000000..5821bc7 --- /dev/null +++ b/fgssl/core/regularizer/__init__.py @@ -0,0 +1 @@ +from federatedscope.core.regularizer.proximal_regularizer import * diff --git a/fgssl/core/regularizer/proximal_regularizer.py b/fgssl/core/regularizer/proximal_regularizer.py new file mode 100644 index 0000000..aab1eb0 --- /dev/null +++ b/fgssl/core/regularizer/proximal_regularizer.py @@ -0,0 +1,39 @@ +from federatedscope.register import register_regularizer +try: + from torch.nn import Module + import torch +except ImportError: + Module = object + torch = None + +REGULARIZER_NAME = "proximal_regularizer" + + +class ProximalRegularizer(Module): + """Returns the norm of the specific weight update. + + Arguments: + p (int): The order of norm. + tensor_before: The original matrix or vector + tensor_after: The updated matrix or vector + + Returns: + Tensor: the norm of the given udpate. + """ + def __init__(self): + super(ProximalRegularizer, self).__init__() + + def forward(self, ctx, p=2): + norm = 0. + for w_init, w in zip(ctx.weight_init, ctx.model.parameters()): + norm += torch.pow(torch.norm(w - w_init, p), p) + return norm * 1. 
/ float(p) + + +def call_proximal_regularizer(type): + if type == REGULARIZER_NAME: + regularizer = ProximalRegularizer + return regularizer + + +register_regularizer(REGULARIZER_NAME, call_proximal_regularizer) diff --git a/fgssl/core/sampler.py b/fgssl/core/sampler.py new file mode 100644 index 0000000..6438cf8 --- /dev/null +++ b/fgssl/core/sampler.py @@ -0,0 +1,131 @@ +from abc import ABC, abstractmethod + +import numpy as np + + +class Sampler(ABC): + """ + The strategies of sampling clients for each training round + + Arguments: + client_state: a dict to manager the state of clients (idle or busy) + """ + def __init__(self, client_num): + self.client_state = np.asarray([1] * (client_num + 1)) + # Set the state of server (index=0) to 'working' + self.client_state[0] = 0 + + @abstractmethod + def sample(self, size): + raise NotImplementedError + + def change_state(self, indices, state): + """ + To modify the state of clients (idle or working) + """ + if isinstance(indices, list) or isinstance(indices, np.ndarray): + all_idx = indices + else: + all_idx = [indices] + for idx in all_idx: + if state in ['idle', 'seen']: + self.client_state[idx] = 1 + elif state in ['working', 'unseen']: + self.client_state[idx] = 0 + else: + raise ValueError( + f"The state of client should be one of " + f"['idle', 'working', 'unseen], but got {state}") + + +class UniformSampler(Sampler): + """ + To uniformly sample the clients from all the idle clients + """ + def __init__(self, client_num): + super(UniformSampler, self).__init__(client_num) + + def sample(self, size): + """ + To sample clients + """ + idle_clients = np.nonzero(self.client_state)[0] + sampled_clients = np.random.choice(idle_clients, + size=size, + replace=False).tolist() + self.change_state(sampled_clients, 'working') + return sampled_clients + + +class GroupSampler(Sampler): + """ + To grouply sample the clients based on their responsiveness (or other + client information of the clients) + """ + def __init__(self, client_num, client_info, bins=10): + super(GroupSampler, self).__init__(client_num) + self.bins = bins + self.update_client_info(client_info) + self.candidate_iterator = self.partition() + + def update_client_info(self, client_info): + """ + To update the client information + """ + self.client_info = np.asarray( + [1.0] + [x for x in client_info + ]) # client_info[0] is preversed for the server + assert len(self.client_info) == len( + self.client_state + ), "The first dimension of client_info is mismatched with client_num" + + def partition(self): + """ + To partition the clients into groups according to the client + information + + Arguments: + :returns: a iteration of candidates + """ + sorted_index = np.argsort(self.client_info) + num_per_bins = np.int(len(sorted_index) / self.bins) + + # grouped clients + self.grouped_clients = np.split( + sorted_index, np.cumsum([num_per_bins] * (self.bins - 1))) + + return self.permutation() + + def permutation(self): + candidates = list() + permutation = np.random.permutation(np.arange(self.bins)) + for i in permutation: + np.random.shuffle(self.grouped_clients[i]) + candidates.extend(self.grouped_clients[i]) + + return iter(candidates) + + def sample(self, size, shuffle=False): + """ + To sample clients + """ + if shuffle: + self.candidate_iterator = self.permutation() + + sampled_clients = list() + for i in range(size): + # To find an idle client + while True: + try: + item = next(self.candidate_iterator) + except StopIteration: + self.candidate_iterator = self.permutation() + item = 
next(self.candidate_iterator) + + if self.client_state[item] == 1: + break + + sampled_clients.append(item) + self.change_state(item, 'working') + + return sampled_clients diff --git a/fgssl/core/secret_sharing/__init__.py b/fgssl/core/secret_sharing/__init__.py new file mode 100644 index 0000000..3261af9 --- /dev/null +++ b/fgssl/core/secret_sharing/__init__.py @@ -0,0 +1,2 @@ +from federatedscope.core.secret_sharing.secret_sharing import \ + AdditiveSecretSharing diff --git a/fgssl/core/secret_sharing/secret_sharing.py b/fgssl/core/secret_sharing/secret_sharing.py new file mode 100644 index 0000000..31fb99b --- /dev/null +++ b/fgssl/core/secret_sharing/secret_sharing.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +import numpy as np +try: + import torch +except ImportError: + torch = None + + +class SecretSharing(ABC): + def __init__(self): + pass + + @abstractmethod + def secret_split(self, secret): + pass + + @abstractmethod + def secret_reconstruct(self, secret_seq): + pass + + +class AdditiveSecretSharing(SecretSharing): + """ + AdditiveSecretSharing class, which can split a number into frames and + recover it by summing up + """ + def __init__(self, shared_party_num, size=60): + super(SecretSharing, self).__init__() + assert shared_party_num > 1, "AdditiveSecretSharing require " \ + "shared_party_num > 1" + self.shared_party_num = shared_party_num + self.maximum = 2**size + self.mod_number = 2 * self.maximum + 1 + self.epsilon = 1e8 + self.mod_funs = np.vectorize(lambda x: x % self.mod_number) + self.float2fixedpoint = np.vectorize(self._float2fixedpoint) + self.fixedpoint2float = np.vectorize(self._fixedpoint2float) + + def secret_split(self, secret): + """ + To split the secret into frames according to the shared_party_num + """ + if isinstance(secret, dict): + secret_list = [dict() for _ in range(self.shared_party_num)] + for key in secret: + for idx, each in enumerate(self.secret_split(secret[key])): + secret_list[idx][key] = each + return secret_list + + if isinstance(secret, list) or isinstance(secret, np.ndarray): + secret = np.asarray(secret) + shape = [self.shared_party_num - 1] + list(secret.shape) + elif isinstance(secret, torch.Tensor): + secret = secret.numpy() + shape = [self.shared_party_num - 1] + list(secret.shape) + else: + shape = [self.shared_party_num - 1] + + secret = self.float2fixedpoint(secret) + secret_seq = np.random.randint(low=0, high=self.mod_number, size=shape) + # last_seq = self.mod_funs(secret - self.mod_funs(np.sum(secret_seq, + # axis=0))) + last_seq = self.mod_funs(secret - + self.mod_funs(np.sum(secret_seq, axis=0))) + + secret_seq = np.append(secret_seq, + np.expand_dims(last_seq, axis=0), + axis=0) + return secret_seq + + def secret_reconstruct(self, secret_seq): + """ + To recover the secret + """ + assert len(secret_seq) == self.shared_party_num + merge_model = secret_seq[0].copy() + if isinstance(merge_model, dict): + for key in merge_model: + for idx in range(len(secret_seq)): + if idx == 0: + merge_model[key] = secret_seq[idx][key] + else: + merge_model[key] += secret_seq[idx][key] + merge_model[key] = self.fixedpoint2float(merge_model[key]) + + return merge_model + + def _float2fixedpoint(self, x): + x = round(x * self.epsilon, 0) + assert abs(x) < self.maximum + return x % self.mod_number + + def _fixedpoint2float(self, x): + x = x % self.mod_number + if x > self.maximum: + return -1 * (self.mod_number - x) / self.epsilon + else: + return x / self.epsilon diff --git a/fgssl/core/splitters/__init__.py 
b/fgssl/core/splitters/__init__.py new file mode 100644 index 0000000..d7af0c1 --- /dev/null +++ b/fgssl/core/splitters/__init__.py @@ -0,0 +1,3 @@ +from federatedscope.core.splitters.base_splitter import BaseSplitter + +__all__ = ['BaseSplitter'] diff --git a/fgssl/core/splitters/base_splitter.py b/fgssl/core/splitters/base_splitter.py new file mode 100644 index 0000000..1e80bb9 --- /dev/null +++ b/fgssl/core/splitters/base_splitter.py @@ -0,0 +1,28 @@ +import abc +import inspect + + +class BaseSplitter(abc.ABC): + def __init__(self, client_num): + """ + This is an abstract base class for all splitter. + + Args: + client_num: Divide the dataset into `client_num` pieces. + """ + self.client_num = client_num + + @abc.abstractmethod + def __call__(self, dataset, *args, **kwargs): + raise NotImplementedError + + def __repr__(self): + """ + + Returns: Meta information for `Splitter`. + + """ + sign = inspect.signature(self.__init__).parameters.values() + meta_info = tuple([(val.name, getattr(self, val.name)) + for val in sign]) + return f'{self.__class__.__name__}{meta_info}' diff --git a/fgssl/core/splitters/generic/__init__.py b/fgssl/core/splitters/generic/__init__.py new file mode 100644 index 0000000..8bf4c27 --- /dev/null +++ b/fgssl/core/splitters/generic/__init__.py @@ -0,0 +1,4 @@ +from federatedscope.core.splitters.generic.lda_splitter import LDASplitter +from federatedscope.core.splitters.generic.iid_splitter import IIDSplitter + +__all__ = ['LDASplitter', 'IIDSplitter'] diff --git a/fgssl/core/splitters/generic/iid_splitter.py b/fgssl/core/splitters/generic/iid_splitter.py new file mode 100644 index 0000000..a8032bd --- /dev/null +++ b/fgssl/core/splitters/generic/iid_splitter.py @@ -0,0 +1,17 @@ +import numpy as np +from federatedscope.core.splitters import BaseSplitter + + +class IIDSplitter(BaseSplitter): + def __init__(self, client_num): + super(IIDSplitter, self).__init__(client_num) + + def __call__(self, dataset, prior=None): + dataset = [ds for ds in dataset] + np.random.shuffle(dataset) + length = len(dataset) + prop = [1.0 / self.client_num for _ in range(self.client_num)] + prop = (np.cumsum(prop) * length).astype(int)[:-1] + data_list = np.split(dataset, prop) + data_list = [x.tolist() for x in data_list] + return data_list diff --git a/fgssl/core/splitters/generic/lda_splitter.py b/fgssl/core/splitters/generic/lda_splitter.py new file mode 100644 index 0000000..b32d5e0 --- /dev/null +++ b/fgssl/core/splitters/generic/lda_splitter.py @@ -0,0 +1,20 @@ +import numpy as np +from federatedscope.core.splitters import BaseSplitter +from federatedscope.core.splitters.utils import \ + dirichlet_distribution_noniid_slice + + +class LDASplitter(BaseSplitter): + def __init__(self, client_num, alpha=0.5): + self.alpha = alpha + super(LDASplitter, self).__init__(client_num) + + def __call__(self, dataset, prior=None, **kwargs): + dataset = [ds for ds in dataset] + label = np.array([y for x, y in dataset]) + idx_slice = dirichlet_distribution_noniid_slice(label, + self.client_num, + self.alpha, + prior=prior) + data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] + return data_list diff --git a/fgssl/core/splitters/graph/__init__.py b/fgssl/core/splitters/graph/__init__.py new file mode 100644 index 0000000..0ba8c6e --- /dev/null +++ b/fgssl/core/splitters/graph/__init__.py @@ -0,0 +1,18 @@ +from federatedscope.core.splitters.graph.louvain_splitter import \ + LouvainSplitter +from federatedscope.core.splitters.graph.random_splitter import RandomSplitter +from 
federatedscope.core.splitters.graph.reltype_splitter import \ + RelTypeSplitter +from federatedscope.core.splitters.graph.scaffold_splitter import \ + ScaffoldSplitter +from federatedscope.core.splitters.graph.randchunk_splitter import \ + RandChunkSplitter + +from federatedscope.core.splitters.graph.analyzer import Analyzer +from federatedscope.core.splitters.graph.scaffold_lda_splitter import \ + ScaffoldLdaSplitter + +__all__ = [ + 'LouvainSplitter', 'RandomSplitter', 'RelTypeSplitter', 'ScaffoldSplitter', + 'RandChunkSplitter', 'Analyzer', 'ScaffoldLdaSplitter' +] diff --git a/fgssl/core/splitters/graph/analyzer.py b/fgssl/core/splitters/graph/analyzer.py new file mode 100644 index 0000000..7565797 --- /dev/null +++ b/fgssl/core/splitters/graph/analyzer.py @@ -0,0 +1,182 @@ +import torch + +from typing import List +from torch_geometric.data import Data +from torch_geometric.utils import to_networkx, to_dense_adj, dense_to_sparse + + +class Analyzer(object): + r"""Analyzer for raw graph and split subgraphs. + + Arguments: + raw_data (PyG.data): raw graph. + split_data (list): the list for subgraphs split by splitter. + + """ + def __init__(self, raw_data: Data, split_data: List[Data]): + + self.raw_data = raw_data + self.split_data = split_data + + self.raw_graph = to_networkx(raw_data, to_undirected=True) + self.sub_graphs = [ + to_networkx(g, to_undirected=True) for g in split_data + ] + + def num_missing_edge(self): + r""" + + Returns: + the number of missing edge and the rate of missing edge. + + """ + missing_edge = len(self.raw_graph.edges) - self.fl_adj().shape[1] // 2 + rate_missing_edge = missing_edge / len(self.raw_graph.edges) + + return missing_edge, rate_missing_edge + + def fl_adj(self): + r""" + + Returns: + the adj for missing edge ADJ. + + """ + raw_adj = to_dense_adj(self.raw_data.edge_index)[0] + adj = torch.zeros_like(raw_adj) + if 'index_orig' in self.split_data[0]: + for sub_g in self.split_data: + for row, col in sub_g.edge_index.T: + adj[sub_g.index_orig[row.item()]][sub_g.index_orig[ + col.item()]] = 1 + + else: + raise KeyError('index_orig not in Split Data.') + + return dense_to_sparse(adj)[0] + + def fl_data(self): + r""" + + Returns: + the split edge index. + + """ + fl_data = Data() + for key, item in self.raw_data: + if key == 'edge_index': + fl_data[key] = self.fl_adj() + else: + fl_data[key] = item + + return fl_data + + def missing_data(self): + r""" + + Returns: + the graph data built by missing edge index. + + """ + ms_data = Data() + raw_edge_set = {tuple(x) for x in self.raw_data.edge_index.T.numpy()} + split_edge_set = { + tuple(x) + for x in self.fl_data().edge_index.T.numpy() + } + ms_set = raw_edge_set - split_edge_set + for key, item in self.raw_data: + if key == 'edge_index': + ms_data[key] = torch.tensor([list(x) for x in ms_set], + dtype=torch.int64).T + else: + ms_data[key] = item + + return ms_data + + def portion_ms_node(self): + r""" + + Returns: + the proportion of nodes who miss egde. 
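# Illustrative aside (hypothetical usage, not taken from the diff): a minimal
# round-trip sketch for the AdditiveSecretSharing class added earlier in this
# patch. A secret is split into `shared_party_num` random shares and recovered
# by modular summation; values are assumed to fit the fixed-point range
# (|x| * 1e8 < 2**60 with the default size=60).
import numpy as np
from federatedscope.core.secret_sharing import AdditiveSecretSharing

ss = AdditiveSecretSharing(shared_party_num=3)
shares = ss.secret_split({'w': np.array([0.5, -1.25])})  # one dict of shares per party
recovered = ss.secret_reconstruct(shares)                # expect {'w': [0.5, -1.25]}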
+ + """ + cnt_list = [] + ms_set = {x.item() for x in set(self.missing_data().edge_index[0])} + for sub_data in self.split_data: + cnt = 0 + for idx in sub_data.index_orig: + if idx.item() in ms_set: + cnt += 1 + cnt_list.append(cnt / sub_data.num_nodes) + return cnt_list + + def average_clustering(self): + r""" + + Returns: + the average clustering coefficient for the raw G and split G + + """ + import networkx.algorithms.cluster as cluster + + return cluster.average_clustering( + self.raw_graph), cluster.average_clustering( + to_networkx(self.fl_data())) + + def homophily_value(self, edge_index, y): + r""" + + Returns: + calculate homophily_value + + """ + from torch_sparse import SparseTensor + + if isinstance(edge_index, SparseTensor): + row, col, _ = edge_index.coo() + else: + row, col = edge_index + + return int((y[row] == y[col]).sum()) / row.size(0) + + def homophily(self): + r""" + + Returns: + the homophily for the raw G and split G + + """ + + return self.homophily_value(self.raw_data.edge_index, + self.raw_data.y), self.homophily_value( + self.fl_data().edge_index, + self.fl_data().y) + + def hamming_distance_graph(self, data): + r""" + + Returns: + calculate the hamming distance of graph data + + """ + edge_index, x = data.edge_index, data.x + cnt = 0 + for row, col in edge_index.T: + row, col = row.item(), col.item() + cnt += torch.sum(x[row] != x[col]).item() + + return cnt / edge_index.shape[1] + + def hamming(self): + r""" + + Returns: + the average hamming distance of feature for the raw G, split G + and missing edge G + + """ + return self.hamming_distance_graph( + self.raw_data), self.hamming_distance_graph( + self.fl_data()), self.hamming_distance_graph( + self.missing_data()) diff --git a/fgssl/core/splitters/graph/louvain_splitter.py b/fgssl/core/splitters/graph/louvain_splitter.py new file mode 100644 index 0000000..908ae9a --- /dev/null +++ b/fgssl/core/splitters/graph/louvain_splitter.py @@ -0,0 +1,74 @@ +import torch + +from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import to_networkx, from_networkx + +import networkx as nx +import community as community_louvain + +from federatedscope.core.splitters import BaseSplitter + + +class LouvainSplitter(BaseTransform, BaseSplitter): + r""" + Split Data into small data via louvain algorithm. + + Args: + client_num (int): Split data into client_num of pieces. + delta (int): The gap between the number of nodes on the each client. 
+ + """ + def __init__(self, client_num, delta=20): + self.delta = delta + BaseSplitter.__init__(self, client_num) + + def __call__(self, data, **kwargs): + data.index_orig = torch.arange(data.num_nodes) + G = to_networkx( + data, + node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'], + to_undirected=True) + nx.set_node_attributes(G, + dict([(nid, nid) + for nid in range(nx.number_of_nodes(G))]), + name="index_orig") + partition = community_louvain.best_partition(G) + + cluster2node = {} + for node in partition: + cluster = partition[node] + if cluster not in cluster2node: + cluster2node[cluster] = [node] + else: + cluster2node[cluster].append(node) + + max_len = len(G) // self.client_num - self.delta + max_len_client = len(G) // self.client_num + + tmp_cluster2node = {} + for cluster in cluster2node: + while len(cluster2node[cluster]) > max_len: + tmp_cluster = cluster2node[cluster][:max_len] + tmp_cluster2node[len(cluster2node) + len(tmp_cluster2node) + + 1] = tmp_cluster + cluster2node[cluster] = cluster2node[cluster][max_len:] + cluster2node.update(tmp_cluster2node) + + orderedc2n = (zip(cluster2node.keys(), cluster2node.values())) + orderedc2n = sorted(orderedc2n, key=lambda x: len(x[1]), reverse=True) + + client_node_idx = {idx: [] for idx in range(self.client_num)} + idx = 0 + for (cluster, node_list) in orderedc2n: + while len(node_list) + len( + client_node_idx[idx]) > max_len_client + self.delta: + idx = (idx + 1) % self.client_num + client_node_idx[idx] += node_list + idx = (idx + 1) % self.client_num + + graphs = [] + for owner in client_node_idx: + nodes = client_node_idx[owner] + graphs.append(from_networkx(nx.subgraph(G, nodes))) + + return graphs diff --git a/fgssl/core/splitters/graph/randchunk_splitter.py b/fgssl/core/splitters/graph/randchunk_splitter.py new file mode 100644 index 0000000..07e2e93 --- /dev/null +++ b/fgssl/core/splitters/graph/randchunk_splitter.py @@ -0,0 +1,36 @@ +import numpy as np + +from torch_geometric.transforms import BaseTransform +from federatedscope.core.splitters import BaseSplitter + + +class RandChunkSplitter(BaseTransform, BaseSplitter): + def __init__(self, client_num): + BaseSplitter.__init__(self, client_num) + + def __call__(self, dataset, **kwargs): + r"""Split dataset via random chunk. + + Arguments: + dataset (List or PyG.dataset): The datasets. + + Returns: + data_list (List(List(PyG.data))): Splited dataset via random + chunk split. + """ + data_list = [] + dataset = [ds for ds in dataset] + num_graph = len(dataset) + + # Split dataset + num_graph = len(dataset) + min_size = min(50, int(num_graph / self.client_num)) + + for i in range(self.client_num): + data_list.append(dataset[i * min_size:(i + 1) * min_size]) + for graph in dataset[self.client_num * min_size:]: + client_idx = np.random.randint(low=0, high=self.client_num, + size=1)[0] + data_list[client_idx].append(graph) + + return data_list diff --git a/fgssl/core/splitters/graph/random_splitter.py b/fgssl/core/splitters/graph/random_splitter.py new file mode 100644 index 0000000..a3c12be --- /dev/null +++ b/fgssl/core/splitters/graph/random_splitter.py @@ -0,0 +1,105 @@ +import torch + +from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import to_networkx, from_networkx + +import numpy as np +import networkx as nx + +from federatedscope.core.splitters import BaseSplitter + +EPSILON = 1e-5 + + +class RandomSplitter(BaseTransform, BaseSplitter): + r""" + Split Data into small data via random sampling. 
+ + Args: + client_num (int): Split data into client_num of pieces. + sampling_rate (str): Samples of the unique nodes for each client, + eg. '0.2,0.2,0.2'. + overlapping_rate(float): Additional samples of overlapping data, + eg. '0.4' + drop_edge(float): Drop edges (drop_edge / client_num) for each + client whthin overlapping part. + + """ + def __init__(self, + client_num, + sampling_rate=None, + overlapping_rate=0, + drop_edge=0): + BaseSplitter.__init__(self, client_num) + self.ovlap = overlapping_rate + if sampling_rate is not None: + self.sampling_rate = np.array( + [float(val) for val in sampling_rate.split(',')]) + else: + # Default: Average + self.sampling_rate = (np.ones(client_num) - + self.ovlap) / client_num + + if len(self.sampling_rate) != client_num: + raise ValueError( + f'The client_num ({client_num}) should be equal to the ' + f'lenghth of sampling_rate and overlapping_rate.') + + if abs((sum(self.sampling_rate) + self.ovlap) - 1) > EPSILON: + raise ValueError( + f'The sum of sampling_rate:{self.sampling_rate} and ' + f'overlapping_rate({self.ovlap}) should be 1.') + + self.drop_edge = drop_edge + + def __call__(self, data, **kwargs): + data.index_orig = torch.arange(data.num_nodes) + G = to_networkx( + data, + node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'], + to_undirected=True) + nx.set_node_attributes(G, + dict([(nid, nid) + for nid in range(nx.number_of_nodes(G))]), + name="index_orig") + + client_node_idx = {idx: [] for idx in range(self.client_num)} + + indices = np.random.permutation(data.num_nodes) + sum_rate = 0 + for idx, rate in enumerate(self.sampling_rate): + client_node_idx[idx] = indices[round(sum_rate * + data.num_nodes):round( + (sum_rate + rate) * + data.num_nodes)] + sum_rate += rate + + if self.ovlap: + ovlap_nodes = indices[round(sum_rate * data.num_nodes):] + for idx in client_node_idx: + client_node_idx[idx] = np.concatenate( + (client_node_idx[idx], ovlap_nodes)) + + # Drop_edge index for each client + if self.drop_edge: + ovlap_graph = nx.Graph(nx.subgraph(G, ovlap_nodes)) + ovlap_edge_ind = np.random.permutation( + ovlap_graph.number_of_edges()) + drop_all = ovlap_edge_ind[:round(ovlap_graph.number_of_edges() * + self.drop_edge)] + drop_client = [ + drop_all[s:s + round(len(drop_all) / self.client_num)] + for s in range(0, len(drop_all), + round(len(drop_all) / self.client_num)) + ] + + graphs = [] + for owner in client_node_idx: + nodes = client_node_idx[owner] + sub_g = nx.Graph(nx.subgraph(G, nodes)) + if self.drop_edge: + sub_g.remove_edges_from( + np.array(ovlap_graph.edges)[drop_client[owner]]) + graphs.append(from_networkx(sub_g)) + + return graphs diff --git a/fgssl/core/splitters/graph/reltype_splitter.py b/fgssl/core/splitters/graph/reltype_splitter.py new file mode 100644 index 0000000..2452add --- /dev/null +++ b/fgssl/core/splitters/graph/reltype_splitter.py @@ -0,0 +1,65 @@ +import torch + +from torch_geometric.data import Data +from torch_geometric.utils import to_undirected +from torch_geometric.transforms import BaseTransform + +from federatedscope.core.splitters.utils import \ + dirichlet_distribution_noniid_slice +from federatedscope.core.splitters import BaseSplitter + + +class RelTypeSplitter(BaseTransform, BaseSplitter): + r""" + Split Data into small data via dirichlet distribution to + generate non-i.i.d data split. + + Arguments: + client_num (int): Split data into client_num of pieces. + alpha (float): parameter controlling the identicalness among clients. 
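# Quick intuition for `alpha` (illustrative only): per-class client proportions
# are drawn from Dirichlet(alpha, ..., alpha), so a small alpha concentrates a
# class on a few clients while a large alpha approaches an even 1/client_num
# split.
import numpy as np

rng = np.random.default_rng(0)
print(rng.dirichlet([0.1] * 5))   # heavily skewed, e.g. one client dominates
print(rng.dirichlet([100.] * 5))  # close to uniform 0.2 per client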
+ + """ + def __init__(self, client_num, alpha=0.5, realloc_mask=False): + BaseSplitter.__init__(self, client_num) + self.alpha = alpha + self.realloc_mask = realloc_mask + + def __call__(self, data, **kwargs): + data_list = [] + label = data.edge_type.numpy() + idx_slice = dirichlet_distribution_noniid_slice( + label, self.client_num, self.alpha) + # Reallocation train/val/test mask + train_ratio = data.train_edge_mask.sum().item() / data.num_edges + test_ratio = data.test_edge_mask.sum().item() / data.num_edges + for idx_j in idx_slice: + edge_index = data.edge_index.T[idx_j].T + edge_type = data.edge_type[idx_j] + train_edge_mask = data.train_edge_mask[idx_j] + valid_edge_mask = data.valid_edge_mask[idx_j] + test_edge_mask = data.test_edge_mask[idx_j] + if self.realloc_mask: + num_edges = edge_index.size(-1) + indices = torch.randperm(num_edges) + train_edge_mask = torch.zeros(num_edges, dtype=torch.bool) + train_edge_mask[indices[:round(train_ratio * + num_edges)]] = True + valid_edge_mask = torch.zeros(num_edges, dtype=torch.bool) + valid_edge_mask[ + indices[round(train_ratio * + num_edges):-round(test_ratio * + num_edges)]] = True + test_edge_mask = torch.zeros(num_edges, dtype=torch.bool) + test_edge_mask[indices[-round(test_ratio * num_edges):]] = True + sub_g = Data(x=data.x, + edge_index=edge_index, + index_orig=data.index_orig, + edge_type=edge_type, + train_edge_mask=train_edge_mask, + valid_edge_mask=valid_edge_mask, + test_edge_mask=test_edge_mask, + input_edge_index=to_undirected( + edge_index.T[train_edge_mask].T)) + data_list.append(sub_g) + + return data_list diff --git a/fgssl/core/splitters/graph/scaffold_lda_splitter.py b/fgssl/core/splitters/graph/scaffold_lda_splitter.py new file mode 100644 index 0000000..87b119a --- /dev/null +++ b/fgssl/core/splitters/graph/scaffold_lda_splitter.py @@ -0,0 +1,180 @@ +import logging +import numpy as np +import torch + +from rdkit import Chem +from rdkit import RDLogger +from federatedscope.core.splitters.utils import \ + dirichlet_distribution_noniid_slice +from federatedscope.core.splitters.graph.scaffold_splitter import \ + generate_scaffold +from federatedscope.core.splitters import BaseSplitter + +logger = logging.getLogger(__name__) + +RDLogger.DisableLog('rdApp.*') + + +class GenFeatures: + r"""Implementation of 'CanonicalAtomFeaturizer' and + 'CanonicalBondFeaturizer' in DGL. + Source: https://lifesci.dgl.ai/_modules/dgllife/utils/featurizers.html + + Arguments: + data: PyG.data in PyG.dataset. + + Returns: + data: PyG.data, data passing featurizer. + + """ + def __init__(self): + self.symbols = [ + 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', + 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', + 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', + 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'other' + ] + + self.hybridizations = [ + Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + 'other', + ] + + self.stereos = [ + Chem.rdchem.BondStereo.STEREONONE, + Chem.rdchem.BondStereo.STEREOANY, + Chem.rdchem.BondStereo.STEREOZ, + Chem.rdchem.BondStereo.STEREOE, + Chem.rdchem.BondStereo.STEREOCIS, + Chem.rdchem.BondStereo.STEREOTRANS, + ] + + def __call__(self, data, **kwargs): + mol = Chem.MolFromSmiles(data.smiles) + + xs = [] + for atom in mol.GetAtoms(): + symbol = [0.] 
* len(self.symbols) + if atom.GetSymbol() in self.symbols: + symbol[self.symbols.index(atom.GetSymbol())] = 1. + else: + symbol[self.symbols.index('other')] = 1. + degree = [0.] * 10 + degree[atom.GetDegree()] = 1. + implicit = [0.] * 6 + implicit[atom.GetImplicitValence()] = 1. + formal_charge = atom.GetFormalCharge() + radical_electrons = atom.GetNumRadicalElectrons() + hybridization = [0.] * len(self.hybridizations) + if atom.GetHybridization() in self.hybridizations: + hybridization[self.hybridizations.index( + atom.GetHybridization())] = 1. + else: + hybridization[self.hybridizations.index('other')] = 1. + aromaticity = 1. if atom.GetIsAromatic() else 0. + hydrogens = [0.] * 5 + hydrogens[atom.GetTotalNumHs()] = 1. + + x = torch.tensor(symbol + degree + implicit + [formal_charge] + + [radical_electrons] + hybridization + + [aromaticity] + hydrogens) + xs.append(x) + + data.x = torch.stack(xs, dim=0) + + edge_attrs = [] + for bond in mol.GetBonds(): + bond_type = bond.GetBondType() + single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0. + double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0. + triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0. + aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0. + conjugation = 1. if bond.GetIsConjugated() else 0. + ring = 1. if bond.IsInRing() else 0. + stereo = [0.] * 6 + stereo[self.stereos.index(bond.GetStereo())] = 1. + + edge_attr = torch.tensor( + [single, double, triple, aromatic, conjugation, ring] + stereo) + + edge_attrs += [edge_attr, edge_attr] + + if len(edge_attrs) == 0: + data.edge_index = torch.zeros((2, 0), dtype=torch.long) + data.edge_attr = torch.zeros((0, 10), dtype=torch.float) + else: + num_atoms = mol.GetNumAtoms() + feats = torch.stack(edge_attrs, dim=0) + feats = torch.cat([feats, torch.zeros(feats.shape[0], 1)], dim=1) + self_loop_feats = torch.zeros(num_atoms, feats.shape[1]) + self_loop_feats[:, -1] = 1 + feats = torch.cat([feats, self_loop_feats], dim=0) + data.edge_attr = feats + + return data + + +def gen_scaffold_lda_split(dataset, client_num=5, alpha=0.1): + r""" + return dict{ID:[idxs]} + """ + logger.info('Scaffold split might take minutes, please wait...') + scaffolds = {} + for idx, data in enumerate(dataset): + smiles = data.smiles + _ = Chem.MolFromSmiles(smiles) + scaffold = generate_scaffold(smiles) + if scaffold not in scaffolds: + scaffolds[scaffold] = [idx] + else: + scaffolds[scaffold].append(idx) + # Sort from largest to smallest scaffold sets + scaffolds = {key: sorted(value) for key, value in scaffolds.items()} + scaffold_list = [ + list(scaffold_set) + for (scaffold, + scaffold_set) in sorted(scaffolds.items(), + key=lambda x: (len(x[1]), x[1][0]), + reverse=True) + ] + label = np.zeros(len(dataset)) + for i in range(len(scaffold_list)): + label[scaffold_list[i]] = i + 1 + label = torch.LongTensor(label) + # Split data to list + idx_slice = dirichlet_distribution_noniid_slice(label, client_num, alpha) + return idx_slice + + +class ScaffoldLdaSplitter(BaseSplitter): + r"""First adopt scaffold splitting and then assign the samples to + clients according to Latent Dirichlet Allocation. + + Arguments: + dataset (List or PyG.dataset): The molecular datasets. + alpha (float): Partition hyperparameter in LDA, smaller alpha + generates more extreme heterogeneous scenario. + + Returns: + data_list (List(List(PyG.data))): Splited dataset via scaffold split. 
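# Background sketch (illustrative, requires rdkit): the `generate_scaffold`
# helper used above reduces each molecule to its Murcko scaffold, so molecules
# sharing a core ring system land in the same group before the Dirichlet (LDA)
# assignment to clients.
from rdkit.Chem.Scaffolds import MurckoScaffold

print(MurckoScaffold.MurckoScaffoldSmiles(smiles='CCc1ccccc1'))  # benzene core 'c1ccccc1'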
+ + """ + def __init__(self, client_num, alpha): + super(ScaffoldLdaSplitter, self).__init__(client_num) + self.alpha = alpha + + def __call__(self, dataset): + featurizer = GenFeatures() + data = [] + for ds in dataset: + ds = featurizer(ds) + data.append(ds) + dataset = data + idx_slice = gen_scaffold_lda_split(dataset, self.client_num, + self.alpha) + data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] + return data_list diff --git a/fgssl/core/splitters/graph/scaffold_splitter.py b/fgssl/core/splitters/graph/scaffold_splitter.py new file mode 100644 index 0000000..db41779 --- /dev/null +++ b/fgssl/core/splitters/graph/scaffold_splitter.py @@ -0,0 +1,69 @@ +import logging +import numpy as np + +from rdkit import Chem +from rdkit import RDLogger +from rdkit.Chem.Scaffolds import MurckoScaffold + +from federatedscope.core.splitters import BaseSplitter + +logger = logging.getLogger(__name__) + +RDLogger.DisableLog('rdApp.*') + + +def generate_scaffold(smiles, include_chirality=False): + """return scaffold string of target molecule""" + mol = Chem.MolFromSmiles(smiles) + scaffold = MurckoScaffold\ + .MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality) + return scaffold + + +def gen_scaffold_split(dataset, client_num=5): + r""" + return dict{ID:[idxs]} + """ + logger.info('Scaffold split might take minutes, please wait...') + scaffolds = {} + for idx, data in enumerate(dataset): + smiles = data.smiles + _ = Chem.MolFromSmiles(smiles) + scaffold = generate_scaffold(smiles) + if scaffold not in scaffolds: + scaffolds[scaffold] = [idx] + else: + scaffolds[scaffold].append(idx) + # Sort from largest to smallest scaffold sets + scaffolds = {key: sorted(value) for key, value in scaffolds.items()} + scaffold_list = [ + list(scaffold_set) + for (scaffold, + scaffold_set) in sorted(scaffolds.items(), + key=lambda x: (len(x[1]), x[1][0]), + reverse=True) + ] + scaffold_idxs = sum(scaffold_list, []) + # Split data to list + splits = np.array_split(scaffold_idxs, client_num) + return [splits[ID] for ID in range(client_num)] + + +class ScaffoldSplitter(BaseSplitter): + def __init__(self, client_num): + super(ScaffoldSplitter, self).__init__(client_num) + + def __call__(self, dataset, **kwargs): + r"""Split dataset with smiles string into scaffold split + + Arguments: + dataset (List or PyG.dataset): The molecular datasets. + + Returns: + data_list (List(List(PyG.data))): Splited dataset via scaffold + split. 
+ """ + dataset = [ds for ds in dataset] + idx_slice = gen_scaffold_split(dataset) + data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] + return data_list diff --git a/fgssl/core/splitters/utils.py b/fgssl/core/splitters/utils.py new file mode 100644 index 0000000..9be87d2 --- /dev/null +++ b/fgssl/core/splitters/utils.py @@ -0,0 +1,87 @@ +import numpy as np + + +def _split_according_to_prior(label, client_num, prior): + assert client_num == len(prior) + classes = len(np.unique(label)) + assert classes == len(np.unique(np.concatenate(prior, 0))) + + # counting + frequency = np.zeros(shape=(client_num, classes)) + for idx, client_prior in enumerate(prior): + for each in client_prior: + frequency[idx][each] += 1 + sum_frequency = np.sum(frequency, axis=0) + + idx_slice = [[] for _ in range(client_num)] + for k in range(classes): + idx_k = np.where(label == k)[0] + np.random.shuffle(idx_k) + nums_k = np.ceil(frequency[:, k] / sum_frequency[k] * + len(idx_k)).astype(int) + while len(idx_k) < np.sum(nums_k): + random_client = np.random.choice(range(client_num)) + if nums_k[random_client] > 0: + nums_k[random_client] -= 1 + assert len(idx_k) == np.sum(nums_k) + idx_slice = [ + idx_j + idx.tolist() for idx_j, idx in zip( + idx_slice, np.split(idx_k, + np.cumsum(nums_k)[:-1])) + ] + + for i in range(len(idx_slice)): + np.random.shuffle(idx_slice[i]) + return idx_slice + + +def dirichlet_distribution_noniid_slice(label, + client_num, + alpha, + min_size=1, + prior=None): + r"""Get sample index list for each client from the Dirichlet distribution. + https://github.com/FedML-AI/FedML/blob/master/fedml_core/non_iid + partition/noniid_partition.py + + Arguments: + label (np.array): Label list to be split. + client_num (int): Split label into client_num parts. + alpha (float): alpha of LDA. + min_size (int): min number of sample in each client + Returns: + idx_slice (List): List of splited label index slice. + """ + if len(label.shape) != 1: + raise ValueError('Only support single-label tasks!') + + if prior is not None: + return _split_according_to_prior(label, client_num, prior) + + num = len(label) + classes = len(np.unique(label)) + assert num > client_num * min_size, f'The number of sample should be ' \ + f'greater than' \ + f' {client_num * min_size}.' 
+ size = 0 + while size < min_size: + idx_slice = [[] for _ in range(client_num)] + for k in range(classes): + # for label k + idx_k = np.where(label == k)[0] + np.random.shuffle(idx_k) + prop = np.random.dirichlet(np.repeat(alpha, client_num)) + # prop = np.array([ + # p * (len(idx_j) < num / client_num) + # for p, idx_j in zip(prop, idx_slice) + # ]) + # prop = prop / sum(prop) + prop = (np.cumsum(prop) * len(idx_k)).astype(int)[:-1] + idx_slice = [ + idx_j + idx.tolist() + for idx_j, idx in zip(idx_slice, np.split(idx_k, prop)) + ] + size = min([len(idx_j) for idx_j in idx_slice]) + for i in range(client_num): + np.random.shuffle(idx_slice[i]) + return idx_slice diff --git a/fgssl/core/strategy.py b/fgssl/core/strategy.py new file mode 100644 index 0000000..5c4a80b --- /dev/null +++ b/fgssl/core/strategy.py @@ -0,0 +1,23 @@ +import sys + + +class Strategy(object): + def __init__(self, stg_type=None, threshold=0): + self._stg_type = stg_type + self._threshold = threshold + + @property + def stg_type(self): + return self._stg_type + + @stg_type.setter + def stg_type(self, value): + self._stg_type = value + + @property + def threshold(self): + return self._threshold + + @threshold.setter + def threshold(self, value): + self._threshold = value diff --git a/fgssl/core/trainers/__init__.py b/fgssl/core/trainers/__init__.py new file mode 100644 index 0000000..de9221e --- /dev/null +++ b/fgssl/core/trainers/__init__.py @@ -0,0 +1,19 @@ +from federatedscope.core.trainers.base_trainer import BaseTrainer +from federatedscope.core.trainers.trainer import Trainer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +from federatedscope.core.trainers.trainer_multi_model import \ + GeneralMultiModelTrainer +from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer +from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer +from federatedscope.core.trainers.trainer_FedEM import FedEMTrainer +from federatedscope.core.trainers.context import Context +from federatedscope.core.trainers.trainer_fedprox import wrap_fedprox_trainer +from federatedscope.core.trainers.trainer_nbafl import wrap_nbafl_trainer, \ + wrap_nbafl_server + +__all__ = [ + 'Trainer', 'Context', 'GeneralTorchTrainer', 'GeneralMultiModelTrainer', + 'wrap_pFedMeTrainer', 'wrap_DittoTrainer', 'FedEMTrainer', + 'wrap_fedprox_trainer', 'wrap_nbafl_trainer', 'wrap_nbafl_server', + 'BaseTrainer' +] diff --git a/fgssl/core/trainers/base_trainer.py b/fgssl/core/trainers/base_trainer.py new file mode 100644 index 0000000..50bf572 --- /dev/null +++ b/fgssl/core/trainers/base_trainer.py @@ -0,0 +1,29 @@ +import abc + + +class BaseTrainer(abc.ABC): + def __init__(self, model, data, device, **kwargs): + self.model = model + self.data = data + self.device = device + self.kwargs = kwargs + + @abc.abstractmethod + def train(self): + raise NotImplementedError + + @abc.abstractmethod + def evaluate(self, target_data_split_name='test'): + raise NotImplementedError + + @abc.abstractmethod + def update(self, model_parameters, strict=False): + raise NotImplementedError + + @abc.abstractmethod + def get_model_para(self): + raise NotImplementedError + + @abc.abstractmethod + def print_trainer_meta_info(self): + raise NotImplementedError diff --git a/fgssl/core/trainers/context.py b/fgssl/core/trainers/context.py new file mode 100644 index 0000000..e612339 --- /dev/null +++ b/fgssl/core/trainers/context.py @@ -0,0 +1,269 @@ +import logging +import collections + +from 
federatedscope.core.auxiliaries.criterion_builder import get_criterion +from federatedscope.core.auxiliaries.model_builder import \ + get_trainable_para_names +from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer +from federatedscope.core.auxiliaries.enums import MODE +from federatedscope.core.auxiliaries.utils import calculate_batch_epoch_num +from federatedscope.core.data import ClientData + +logger = logging.getLogger(__name__) + + +class LifecycleDict(dict): + """A customized dict that provides lifecycle management + Arguments: + init_dict: initialized dict + """ + __delattr__ = dict.__delitem__ + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError("Attribute {} is not found".format(item)) + + def __init__(self, init_dict=None): + if init_dict is not None: + super(LifecycleDict, self).__init__(init_dict) + self.lifecycles = collections.defaultdict(set) + + def __setattr__(self, key, value): + if isinstance(value, CtxVar): + self.lifecycles[value.lifecycle].add(key) + super(LifecycleDict, self).__setitem__(key, value.obj) + else: + super(LifecycleDict, self).__setitem__(key, value) + + def clear(self, lifecycle): + keys = list(self.lifecycles[lifecycle]) + for key in keys: + if key in self: + del self[key] + self.lifecycles[lifecycle].remove(key) + + +class Context(LifecycleDict): + """Record and pass variables among different hook functions + Arguments: + model: training model + cfg: config + data (dict): a dict contains train/val/test dataset or dataloader + device: running device + init_dict (dict): a dict used to initialize the instance of Context + init_attr (bool): if set up the static variables + Note: + - The variables within an instance of class `Context` + can be set/get as an attribute. + ``` + ctx.${NAME_VARIABLE} = ${VALUE_VARIABLE} + ``` + where `${NAME_VARIABLE}` and `${VALUE_VARIABLE}` + is the name and value of the variable. + + - To achieve automatically lifecycle management, you can + wrap the variable with `CtxVar` and a lifecycle parameter + as follows + ``` + ctx.${NAME_VARIABLE} = CtxVar(${VALUE_VARIABLE}, ${LFECYCLE}) + ``` + The parameter `${LFECYCLE}` can be chosen from `LIFECYCLE.BATCH`, + `LIFECYCLE.EPOCH` and `LIFECYCLE.ROUTINE`. + Then the variable `ctx.${NAME_VARIABLE}` will be deleted at + the end of the corresponding stage + - `LIFECYCLE.BATCH`: the variables will + be deleted after running a batch + - `LIFECYCLE.EPOCH`: the variables will be + deleted after running a epoch + - `LIFECYCLE.ROUTINE`: the variables will be + deleted after running a routine + More details please refer to our + [tutorial](https://federatedscope.io/docs/trainer/). 
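# Minimal standalone sketch of the lifecycle mechanism described above
# (illustrative; 'batch' is one of the lifecycle strings CtxVar accepts).
from federatedscope.core.trainers.context import LifecycleDict, CtxVar

ctx = LifecycleDict()
ctx.tmp_loss = CtxVar(0.0, 'batch')   # tracked under the 'batch' lifecycle
ctx.clear('batch')                    # 'tmp_loss' is deleted here
assert 'tmp_loss' not in ctx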
+ + - Context also maintains some special variables across + different routines, like + - cfg + - model + - data + - device + - ${split}_data: the dataset object of data split + named `${split}` + - ${split}_loader: the data loader object of data + split named `${split}` + - num_${split}_data: the number of examples within + the dataset named `${split}` + """ + def __init__(self, + model, + cfg, + data=None, + device=None, + init_dict=None, + init_attr=True): + super(Context, self).__init__(init_dict) + + self.cfg = cfg + self.model = model + self.data = data + self.device = device + + self.cur_mode = None + self.mode_stack = list() + + self.cur_split = None + self.split_stack = list() + + self.lifecycles = collections.defaultdict(set) + + if init_attr: + # setup static variables for training/evaluation + self.setup_vars() + + def setup_vars(self): + if self.cfg.backend == 'torch': + self.trainable_para_names = get_trainable_para_names(self.model) + self.criterion = get_criterion(self.cfg.criterion.type, + self.device) + self.regularizer = get_regularizer(self.cfg.regularizer.type) + self.grad_clip = self.cfg.grad.grad_clip + if isinstance(self.data, ClientData): + self.data.setup(self.cfg) + elif self.cfg.backend == 'tensorflow': + self.trainable_para_names = self.model.trainable_variables() + self.criterion = None + self.regularizer = None + self.optimizer = None + self.grad_clip = None + + # Process training data + if self.get('train_data', None) is not None or self.get( + 'train_loader', None) is not None: + # Calculate the number of update steps during training given the + # local_update_steps + self.num_train_batch, self.num_train_batch_last_epoch, \ + self.num_train_epoch, self.num_total_train_batch = \ + calculate_batch_epoch_num( + self.cfg.train.local_update_steps, + self.cfg.train.batch_or_epoch, self.num_train_data, + self.cfg.dataloader.batch_size, + self.cfg.dataloader.drop_last) + + # Process evaluation data + for mode in ["val", "test"]: + setattr(self, "num_{}_epoch".format(mode), 1) + if self.get("{}_data".format(mode)) is not None or self.get( + "{}_loader".format(mode)) is not None: + setattr( + self, "num_{}_batch".format(mode), + getattr(self, "num_{}_data".format(mode)) // + self.cfg.dataloader.batch_size + + int(not self.cfg.dataloader.drop_last and bool( + getattr(self, "num_{}_data".format(mode)) % + self.cfg.dataloader.batch_size))) + + def track_mode(self, mode): + self.mode_stack.append(mode) + self.cur_mode = self.mode_stack[-1] + self.change_mode(self.cur_mode) + + def reset_mode(self): + self.mode_stack.pop() + self.cur_mode = self.mode_stack[-1] if len( + self.mode_stack) != 0 else None + if len(self.mode_stack) != 0: + self.change_mode(self.cur_mode) + + def change_mode(self, mode): + # change state + if self.cfg.backend == 'torch': + getattr( + self.model, 'train' + if mode == MODE.TRAIN or mode == MODE.FINETUNE else 'eval')() + else: + pass + + def track_split(self, dataset): + # stack-style to enable mixture usage such as evaluation on train + # dataset + self.split_stack.append(dataset) + self.cur_split = self.split_stack[-1] + + def reset_split(self): + self.split_stack.pop() + self.cur_split = self.split_stack[-1] if \ + len(self.split_stack) != 0 else None + + def check_split(self, target_split_name, skip=False): + if self.get(f"{target_split_name}_data") is None and self.get( + f"{target_split_name}_loader") is None: + if skip: + logger.warning( + f"No {target_split_name}_data or" + f" {target_split_name}_loader in the trainer, " + f"will skip 
evaluation" + f"If this is not the case you want, please check " + f"whether there is typo for the name") + return False + else: + raise ValueError(f"No {target_split_name}_data or" + f" {target_split_name}_loader in the trainer") + else: + return True + + +class CtxVar(object): + """Basic variable class + Arguments: + lifecycle: specific lifecycle of the attribute + """ + + LIEFTCYCLES = ["batch", "epoch", "routine", None] + + def __init__(self, obj, lifecycle=None): + assert lifecycle in CtxVar.LIEFTCYCLES + self.obj = obj + self.lifecycle = lifecycle + + +def lifecycle(lifecycle): + """Manage the lifecycle of the variables within context, + and blind these operations from user. + Args: + lifecycle: the type of lifecycle, choose from "batch/epoch/routine" + """ + if lifecycle == "routine": + + def decorate(func): + def wrapper(self, mode, hooks_set, dataset_name=None): + self.ctx.track_mode(mode) + self.ctx.track_split(dataset_name or mode) + + res = func(self, mode, hooks_set, dataset_name) + + # Clear the variables at the end of lifecycles + self.ctx.clear(lifecycle) + + # rollback the model and data_split + self.ctx.reset_mode() + self.ctx.reset_split() + + # Move the model into CPU to avoid memory leak + self.discharge_model() + + return res + + return wrapper + else: + + def decorate(func): + def wrapper(self, *args, **kwargs): + res = func(self, *args, **kwargs) + # Clear the variables at the end of lifecycles + self.ctx.clear(lifecycle) + return res + + return wrapper + + return decorate diff --git a/fgssl/core/trainers/tf_trainer.py b/fgssl/core/trainers/tf_trainer.py new file mode 100644 index 0000000..5493111 --- /dev/null +++ b/fgssl/core/trainers/tf_trainer.py @@ -0,0 +1,152 @@ +import tensorflow as tf + +import numpy as np +from federatedscope.core.trainers import Trainer +from federatedscope.core.auxiliaries.enums import MODE +from federatedscope.core.auxiliaries.utils import batch_iter +from federatedscope.core.trainers.context import CtxVar +from federatedscope.core.auxiliaries.enums import LIFECYCLE + + +class GeneralTFTrainer(Trainer): + def train(self, target_data_split_name="train", hooks_set=None): + hooks_set = self.hooks_in_train if hooks_set is None else hooks_set + + self.ctx.check_split(target_data_split_name) + + num_samples = self._run_routine(MODE.TRAIN, hooks_set, + target_data_split_name) + + # TODO: The return values should be more flexible? 
Now: sample_num, + # model_para, results={k:v} + + return num_samples, self.ctx.model.state_dict(), self.ctx.eval_metrics + + def parse_data(self, data): + """Populate "{}_data", "{}_loader" and "num_{}_data" for different + modes + + """ + init_dict = dict() + if isinstance(data, dict): + for mode in ["train", "val", "test"]: + init_dict["{}_data".format(mode)] = None + init_dict["{}_loader".format(mode)] = None + init_dict["num_{}_data".format(mode)] = 0 + if data.get(mode, None) is not None: + init_dict["{}_data".format(mode)] = data.get(mode) + init_dict["num_{}_data".format(mode)] = len(data.get(mode)) + else: + raise TypeError("Type of data should be dict.") + return init_dict + + def register_default_hooks_train(self): + self.register_hook_in_train(self._hook_on_fit_start_init, + "on_fit_start") + self.register_hook_in_train(self._hook_on_epoch_start, + "on_epoch_start") + self.register_hook_in_train(self._hook_on_batch_start_init, + "on_batch_start") + self.register_hook_in_train(self._hook_on_batch_forward, + "on_batch_forward") + self.register_hook_in_train(self._hook_on_batch_forward_regularizer, + "on_batch_forward") + self.register_hook_in_train(self._hook_on_batch_backward, + "on_batch_backward") + self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end") + self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end") + + def register_default_hooks_eval(self): + # test/val + self.register_hook_in_eval(self._hook_on_fit_start_init, + "on_fit_start") + self.register_hook_in_eval(self._hook_on_epoch_start, "on_epoch_start") + self.register_hook_in_eval(self._hook_on_batch_start_init, + "on_batch_start") + self.register_hook_in_eval(self._hook_on_batch_forward, + "on_batch_forward") + self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end") + self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end") + + def _hook_on_fit_start_init(self, ctx): + # prepare model + ctx.model.to(ctx.device) + + # prepare statistics + ctx.loss_batch_total = CtxVar(0., LIFECYCLE.ROUTINE) + ctx.loss_regular_total = CtxVar(0., LIFECYCLE.ROUTINE) + ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE) + ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE) + ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE) + + def _hook_on_epoch_start(self, ctx): + # prepare dataloader + setattr(ctx, "{}_loader".format(ctx.cur_split), + batch_iter(ctx.get("{}_data".format(ctx.cur_split)))) + + def _hook_on_batch_start_init(self, ctx): + # prepare data batch + try: + ctx.data_batch = next(ctx.get("{}_loader".format(ctx.cur_split))) + except StopIteration: + raise StopIteration + + def _hook_on_batch_forward(self, ctx): + + ctx.optimizer = ctx.model.optimizer + + ctx.batch_size = len(ctx.data_batch) + + with ctx.model.graph.as_default(): + with ctx.model.sess.as_default(): + feed_dict = { + ctx.model.input_x: ctx.data_batch['x'], + ctx.model.input_y: ctx.data_batch['y'] + } + _, batch_loss, y_true, y_prob = ctx.model.sess.run( + [ + ctx.model.train_op, ctx.model.losses, + ctx.model.input_y, ctx.model.out + ], + feed_dict=feed_dict) + ctx.loss_batch = batch_loss + ctx.y_true = CtxVar(y_true, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(y_prob, LIFECYCLE.BATCH) + + def _hook_on_batch_forward_regularizer(self, ctx): + pass + + def _hook_on_batch_backward(self, ctx): + pass + + def _hook_on_batch_end(self, ctx): + # TODO: the same with the torch_trainer + # update statistics + ctx.num_samples += ctx.batch_size + ctx.loss_batch_total += ctx.loss_batch + ctx.loss_regular_total += float(ctx.get("loss_regular", 0.)) + + # cache 
label for evaluate + ctx.ys_true.append(ctx.y_true.detach().cpu().numpy()) + ctx.ys_prob.append(ctx.y_prob.detach().cpu().numpy()) + + def _hook_on_fit_end(self, ctx): + """Evaluate metrics. + + """ + ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE) + ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE) + results = self.metric_calculator.eval(ctx) + setattr(ctx, 'eval_metrics', results) + + def update(self, model_parameters, strict=False): + self.ctx.model.load_state_dict(model_parameters, strict=strict) + + def save_model(self, path, cur_round=-1): + pass + + def load_model(self, path): + pass + + def discharge_model(self): + pass diff --git a/fgssl/core/trainers/torch_trainer.py b/fgssl/core/trainers/torch_trainer.py new file mode 100644 index 0000000..a5c2a09 --- /dev/null +++ b/fgssl/core/trainers/torch_trainer.py @@ -0,0 +1,316 @@ +import os +import logging + +import numpy as np +try: + import torch + from torch.utils.data import DataLoader, Dataset +except ImportError: + torch = None + DataLoader = None + Dataset = None + +from federatedscope.core.auxiliaries.enums import MODE +from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler +from federatedscope.core.trainers.trainer import Trainer +from federatedscope.core.trainers.context import CtxVar +from federatedscope.core.data.wrap_dataset import WrapDataset +from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader +from federatedscope.core.auxiliaries.ReIterator import ReIterator +from federatedscope.core.auxiliaries.utils import param2tensor, \ + merge_param_dict +from federatedscope.core.monitors.monitor import Monitor + +logger = logging.getLogger(__name__) + + +class GeneralTorchTrainer(Trainer): + def get_model_para(self): + return self._param_filter( + self.ctx.model.state_dict() if self.cfg.federate. + share_local_model else self.ctx.model.cpu().state_dict()) + + def parse_data(self, data): + """Populate "${split}_data", "${split}_loader" and "num_${ + split}_data" for different data splits + + """ + init_dict = dict() + if isinstance(data, dict): + for split in data.keys(): + if split not in ['train', 'val', 'test']: + continue + init_dict["{}_data".format(split)] = None + init_dict["{}_loader".format(split)] = None + init_dict["num_{}_data".format(split)] = 0 + if data.get(split, None) is not None: + if isinstance(data.get(split), Dataset): + init_dict["{}_data".format(split)] = data.get(split) + init_dict["num_{}_data".format(split)] = len( + data.get(split)) + elif isinstance(data.get(split), DataLoader): + init_dict["{}_loader".format(split)] = data.get(split) + init_dict["num_{}_data".format(split)] = len( + data.get(split).dataset) + elif isinstance(data.get(split), dict): + init_dict["{}_data".format(split)] = data.get(split) + init_dict["num_{}_data".format(split)] = len( + data.get(split)['y']) + else: + raise TypeError("Type {} is not supported.".format( + type(data.get(split)))) + else: + raise TypeError("Type of data should be dict.") + return init_dict + + def update(self, model_parameters, strict=False): + """ + Called by the FL client to update the model parameters + Arguments: + model_parameters (dict): PyTorch Module object's state_dict. 
+ """ + for key in model_parameters: + model_parameters[key] = param2tensor(model_parameters[key]) + # Due to lazy load, we merge two state dict + merged_param = merge_param_dict(self.ctx.model.state_dict().copy(), + self._param_filter(model_parameters)) + self.ctx.model.load_state_dict(merged_param, strict=strict) + + def evaluate(self, target_data_split_name="test"): + with torch.no_grad(): + super(GeneralTorchTrainer, self).evaluate(target_data_split_name) + + return self.ctx.eval_metrics + + def register_default_hooks_train(self): + self.register_hook_in_train(self._hook_on_fit_start_init, + "on_fit_start") + self.register_hook_in_train( + self._hook_on_fit_start_calculate_model_size, "on_fit_start") + self.register_hook_in_train(self._hook_on_epoch_start, + "on_epoch_start") + self.register_hook_in_train(self._hook_on_batch_start_init, + "on_batch_start") + self.register_hook_in_train(self._hook_on_batch_forward, + "on_batch_forward") + self.register_hook_in_train(self._hook_on_batch_forward_regularizer, + "on_batch_forward") + self.register_hook_in_train(self._hook_on_batch_forward_flop_count, + "on_batch_forward") + self.register_hook_in_train(self._hook_on_batch_backward, + "on_batch_backward") + self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end") + self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end") + + def register_default_hooks_ft(self): + self.register_hook_in_ft(self._hook_on_fit_start_init, "on_fit_start") + self.register_hook_in_ft(self._hook_on_fit_start_calculate_model_size, + "on_fit_start") + self.register_hook_in_ft(self._hook_on_epoch_start, "on_epoch_start") + self.register_hook_in_ft(self._hook_on_batch_start_init, + "on_batch_start") + self.register_hook_in_ft(self._hook_on_batch_forward, + "on_batch_forward") + self.register_hook_in_ft(self._hook_on_batch_forward_regularizer, + "on_batch_forward") + self.register_hook_in_ft(self._hook_on_batch_forward_flop_count, + "on_batch_forward") + self.register_hook_in_ft(self._hook_on_batch_backward, + "on_batch_backward") + self.register_hook_in_ft(self._hook_on_batch_end, "on_batch_end") + self.register_hook_in_ft(self._hook_on_fit_end, "on_fit_end") + + def register_default_hooks_eval(self): + # test/val + self.register_hook_in_eval(self._hook_on_fit_start_init, + "on_fit_start") + self.register_hook_in_eval(self._hook_on_epoch_start, "on_epoch_start") + self.register_hook_in_eval(self._hook_on_batch_start_init, + "on_batch_start") + self.register_hook_in_eval(self._hook_on_batch_forward, + "on_batch_forward") + self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end") + self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end") + + def _hook_on_fit_start_init(self, ctx): + # prepare model and optimizer + ctx.model.to(ctx.device) + + if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]: + # Initialize optimizer here to avoid the reuse of optimizers + # across different routines + ctx.optimizer = get_optimizer(ctx.model, + **ctx.cfg[ctx.cur_mode].optimizer) + ctx.scheduler = get_scheduler(ctx.optimizer, + **ctx.cfg[ctx.cur_mode].scheduler) + + # TODO: the number of batch and epoch is decided by the current mode + # and data split, so the number of batch and epoch should be + # initialized at the beginning of the routine + + # prepare statistics + ctx.loss_batch_total = CtxVar(0., LIFECYCLE.ROUTINE) + ctx.loss_regular_total = CtxVar(0., LIFECYCLE.ROUTINE) + ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE) + ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE) + ctx.ys_prob = CtxVar([], 
LIFECYCLE.ROUTINE) + + def _hook_on_fit_start_calculate_model_size(self, ctx): + if not isinstance(self.ctx.monitor, Monitor): + logger.warning( + f"The trainer {type(self)} does contain a valid monitor, " + f"this may be caused by initializing trainer subclasses " + f"without passing a valid monitor instance." + f"Plz check whether this is you want.") + return + if self.ctx.monitor.total_model_size == 0: + self.ctx.monitor.track_model_size(ctx.models) + + def _hook_on_epoch_start(self, ctx): + # prepare dataloader + if ctx.get("{}_loader".format(ctx.cur_split)) is None: + loader = get_dataloader( + WrapDataset(ctx.get("{}_data".format(ctx.cur_split))), + self.cfg, ctx.cur_split) + setattr(ctx, "{}_loader".format(ctx.cur_split), ReIterator(loader)) + elif not isinstance(ctx.get("{}_loader".format(ctx.cur_split)), + ReIterator): + setattr(ctx, "{}_loader".format(ctx.cur_split), + ReIterator(ctx.get("{}_loader".format(ctx.cur_split)))) + else: + ctx.get("{}_loader".format(ctx.cur_split)).reset() + + def _hook_on_batch_start_init(self, ctx): + # prepare data batch + try: + ctx.data_batch = CtxVar( + next(ctx.get("{}_loader".format(ctx.cur_split))), + LIFECYCLE.BATCH) + except StopIteration: + raise StopIteration + + def _hook_on_batch_forward(self, ctx): + x, label = [_.to(ctx.device) for _ in ctx.data_batch] + pred = ctx.model(x) + if len(label.size()) == 0: + label = label.unsqueeze(0) + + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH) + ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH) + + def _hook_on_batch_forward_flop_count(self, ctx): + """ + the monitoring hook to calculate the flops during the fl course + + Note: for customized cases that the forward process is not only + based on ctx.model, please override this function (inheritance + case) or replace this hook (plug-in case) + + :param ctx: + :return: + """ + if not isinstance(self.ctx.monitor, Monitor): + logger.warning( + f"The trainer {type(self)} does contain a valid monitor, " + f"this may be caused by initializing trainer subclasses " + f"without passing a valid monitor instance." + f"Plz check whether this is you want.") + return + + if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \ + == 0: + # calculate the flops_per_sample + try: + x, y = [_.to(ctx.device) for _ in ctx.data_batch] + from fvcore.nn import FlopCountAnalysis + flops_one_batch = FlopCountAnalysis(ctx.model, x).total() + if self.model_nums > 1 and ctx.mirrored_models: + flops_one_batch *= self.model_nums + logger.warning( + "the flops_per_batch is multiplied " + "by internal model nums as self.mirrored_models=True." + "if this is not the case you want, " + "please customize the count hook") + self.ctx.monitor.track_avg_flops(flops_one_batch, + ctx.batch_size) + except: + logger.warning( + "current flop count implementation is for general " + "trainer case: " + "1) ctx.data_batch = [x, y]; and" + "2) the ctx.model takes only x as input." 
+ "Please check the forward format or implement your own " + "flop_count function") + self.ctx.monitor.flops_per_sample = -1 # warning at the + # first failure + + # by default, we assume the data has the same input shape, + # thus simply multiply the flops to avoid redundant forward + self.ctx.monitor.total_flops +=\ + self.ctx.monitor.flops_per_sample * ctx.batch_size + + def _hook_on_batch_forward_regularizer(self, ctx): + ctx.loss_regular = CtxVar( + self.cfg.regularizer.mu * ctx.regularizer(ctx), LIFECYCLE.BATCH) + ctx.loss_task = CtxVar(ctx.loss_batch + ctx.loss_regular, + LIFECYCLE.BATCH) + + def _hook_on_batch_backward(self, ctx): + ctx.optimizer.zero_grad() + ctx.loss_task.backward() + if ctx.grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), + ctx.grad_clip) + + ctx.optimizer.step() + if ctx.scheduler is not None: + ctx.scheduler.step() + + def _hook_on_batch_end(self, ctx): + # update statistics + ctx.num_samples += ctx.batch_size + ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size + ctx.loss_regular_total += float(ctx.get("loss_regular", 0.)) + # cache label for evaluate + ctx.ys_true.append(ctx.y_true.detach().cpu().numpy()) + ctx.ys_prob.append(ctx.y_prob.detach().cpu().numpy()) + + def _hook_on_fit_end(self, ctx): + """Evaluate metrics. + + """ + ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE) + ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE) + results = self.metric_calculator.eval(ctx) + setattr(ctx, 'eval_metrics', results) + + def save_model(self, path, cur_round=-1): + assert self.ctx.model is not None + + ckpt = {'cur_round': cur_round, 'model': self.ctx.model.state_dict()} + torch.save(ckpt, path) + + def load_model(self, path): + assert self.ctx.model is not None + + if os.path.exists(path): + ckpt = torch.load(path, map_location=self.ctx.device) + self.ctx.model.load_state_dict(ckpt['model']) + return ckpt['cur_round'] + else: + raise ValueError("The file {} does NOT exist".format(path)) + + def discharge_model(self): + """Discharge the model from GPU device + + """ + # Avoid memory leak + if not self.cfg.federate.share_local_model: + if torch is None: + pass + else: + self.ctx.model.to(torch.device("cpu")) diff --git a/fgssl/core/trainers/trainer.py b/fgssl/core/trainers/trainer.py new file mode 100644 index 0000000..409b663 --- /dev/null +++ b/fgssl/core/trainers/trainer.py @@ -0,0 +1,389 @@ +import collections +import copy +import logging + +from federatedscope.core.trainers.base_trainer import BaseTrainer +from federatedscope.core.auxiliaries.enums import MODE +from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.auxiliaries.decorators import use_diff +from federatedscope.core.auxiliaries.utils import format_log_hooks +from federatedscope.core.auxiliaries.utils import filter_by_specified_keywords +from federatedscope.core.trainers.context import Context +from federatedscope.core.trainers.context import CtxVar +from federatedscope.core.trainers.context import lifecycle +from federatedscope.core.monitors.metric_calculator import MetricCalculator + +try: + import torch + from torch.utils.data import DataLoader, Dataset +except ImportError: + torch = None + DataLoader = None + Dataset = None + +logger = logging.getLogger(__name__) + + +class Trainer(BaseTrainer): + """ + Register, organize and run the train/test/val procedures + """ + + HOOK_TRIGGER = [ + "on_fit_start", "on_epoch_start", "on_batch_start", "on_batch_forward", + "on_batch_backward", 
"on_batch_end", "on_epoch_end", "on_fit_end" + ] + + def __init__(self, + model, + data, + device, + config, + only_for_eval=False, + monitor=None): + self.cfg = config + self.metric_calculator = MetricCalculator(config.eval.metrics) + + self.ctx = Context(model, + self.cfg, + data, + device, + init_dict=self.parse_data(data)) + + if monitor is None: + logger.warning( + f"Will not use monitor in trainer with class {type(self)}") + self.ctx.monitor = monitor + # the "model_nums", and "models" are used for multi-model case and + # model size calculation + self.model_nums = 1 + self.ctx.models = [model] + # "mirrored_models": whether the internal multi-models adopt the + # same architects and almost the same behaviors, + # which is used to simply the flops, model size calculation + self.ctx.mirrored_models = False + + # Atomic operation during training/evaluation + self.hooks_in_train = collections.defaultdict(list) + + # By default, use the same trigger keys + self.hooks_in_eval = copy.deepcopy(self.hooks_in_train) + self.hooks_in_ft = copy.deepcopy(self.hooks_in_train) + + # register necessary hooks into self.hooks_in_train and + # self.hooks_in_eval + if not only_for_eval: + self.register_default_hooks_train() + if self.cfg.finetune.before_eval: + self.register_default_hooks_ft() + self.register_default_hooks_eval() + + if self.cfg.federate.mode == 'distributed': + self.print_trainer_meta_info() + else: + # in standalone mode, by default, we print the trainer info only + # once for better logs readability + pass + + def parse_data(self, data): + pass + + def register_default_hooks_train(self): + pass + + def register_default_hooks_eval(self): + pass + + def register_default_hooks_ft(self): + pass + + def reset_hook_in_train(self, target_trigger, target_hook_name=None): + hooks_dict = self.hooks_in_train + del_one_hook_idx = self._reset_hook_in_trigger(hooks_dict, + target_hook_name, + target_trigger) + return del_one_hook_idx + + def reset_hook_in_eval(self, target_trigger, target_hook_name=None): + hooks_dict = self.hooks_in_eval + del_one_hook_idx = self._reset_hook_in_trigger(hooks_dict, + target_hook_name, + target_trigger) + return del_one_hook_idx + + def replace_hook_in_train(self, new_hook, target_trigger, + target_hook_name): + del_one_hook_idx = self.reset_hook_in_train( + target_trigger=target_trigger, target_hook_name=target_hook_name) + self.register_hook_in_train(new_hook=new_hook, + trigger=target_trigger, + insert_pos=del_one_hook_idx) + + def replace_hook_in_eval(self, new_hook, target_trigger, target_hook_name): + del_one_hook_idx = self.reset_hook_in_eval( + target_trigger=target_trigger, target_hook_name=target_hook_name) + self.register_hook_in_eval(new_hook=new_hook, + trigger=target_trigger, + insert_pos=del_one_hook_idx) + + def _reset_hook_in_trigger(self, hooks_dict, target_hook_name, + target_trigger): + # clean/delete existing hooks for a specific trigger, + # if target_hook_name given, will clean only the specific one; + # otherwise, will clean all hooks for the trigger. + assert target_trigger in self.HOOK_TRIGGER, \ + f"Got {target_trigger} as hook trigger, you should specify a " \ + f"string within {self.HOOK_TRIGGER}." 
+ del_one_hook_idx = None + if target_hook_name is None: + hooks_dict[target_trigger] = [] + del_one_hook_idx = -1 # -1 indicates del the whole list + else: + for hook_idx in range(len(hooks_dict[target_trigger])): + if target_hook_name == hooks_dict[target_trigger][ + hook_idx].__name__: + del_one = hooks_dict[target_trigger].pop(hook_idx) + logger.info(f"Remove the hook `{del_one.__name__}` from " + f"hooks_set at trigger `{target_trigger}`") + del_one_hook_idx = hook_idx + break + if del_one_hook_idx is None: + logger.warning( + f"In hook del procedure, can't find the target hook " + f"named {target_hook_name}") + return del_one_hook_idx + + def register_hook_in_train(self, + new_hook, + trigger, + insert_pos=None, + base_hook=None, + insert_mode="before"): + hooks_dict = self.hooks_in_train + self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos, + new_hook, trigger) + + def register_hook_in_ft(self, + new_hook, + trigger, + insert_pos=None, + base_hook=None, + insert_mode="before"): + hooks_dict = self.hooks_in_ft + self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos, + new_hook, trigger) + + def register_hook_in_eval(self, + new_hook, + trigger, + insert_pos=None, + base_hook=None, + insert_mode="before"): + hooks_dict = self.hooks_in_eval + self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos, + new_hook, trigger) + + def _register_hook(self, base_hook, hooks_dict, insert_mode, insert_pos, + new_hook, trigger): + assert trigger in self.HOOK_TRIGGER, \ + f"Got {trigger} as hook trigger, you should specify a string " \ + f"within {self.HOOK_TRIGGER}." + # parse the insertion position + target_hook_set = hooks_dict[trigger] + if insert_pos is not None: + assert (insert_pos == -1) or (insert_pos == len(target_hook_set) + == 0) or \ + (0 <= insert_pos <= (len(target_hook_set))), \ + f"Got {insert_pos} as insert pos, you should specify a " \ + f"integer (1) =-1 " \ + f"or (2) =0 for null target_hook_set;" \ + f"or (3) within [0, {(len(target_hook_set))}]." 
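# Note on the branch below: when no explicit `insert_pos` is given but a
# `base_hook` is, the position is resolved relative to that hook, placing the
# new hook ahead of it for insert_mode "before" or immediately after it for
# "after", clamped to the bounds of the trigger's hook list.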
+ elif base_hook is not None: + base_hook_pos = target_hook_set.index(base_hook) + insert_pos = base_hook_pos - 1 if insert_mode == "before" else \ + base_hook_pos + 1 + # bounding the insert_pos in rational range + insert_pos = 0 if insert_pos < 0 else insert_pos + insert_pos = -1 if insert_pos > len( + target_hook_set) else insert_pos + else: + insert_pos = -1 # By default, the new hook is called finally + # register the new hook + if insert_pos == -1: + hooks_dict[trigger].append(new_hook) + else: + hooks_dict[trigger].insert(insert_pos, new_hook) + + @use_diff + def train(self, target_data_split_name="train", hooks_set=None): + hooks_set = hooks_set or self.hooks_in_train + + self.ctx.check_split(target_data_split_name) + + num_samples = self._run_routine(MODE.TRAIN, hooks_set, + target_data_split_name) + + return num_samples, self.get_model_para(), self.ctx.eval_metrics + + def evaluate(self, target_data_split_name="test", hooks_set=None): + hooks_set = hooks_set or self.hooks_in_eval + + if self.ctx.check_split(target_data_split_name, skip=True): + self._run_routine(MODE.TEST, hooks_set, target_data_split_name) + else: + self.ctx.eval_metrics = dict() + + return self.ctx.eval_metrics + + def finetune(self, target_data_split_name="train", hooks_set=None): + hooks_set = hooks_set or self.hooks_in_ft + + self.ctx.check_split(target_data_split_name) + + self._run_routine(MODE.FINETUNE, hooks_set, target_data_split_name) + + @lifecycle(LIFECYCLE.ROUTINE) + def _run_routine(self, mode, hooks_set, dataset_name=None): + """Run the hooks_set and maintain the mode + Arguments: + mode: running mode of client, chosen from train/val/test + Note: + Considering evaluation could be in ```hooks_set["on_epoch_end"]```, + there could be two data loaders in self.ctx, we must tell the + running hooks which data_loader to call and which + num_samples to count + """ + for hook in hooks_set["on_fit_start"]: + hook(self.ctx) + + self._run_epoch(hooks_set) + + for hook in hooks_set["on_fit_end"]: + hook(self.ctx) + + return self.ctx.num_samples + + @lifecycle(LIFECYCLE.EPOCH) + def _run_epoch(self, hooks_set): + for epoch_i in range(self.ctx.get(f"num_{self.ctx.cur_split}_epoch")): + self.ctx.cur_epoch_i = CtxVar(epoch_i, "epoch") + + for hook in hooks_set["on_epoch_start"]: + hook(self.ctx) + + self._run_batch(hooks_set) + + for hook in hooks_set["on_epoch_end"]: + hook(self.ctx) + + @lifecycle(LIFECYCLE.BATCH) + def _run_batch(self, hooks_set): + for batch_i in range(self.ctx.get(f"num_{self.ctx.cur_split}_batch")): + self.ctx.cur_batch_i = CtxVar(batch_i, LIFECYCLE.BATCH) + + for hook in hooks_set["on_batch_start"]: + hook(self.ctx) + + for hook in hooks_set["on_batch_forward"]: + hook(self.ctx) + + for hook in hooks_set["on_batch_backward"]: + hook(self.ctx) + + for hook in hooks_set["on_batch_end"]: + hook(self.ctx) + + # Break in the final epoch + if self.ctx.cur_mode in [ + MODE.TRAIN, MODE.FINETUNE + ] and self.ctx.cur_epoch_i == self.ctx.num_train_epoch - 1: + if batch_i >= self.ctx.num_train_batch_last_epoch - 1: + break + + def update(self, model_parameters, strict=False): + ''' + Called by the FL client to update the model parameters + Arguments: + model_parameters (dict): {model_name: model_val} + strict (bool): ensure the k-v paris are strictly same + ''' + pass + + def get_model_para(self): + ''' + + :return: model_parameters (dict): {model_name: model_val} + ''' + pass + + def print_trainer_meta_info(self): + ''' + print some meta info for code-users, e.g., model type; the para + names will be 
filtered out, etc., + ''' + logger.info(f"Model meta-info: {type(self.ctx.model)}.") + logger.debug(f"Model meta-info: {self.ctx.model}.") + # logger.info(f"Data meta-info: {self.ctx['data']}.") + + ori_para_names = set(self.ctx.model.state_dict().keys()) + preserved_paras = self._param_filter(self.ctx.model.state_dict()) + preserved_para_names = set(preserved_paras.keys()) + filtered_para_names = ori_para_names - preserved_para_names + logger.info(f"Num of original para names: {len(ori_para_names)}.") + logger.info(f"Num of original trainable para names:" + f" {len(self.ctx['trainable_para_names'])}.") + logger.info( + f"Num of preserved para names in local update:" + f" {len(preserved_para_names)}. \n" + f"Preserved para names in local update: {preserved_para_names}.") + logger.info( + f"Num of filtered para names in local update:" + f" {len(filtered_para_names)}. \n" + f"Filtered para names in local update: {filtered_para_names}.") + + logger.info(f"After register default hooks,\n" + f"\tthe hooks_in_train is:\n\t" + f"{format_log_hooks(self.hooks_in_train)};\n" + f"\tthe hooks_in_eval is:\n\ + t{format_log_hooks(self.hooks_in_eval)}") + + def _param_filter(self, state_dict, filter_keywords=None): + ''' + model parameter filter when transmit between local and gloabl, + which is useful in personalization. + e.g., setting cfg.personalization.local_param= ['bn', 'norms'] + indicates the implementation of + "FedBN: Federated Learning on Non-IID Features via Local Batch + Normalization, ICML2021", which can be found in + https://openreview.net/forum?id=6YEQUn0QICG + + Arguments: + state_dict (dict): PyTorch Module object's state_dict. + Returns: + state_dict (dict): remove the keys that match any of the given + keywords. + ''' + if self.cfg.federate.method in ["local", "global"]: + return {} + + if filter_keywords is None: + filter_keywords = self.cfg.personalization.local_param + + trainable_filter = lambda p: True if \ + self.cfg.personalization.share_non_trainable_para else \ + lambda p: p in self.ctx.trainable_para_names + keyword_filter = filter_by_specified_keywords + return dict( + filter( + lambda elem: trainable_filter(elem[1]) and keyword_filter( + elem[0], filter_keywords), state_dict.items())) + + def save_model(self, path, cur_round=-1): + raise NotImplementedError( + "The function `save_model` should be implemented according to " + "the ML backend (Pytorch, Tensorflow ...).") + + def load_model(self, path): + raise NotImplementedError( + "The function `load_model` should be implemented according to " + "the ML backend (Pytorch, Tensorflow ...).") diff --git a/fgssl/core/trainers/trainer_Ditto.py b/fgssl/core/trainers/trainer_Ditto.py new file mode 100644 index 0000000..e4e4a02 --- /dev/null +++ b/fgssl/core/trainers/trainer_Ditto.py @@ -0,0 +1,219 @@ +import copy +import logging + +import torch + +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +from federatedscope.core.optimizer import wrap_regularized_optimizer +from federatedscope.core.auxiliaries.utils import calculate_batch_epoch_num +from typing import Type + +logger = logging.getLogger(__name__) + +DEBUG_DITTO = False + + +def wrap_DittoTrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + """ + Build a `DittoTrainer` with a plug-in manner, by registering new + functions into specific `BaseTrainer` + + The Ditto implementation, "Ditto: Fair and Robust Federated Learning + Through 
Personalization. (ICML2021)" + based on the Algorithm 2 in their paper and official codes: + https://github.com/litian96/ditto + """ + + # ---------------- attribute-level plug-in ----------------------- + init_Ditto_ctx(base_trainer) + + # ---------------- action-level plug-in ----------------------- + base_trainer.register_hook_in_train(new_hook=_hook_on_fit_start_clean, + trigger='on_fit_start', + insert_pos=-1) + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_set_regularized_para, + trigger="on_fit_start", + insert_pos=0) + base_trainer.register_hook_in_train( + new_hook=hook_on_batch_start_switch_model, + trigger="on_batch_start", + insert_pos=0) + base_trainer.register_hook_in_train(new_hook=hook_on_batch_forward_cnt_num, + trigger="on_batch_forward", + insert_pos=-1) + base_trainer.register_hook_in_train(new_hook=_hook_on_batch_end_flop_count, + trigger="on_batch_end", + insert_pos=-1) + base_trainer.register_hook_in_train(new_hook=_hook_on_fit_end_calibrate, + trigger='on_fit_end', + insert_pos=-1) + # evaluation is based on the local personalized model + base_trainer.register_hook_in_eval( + new_hook=hook_on_fit_start_switch_local_model, + trigger="on_fit_start", + insert_pos=0) + base_trainer.register_hook_in_eval( + new_hook=hook_on_fit_end_switch_global_model, + trigger="on_fit_end", + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_free_cuda, + trigger="on_fit_end", + insert_pos=-1) + base_trainer.register_hook_in_eval(new_hook=hook_on_fit_end_free_cuda, + trigger="on_fit_end", + insert_pos=-1) + + return base_trainer + + +def init_Ditto_ctx(base_trainer): + """ + init necessary attributes used in Ditto, + `global_model` acts as the shared global model in FedAvg; + `local_model` acts as personalized model will be optimized with + regularization based on weights of `global_model` + + """ + ctx = base_trainer.ctx + cfg = base_trainer.cfg + + ctx.global_model = copy.deepcopy(ctx.model) + ctx.local_model = copy.deepcopy(ctx.model) # the personalized model + ctx.models = [ctx.local_model, ctx.global_model] + + ctx.model = ctx.global_model + ctx.use_local_model_current = False + + ctx.num_samples_local_model_train = 0 + + # track the batch_num, epoch_num, for local & global model respectively + cfg_p_local_update_steps = cfg.personalization.local_update_steps + ctx.num_train_batch_for_local_model, \ + ctx.num_train_batch_last_epoch_for_local_model, \ + ctx.num_train_epoch_for_local_model, \ + ctx.num_total_train_batch = \ + calculate_batch_epoch_num(cfg_p_local_update_steps, + cfg.train.batch_or_epoch, + ctx.num_train_data, + cfg.dataloader.batch_size, + cfg.dataloader.drop_last) + + # In the first + # 1. `num_train_batch` and `num_train_batch_last_epoch` + # (batch_or_epoch == 'batch' case) or + # 2. 
`num_train_epoch`, + # (batch_or_epoch == 'epoch' case) + # we will manipulate local models, and manipulate global model in the + # remaining steps + if cfg.train.batch_or_epoch == 'batch': + ctx.num_train_batch += ctx.num_train_batch_for_local_model + ctx.num_train_batch_last_epoch += \ + ctx.num_train_batch_last_epoch_for_local_model + else: + ctx.num_train_epoch += ctx.num_train_epoch_for_local_model + + +def hook_on_fit_start_set_regularized_para(ctx): + # set the compared model data for local personalized model + ctx.global_model.to(ctx.device) + ctx.local_model.to(ctx.device) + ctx.global_model.train() + ctx.local_model.train() + compared_global_model_para = [{ + "params": list(ctx.global_model.parameters()) + }] + + ctx.optimizer_for_global_model = get_optimizer(ctx.global_model, + **ctx.cfg.train.optimizer) + ctx.optimizer_for_local_model = get_optimizer(ctx.local_model, + **ctx.cfg.train.optimizer) + + ctx.optimizer_for_local_model = wrap_regularized_optimizer( + ctx.optimizer_for_local_model, ctx.cfg.personalization.regular_weight) + + ctx.optimizer_for_local_model.set_compared_para_group( + compared_global_model_para) + + +def _hook_on_fit_start_clean(ctx): + # remove the unnecessary optimizer + del ctx.optimizer + ctx.num_samples_local_model_train = 0 + + +def _hook_on_fit_end_calibrate(ctx): + # make the num_samples_train only related to the global model. + # (num_samples_train will be used in aggregation process) + ctx.num_samples -= ctx.num_samples_local_model_train + ctx.eval_metrics['train_total'] = ctx.num_samples + ctx.eval_metrics['train_total_local_model'] = \ + ctx.num_samples_local_model_train + + +def _hook_on_batch_end_flop_count(ctx): + # besides the normal forward flops, the regularization adds the cost of + # number of model parameters + ctx.monitor.total_flops += ctx.monitor.total_model_size / 2 + + +def hook_on_batch_forward_cnt_num(ctx): + if ctx.use_local_model_current: + ctx.num_samples_local_model_train += ctx.batch_size + + +def hook_on_batch_start_switch_model(ctx): + if ctx.cfg.train.batch_or_epoch == 'batch': + if ctx.cur_epoch_i == (ctx.num_train_epoch - 1): + ctx.use_local_model_current = \ + ctx.cur_batch_i < \ + ctx.num_train_batch_last_epoch_for_local_model + else: + ctx.use_local_model_current = \ + ctx.cur_batch_i < ctx.num_train_batch_for_local_model + else: + ctx.use_local_model_current = \ + ctx.cur_epoch_i < ctx.num_train_epoch_for_local_model + + if DEBUG_DITTO: + logger.info("====================================================") + logger.info(f"cur_epoch_i: {ctx.cur_epoch_i}") + logger.info(f"num_train_epoch: {ctx.num_train_epoch}") + logger.info(f"cur_batch_i: {ctx.cur_batch_i}") + logger.info(f"num_train_batch: {ctx.num_train_batch}") + logger.info(f"num_train_batch_for_local_model: " + f"{ctx.num_train_batch_for_local_model}") + logger.info(f"num_train_epoch_for_local_model: " + f"{ctx.num_train_epoch_for_local_model}") + logger.info(f"use_local_model: {ctx.use_local_model_current}") + + if ctx.use_local_model_current: + ctx.model = ctx.local_model + ctx.optimizer = ctx.optimizer_for_local_model + else: + ctx.model = ctx.global_model + ctx.optimizer = ctx.optimizer_for_global_model + + +# Note that Ditto only updates the para of global_model received from other +# FL participants, and in the remaining steps, ctx.model has been = +# ctx.global_model, thus we do not need register the following hook +# def hook_on_fit_end_link_global_model(ctx): +# ctx.model = ctx.global_model + + +def hook_on_fit_start_switch_local_model(ctx): + 
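+    # Evaluation in Ditto is carried out on the personalized local model, so
+    # this hook swaps it in before the eval routine starts. A minimal wrapping
+    # sketch (assuming an already-built GeneralTorchTrainer instance named
+    # `base_trainer`; the variable names are illustrative):
+    #
+    #     trainer = wrap_DittoTrainer(base_trainer)
+    #     num_samples, model_para, metrics = trainer.train()  # global + local
+    #     eval_metrics = trainer.evaluate(target_data_split_name="test")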
ctx.model = ctx.local_model + ctx.model.eval() + + +def hook_on_fit_end_switch_global_model(ctx): + ctx.model = ctx.global_model + + +def hook_on_fit_end_free_cuda(ctx): + ctx.global_model.to(torch.device("cpu")) + ctx.local_model.to(torch.device("cpu")) diff --git a/fgssl/core/trainers/trainer_FedEM.py b/fgssl/core/trainers/trainer_FedEM.py new file mode 100644 index 0000000..e6da6df --- /dev/null +++ b/fgssl/core/trainers/trainer_FedEM.py @@ -0,0 +1,169 @@ +from typing import Type + +import numpy as np +import torch +from torch.nn.functional import softmax as f_softmax + +from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.context import CtxVar +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +from federatedscope.core.trainers.trainer_multi_model import \ + GeneralMultiModelTrainer + + +class FedEMTrainer(GeneralMultiModelTrainer): + """ + The FedEM implementation, "Federated Multi-Task Learning under a + Mixture of Distributions (NeurIPS 2021)" + based on the Algorithm 1 in their paper and official codes: + https://github.com/omarfoq/FedEM + """ + def __init__(self, + model_nums, + models_interact_mode="sequential", + model=None, + data=None, + device=None, + config=None, + base_trainer: Type[GeneralTorchTrainer] = None): + super(FedEMTrainer, + self).__init__(model_nums, models_interact_mode, model, data, + device, config, base_trainer) + device = self.ctx.device + + # --------------- attribute-level modifications ---------------------- + # used to mixture the internal models + self.weights_internal_models = (torch.ones(self.model_nums) / + self.model_nums).to(device) + self.weights_data_sample = ( + torch.ones(self.model_nums, self.ctx.num_train_batch) / + self.model_nums).to(device) + + self.ctx.all_losses_model_batch = torch.zeros( + self.model_nums, self.ctx.num_train_batch).to(device) + self.ctx.cur_batch_idx = -1 + # `ctx[f"{cur_data}_y_prob_ensemble"] = 0` in + # func `_hook_on_fit_end_ensemble_eval` + # -> self.ctx.test_y_prob_ensemble = 0 + # -> self.ctx.train_y_prob_ensemble = 0 + # -> self.ctx.val_y_prob_ensemble = 0 + + # ---------------- action-level modifications ----------------------- + # see register_multiple_model_hooks(), + # which is called in the __init__ of `GeneralMultiModelTrainer` + + def register_multiple_model_hooks(self): + """ + customized multiple_model_hooks, which is called + in the __init__ of `GeneralMultiModelTrainer` + """ + # First register hooks for model 0 + # ---------------- train hooks ----------------------- + self.register_hook_in_train( + new_hook=self.hook_on_fit_start_mixture_weights_update, + trigger="on_fit_start", + insert_pos=0) # insert at the front + self.register_hook_in_train( + new_hook=self._hook_on_fit_start_flop_count, + trigger="on_fit_start", + insert_pos=1 # follow the mixture operation + ) + self.register_hook_in_train(new_hook=self._hook_on_fit_end_flop_count, + trigger="on_fit_end", + insert_pos=-1) + self.register_hook_in_train( + new_hook=self.hook_on_batch_forward_weighted_loss, + trigger="on_batch_forward", + insert_pos=-1) + self.register_hook_in_train( + new_hook=self.hook_on_batch_start_track_batch_idx, + trigger="on_batch_start", + insert_pos=0) # insert at the front + # ---------------- eval hooks ----------------------- + self.register_hook_in_eval( + new_hook=self.hook_on_batch_end_gather_loss, + trigger="on_batch_end", + insert_pos=0 + ) # insert at the front, (we need gather the loss before clean it) + self.register_hook_in_eval( + 
new_hook=self.hook_on_batch_start_track_batch_idx, + trigger="on_batch_start", + insert_pos=0) # insert at the front + # replace the original evaluation into the ensemble one + self.replace_hook_in_eval(new_hook=self._hook_on_fit_end_ensemble_eval, + target_trigger="on_fit_end", + target_hook_name="_hook_on_fit_end") + + # Then for other models, set the same hooks as model 0 + # since we differentiate different models in the hook + # implementations via ctx.cur_model_idx + self.hooks_in_train_multiple_models.extend([ + self.hooks_in_train_multiple_models[0] + for _ in range(1, self.model_nums) + ]) + self.hooks_in_eval_multiple_models.extend([ + self.hooks_in_eval_multiple_models[0] + for _ in range(1, self.model_nums) + ]) + + def hook_on_batch_start_track_batch_idx(self, ctx): + # for both train & eval + ctx.cur_batch_idx = (self.ctx.cur_batch_idx + + 1) % self.ctx.num_train_batch + + def hook_on_batch_forward_weighted_loss(self, ctx): + # for only train + ctx.loss_batch *= self.weights_internal_models[ctx.cur_model_idx] + + def hook_on_batch_end_gather_loss(self, ctx): + # for only eval + # before clean the loss_batch; we record it + # for further weights_data_sample update + ctx.all_losses_model_batch[ctx.cur_model_idx][ + ctx.cur_batch_idx] = ctx.loss_batch.item() + + def hook_on_fit_start_mixture_weights_update(self, ctx): + # for only train + if ctx.cur_model_idx != 0: + # do the mixture_weights_update once + pass + else: + # gathers losses for all sample in iterator + # for each internal model, calling *evaluate()* + for model_idx in range(self.model_nums): + self._switch_model_ctx(model_idx) + self.evaluate(target_data_split_name="train") + + self.weights_data_sample = f_softmax( + (torch.log(self.weights_internal_models) - + ctx.all_losses_model_batch.T), + dim=1).T + self.weights_internal_models = self.weights_data_sample.mean(dim=1) + + # restore the model_ctx + self._switch_model_ctx(0) + + def _hook_on_fit_start_flop_count(self, ctx): + self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \ + self.model_nums * ctx.num_train_data + + def _hook_on_fit_end_flop_count(self, ctx): + self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \ + self.model_nums * ctx.num_train_data + + def _hook_on_fit_end_ensemble_eval(self, ctx): + """ + Ensemble evaluation + """ + if ctx.get("ys_prob_ensemble", None) is None: + ctx.ys_prob_ensemble = CtxVar(0, LIFECYCLE.ROUTINE) + ctx.ys_prob_ensemble += np.concatenate( + ctx.ys_prob) * self.weights_internal_models[ + ctx.cur_model_idx].item() + + # do metrics calculation after the last internal model evaluation done + if ctx.cur_model_idx == self.model_nums - 1: + ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), + LIFECYCLE.ROUTINE) + ctx.ys_prob = ctx.ys_prob_ensemble + ctx.eval_metrics = self.metric_calculator.eval(ctx) diff --git a/fgssl/core/trainers/trainer_fedprox.py b/fgssl/core/trainers/trainer_fedprox.py new file mode 100644 index 0000000..89e02da --- /dev/null +++ b/fgssl/core/trainers/trainer_fedprox.py @@ -0,0 +1,73 @@ +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +from typing import Type +from copy import deepcopy + + +def wrap_fedprox_trainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + """Implementation of fedprox refer to `Federated Optimization in + Heterogeneous Networks` [Tian Li, et al., 2020] + (https://proceedings.mlsys.org/paper/2020/ \ + file/38af86134b65d0f10fe33d30dd76442e-Paper.pdf) + + """ + + # ---------------- attribute-level 
plug-in ----------------------- + init_fedprox_ctx(base_trainer) + + # ---------------- action-level plug-in ----------------------- + base_trainer.register_hook_in_train(new_hook=record_initialization, + trigger='on_fit_start', + insert_pos=-1) + + base_trainer.register_hook_in_eval(new_hook=record_initialization, + trigger='on_fit_start', + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=del_initialization, + trigger='on_fit_end', + insert_pos=-1) + + base_trainer.register_hook_in_eval(new_hook=del_initialization, + trigger='on_fit_end', + insert_pos=-1) + + return base_trainer + + +def init_fedprox_ctx(base_trainer): + """Set proximal regularizer and the factor of regularizer + + """ + ctx = base_trainer.ctx + cfg = base_trainer.cfg + + cfg.defrost() + cfg.regularizer.type = 'proximal_regularizer' + cfg.regularizer.mu = cfg.fedprox.mu + cfg.freeze() + + from federatedscope.core.auxiliaries.regularizer_builder import \ + get_regularizer + ctx.regularizer = get_regularizer(cfg.regularizer.type) + + +# ---------------------------------------------------------------------- # +# Additional functions for FedProx algorithm +# ---------------------------------------------------------------------- # + + +# Trainer +def record_initialization(ctx): + """Record the initialized weights within local updates + + """ + ctx.weight_init = deepcopy( + [_.data.detach() for _ in ctx.model.parameters()]) + + +def del_initialization(ctx): + """Clear the variable to avoid memory leakage + + """ + ctx.weight_init = None diff --git a/fgssl/core/trainers/trainer_multi_model.py b/fgssl/core/trainers/trainer_multi_model.py new file mode 100644 index 0000000..b2fe60c --- /dev/null +++ b/fgssl/core/trainers/trainer_multi_model.py @@ -0,0 +1,313 @@ +import copy +from types import FunctionType +from typing import Type + +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer + +import numpy as np + + +class GeneralMultiModelTrainer(GeneralTorchTrainer): + def __init__(self, + model_nums, + models_interact_mode="sequential", + model=None, + data=None, + device=None, + config=None, + base_trainer: Type[GeneralTorchTrainer] = None): + """ + `GeneralMultiModelTrainer` supports train/eval via multiple + internal models + + Arguments: + model_nums (int): how many internal models and optimizers + will be held by the trainer + models_interact_mode (str): how the models interact, can be + "sequential" or "parallel". + model: training model + data: a dict contains train/val/test data + device: device to run + config: for trainer-related configuration + base_trainer: if given, the GeneralMultiModelTrainer init + will based on base_trainer copy + + The sequential mode indicates the interaction at + run_routine level + [one model runs its whole routine, then do sth. for + interaction, then next model runs its whole routine] + ... -> run_routine_model_i + -> _switch_model_ctx + -> (on_fit_end, _interact_to_other_models) + -> run_routine_model_i+1 + -> ... + + The parallel mode indicates the interaction + at point-in-time level + [At a specific point-in-time, one model call hooks ( + including interaction), then next model call hooks] + ... -> (on_xxx_point, hook_xxx_model_i) + -> (on_xxx_point, _interact_to_other_models) + -> (on_xxx_point, _switch_model_ctx) + -> (on_xxx_point, hook_xxx_model_i+1) + -> ... 
+ + """ + # support two initialization methods for the `GeneralMultiModelTrainer` + # 1) from another trainer; or 2) standard init manner given (model, + # data, device, config) + if base_trainer is None: + assert model is not None and \ + data is not None and \ + device is not None and \ + config is not None, "when not copy construction, (model, " \ + "data, device, config) should not be " \ + "None" + super(GeneralMultiModelTrainer, + self).__init__(model, data, device, config) + else: + assert isinstance(base_trainer, GeneralMultiModelTrainer) or \ + issubclass(type(base_trainer), GeneralMultiModelTrainer) \ + or isinstance(base_trainer, GeneralTorchTrainer) or \ + issubclass(type(base_trainer), GeneralTorchTrainer) or \ + "can only copy instances of `GeneralMultiModelTrainer` " \ + "and its subclasses, or " \ + "`GeneralTorchTrainer` and its subclasses" + self.__dict__ = copy.deepcopy(base_trainer.__dict__) + + assert models_interact_mode in ["sequential", "parallel"], \ + f"Invalid models_interact_mode, should be `sequential` or " \ + f"`parallel`, but got {models_interact_mode}" + self.models_interact_mode = models_interact_mode + + if int(model_nums) != model_nums or model_nums < 1: + raise ValueError( + f"model_nums should be integer and >= 1, got {model_nums}.") + self.model_nums = model_nums + + self.ctx.cur_model_idx = 0 # used to mark cur model + + # different internal models can have different hook_set + self.hooks_in_train_multiple_models = [self.hooks_in_train] + self.hooks_in_eval_multiple_models = [self.hooks_in_eval] + self.init_multiple_models() + self.init_multiple_model_hooks() + assert len(self.ctx.models) == model_nums == \ + len(self.hooks_in_train_multiple_models) == len( + self.hooks_in_eval_multiple_models),\ + "After init, len(hooks_in_train_multiple_models), " \ + "len(hooks_in_eval_multiple_models), " \ + "len(ctx.models) and model_nums should be the same" + + def init_multiple_models(self): + """ + init multiple models and optimizers: the default implementation + is copy init manner; + ========================= Extension ============================= + users can override this function according to their own + requirements + """ + + additional_models = [ + copy.deepcopy(self.ctx.model) for _ in range(self.model_nums - 1) + ] + self.ctx.models = [self.ctx.model] + additional_models + + self.ctx.optimizers = [ + get_optimizer(self.ctx.models[i], **self.cfg.train.optimizer) + for i in range(0, self.model_nums) + ] + + def register_multiple_model_hooks(self): + """ + By default, all internal models adopt the same hook_set. + ========================= Extension ============================= + Users can override this function to register customized hooks + for different internal models. 
+ + Note: + for sequential mode, users can append interact_hook on + begin/end triggers such as + " -> (on_fit_end, _interact_to_other_models) -> " + + for parallel mode, users can append interact_hook on any + trigger they want such as + " -> (on_xxx_point, _interact_to_other_models) -> " + + self.ctx, we must tell the running hooks which data_loader to + call and which num_samples to count + """ + + self.hooks_in_train_multiple_models.extend([ + self.hooks_in_train_multiple_models[0] + for _ in range(1, self.model_nums) + ]) + self.hooks_in_eval_multiple_models.extend([ + self.hooks_in_eval_multiple_models[0] + for _ in range(1, self.model_nums) + ]) + + def init_multiple_model_hooks(self): + self.register_multiple_model_hooks() + if self.models_interact_mode == "sequential": + # hooks_in_xxx is a list of dict, hooks_in_xxx[i] stores + # specific set for i-th internal model; + # for each dict, the key indicates point-in-time and the value + # indicates specific hook + self.hooks_in_train = self.hooks_in_train_multiple_models + self.hooks_in_eval = self.hooks_in_eval_multiple_models + elif self.models_interact_mode == "parallel": + # hooks_in_xxx is a dict whose key indicates point-in-time and + # value indicates specific hook + for trigger in list(self.hooks_in_train.keys()): + self.hooks_in_train[trigger] = [] + self.hooks_in_eval[trigger] = [] + for model_idx in range(len(self.ctx.models)): + self.hooks_in_train[trigger].extend( + self.hooks_in_train_multiple_models[model_idx] + [trigger]) + self.hooks_in_train[trigger].extend( + [self._switch_model_ctx]) + self.hooks_in_eval[trigger].extend( + self.hooks_in_eval_multiple_models[model_idx][trigger]) + self.hooks_in_eval[trigger].extend( + [self._switch_model_ctx]) + else: + raise RuntimeError( + f"Invalid models_interact_mode, should be `sequential` or " + f"`parallel`," + f" but got {self.models_interact_mode}") + + def register_hook_in_train(self, + new_hook, + trigger, + model_idx=0, + insert_pos=None, + base_hook=None, + insert_mode="before"): + hooks_dict = self.hooks_in_train_multiple_models[model_idx] + self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos, + new_hook, trigger) + + def register_hook_in_eval(self, + new_hook, + trigger, + model_idx=0, + insert_pos=None, + base_hook=None, + insert_mode="before"): + hooks_dict = self.hooks_in_eval_multiple_models[model_idx] + self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos, + new_hook, trigger) + + def _switch_model_ctx(self, next_model_idx=None): + if next_model_idx is None: + next_model_idx = (self.ctx.cur_model_idx + 1) % len( + self.ctx.models) + self.ctx.cur_model_idx = next_model_idx + self.ctx.model = self.ctx.models[next_model_idx] + self.ctx.optimizer = self.ctx.optimizers[next_model_idx] + + def _run_routine(self, mode, hooks_set, dataset_name=None): + """Run the hooks_set and maintain the mode for multiple internal models + + Arguments: + mode: running mode of client, chosen from train/val/test + + Note: + Considering evaluation could be in ```hooks_set[ + "on_epoch_end"]```, there could be two data loaders in + self.ctx, we must tell the running hooks which data_loader to call + and which num_samples to count + + """ + num_samples_model = list() + if self.models_interact_mode == "sequential": + assert isinstance(hooks_set, list) and isinstance(hooks_set[0], + dict), \ + "When models_interact_mode=sequential, " \ + "hooks_set should be a list of dict" \ + "hooks_set[i] stores specific set for i-th internal model." 
\ + "For each dict, the key indicates point-in-time and the " \ + "value indicates specific hook" + for model_idx in range(len(self.ctx.models)): + # switch different hooks & ctx for different internal models + hooks_set_model_i = hooks_set[model_idx] + self._switch_model_ctx(model_idx) + # [Interaction at run_routine level] + # one model runs its whole routine, then do sth. for + # interaction, then next model runs its whole routine + # ... -> run_routine_model_i + # -> _switch_model_ctx + # -> (on_fit_end, _interact_to_other_models) + # -> run_routine_model_i+1 + # -> ... + num_samples = super()._run_routine(mode, hooks_set_model_i, + dataset_name) + num_samples_model.append(num_samples) + elif self.models_interact_mode == "parallel": + assert isinstance(hooks_set, dict), \ + "When models_interact_mode=parallel, hooks_set should be a " \ + "dict whose key indicates point-in-time and value indicates " \ + "specific hook" + # [Interaction at point-in-time level] + # at a specific point-in-time, one model call hooks (including + # interaction), then next model call hooks + # ... -> (on_xxx_point, hook_xxx_model_i) + # -> (on_xxx_point, _interact_to_other_models) + # -> (on_xxx_point, _switch_model_ctx) + # -> (on_xxx_point, hook_xxx_model_i+1) + # -> ... + num_samples = super()._run_routine(mode, hooks_set, dataset_name) + num_samples_model.append(num_samples) + else: + raise RuntimeError( + f"Invalid models_interact_mode, should be `sequential` or " + f"`parallel`," + f" but got {self.models_interact_mode}") + # For now, we return the average number of samples for different models + return np.mean(num_samples_model) + + def get_model_para(self): + """ + return multiple model parameters + :return: + """ + trained_model_para = [] + for model_idx in range(self.model_nums): + trained_model_para.append( + self._param_filter( + self.ctx.models[model_idx].cpu().state_dict())) + + return trained_model_para[ + 0] if self.model_nums == 1 else trained_model_para + + def update(self, model_parameters, strict=False): + # update multiple model paras + """ + Arguments: + model_parameters (list[dict]): Multiple pyTorch Module object's + state_dict. 
+ """ + if self.model_nums == 1: + super().update(model_parameters, strict=strict) + else: + assert isinstance(model_parameters, list) and isinstance( + model_parameters[0], dict), \ + "model_parameters should a list of multiple state_dict" + assert len(model_parameters) == self.model_nums, \ + f"model_parameters should has the same length to " \ + f"self.model_nums, " \ + f"but got {len(model_parameters)} and {self.model_nums} " \ + f"respectively" + for model_idx in range(self.model_nums): + self.ctx.models[model_idx].load_state_dict(self._param_filter( + model_parameters[model_idx]), + strict=strict) + + def train(self, target_data_split_name="train"): + # return multiple model paras + sample_size, _, results = super().train(target_data_split_name) + + return sample_size, self.get_model_para(), results diff --git a/fgssl/core/trainers/trainer_nbafl.py b/fgssl/core/trainers/trainer_nbafl.py new file mode 100644 index 0000000..53959b2 --- /dev/null +++ b/fgssl/core/trainers/trainer_nbafl.py @@ -0,0 +1,141 @@ +from federatedscope.core.auxiliaries.utils import get_random +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +from typing import Type +from copy import deepcopy + +import numpy as np +import torch + + +def wrap_nbafl_trainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + """Implementation of NbAFL refer to `Federated Learning with + Differential Privacy: Algorithms and Performance Analysis` [et al., 2020] + (https://ieeexplore.ieee.org/abstract/document/9069945/) + + Arguments: + mu: the factor of the regularizer + epsilon: the distinguishable bound + w_clip: the threshold to clip weights + + """ + + # ---------------- attribute-level plug-in ----------------------- + init_nbafl_ctx(base_trainer) + + # ---------------- action-level plug-in ----------------------- + base_trainer.register_hook_in_train(new_hook=record_initialization, + trigger='on_fit_start', + insert_pos=-1) + + base_trainer.register_hook_in_eval(new_hook=record_initialization, + trigger='on_fit_start', + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=del_initialization, + trigger='on_fit_end', + insert_pos=-1) + + base_trainer.register_hook_in_eval(new_hook=del_initialization, + trigger='on_fit_end', + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=inject_noise_in_upload, + trigger='on_fit_end', + insert_pos=-1) + return base_trainer + + +def init_nbafl_ctx(base_trainer): + """Set proximal regularizer, and the scale of gaussian noise + + """ + ctx = base_trainer.ctx + cfg = base_trainer.cfg + + # set proximal regularizer + cfg.defrost() + cfg.regularizer.type = 'proximal_regularizer' + cfg.regularizer.mu = cfg.nbafl.mu + cfg.freeze() + from federatedscope.core.auxiliaries.regularizer_builder import \ + get_regularizer + ctx.regularizer = get_regularizer(cfg.regularizer.type) + + # set noise scale during upload + if cfg.trainer.type == 'nodefullbatch_trainer': + num_train_data = sum(ctx.train_loader.dataset[0]['train_mask']) + else: + num_train_data = ctx.num_train_data + ctx.nbafl_scale_u = cfg.nbafl.w_clip * cfg.federate.total_round_num * \ + cfg.nbafl.constant / num_train_data / \ + cfg.nbafl.epsilon + + +# ---------------------------------------------------------------------- # +# Additional functions for NbAFL algorithm +# ---------------------------------------------------------------------- # + + +# Trainer +def record_initialization(ctx): + """Record the initialized weights within local updates + + """ + ctx.weight_init 
= deepcopy( + [_.data.detach() for _ in ctx.model.parameters()]) + + +def del_initialization(ctx): + """Clear the variable to avoid memory leakage + + """ + ctx.weight_init = None + + +def inject_noise_in_upload(ctx): + """Inject noise into weights before the client upload them to server + + """ + for p in ctx.model.parameters(): + noise = get_random("Normal", p.shape, { + "loc": 0, + "scale": ctx.nbafl_scale_u + }, p.device) + p.data += noise + + +# Server +def inject_noise_in_broadcast(cfg, sample_client_num, model): + """Inject noise into weights before the server broadcasts them + + """ + + # Clip weight + for p in model.parameters(): + p.data = p.data / torch.max( + torch.ones(size=p.shape, device=p.data.device), + torch.abs(p.data) / cfg.nbafl.w_clip) + if len(sample_client_num) > 0: + # Inject noise + L = cfg.federate.sample_client_num if cfg.federate.sample_client_num\ + > 0 else cfg.federate.client_num + if cfg.federate.total_round_num > np.sqrt(cfg.federate.client_num) * L: + scale_d = 2 * cfg.nbafl.w_clip * cfg.nbafl.constant * np.sqrt( + np.power(cfg.federate.total_round_num, 2) - + np.power(L, 2) * cfg.federate.client_num) / ( + min(sample_client_num) * cfg.federate.client_num * + cfg.nbafl.epsilon) + for p in model.parameters(): + p.data += get_random("Normal", p.shape, { + "loc": 0, + "scale": scale_d + }, p.device) + + +# def wrap_nbafl_server(server: Type[Server]) -> Type[Server]: +def wrap_nbafl_server(server): + """Register noise injector for the server + + """ + server.register_noise_injector(inject_noise_in_broadcast) diff --git a/fgssl/core/trainers/trainer_pFedMe.py b/fgssl/core/trainers/trainer_pFedMe.py new file mode 100644 index 0000000..601429d --- /dev/null +++ b/fgssl/core/trainers/trainer_pFedMe.py @@ -0,0 +1,148 @@ +import copy + +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +from federatedscope.core.optimizer import wrap_regularized_optimizer +from typing import Type + + +def wrap_pFedMeTrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + """ + Build a `pFedMeTrainer` with a plug-in manner, by registering new + functions into specific `BaseTrainer` + + The pFedMe implementation, "Personalized Federated Learning with Moreau + Envelopes (NeurIPS 2020)" + is based on the Algorithm 1 in their paper and official codes: + https://github.com/CharlieDinh/pFedMe + """ + + # ---------------- attribute-level plug-in ----------------------- + init_pFedMe_ctx(base_trainer) + + # ---------------- action-level plug-in ----------------------- + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_set_local_para_tmp, + trigger="on_fit_start", + insert_pos=-1) + base_trainer.register_hook_in_train( + new_hook=hook_on_epoch_end_update_local, + trigger="on_epoch_end", + insert_pos=-1) + base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_update_local, + trigger="on_fit_end", + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=_hook_on_batch_end_flop_count, + trigger="on_batch_end", + insert_pos=-1) + base_trainer.register_hook_in_train(new_hook=_hook_on_epoch_end_flop_count, + trigger="on_epoch_end", + insert_pos=-1) + + # for "on_batch_start" trigger: replace the original hooks into new ones + # of pFedMe + # 1) cache the original hooks for "on_batch_start" + base_trainer.ctx.original_hook_on_batch_start_train = \ + base_trainer.hooks_in_train["on_batch_start"] + # 2) replace the original hooks for "on_batch_start" + base_trainer.replace_hook_in_train( + 
new_hook=hook_on_batch_start_init_pfedme, + target_trigger="on_batch_start", + target_hook_name=None) + + return base_trainer + + +def init_pFedMe_ctx(base_trainer): + """ + init necessary attributes used in pFedMe, + some new attributes will be with prefix `pFedMe` optimizer to avoid + namespace pollution + """ + ctx = base_trainer.ctx + cfg = base_trainer.cfg + + # pFedMe finds approximate model with K steps using the same data batch + # the complexity of each pFedMe client is K times the one of FedAvg + ctx.pFedMe_K = cfg.personalization.K + ctx.num_train_epoch *= ctx.pFedMe_K + ctx.pFedMe_approx_fit_counter = 0 + + # the local_model_tmp is used to be the referenced parameter when + # finding the approximate \theta in paper + # will be copied from model every run_routine + ctx.pFedMe_local_model_tmp = None + + +def hook_on_fit_start_set_local_para_tmp(ctx): + # the optimizer used in pFedMe is based on Moreau Envelopes regularization + # besides, there are two distinct lr for the approximate model and base + # model + ctx.optimizer = wrap_regularized_optimizer( + ctx.optimizer, ctx.cfg.personalization.regular_weight) + for g in ctx.optimizer.param_groups: + g['lr'] = ctx.cfg.personalization.lr + ctx.pFedMe_outer_lr = ctx.cfg.train.optimizer.lr + + ctx.pFedMe_local_model_tmp = copy.deepcopy(ctx.model) + # set the compared model data, then the optimizer will find approximate + # model using trainer.cfg.personalization.lr + compared_global_model_para = [{ + "params": list(ctx.pFedMe_local_model_tmp.parameters()) + }] + ctx.optimizer.set_compared_para_group(compared_global_model_para) + + +def hook_on_batch_start_init_pfedme(ctx): + # refresh data every K step + if ctx.pFedMe_approx_fit_counter == 0: + if ctx.cur_mode == "train": + for hook in ctx.original_hook_on_batch_start_train: + hook(ctx) + else: + for hook in ctx.original_hook_on_batch_start_eval: + hook(ctx) + ctx.data_batch_cache = copy.deepcopy(ctx.data_batch) + else: + # reuse the data_cache since the original hook `_hook_on_batch_end` + # will clean `data_batch` + ctx.data_batch = copy.deepcopy(ctx.data_batch_cache) + ctx.pFedMe_approx_fit_counter = (ctx.pFedMe_approx_fit_counter + + 1) % ctx.pFedMe_K + + +def _hook_on_batch_end_flop_count(ctx): + # besides the normal forward flops, pFedMe introduces + # 1) the regularization adds the cost of number of model parameters + ctx.monitor.total_flops += ctx.monitor.total_model_size / 2 + + +def _hook_on_epoch_end_flop_count(ctx): + # due to the local weight updating + ctx.monitor.total_flops += ctx.monitor.total_model_size / 2 + + +def hook_on_epoch_end_update_local(ctx): + # update local weight after finding approximate theta + for client_param, local_para_tmp in zip( + ctx.model.parameters(), ctx.pFedMe_local_model_tmp.parameters()): + local_para_tmp.data = local_para_tmp.data - \ + ctx.optimizer.regular_weight * \ + ctx.pFedMe_outer_lr * (local_para_tmp.data - + client_param.data) + + # set the compared model data, then the optimizer will find approximate + # model using trainer.cfg.personalization.lr + compared_global_model_para = [{ + "params": list(ctx.pFedMe_local_model_tmp.parameters()) + }] + ctx.optimizer.set_compared_para_group(compared_global_model_para) + + +def hook_on_fit_end_update_local(ctx): + for param, local_para_tmp in zip(ctx.model.parameters(), + ctx.pFedMe_local_model_tmp.parameters()): + param.data = local_para_tmp.data + + del ctx.pFedMe_local_model_tmp diff --git a/fgssl/core/workers/__init__.py b/fgssl/core/workers/__init__.py new file mode 100644 index 
0000000..777c989 --- /dev/null +++ b/fgssl/core/workers/__init__.py @@ -0,0 +1,10 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division +from __future__ import with_statement + +from federatedscope.core.workers.base_worker import Worker +from federatedscope.core.workers.server import Server +from federatedscope.core.workers.client import Client + +__all__ = ['Worker', 'Server', 'Client'] diff --git a/fgssl/core/workers/base_client.py b/fgssl/core/workers/base_client.py new file mode 100644 index 0000000..8ec2d3b --- /dev/null +++ b/fgssl/core/workers/base_client.py @@ -0,0 +1,121 @@ +import abc +from federatedscope.core.workers.base_worker import Worker + + +class BaseClient(Worker): + def __init__(self, ID, state, config, model, strategy): + super(BaseClient, self).__init__(ID, state, config, model, strategy) + self.msg_handlers = dict() + + def register_handlers(self, msg_type, callback_func): + """ + To bind a message type with a handling function. + + Arguments: + msg_type (str): The defined message type + callback_func: The handling functions to handle the received + message + """ + self.msg_handlers[msg_type] = callback_func + + def _register_default_handlers(self): + self.register_handlers('assign_client_id', + self.callback_funcs_for_assign_id) + self.register_handlers('ask_for_join_in_info', + self.callback_funcs_for_join_in_info) + self.register_handlers('address', self.callback_funcs_for_address) + self.register_handlers('model_para', + self.callback_funcs_for_model_para) + self.register_handlers('ss_model_para', + self.callback_funcs_for_model_para) + self.register_handlers('evaluate', self.callback_funcs_for_evaluate) + self.register_handlers('finish', self.callback_funcs_for_finish) + self.register_handlers('converged', self.callback_funcs_for_converged) + + @abc.abstractmethod + def run(self): + """ + To listen to the message and handle them accordingly (used for + distributed mode) + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_model_para(self, message): + """ + The handling function for receiving model parameters, + which triggers the local training process. + This handling function is widely used in various FL courses. + + Arguments: + message: The received message, which includes sender, receiver, + state, and content. + More detail can be found in federatedscope.core.message + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_assign_id(self, message): + """ + The handling function for receiving the client_ID assigned by the + server (during the joining process), + which is used in the distributed mode. + + Arguments: + message: The received message + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_join_in_info(self, message): + """ + The handling function for receiving the request of join in information + (such as batch_size, num_of_samples) during the joining process. 
+ + Arguments: + message: The received message + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_address(self, message): + """ + The handling function for receiving other clients' IP addresses, + which is used for constructing a complex topology + + Arguments: + message: The received message + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_evaluate(self, message): + """ + The handling function for receiving the request of evaluating + + Arguments: + message: The received message + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_finish(self, message): + """ + The handling function for receiving the signal of finishing the FL + course. + + Arguments: + message: The received message + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_converged(self, message): + """ + The handling function for receiving the signal that the FL course + converged + + Arguments: + message: The received message + """ + raise NotImplementedError diff --git a/fgssl/core/workers/base_server.py b/fgssl/core/workers/base_server.py new file mode 100644 index 0000000..10788bf --- /dev/null +++ b/fgssl/core/workers/base_server.py @@ -0,0 +1,74 @@ +import abc +from federatedscope.core.workers import Worker + + +class BaseServer(Worker): + def __init__(self, ID, state, config, model, strategy): + super(BaseServer, self).__init__(ID, state, config, model, strategy) + self.msg_handlers = dict() + + def register_handlers(self, msg_type, callback_func): + """ + To bind a message type with a handling function. + + Arguments: + msg_type (str): The defined message type + callback_func: The handling functions to handle the received + message + """ + self.msg_handlers[msg_type] = callback_func + + def _register_default_handlers(self): + self.register_handlers('join_in', self.callback_funcs_for_join_in) + self.register_handlers('join_in_info', self.callback_funcs_for_join_in) + self.register_handlers('model_para', self.callback_funcs_model_para) + self.register_handlers('metrics', self.callback_funcs_for_metrics) + + @abc.abstractmethod + def run(self): + """ + To start the FL course, listen and handle messages (for distributed + mode). + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_model_para(self, message): + """ + The handling function for receiving model parameters, which triggers + check_and_move_on (perform aggregation when enough feedback has + been received). + This handling function is widely used in various FL courses. + + Arguments: + message: The received message, which includes sender, receiver, + state, and content. More detail can be found in + federatedscope.core.message + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_join_in(self, message): + """ + The handling function for receiving the join in information. The + server might request for some information (such as num_of_samples) + if necessary, assign IDs for the servers. + If all the clients have joined in, the training process will be + triggered. + + Arguments: + message: The received message + """ + raise NotImplementedError + + @abc.abstractmethod + def callback_funcs_for_metrics(self, message): + """ + The handling function for receiving the evaluation results, + which triggers check_and_move_on + (perform aggregation when enough feedback has been received). 
+ + Arguments: + message: The received message + """ + raise NotImplementedError diff --git a/fgssl/core/workers/base_worker.py b/fgssl/core/workers/base_worker.py new file mode 100644 index 0000000..f7f064d --- /dev/null +++ b/fgssl/core/workers/base_worker.py @@ -0,0 +1,55 @@ +from federatedscope.core.monitors.monitor import Monitor + + +class Worker(object): + """ + The base worker class. + """ + def __init__(self, ID=-1, state=0, config=None, model=None, strategy=None): + self._ID = ID + self._state = state + self._model = model + self._cfg = config + self._strategy = strategy + self._mode = self._cfg.federate.mode.lower() + self._monitor = Monitor(config, monitored_object=self) + + @property + def ID(self): + return self._ID + + @ID.setter + def ID(self, value): + self._ID = value + + @property + def state(self): + return self._state + + @state.setter + def state(self, value): + self._state = value + + @property + def model(self): + return self._model + + @model.setter + def model(self, value): + self._model = value + + @property + def strategy(self): + return self._strategy + + @strategy.setter + def strategy(self, value): + self._strategy = value + + @property + def mode(self): + return self._mode + + @mode.setter + def mode(self, value): + self._mode = value diff --git a/fgssl/core/workers/client.py b/fgssl/core/workers/client.py new file mode 100644 index 0000000..39895d8 --- /dev/null +++ b/fgssl/core/workers/client.py @@ -0,0 +1,517 @@ +import copy +import logging +import sys +import pickle + +from federatedscope.core.message import Message +from federatedscope.core.communication import StandaloneCommManager, \ + gRPCCommManager +from federatedscope.core.monitors.early_stopper import EarlyStopper +from federatedscope.core.workers import Worker +from federatedscope.core.auxiliaries.trainer_builder import get_trainer +from federatedscope.core.secret_sharing import AdditiveSecretSharing +from federatedscope.core.auxiliaries.utils import merge_dict, \ + calculate_time_cost +from federatedscope.core.workers.base_client import BaseClient + +logger = logging.getLogger(__name__) + + +class Client(BaseClient): + """ + The Client class, which describes the behaviors of client in an FL course. + The behaviors are described by the handling functions (named as + callback_funcs_for_xxx) + + Arguments: + ID: The unique ID of the client, which is assigned by the server + when joining the FL course + server_id: (Default) 0 + state: The training round + config: The configuration + data: The data owned by the client + model: The model maintained locally + device: The device to run local training and evaluation + strategy: redundant attribute + """ + def __init__(self, + ID=-1, + server_id=None, + state=-1, + config=None, + data=None, + model=None, + device='cpu', + strategy=None, + is_unseen_client=False, + *args, + **kwargs): + + super(Client, self).__init__(ID, state, config, model, strategy) + + # the unseen_client indicates that whether this client contributes to + # FL process by training on its local data and uploading the local + # model update, which is useful for check the participation + # generalization gap in + # [ICLR'22, What Do We Mean by Generalization in Federated Learning?] 
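+        # Construction sketch for the distributed mode (the addresses, ports
+        # and config are placeholders; the extra keyword arguments are consumed
+        # further below in this __init__):
+        #
+        #     client = Client(ID=-1, server_id=0, config=cfg, data=data,
+        #                     model=model, device='cuda:0',
+        #                     host='127.0.0.1', port=50052,
+        #                     server_host='127.0.0.1', server_port=50051)
+        #     client.join_in()  # announce itself to the server
+        #     client.run()      # handle messages until a 'finish' arrives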
+        self.is_unseen_client = is_unseen_client
+
+        # Attacks are only supported in the standalone mode;
+        # check whether this client is an attacker: a client is an attacker
+        # if config.attack.attack_method is provided
+        self.is_attacker = config.attack.attacker_id == ID and \
+            config.attack.attack_method != '' and \
+            config.federate.mode == 'standalone'
+
+        # Build Trainer
+        # the trainer might need configurations other than those under the
+        # trainer node
+        self.trainer = get_trainer(model=model,
+                                   data=data,
+                                   device=device,
+                                   config=self._cfg,
+                                   is_attacker=self.is_attacker,
+                                   monitor=self._monitor)
+
+        # For client-side evaluation
+        self.best_results = dict()
+        self.history_results = dict()
+        # in the local or global training mode, we use the early stopper;
+        # otherwise, we set patience=0 to deactivate the local early-stopper
+        patience = self._cfg.early_stop.patience if \
+            self._cfg.federate.method in [
+                "local", "global"
+            ] else 0
+        self.early_stopper = EarlyStopper(
+            patience, self._cfg.early_stop.delta,
+            self._cfg.early_stop.improve_indicator_mode,
+            self._cfg.early_stop.the_smaller_the_better)
+
+        # Secret Sharing Manager and message buffer
+        self.ss_manager = AdditiveSecretSharing(
+            shared_party_num=int(self._cfg.federate.sample_client_num
+                                 )) if self._cfg.federate.use_ss else None
+        self.msg_buffer = {'train': dict(), 'eval': dict()}
+
+        # Register message handlers
+        self._register_default_handlers()
+
+        # Computation and communication ability
+        if 'resource_info' in kwargs and kwargs['resource_info'] is not None:
+            self.comp_speed = float(
+                kwargs['resource_info']['computation']) / 1000.  # (s/sample)
+            self.comm_bandwidth = float(
+                kwargs['resource_info']['communication'])  # (kbit/s)
+        else:
+            self.comp_speed = None
+            self.comm_bandwidth = None
+
+        if self._cfg.backend == 'torch':
+            self.model_size = sys.getsizeof(pickle.dumps(
+                self.model)) / 1024.0 * 8.
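+            # i.e., pickled size in bytes -> /1024 gives kilobytes -> *8 gives
+            # kilobits, matching the kbit/s bandwidth used for the simulated
+            # communication cost in _gen_timestamp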
# kbits + else: + # TODO: calculate model size for TF Model + self.model_size = 1.0 + logger.warning(f'The calculation of model size in backend:' + f'{self._cfg.backend} is not provided.') + + # Initialize communication manager + self.server_id = server_id + if self.mode == 'standalone': + comm_queue = kwargs['shared_comm_queue'] + self.comm_manager = StandaloneCommManager(comm_queue=comm_queue, + monitor=self._monitor) + self.local_address = None + elif self.mode == 'distributed': + host = kwargs['host'] + port = kwargs['port'] + server_host = kwargs['server_host'] + server_port = kwargs['server_port'] + self.comm_manager = gRPCCommManager( + host=host, port=port, client_num=self._cfg.federate.client_num) + logger.info('Client: Listen to {}:{}...'.format(host, port)) + self.comm_manager.add_neighbors(neighbor_id=server_id, + address={ + 'host': server_host, + 'port': server_port + }) + self.local_address = { + 'host': self.comm_manager.host, + 'port': self.comm_manager.port + } + + def _gen_timestamp(self, init_timestamp, instance_number): + if init_timestamp is None: + return None + + comp_cost, comm_cost = calculate_time_cost( + instance_number=instance_number, + comm_size=self.model_size, + comp_speed=self.comp_speed, + comm_bandwidth=self.comm_bandwidth) + return init_timestamp + comp_cost + comm_cost + + def _calculate_model_delta(self, init_model, updated_model): + if not isinstance(init_model, list): + init_model = [init_model] + updated_model = [updated_model] + + model_deltas = list() + for model_index in range(len(init_model)): + model_delta = copy.deepcopy(init_model[model_index]) + for key in init_model[model_index].keys(): + model_delta[key] = updated_model[model_index][ + key] - init_model[model_index][key] + model_deltas.append(model_delta) + + if len(model_deltas) > 1: + return model_deltas + else: + return model_deltas[0] + + def join_in(self): + """ + To send 'join_in' message to the server for joining in the FL course. + """ + self.comm_manager.send( + Message(msg_type='join_in', + sender=self.ID, + receiver=[self.server_id], + timestamp=0, + content=self.local_address)) + + def run(self): + """ + To listen to the message and handle them accordingly (used for + distributed mode) + """ + while True: + msg = self.comm_manager.receive() + if self.state <= msg.state: + self.msg_handlers[msg.msg_type](msg) + + if msg.msg_type == 'finish': + break + + def callback_funcs_for_model_para(self, message: Message): + """ + The handling function for receiving model parameters, + which triggers the local training process. + This handling function is widely used in various FL courses. + + Arguments: + message: The received message, which includes sender, receiver, + state, and content. 
+ More detail can be found inww federatedscope.core.message + """ + if 'ss' in message.msg_type: + # A fragment of the shared secret + state, content, timestamp = message.state, message.content, \ + message.timestamp + self.msg_buffer['train'][state].append(content) + + if len(self.msg_buffer['train'] + [state]) == self._cfg.federate.client_num: + # Check whether the received fragments are enough + model_list = self.msg_buffer['train'][state] + sample_size, first_aggregate_model_para = model_list[0] + single_model_case = True + if isinstance(first_aggregate_model_para, list): + assert isinstance(first_aggregate_model_para[0], dict), \ + "aggregate_model_para should a list of multiple " \ + "state_dict for multiple models" + single_model_case = False + else: + assert isinstance(first_aggregate_model_para, dict), \ + "aggregate_model_para should " \ + "a state_dict for single model case" + first_aggregate_model_para = [first_aggregate_model_para] + model_list = [[model] for model in model_list] + + for sub_model_idx, aggregate_single_model_para in enumerate( + first_aggregate_model_para): + for key in aggregate_single_model_para: + for i in range(1, len(model_list)): + aggregate_single_model_para[key] += model_list[i][ + sub_model_idx][key] + + self.comm_manager.send( + Message(msg_type='model_para', + sender=self.ID, + receiver=[self.server_id], + state=self.state, + timestamp=timestamp, + content=(sample_size, first_aggregate_model_para[0] + if single_model_case else + first_aggregate_model_para))) + + else: + round = message.state + sender = message.sender + timestamp = message.timestamp + content = message.content + # When clients share the local model, we must set strict=True to + # ensure all the model params (which might be updated by other + # clients in the previous local training process) are overwritten + # and synchronized with the received model + self.trainer.update(content, + strict=self._cfg.federate.share_local_model) + self.state = round + skip_train_isolated_or_global_mode = \ + self.early_stopper.early_stopped and \ + self._cfg.federate.method in ["local", "global"] + if self.is_unseen_client or skip_train_isolated_or_global_mode: + # for these cases (1) unseen client (2) isolated_global_mode, + # we do not local train and upload local model + sample_size, model_para_all, results = \ + 0, self.trainer.get_model_para(), {} + if skip_train_isolated_or_global_mode: + logger.info( + f"[Local/Global mode] Client #{self.ID} has been " + f"early stopped, we will skip the local training") + self._monitor.local_converged() + else: + if self.early_stopper.early_stopped and \ + self._monitor.local_convergence_round == 0: + logger.info( + f"[Normal FL Mode] Client #{self.ID} has been locally " + f"early stopped. " + f"The next FL update may result in negative effect") + self._monitor.local_converged() + sample_size, model_para_all, results = self.trainer.train() + if self._cfg.federate.share_local_model and not \ + self._cfg.federate.online_aggr: + model_para_all = copy.deepcopy(model_para_all) + train_log_res = self._monitor.format_eval_res( + results, + rnd=self.state, + role='Client #{}'.format(self.ID), + return_raw=True) + logger.info(train_log_res) + if self._cfg.wandb.use and self._cfg.wandb.client_train_info: + self._monitor.save_formatted_results(train_log_res, + save_file_name="") + + # Return the feedbacks to the server after local update + if self._cfg.federate.use_ss: + assert not self.is_unseen_client, \ + "Un-support using secret sharing for unseen clients." 
\ + "i.e., you set cfg.federate.use_ss=True and " \ + "cfg.federate.unseen_clients_rate in (0, 1)" + single_model_case = True + if isinstance(model_para_all, list): + assert isinstance(model_para_all[0], dict), \ + "model_para should a list of " \ + "multiple state_dict for multiple models" + single_model_case = False + else: + assert isinstance(model_para_all, dict), \ + "model_para should a state_dict for single model case" + model_para_all = [model_para_all] + model_para_list_all = [] + for model_para in model_para_all: + for key in model_para: + model_para[key] = model_para[key] * sample_size + model_para_list = self.ss_manager.secret_split(model_para) + model_para_list_all.append(model_para_list) + # print(model_para) + # print(self.ss_manager.secret_reconstruct( + # model_para_list)) + frame_idx = 0 + for neighbor in self.comm_manager.neighbors: + if neighbor != self.server_id: + content_frame = model_para_list_all[0][frame_idx] if \ + single_model_case else \ + [model_para_list[frame_idx] for model_para_list + in model_para_list_all] + self.comm_manager.send( + Message(msg_type='ss_model_para', + sender=self.ID, + receiver=[neighbor], + state=self.state, + timestamp=self._gen_timestamp( + init_timestamp=timestamp, + instance_number=sample_size), + content=content_frame)) + frame_idx += 1 + content_frame = model_para_list_all[0][frame_idx] if \ + single_model_case else \ + [model_para_list[frame_idx] for model_para_list in + model_para_list_all] + self.msg_buffer['train'][self.state] = [(sample_size, + content_frame)] + else: + if self._cfg.asyn.use: + # Return the model delta when using asynchronous training + # protocol, because the staled updated might be discounted + # and cause that the sum of the aggregated weights might + # not be equal to 1 + shared_model_para = self._calculate_model_delta( + init_model=content, updated_model=model_para_all) + else: + shared_model_para = model_para_all + + self.comm_manager.send( + Message(msg_type='model_para', + sender=self.ID, + receiver=[sender], + state=self.state, + timestamp=self._gen_timestamp( + init_timestamp=timestamp, + instance_number=sample_size), + content=(sample_size, shared_model_para))) + + def callback_funcs_for_assign_id(self, message: Message): + """ + The handling function for receiving the client_ID assigned by the + server (during the joining process), + which is used in the distributed mode. + + Arguments: + message: The received message + """ + content = message.content + self.ID = int(content) + logger.info('Client (address {}:{}) is assigned with #{:d}.'.format( + self.comm_manager.host, self.comm_manager.port, self.ID)) + + def callback_funcs_for_join_in_info(self, message: Message): + """ + The handling function for receiving the request of join in information + (such as batch_size, num_of_samples) during the joining process. 
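In the secret-sharing branch above, each parameter is first scaled by the local sample size and then split with `AdditiveSecretSharing.secret_split`, so an individual weighted update is only ever seen as random-looking shares; the server later recovers the aggregate via the `fixedpoint2float` recovery function. The internals of `AdditiveSecretSharing` are not included in this patch, so the following is only a toy sketch of additive sharing over a prime field (hypothetical class name, scalar secrets instead of state dicts):

```python
import random

class ToyAdditiveSecretSharing:
    """Minimal additive secret sharing with fixed-point encoding (toy)."""
    def __init__(self, shared_party_num, precision=2 ** 16, prime=2 ** 61 - 1):
        self.n = shared_party_num
        self.precision = precision          # fixed-point scaling factor
        self.prime = prime

    def float2fixedpoint(self, x):
        return int(round(x * self.precision)) % self.prime

    def fixedpoint2float(self, x):
        x = x % self.prime
        if x > self.prime // 2:             # map back to a signed value
            x -= self.prime
        return x / self.precision

    def secret_split(self, secret):
        """Split a float into n shares that sum to it modulo the prime."""
        fp = self.float2fixedpoint(secret)
        shares = [random.randrange(self.prime) for _ in range(self.n - 1)]
        shares.append((fp - sum(shares)) % self.prime)
        return shares

    def secret_reconstruct(self, shares):
        return self.fixedpoint2float(sum(shares) % self.prime)


if __name__ == '__main__':
    ss = ToyAdditiveSecretSharing(shared_party_num=3)
    # each "client" shares its sample-size-weighted parameter
    weighted_params = [0.25 * 40, -1.5 * 60, 0.75 * 100]
    all_shares = [ss.secret_split(w) for w in weighted_params]
    # parties sum the shares they hold; only the summed shares are revealed
    summed_shares = [sum(col) % ss.prime for col in zip(*all_shares)]
    print(ss.secret_reconstruct(summed_shares), sum(weighted_params))  # -5.0 -5.0
```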
+ + Arguments: + message: The received message + """ + requirements = message.content + timestamp = message.timestamp + join_in_info = dict() + for requirement in requirements: + if requirement.lower() == 'num_sample': + if self._cfg.train.batch_or_epoch == 'batch': + num_sample = self._cfg.train.local_update_steps * \ + self._cfg.dataloader.batch_size + else: + num_sample = self._cfg.train.local_update_steps * \ + len(self.data['train']) + join_in_info['num_sample'] = num_sample + if self._cfg.trainer.type == 'nodefullbatch_trainer': + join_in_info['num_sample'] = self.data['data'].x.shape[0] + elif requirement.lower() == 'client_resource': + assert self.comm_bandwidth is not None and self.comp_speed \ + is not None, "The requirement join_in_info " \ + "'client_resource' does not exist." + join_in_info['client_resource'] = self.model_size / \ + self.comm_bandwidth + self.comp_speed + else: + raise ValueError( + 'Fail to get the join in information with type {}'.format( + requirement)) + self.comm_manager.send( + Message(msg_type='join_in_info', + sender=self.ID, + receiver=[self.server_id], + state=self.state, + timestamp=timestamp, + content=join_in_info)) + + def callback_funcs_for_address(self, message: Message): + """ + The handling function for receiving other clients' IP addresses, + which is used for constructing a complex topology + + Arguments: + message: The received message + """ + content = message.content + for neighbor_id, address in content.items(): + if int(neighbor_id) != self.ID: + self.comm_manager.add_neighbors(neighbor_id, address) + + def callback_funcs_for_evaluate(self, message: Message): + """ + The handling function for receiving the request of evaluating + + Arguments: + message: The received message + """ + sender, timestamp = message.sender, message.timestamp + self.state = message.state + if message.content is not None: + self.trainer.update(message.content, + strict=self._cfg.federate.share_local_model) + if self.early_stopper.early_stopped and self._cfg.federate.method in [ + "local", "global" + ]: + metrics = list(self.best_results.values())[0] + else: + metrics = {} + if self._cfg.finetune.before_eval: + self.trainer.finetune() + for split in self._cfg.eval.split: + # TODO: The time cost of evaluation is not considered here + eval_metrics = self.trainer.evaluate( + target_data_split_name=split) + + if self._cfg.federate.mode == 'distributed': + logger.info( + self._monitor.format_eval_res(eval_metrics, + rnd=self.state, + role='Client #{}'.format( + self.ID), + return_raw=True)) + + metrics.update(**eval_metrics) + + formatted_eval_res = self._monitor.format_eval_res( + metrics, + rnd=self.state, + role='Client #{}'.format(self.ID), + forms='raw', + return_raw=True) + self._monitor.update_best_result( + self.best_results, + formatted_eval_res['Results_raw'], + results_type=f"client #{self.ID}", + round_wise_update_key=self._cfg.eval. + best_res_update_round_wise_key) + self.history_results = merge_dict( + self.history_results, formatted_eval_res['Results_raw']) + self.early_stopper.track_and_check(self.history_results[ + self._cfg.eval.best_res_update_round_wise_key]) + + self.comm_manager.send( + Message(msg_type='metrics', + sender=self.ID, + receiver=[sender], + state=self.state, + timestamp=timestamp, + content=metrics)) + + def callback_funcs_for_finish(self, message: Message): + """ + The handling function for receiving the signal of finishing the FL + course. 
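A little further up, `callback_funcs_for_join_in_info` derives `num_sample` from the local update schedule and, when requested, a rough `client_resource` latency estimate from model size, bandwidth and computation speed. Pulled out of the class, that arithmetic looks roughly like the standalone illustration below (all example values are hypothetical):

```python
def estimate_join_in_info(local_update_steps, batch_size, train_set_size,
                          batch_or_epoch, model_size_kbits=None,
                          comm_bandwidth=None, comp_speed=None):
    # samples touched per round: steps * batch size in batch mode,
    # steps * dataset size in epoch mode
    if batch_or_epoch == 'batch':
        num_sample = local_update_steps * batch_size
    else:
        num_sample = local_update_steps * train_set_size

    info = {'num_sample': num_sample}
    if None not in (model_size_kbits, comm_bandwidth, comp_speed):
        # rough latency proxy, mirroring the expression used above
        info['client_resource'] = model_size_kbits / comm_bandwidth + comp_speed
    return info


print(estimate_join_in_info(local_update_steps=10, batch_size=32,
                            train_set_size=5000, batch_or_epoch='batch',
                            model_size_kbits=800.0, comm_bandwidth=1000.0,
                            comp_speed=0.002))
# {'num_sample': 320, 'client_resource': 0.802}
```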
+ + Arguments: + message: The received message + """ + logger.info( + f"================= client {self.ID} received finish message " + f"=================") + + if message.content is not None: + self.trainer.update(message.content, + strict=self._cfg.federate.share_local_model) + + self._monitor.finish_fl() + + def callback_funcs_for_converged(self, message: Message): + """ + The handling function for receiving the signal that the FL course + converged + + Arguments: + message: The received message + """ + + self._monitor.global_converged() diff --git a/fgssl/core/workers/server.py b/fgssl/core/workers/server.py new file mode 100644 index 0000000..972579c --- /dev/null +++ b/fgssl/core/workers/server.py @@ -0,0 +1,1012 @@ +import logging +import copy +import os +import sys + +import numpy as np +import pickle + +import seaborn as sns +from matplotlib import pyplot as plt +from sklearn import manifold + +from federatedscope.core.monitors.early_stopper import EarlyStopper +from federatedscope.core.message import Message +from federatedscope.core.communication import StandaloneCommManager, \ + gRPCCommManager +from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator +from federatedscope.core.auxiliaries.sampler_builder import get_sampler +from federatedscope.core.auxiliaries.utils import merge_dict, Timeout, \ + merge_param_dict +from federatedscope.core.auxiliaries.trainer_builder import get_trainer +from federatedscope.core.secret_sharing import AdditiveSecretSharing +from federatedscope.core.workers.base_server import BaseServer + +logger = logging.getLogger(__name__) + + +class Server(BaseServer): + """ + The Server class, which describes the behaviors of server in an FL course. + The behaviors are described by the handled functions (named as + callback_funcs_for_xxx). 
+ + Arguments: + ID: The unique ID of the server, which is set to 0 by default + state: The training round + config: the configuration + data: The data owned by the server (for global evaluation) + model: The model used for aggregation + client_num: The (expected) client num to start the FL course + total_round_num: The total number of the training round + device: The device to run local training and evaluation + strategy: redundant attribute + """ + def __init__(self, + ID=-1, + state=0, + config=None, + data=None, + model=None, + client_num=5, + total_round_num=10, + device='cpu', + strategy=None, + unseen_clients_id=None, + **kwargs): + + super(Server, self).__init__(ID, state, config, model, strategy) + + self.data = data + self.device = device + self.best_results = dict() + self.history_results = dict() + self.early_stopper = EarlyStopper( + self._cfg.early_stop.patience, self._cfg.early_stop.delta, + self._cfg.early_stop.improve_indicator_mode, + self._cfg.early_stop.the_smaller_the_better) + + if self._cfg.federate.share_local_model: + # put the model to the specified device + model.to(device) + # Build aggregator + self.aggregator = get_aggregator(self._cfg.federate.method, + model=model, + device=device, + online=self._cfg.federate.online_aggr, + config=self._cfg) + if self._cfg.federate.restore_from != '': + if not os.path.exists(self._cfg.federate.restore_from): + logger.warning(f'Invalid `restore_from`:' + f' {self._cfg.federate.restore_from}.') + else: + _ = self.aggregator.load_model(self._cfg.federate.restore_from) + logger.info("Restored the model from {}-th round's ckpt") + + if int(config.model.model_num_per_trainer) != \ + config.model.model_num_per_trainer or \ + config.model.model_num_per_trainer < 1: + raise ValueError( + f"model_num_per_trainer should be integer and >= 1, " + f"got {config.model.model_num_per_trainer}.") + self.model_num = config.model.model_num_per_trainer + self.models = [self.model] + self.aggregators = [self.aggregator] + if self.model_num > 1: + self.models.extend( + [copy.deepcopy(self.model) for _ in range(self.model_num - 1)]) + self.aggregators.extend([ + copy.deepcopy(self.aggregator) + for _ in range(self.model_num - 1) + ]) + + # function for recovering shared secret + self.recover_fun = AdditiveSecretSharing( + shared_party_num=int(self._cfg.federate.sample_client_num) + ).fixedpoint2float if self._cfg.federate.use_ss else None + + if self._cfg.federate.make_global_eval: + # set up a trainer for conducting evaluation in server + assert self.model is not None + assert self.data is not None + self.trainer = get_trainer( + model=self.model, + data=self.data, + device=self.device, + config=self._cfg, + only_for_eval=True, + monitor=self._monitor + ) # the trainer is only used for global evaluation + self.trainers = [self.trainer] + if self.model_num > 1: + # By default, the evaluation is conducted by calling + # trainer[i].eval over all internal models + self.trainers.extend([ + copy.deepcopy(self.trainer) + for _ in range(self.model_num - 1) + ]) + + # Initialize the number of joined-in clients + self._client_num = client_num + self._total_round_num = total_round_num + self.sample_client_num = int(self._cfg.federate.sample_client_num) + self.join_in_client_num = 0 + self.join_in_info = dict() + # the unseen clients indicate the ones that do not contribute to FL + # process by training on their local data and uploading their local + # model update. 
The splitting is useful to check participation + # generalization gap in + # [ICLR'22, What Do We Mean by Generalization in Federated Learning?] + self.unseen_clients_id = [] if unseen_clients_id is None \ + else unseen_clients_id + + # Server state + self.is_finish = False + + # Sampler + if self._cfg.federate.sampler in ['uniform']: + self.sampler = get_sampler( + sample_strategy=self._cfg.federate.sampler, + client_num=self.client_num, + client_info=None) + else: + # Some type of sampler would be instantiated in trigger_for_start, + # since they need more information + self.sampler = None + + # Current Timestamp + self.cur_timestamp = 0 + self.deadline_for_cur_round = 1 + + # Staleness toleration + self.staleness_toleration = self._cfg.asyn.staleness_toleration if \ + self._cfg.asyn.use else 0 + self.dropout_num = 0 + + # Device information + self.resource_info = kwargs['resource_info'] \ + if 'resource_info' in kwargs else None + self.client_resource_info = kwargs['client_resource_info'] \ + if 'client_resource_info' in kwargs else None + + # Register message handlers + self._register_default_handlers() + + # Initialize communication manager and message buffer + self.msg_buffer = {'train': dict(), 'eval': dict()} + self.staled_msg_buffer = list() + if self.mode == 'standalone': + comm_queue = kwargs['shared_comm_queue'] + self.comm_manager = StandaloneCommManager(comm_queue=comm_queue, + monitor=self._monitor) + elif self.mode == 'distributed': + host = kwargs['host'] + port = kwargs['port'] + self.comm_manager = gRPCCommManager(host=host, + port=port, + client_num=client_num) + logger.info('Server: Listen to {}:{}...'.format(host, port)) + + # inject noise before broadcast + self._noise_injector = None + + @property + def client_num(self): + return self._client_num + + @client_num.setter + def client_num(self, value): + self._client_num = value + + @property + def total_round_num(self): + return self._total_round_num + + @total_round_num.setter + def total_round_num(self, value): + self._total_round_num = value + + def register_noise_injector(self, func): + self._noise_injector = func + + def run(self): + """ + To start the FL course, listen and handle messages (for distributed + mode). 
+ """ + + # Begin: Broadcast model parameters and start to FL train + while self.join_in_client_num < self.client_num: + msg = self.comm_manager.receive() + self.msg_handlers[msg.msg_type](msg) + + # Running: listen for message (updates from clients), + # aggregate and broadcast feedbacks (aggregated model parameters) + min_received_num = self._cfg.asyn.min_received_num \ + if self._cfg.asyn.use else self._cfg.federate.sample_client_num + num_failure = 0 + time_budget = self._cfg.asyn.time_budget if self._cfg.asyn.use else -1 + with Timeout(time_budget) as time_counter: + while self.state <= self.total_round_num: + try: + msg = self.comm_manager.receive() + move_on_flag = self.msg_handlers[msg.msg_type](msg) + if move_on_flag: + time_counter.reset() + except TimeoutError: + logger.info('Time out at the training round #{}'.format( + self.state)) + move_on_flag_eval = self.check_and_move_on( + min_received_num=min_received_num, + check_eval_result=True) + move_on_flag = self.check_and_move_on( + min_received_num=min_received_num) + if not move_on_flag and not move_on_flag_eval: + num_failure += 1 + # Terminate the training if the number of failure + # exceeds the maximum number (default value: 10) + if time_counter.exceed_max_failure(num_failure): + logger.info(f'----------- Training fails at round ' + f'#{self.state}-------------') + break + + # Time out, broadcast the model para and re-start + # the training round + logger.info( + f'----------- Re-starting the training round (' + f'Round #{self.state}) for {num_failure} time ' + f'-------------') + # TODO: Clean the msg_buffer + if self.state in self.msg_buffer['train']: + self.msg_buffer['train'][self.state].clear() + + self.broadcast_model_para( + msg_type='model_para', + sample_client_num=self.sample_client_num) + else: + num_failure = 0 + time_counter.reset() + + self.terminate(msg_type='finish') + + def check_and_move_on(self, + check_eval_result=False, + min_received_num=None): + """ + To check the message_buffer. When enough messages are receiving, + some events (such as perform aggregation, evaluation, and move to + the next training round) would be triggered. + + Arguments: + check_eval_result (bool): If True, check the message buffer for + evaluation; and check the message buffer for training otherwise. 
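`check_and_move_on` below drives the round progression. Stripped of asynchrony, staleness handling and timeouts, the synchronous control flow it implements can be condensed into the toy loop sketched here (everything is illustrative, including the scalar "models"):

```python
import random

def toy_fl_rounds(sample_client_num=4, total_round_num=5, eval_freq=2):
    """Condensed, synchronous view of check_and_move_on: fill the per-round
    buffer, aggregate once it holds min_received_num updates, evaluate every
    eval_freq rounds, then start the next round."""
    global_model = 0.0
    min_received_num = sample_client_num          # synchronous setting
    for state in range(total_round_num):
        # each sampled client reports (sample_size, local_model); scalars here
        buffer = {cid: (random.randint(10, 50), global_model + random.random())
                  for cid in range(1, sample_client_num + 1)}
        if len(buffer) < min_received_num:        # not enough feedback yet
            continue
        total = sum(n for n, _ in buffer.values())
        # sample-size weighted average, as the aggregators do
        global_model = sum(n * m for n, m in buffer.values()) / total
        if (state + 1) % eval_freq == 0 and state + 1 != total_round_num:
            print(f'round {state + 1}: evaluate, model={global_model:.3f}')
    print(f'final model={global_model:.3f}')


toy_fl_rounds()
```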
+ """ + if min_received_num is None: + if self._cfg.asyn.use: + min_received_num = self._cfg.asyn.min_received_num + else: + min_received_num = self._cfg.federate.sample_client_num + assert min_received_num <= self.sample_client_num + + if check_eval_result and self._cfg.federate.mode.lower( + ) == "standalone": + # in evaluation stage and standalone simulation mode, we assume + # strong synchronization that receives responses from all clients + min_received_num = len(self.comm_manager.get_neighbors().keys()) + + move_on_flag = True # To record whether moving to a new training + # round or finishing the evaluation + if self.check_buffer(self.state, min_received_num, check_eval_result): + if not check_eval_result: + # Receiving enough feedback in the training process + aggregated_num = self._perform_federated_aggregation() + + self.state += 1 + if self.state % self._cfg.eval.freq == 0 and self.state != \ + self.total_round_num: + # Evaluate + logger.info(f'Server: Starting evaluation at the end ' + f'of round {self.state - 1}.') + self.eval() + + if self.state < self.total_round_num: + # Move to next round of training + logger.info( + f'----------- Starting a new training round (Round ' + f'#{self.state}) -------------') + # Clean the msg_buffer + self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer['train'][self.state] = dict() + self.staled_msg_buffer.clear() + # Start a new training round + self._start_new_training_round(aggregated_num) + else: + # Final Evaluate + logger.info('Server: Training is finished! Starting ' + 'evaluation.') + self.eval() + + else: + # Receiving enough feedback in the evaluation process + self._merge_and_format_eval_results() + + else: + move_on_flag = False + + return move_on_flag + + def check_and_save(self): + """ + To save the results and save model after each evaluation. + """ + + # early stopping + if "Results_weighted_avg" in self.history_results and \ + self._cfg.eval.best_res_update_round_wise_key in \ + self.history_results['Results_weighted_avg']: + should_stop = self.early_stopper.track_and_check( + self.history_results['Results_weighted_avg'][ + self._cfg.eval.best_res_update_round_wise_key]) + elif "Results_avg" in self.history_results and \ + self._cfg.eval.best_res_update_round_wise_key in \ + self.history_results['Results_avg']: + should_stop = self.early_stopper.track_and_check( + self.history_results['Results_avg'][ + self._cfg.eval.best_res_update_round_wise_key]) + else: + should_stop = False + + if should_stop: + self._monitor.global_converged() + self.comm_manager.send( + Message( + msg_type="converged", + sender=self.ID, + receiver=list(self.comm_manager.neighbors.keys()), + timestamp=self.cur_timestamp, + state=self.state, + )) + self.state = self.total_round_num + 1 + + if should_stop or self.state == self.total_round_num: + logger.info('Server: Final evaluation is finished! 
Starting ' + 'merging results.') + # last round or early stopped + self.save_best_results() + if not self._cfg.federate.make_global_eval: + self.save_client_eval_results() + self.terminate(msg_type='finish') + + # Clean the clients evaluation msg buffer + if not self._cfg.federate.make_global_eval: + round = max(self.msg_buffer['eval'].keys()) + self.msg_buffer['eval'][round].clear() + + if self.state == self.total_round_num: + # break out the loop for distributed mode + self.state += 1 + + def _perform_federated_aggregation(self): + """ + Perform federated aggregation and update the global model + """ + train_msg_buffer = self.msg_buffer['train'][self.state] + for model_idx in range(self.model_num): + model = self.models[model_idx] + aggregator = self.aggregators[model_idx] + msg_list = list() + staleness = list() + + for client_id in train_msg_buffer.keys(): + if self.model_num == 1: + msg_list.append(train_msg_buffer[client_id]) + else: + train_data_size, model_para_multiple = \ + train_msg_buffer[client_id] + msg_list.append( + (train_data_size, model_para_multiple[model_idx])) + + # The staleness of the messages in train_msg_buffer + # should be 0 + staleness.append((client_id, 0)) + + for staled_message in self.staled_msg_buffer: + state, client_id, content = staled_message + if self.model_num == 1: + msg_list.append(content) + else: + train_data_size, model_para_multiple = content + msg_list.append( + (train_data_size, model_para_multiple[model_idx])) + + staleness.append((client_id, self.state - state)) + + # Trigger the monitor here (for training) + if 'dissim' in self._cfg.eval.monitoring: + # TODO: fix this + B_val = self._monitor.calc_blocal_dissim( + model.load_state_dict(strict=False), msg_list) + formatted_eval_res = self._monitor.format_eval_res( + B_val, rnd=self.state, role='Server #') + logger.info(formatted_eval_res) + + # Aggregate + aggregated_num = len(msg_list) + agg_info = { + 'client_feedback': msg_list, + 'recover_fun': self.recover_fun, + 'staleness': staleness, + } + # logger.info(f'The staleness is {staleness}') + result = aggregator.aggregate(agg_info) + # Due to lazy load, we merge two state dict + merged_param = merge_param_dict(model.state_dict().copy(), result) + model.load_state_dict(merged_param, strict=False) + if self.state == 198: + visual_tsne(model, self.data["val"].dataset[0], self.__class__.__name__) + return aggregated_num + + def _start_new_training_round(self, aggregated_num=0): + """ + The behaviors for starting a new training round + """ + if self._cfg.asyn.use: # for asynchronous training + if self._cfg.asyn.aggregator == "time_up": + # Update the deadline according to the time budget + self.deadline_for_cur_round = \ + self.cur_timestamp + self._cfg.asyn.time_budget + + if self._cfg.asyn.broadcast_manner == \ + 'after_aggregating': + if self._cfg.asyn.overselection: + sample_client_num = self.sample_client_num + else: + sample_client_num = aggregated_num + \ + self.dropout_num + + self.broadcast_model_para(msg_type='model_para', + sample_client_num=sample_client_num) + self.dropout_num = 0 + else: # for synchronous training + self.broadcast_model_para(msg_type='model_para', + sample_client_num=self.sample_client_num) + + def _merge_and_format_eval_results(self): + """ + The behaviors of server when receiving enough evaluating results + """ + # Get all the message & aggregate + formatted_eval_res = \ + self.merge_eval_results_from_all_clients() + self.history_results = merge_dict(self.history_results, + formatted_eval_res) + if self.mode == 
'standalone' and \ + self._monitor.wandb_online_track and \ + self._monitor.use_wandb: + self._monitor.merge_system_metrics_simulation_mode( + file_io=False, from_global_monitors=True) + self.check_and_save() + + def save_best_results(self): + """ + To Save the best evaluation results. + """ + + if self._cfg.federate.save_to != '': + self.aggregator.save_model(self._cfg.federate.save_to, self.state) + formatted_best_res = self._monitor.format_eval_res( + results=self.best_results, + rnd="Final", + role='Server #', + forms=["raw"], + return_raw=True) + logger.info(formatted_best_res) + self._monitor.save_formatted_results(formatted_best_res) + + def save_client_eval_results(self): + """ + save the evaluation results of each client when the fl course + early stopped or terminated + + :return: + """ + round = max(self.msg_buffer['eval'].keys()) + eval_msg_buffer = self.msg_buffer['eval'][round] + + with open(os.path.join(self._cfg.outdir, "eval_results.log"), + "a") as outfile: + for client_id, client_eval_results in eval_msg_buffer.items(): + formatted_res = self._monitor.format_eval_res( + client_eval_results, + rnd=self.state, + role='Client #{}'.format(client_id), + return_raw=True) + logger.info(formatted_res) + outfile.write(str(formatted_res) + "\n") + + def merge_eval_results_from_all_clients(self): + """ + Merge evaluation results from all clients, update best, + log the merged results and save them into eval_results.log + + :returns: the formatted merged results + """ + + round = max(self.msg_buffer['eval'].keys()) + eval_msg_buffer = self.msg_buffer['eval'][round] + eval_res_participated_clients = [] + eval_res_unseen_clients = [] + for client_id in eval_msg_buffer: + if eval_msg_buffer[client_id] is None: + continue + if client_id in self.unseen_clients_id: + eval_res_unseen_clients.append(eval_msg_buffer[client_id]) + else: + eval_res_participated_clients.append( + eval_msg_buffer[client_id]) + + formatted_logs_all_set = dict() + for merge_type, eval_res_set in [("participated", + eval_res_participated_clients), + ("unseen", eval_res_unseen_clients)]: + if eval_res_set != []: + metrics_all_clients = dict() + for client_eval_results in eval_res_set: + for key in client_eval_results.keys(): + if key not in metrics_all_clients: + metrics_all_clients[key] = list() + metrics_all_clients[key].append( + float(client_eval_results[key])) + formatted_logs = self._monitor.format_eval_res( + metrics_all_clients, + rnd=round, + role='Server #', + forms=self._cfg.eval.report) + if merge_type == "unseen": + for key, val in copy.deepcopy(formatted_logs).items(): + if isinstance(val, dict): + # to avoid the overrides of results using the + # same name, we use new keys with postfix `unseen`: + # 'Results_weighted_avg' -> + # 'Results_weighted_avg_unseen' + formatted_logs[key + "_unseen"] = val + del formatted_logs[key] + logger.info(formatted_logs) + formatted_logs_all_set.update(formatted_logs) + self._monitor.update_best_result( + self.best_results, + metrics_all_clients, + results_type="unseen_client_best_individual" + if merge_type == "unseen" else "client_best_individual", + round_wise_update_key=self._cfg.eval. 
+ best_res_update_round_wise_key) + self._monitor.save_formatted_results(formatted_logs) + for form in self._cfg.eval.report: + if form != "raw": + metric_name = form + "_unseen" if merge_type == \ + "unseen" else form + self._monitor.update_best_result( + self.best_results, + formatted_logs[f"Results_{metric_name}"], + results_type=f"unseen_client_summarized_{form}" + if merge_type == "unseen" else + f"client_summarized_{form}", + round_wise_update_key=self._cfg.eval. + best_res_update_round_wise_key) + + return formatted_logs_all_set + + def broadcast_model_para(self, + msg_type='model_para', + sample_client_num=-1, + filter_unseen_clients=True): + """ + To broadcast the message to all clients or sampled clients + + Arguments: + msg_type: 'model_para' or other user defined msg_type + sample_client_num: the number of sampled clients in the broadcast + behavior. And sample_client_num = -1 denotes to broadcast to + all the clients. + filter_unseen_clients: whether filter out the unseen clients that + do not contribute to FL process by training on their local + data and uploading their local model update. The splitting is + useful to check participation generalization gap in [ICLR'22, + What Do We Mean by Generalization in Federated Learning?] + You may want to set it to be False when in evaluation stage + """ + if filter_unseen_clients: + # to filter out the unseen clients when sampling + self.sampler.change_state(self.unseen_clients_id, 'unseen') + + if sample_client_num > 0: + receiver = self.sampler.sample(size=sample_client_num) + else: + # broadcast to all clients + receiver = list(self.comm_manager.neighbors.keys()) + if msg_type == 'model_para': + self.sampler.change_state(receiver, 'working') + + if self._noise_injector is not None and msg_type == 'model_para': + # Inject noise only when broadcast parameters + for model_idx_i in range(len(self.models)): + num_sample_clients = [ + v["num_sample"] for v in self.join_in_info.values() + ] + self._noise_injector(self._cfg, num_sample_clients, + self.models[model_idx_i]) + + skip_broadcast = self._cfg.federate.method in ["local", "global"] + if self.model_num > 1: + model_para = [{} if skip_broadcast else model.state_dict() + for model in self.models] + else: + model_para = {} if skip_broadcast else self.model.state_dict() + + self.comm_manager.send( + Message(msg_type=msg_type, + sender=self.ID, + receiver=receiver, + state=min(self.state, self.total_round_num), + timestamp=self.cur_timestamp, + content=model_para)) + if self._cfg.federate.online_aggr: + for idx in range(self.model_num): + self.aggregators[idx].reset() + + if filter_unseen_clients: + # restore the state of the unseen clients within sampler + self.sampler.change_state(self.unseen_clients_id, 'seen') + + def broadcast_client_address(self): + """ + To broadcast the communication addresses of clients (used for + additive secret sharing) + """ + + self.comm_manager.send( + Message(msg_type='address', + sender=self.ID, + receiver=list(self.comm_manager.neighbors.keys()), + state=self.state, + timestamp=self.cur_timestamp, + content=self.comm_manager.get_neighbors())) + + def check_buffer(self, + cur_round, + min_received_num, + check_eval_result=False): + """ + To check the message buffer + + Arguments: + cur_round (int): The current round number + min_received_num (int): The minimal number of the receiving messages + check_eval_result (bool): To check training results for evaluation + results + :returns: Whether enough messages have been received or not + :rtype: bool + """ + 
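`broadcast_model_para` above delegates receiver selection to a sampler built via `get_sampler`, whose implementation is not part of this hunk; conceptually it keeps a per-client state and samples only among idle clients, which is also how the unseen clients are filtered out before broadcasting. A toy stand-in that mimics the call pattern used here (the class name and its internal behavior are assumptions):

```python
import random

class ToyUniformSampler:
    """Toy stand-in for the object returned by get_sampler (assumed to track
    a per-client state and to sample uniformly among idle clients)."""
    def __init__(self, client_num):
        self.state = {cid: 'idle' for cid in range(1, client_num + 1)}

    def change_state(self, ids, new_state):
        for cid in ([ids] if isinstance(ids, int) else ids):
            self.state[cid] = new_state

    def sample(self, size):
        idle = [cid for cid, s in self.state.items() if s == 'idle']
        return random.sample(idle, size)


sampler = ToyUniformSampler(client_num=5)
sampler.change_state([5], 'unseen')          # hold out the unseen client
receivers = sampler.sample(size=2)           # never returns client 5
sampler.change_state(receivers, 'working')   # busy until model_para arrives
print(receivers)
sampler.change_state(receivers, 'idle')      # done once updates are received
```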
+ if check_eval_result: + if 'eval' not in self.msg_buffer.keys() or len( + self.msg_buffer['eval'].keys()) == 0: + return False + + buffer = self.msg_buffer['eval'] + cur_round = max(buffer.keys()) + cur_buffer = buffer[cur_round] + return len(cur_buffer) >= min_received_num + else: + if cur_round not in self.msg_buffer['train']: + cur_buffer = dict() + else: + cur_buffer = self.msg_buffer['train'][cur_round] + if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': + if self.cur_timestamp >= self.deadline_for_cur_round and len( + cur_buffer) + len(self.staled_msg_buffer) == 0: + # When the time budget is run out but the server has not + # received any feedback + logger.warning( + f'The server has not received any feedback when the ' + f'time budget has run out, therefore the server would ' + f'wait for more {self._cfg.asyn.time_budget} seconds. ' + f'Maybe you should carefully reset ' + f'`cfg.asyn.time_budget` to a reasonable value.') + self.deadline_for_cur_round += self._cfg.asyn.time_budget + if self._cfg.asyn.broadcast_manner == \ + 'after_aggregating' and self.dropout_num != 0: + self.broadcast_model_para( + msg_type='model_para', + sample_client_num=self.dropout_num) + self.dropout_num = 0 + return self.cur_timestamp >= self.deadline_for_cur_round + else: + return len(cur_buffer)+len(self.staled_msg_buffer) >= \ + min_received_num + + def check_client_join_in(self): + """ + To check whether all the clients have joined in the FL course. + """ + + if len(self._cfg.federate.join_in_info) != 0: + return len(self.join_in_info) == self.client_num + else: + return self.join_in_client_num == self.client_num + + def trigger_for_start(self): + """ + To start the FL course when the expected number of clients have joined + """ + + if self.check_client_join_in(): + if self._cfg.federate.use_ss: + self.broadcast_client_address() + + # get sampler + if 'client_resource' in self._cfg.federate.join_in_info: + client_resource = [ + self.join_in_info[client_index]['client_resource'] + for client_index in np.arange(1, self.client_num + 1) + ] + else: + if self._cfg.backend == 'torch': + model_size = sys.getsizeof(pickle.dumps( + self.model)) / 1024.0 * 8. + else: + # TODO: calculate model size for TF Model + model_size = 1.0 + logger.warning(f'The calculation of model size in backend:' + f'{self._cfg.backend} is not provided.') + + client_resource = [ + model_size / float(x['communication']) + + float(x['computation']) / 1000. + for x in self.client_resource_info + ] if self.client_resource_info is not None else None + + if self.sampler is None: + self.sampler = get_sampler( + sample_strategy=self._cfg.federate.sampler, + client_num=self.client_num, + client_info=client_resource) + + # change the deadline if the asyn.aggregator is `time up` + if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': + self.deadline_for_cur_round = self.cur_timestamp + \ + self._cfg.asyn.time_budget + + logger.info( + '----------- Starting training (Round #{:d}) -------------'. 
+ format(self.state)) + self.broadcast_model_para(msg_type='model_para', + sample_client_num=self.sample_client_num) + + def trigger_for_time_up(self, check_timestamp=None): + """ + The handler for time up: modify the currency timestamp + and check the trigger condition + """ + if self.is_finish: + return False + + if check_timestamp is not None and \ + check_timestamp < self.deadline_for_cur_round: + return False + + self.cur_timestamp = self.deadline_for_cur_round + self.check_and_move_on() + return True + + def terminate(self, msg_type='finish'): + """ + To terminate the FL course + """ + self.is_finish = True + if self.model_num > 1: + model_para = [model.state_dict() for model in self.models] + else: + model_para = self.model.state_dict() + + self._monitor.finish_fl() + + self.comm_manager.send( + Message(msg_type=msg_type, + sender=self.ID, + receiver=list(self.comm_manager.neighbors.keys()), + state=self.state, + timestamp=self.cur_timestamp, + content=model_para)) + + def eval(self): + """ + To conduct evaluation. When cfg.federate.make_global_eval=True, + a global evaluation is conducted by the server. + """ + + if self._cfg.federate.make_global_eval: + # By default, the evaluation is conducted one-by-one for all + # internal models; + # for other cases such as ensemble, override the eval function + for i in range(self.model_num): + trainer = self.trainers[i] + # Preform evaluation in server + metrics = {} + for split in self._cfg.eval.split: + eval_metrics = trainer.evaluate( + target_data_split_name=split) + metrics.update(**eval_metrics) + formatted_eval_res = self._monitor.format_eval_res( + metrics, + rnd=self.state, + role='Server #', + forms=self._cfg.eval.report, + return_raw=self._cfg.federate.make_global_eval) + self._monitor.update_best_result( + self.best_results, + formatted_eval_res['Results_raw'], + results_type="server_global_eval", + round_wise_update_key=self._cfg.eval. + best_res_update_round_wise_key) + self.history_results = merge_dict(self.history_results, + formatted_eval_res) + self._monitor.save_formatted_results(formatted_eval_res) + logger.info(formatted_eval_res) + self.check_and_save() + else: + # Preform evaluation in clients + self.broadcast_model_para(msg_type='evaluate', + filter_unseen_clients=False) + + def callback_funcs_model_para(self, message: Message): + """ + The handling function for receiving model parameters, which triggers + check_and_move_on (perform aggregation when enough feedback has + been received). + This handling function is widely used in various FL courses. + + Arguments: + message: The received message, which includes sender, receiver, + state, and content. 
More detail can be found in + federatedscope.core.message + """ + if self.is_finish: + return 'finish' + + round = message.state + sender = message.sender + timestamp = message.timestamp + content = message.content + self.sampler.change_state(sender, 'idle') + + # update the currency timestamp according to the received message + assert timestamp >= self.cur_timestamp # for test + self.cur_timestamp = timestamp + + if round == self.state: + if round not in self.msg_buffer['train']: + self.msg_buffer['train'][round] = dict() + # Save the messages in this round + self.msg_buffer['train'][round][sender] = content + elif round >= self.state - self.staleness_toleration: + # Save the staled messages + self.staled_msg_buffer.append((round, sender, content)) + else: + # Drop the out-of-date messages + logger.info(f'Drop a out-of-date message from round #{round}') + self.dropout_num += 1 + + if self._cfg.federate.online_aggr: + self.aggregator.inc(content) + + move_on_flag = self.check_and_move_on() + if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \ + 'after_receiving': + self.broadcast_model_para(msg_type='model_para', + sample_client_num=1) + + return move_on_flag + + def callback_funcs_for_join_in(self, message: Message): + """ + The handling function for receiving the join in information. The + server might request for some information (such as num_of_samples) + if necessary, assign IDs for the servers. + If all the clients have joined in, the training process will be + triggered. + + Arguments: + message: The received message + """ + + if 'info' in message.msg_type: + sender, info = message.sender, message.content + for key in self._cfg.federate.join_in_info: + assert key in info + self.join_in_info[sender] = info + logger.info('Server: Client #{:d} has joined in !'.format(sender)) + else: + self.join_in_client_num += 1 + sender, address = message.sender, message.content + if int(sender) == -1: # assign number to client + sender = self.join_in_client_num + self.comm_manager.add_neighbors(neighbor_id=sender, + address=address) + self.comm_manager.send( + Message(msg_type='assign_client_id', + sender=self.ID, + receiver=[sender], + state=self.state, + timestamp=self.cur_timestamp, + content=str(sender))) + else: + self.comm_manager.add_neighbors(neighbor_id=sender, + address=address) + + if len(self._cfg.federate.join_in_info) != 0: + self.comm_manager.send( + Message(msg_type='ask_for_join_in_info', + sender=self.ID, + receiver=[sender], + state=self.state, + timestamp=self.cur_timestamp, + content=self._cfg.federate.join_in_info.copy())) + + self.trigger_for_start() + + def callback_funcs_for_metrics(self, message: Message): + """ + The handling function for receiving the evaluation results, + which triggers check_and_move_on + (perform aggregation when enough feedback has been received). 
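The routing above decides what happens to each incoming `model_para` message relative to the server's current round: counted towards the current round, kept as a stale-but-tolerated update, or dropped. Restated as a standalone predicate:

```python
def route_update(msg_round, server_round, staleness_toleration):
    """Mirror of the routing rule in callback_funcs_model_para above."""
    if msg_round == server_round:
        return 'train_buffer'            # counted for the current round
    elif msg_round >= server_round - staleness_toleration:
        return 'staled_buffer'           # kept, aggregated with its staleness
    else:
        return 'dropped'                 # too old; increases dropout_num


for r in (10, 9, 7):
    print(r, route_update(r, server_round=10, staleness_toleration=2))
# 10 train_buffer / 9 staled_buffer / 7 dropped
```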
+ + Arguments: + message: The received message + """ + + round = message.state + sender = message.sender + content = message.content + + if round not in self.msg_buffer['eval'].keys(): + self.msg_buffer['eval'][round] = dict() + + self.msg_buffer['eval'][round][sender] = content + + return self.check_and_move_on(check_eval_result=True) + +import seaborn as sns +import pandas as pd +def visual_tsne(model, data,name): + labels = data.y + model.eval() + z = model(data) + num_class = labels.max().item() + 1 + z = z.detach().cpu().numpy() + tsne = manifold.TSNE(n_components=2, perplexity=35, init='pca') + plt.figure(figsize=(8, 8)) + x_tsne_data = list() + f = tsne.fit_transform(z) + for clazz in range(num_class): + fp = f[labels == clazz] + clazz = np.full(fp.shape[0], clazz) + clazz = np.expand_dims(clazz, axis=1) + fe = np.concatenate([fp, clazz], axis=1) + x_tsne_data.append(fe) + + x_tsne_data = np.concatenate(x_tsne_data, axis=0) + df_tsne = pd.DataFrame(x_tsne_data, columns=["dim1", "dim2", "class"]) + + sns.scatterplot(data=df_tsne, palette="bright", hue='class', x='dim1', y='dim2') + plt.legend([],[], frameon=False) + plt.xticks([]) + plt.yticks([]) + plt.xlabel("") + plt.ylabel("") + import os + if not os.path.exists('data/output/tsne/'): + os.mkdir("data/output/tsne/") + plt.savefig("data/output/tsne/result_"+ name + ".png", format='png', dpi=800, + pad_inches=0.1, bbox_inches='tight') + plt.show() \ No newline at end of file diff --git a/fgssl/cross_backends/README.md b/fgssl/cross_backends/README.md new file mode 100644 index 0000000..451243d --- /dev/null +++ b/fgssl/cross_backends/README.md @@ -0,0 +1,30 @@ +## Cross-Backend Federated Learning + +We provide an example for constructing cross-backend (Tensorflow and PyTorch) federated learning, which trains an LR model on the synthetic toy data. 
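Although the two backends differ, both sides exchange model parameters as plain name-to-array mappings, and the TensorFlow-side `FedAvgAggregator` in `tf_aggregator.py` averages them weighted by each client's sample count. A minimal NumPy illustration of that weighting (toy values):

```python
import numpy as np

# (sample_size, state_dict) pairs as sent by the clients
client_feedback = [
    (100, {'fc.weight': np.ones((2, 4)), 'fc.bias': np.zeros(2)}),
    (300, {'fc.weight': 3 * np.ones((2, 4)), 'fc.bias': np.ones(2)}),
]

total = sum(n for n, _ in client_feedback)
avg = {
    key: sum(n / total * np.asarray(params[key])
             for n, params in client_feedback)
    for key in client_feedback[0][1]
}
print(avg['fc.weight'][0, 0])  # 0.25 * 1 + 0.75 * 3 = 2.5
print(avg['fc.bias'][0])       # 0.25 * 0 + 0.75 * 1 = 0.75
```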
+ +The server runs with Tensorflow, and clients run with PyTorch (the suggested version of Tensorflow is 1.12.0): +```shell script +# Generate toy data +python ../../scripts/distributed_scripts/gen_data.py +# Server +python ../main.py --cfg distributed_tf_server.yaml + +# Clients +python ../main.py --cfg ../../scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml +python ../main.py --cfg ../../scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml +python ../main.py --cfg ../../scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml +``` + +One of the client runs with Tensorflow, and the server and other clients run with PyTorch: +```shell script +# Generate toy data +python ../../scripts/distributed_scripts/gen_data.py +# Server +python ../main.py --cfg ../../scripts/distributed_scripts/distributed_configs/distributed_server.yaml + +# Clients with Pytorch +python ../main.py --cfg ../../scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml +python ../main.py --cfg ../../scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml +# Clients with Tensorflow +python ../main.py --cfg distributed_tf_client_3.yaml +``` diff --git a/fgssl/cross_backends/__init__.py b/fgssl/cross_backends/__init__.py new file mode 100644 index 0000000..4688291 --- /dev/null +++ b/fgssl/cross_backends/__init__.py @@ -0,0 +1,4 @@ +from federatedscope.cross_backends.tf_lr import LogisticRegression +from federatedscope.cross_backends.tf_aggregator import FedAvgAggregator + +__all__ = ['LogisticRegression', 'FedAvgAggregator'] diff --git a/fgssl/cross_backends/distributed_tf_client_3.yaml b/fgssl/cross_backends/distributed_tf_client_3.yaml new file mode 100644 index 0000000..61792c2 --- /dev/null +++ b/fgssl/cross_backends/distributed_tf_client_3.yaml @@ -0,0 +1,24 @@ +use_gpu: False +backend: 'tensorflow' +federate: + client_num: 3 + mode: 'distributed' + total_round_num: 20 + make_global_eval: False + online_aggr: False +distribute: + use: True + server_host: '127.0.0.1' + server_port: 50051 + client_host: '127.0.0.1' + client_port: 50054 + role: 'client' + data_file: 'toy_data/client_3_data' +trainer: + type: 'general' +eval: + freq: 10 +data: + type: 'toy' +model: + type: 'lr' \ No newline at end of file diff --git a/fgssl/cross_backends/distributed_tf_server.yaml b/fgssl/cross_backends/distributed_tf_server.yaml new file mode 100644 index 0000000..cd1b23c --- /dev/null +++ b/fgssl/cross_backends/distributed_tf_server.yaml @@ -0,0 +1,22 @@ +use_gpu: False +backend: 'tensorflow' +federate: + client_num: 3 + mode: 'distributed' + total_round_num: 20 + make_global_eval: True + online_aggr: False +distribute: + use: True + server_host: '127.0.0.1' + server_port: 50051 + role: 'server' + data_file: 'toy_data/server_data' +trainer: + type: 'general' +eval: + freq: 10 +data: + type: 'toy' +model: + type: 'lr' \ No newline at end of file diff --git a/fgssl/cross_backends/tf_aggregator.py b/fgssl/cross_backends/tf_aggregator.py new file mode 100644 index 0000000..6c59fb3 --- /dev/null +++ b/fgssl/cross_backends/tf_aggregator.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from copy import deepcopy +import numpy as np + + +class FedAvgAggregator(object): + def __init__(self, model=None, device='cpu'): + self.model = model + self.device = device + + def aggregate(self, agg_info): + models = agg_info["client_feedback"] + avg_model = 
self._para_weighted_avg(models) + + return avg_model + + def _para_weighted_avg(self, models): + + training_set_size = 0 + for i in range(len(models)): + sample_size, _ = models[i] + training_set_size += sample_size + + sample_size, avg_model = models[0] + for key in avg_model: + for i in range(len(models)): + local_sample_size, local_model = models[i] + weight = local_sample_size / training_set_size + if i == 0: + avg_model[key] = np.asarray(local_model[key]) * weight + else: + avg_model[key] += np.asarray(local_model[key]) * weight + + return avg_model + + def update(self, model_parameters): + ''' + Arguments: + model_parameters (dict): PyTorch Module object's state_dict. + ''' + self.model.load_state_dict(model_parameters) diff --git a/fgssl/cross_backends/tf_lr.py b/fgssl/cross_backends/tf_lr.py new file mode 100644 index 0000000..2777562 --- /dev/null +++ b/fgssl/cross_backends/tf_lr.py @@ -0,0 +1,81 @@ +import tensorflow as tf +import numpy as np + + +class LogisticRegression(object): + def __init__(self, in_channels, class_num, use_bias=True): + + self.input_x = tf.placeholder(tf.float32, [None, in_channels], + name='input_x') + self.input_y = tf.placeholder(tf.float32, [None, 1], name='input_y') + + self.out = self.fc_layer(input_x=self.input_x, + in_channels=in_channels, + class_num=class_num, + use_bias=use_bias) + + with tf.name_scope('loss'): + self.losses = tf.losses.mean_squared_error(predictions=self.out, + labels=self.input_y) + + with tf.name_scope('train_op'): + self.optimizer = tf.train.GradientDescentOptimizer( + learning_rate=0.001) + self.train_op = self.optimizer.minimize(self.losses) + + self.sess = tf.Session() + self.graph = tf.get_default_graph() + + with self.graph.as_default(): + with self.sess.as_default(): + tf.global_variables_initializer().run() + + def fc_layer(self, input_x, in_channels, class_num, use_bias=True): + with tf.name_scope('fc'): + fc_w = tf.Variable(tf.truncated_normal([in_channels, class_num], + stddev=0.1), + name='weight') + if use_bias: + fc_b = tf.Variable(tf.constant(0.0, shape=[ + class_num, + ]), + name='bias') + fc_out = tf.nn.bias_add(tf.matmul(input_x, fc_w), fc_b) + else: + fc_out = tf.matmul(input_x, fc_w) + + return fc_out + + def to(self, device): + pass + + def trainable_variables(self): + return tf.trainable_variables() + + def state_dict(self): + with self.graph.as_default(): + with self.sess.as_default(): + model_param = list() + param_name = list() + for var in tf.global_variables(): + param = self.graph.get_tensor_by_name(var.name).eval() + if 'weight' in var.name: + param = np.transpose(param, (1, 0)) + model_param.append(param) + param_name.append(var.name.split(':')[0].replace("/", '.')) + + model_dict = {k: v for k, v in zip(param_name, model_param)} + + return model_dict + + def load_state_dict(self, model_para, strict=False): + with self.graph.as_default(): + with self.sess.as_default(): + for name in model_para.keys(): + new_param = model_para[name] + + param = self.graph.get_tensor_by_name( + name.replace('.', '/') + (':0')) + if 'weight' in name: + new_param = np.transpose(new_param, (1, 0)) + tf.assign(param, new_param).eval() diff --git a/fgssl/cv/__init__.py b/fgssl/cv/__init__.py new file mode 100644 index 0000000..f8e91f2 --- /dev/null +++ b/fgssl/cv/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division diff --git a/fgssl/cv/baseline/fedavg_convnet2_on_celeba.yaml b/fgssl/cv/baseline/fedavg_convnet2_on_celeba.yaml new file 
mode 100644 index 0000000..500d94f --- /dev/null +++ b/fgssl/cv/baseline/fedavg_convnet2_on_celeba.yaml @@ -0,0 +1,33 @@ +use_gpu: True +device: 0 +early_stop: + patience: 10 +federate: + mode: standalone + total_round_num: 100 + sample_client_num: 10 +data: + root: data/ + type: celeba + splits: [0.6,0.2,0.2] + subsample: 0.1 + transform: [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]] +dataloader: + batch_size: 5 +model: + type: convnet2 + hidden: 2048 + out_channels: 2 + dropout: 0.0 +train: + local_update_steps: 10 + optimizer: + lr: 0.001 + weight_decay: 0.0 +criterion: + type: CrossEntropyLoss +trainer: + type: cvtrainer +eval: + freq: 10 + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/cv/baseline/fedavg_convnet2_on_femnist.yaml b/fgssl/cv/baseline/fedavg_convnet2_on_femnist.yaml new file mode 100644 index 0000000..9a29857 --- /dev/null +++ b/fgssl/cv/baseline/fedavg_convnet2_on_femnist.yaml @@ -0,0 +1,37 @@ +use_gpu: True +device: 0 +early_stop: + patience: 5 +seed: 12345 +federate: + mode: standalone + total_round_num: 300 + sample_client_rate: 0.2 +data: + root: data/ + type: femnist + splits: [0.6,0.2,0.2] + subsample: 0.05 + transform: [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]] +dataloader: + batch_size: 10 +model: + type: convnet2 + hidden: 2048 + out_channels: 62 + dropout: 0.0 +train: + local_update_steps: 1 + batch_or_epoch: epoch + optimizer: + lr: 0.01 + weight_decay: 0.0 +grad: + grad_clip: 5.0 +criterion: + type: CrossEntropyLoss +trainer: + type: cvtrainer +eval: + freq: 10 + metrics: ['acc', 'correct'] diff --git a/fgssl/cv/baseline/fedbn_convnet2_on_femnist.yaml b/fgssl/cv/baseline/fedbn_convnet2_on_femnist.yaml new file mode 100644 index 0000000..d925a1b --- /dev/null +++ b/fgssl/cv/baseline/fedbn_convnet2_on_femnist.yaml @@ -0,0 +1,39 @@ +use_gpu: True +device: 0 +early_stop: + patience: 5 +seed: 12345 +federate: + mode: standalone + total_round_num: 300 + sample_client_rate: 0.2 +data: + root: data/ + type: femnist + splits: [0.6,0.2,0.2] + subsample: 0.05 + transform: [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]] +dataloader: + batch_size: 10 +model: + type: convnet2 + hidden: 2048 + out_channels: 62 + dropout: 0.0 +personalization: + local_param: [ 'bn', 'norms' ] # FedBN +train: + local_update_steps: 1 + batch_or_epoch: epoch + optimizer: + lr: 0.01 + weight_decay: 0.0 +grad: + grad_clip: 5.0 +criterion: + type: CrossEntropyLoss +trainer: + type: cvtrainer +eval: + freq: 10 + metrics: ['acc', 'correct'] diff --git a/fgssl/cv/dataloader/__init__.py b/fgssl/cv/dataloader/__init__.py new file mode 100644 index 0000000..e5e3429 --- /dev/null +++ b/fgssl/cv/dataloader/__init__.py @@ -0,0 +1,3 @@ +from federatedscope.cv.dataloader.dataloader import load_cv_dataset + +__all__ = ['load_cv_dataset'] diff --git a/fgssl/cv/dataloader/dataloader.py b/fgssl/cv/dataloader/dataloader.py new file mode 100644 index 0000000..9ea27e4 --- /dev/null +++ b/fgssl/cv/dataloader/dataloader.py @@ -0,0 +1,41 @@ +from federatedscope.cv.dataset.leaf_cv import LEAF_CV +from federatedscope.core.auxiliaries.transform_builder import get_transform + + +def load_cv_dataset(config=None): + r""" + return { + 'client_id': { + 'train': DataLoader(), + 'test': DataLoader(), + 'val': DataLoader() + } + } + """ + splits = config.data.splits + + path = config.data.root + name = config.data.type.lower() + transforms_funcs = get_transform(config, 'torchvision') + + if name in ['femnist', 'celeba']: + dataset = 
LEAF_CV(root=path, + name=name, + s_frac=config.data.subsample, + tr_frac=splits[0], + val_frac=splits[1], + seed=1234, + **transforms_funcs) + else: + raise ValueError(f'No dataset named: {name}!') + + client_num = min(len(dataset), config.federate.client_num + ) if config.federate.client_num > 0 else len(dataset) + config.merge_from_list(['federate.client_num', client_num]) + + # Convert list to dict + data_dict = dict() + for client_idx in range(1, client_num + 1): + data_dict[client_idx] = dataset[client_idx - 1] + + return data_dict, config diff --git a/fgssl/cv/dataset/__init__.py b/fgssl/cv/dataset/__init__.py new file mode 100644 index 0000000..c0b3138 --- /dev/null +++ b/fgssl/cv/dataset/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/cv/dataset/leaf.py b/fgssl/cv/dataset/leaf.py new file mode 100644 index 0000000..eb253e2 --- /dev/null +++ b/fgssl/cv/dataset/leaf.py @@ -0,0 +1,128 @@ +import zipfile +import os +import torch + +import numpy as np +import os.path as osp + +from torch.utils.data import Dataset + +LEAF_NAMES = [ + 'femnist', 'celeba', 'synthetic', 'shakespeare', 'twitter', 'subreddit' +] + + +def is_exists(path, names): + exists_list = [osp.exists(osp.join(path, name)) for name in names] + return False not in exists_list + + +class LEAF(Dataset): + """Base class for LEAF dataset from "LEAF: A Benchmark for Federated Settings" + + Arguments: + root (str): root path. + name (str): name of dataset, in `LEAF_NAMES`. + transform: transform for x. + target_transform: transform for y. + + """ + def __init__(self, root, name, transform, target_transform): + self.root = root + self.name = name + self.data_dict = {} + if name not in LEAF_NAMES: + raise ValueError(f'No leaf dataset named {self.name}') + self.transform = transform + self.target_transform = target_transform + self.process_file() + + @property + def raw_file_names(self): + names = ['all_data.zip'] + return names + + @property + def extracted_file_names(self): + names = ['all_data'] + return names + + @property + def raw_dir(self): + return osp.join(self.root, self.name, 'raw') + + @property + def processed_dir(self): + return osp.join(self.root, self.name, 'processed') + + def __repr__(self): + return f'{self.__class__.__name__}({self.__len__()})' + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, index): + raise NotImplementedError + + def __iter__(self): + for index in range(len(self.data_dict)): + yield self.__getitem__(index) + + def download(self): + raise NotImplementedError + + def extract(self): + for name in self.raw_file_names: + with zipfile.ZipFile(osp.join(self.raw_dir, name), 'r') as f: + f.extractall(self.raw_dir) + + def process_file(self): + os.makedirs(self.processed_dir, exist_ok=True) + if len(os.listdir(self.processed_dir)) == 0: + if not is_exists(self.raw_dir, self.extracted_file_names): + if not is_exists(self.raw_dir, self.raw_file_names): + self.download() + self.extract() + self.process() + + def process(self): + raise NotImplementedError + + +class LocalDataset(Dataset): + """ + Convert data list to torch Dataset to save memory usage. + """ + def __init__(self, + Xs, + targets, + pre_process=None, + transform=None, + target_transform=None): + assert len(Xs) == len( + targets), "The number of data and labels are not equal." 
+ self.Xs = np.array(Xs) + self.targets = np.array(targets) + self.pre_process = pre_process + self.transform = transform + self.target_transform = target_transform + + def __len__(self): + return len(self.Xs) + + def __getitem__(self, idx): + data, target = self.Xs[idx], self.targets[idx] + if self.pre_process: + data = self.pre_process(data) + + if self.transform: + data = self.transform(data) + + if self.target_transform: + target = self.target_transform(target) + + return data, target + + def extend(self, dataset): + self.Xs = np.vstack((self.Xs, dataset.Xs)) + self.targets = np.hstack((self.targets, dataset.targets)) diff --git a/fgssl/cv/dataset/leaf_cv.py b/fgssl/cv/dataset/leaf_cv.py new file mode 100644 index 0000000..f1be136 --- /dev/null +++ b/fgssl/cv/dataset/leaf_cv.py @@ -0,0 +1,179 @@ +import os +import random +import json +import torch +import math + +import numpy as np +import os.path as osp + +from PIL import Image +from tqdm import tqdm + +from sklearn.model_selection import train_test_split + +from federatedscope.core.auxiliaries.utils import save_local_data, download_url +from federatedscope.cv.dataset.leaf import LEAF + +IMAGE_SIZE = {'femnist': (28, 28), 'celeba': (84, 84, 3)} +MODE = {'femnist': 'L', 'celeba': 'RGB'} + + +class LEAF_CV(LEAF): + """ + LEAF CV dataset from "LEAF: A Benchmark for Federated Settings" + + leaf.cmu.edu + + Arguments: + root (str): root path. + name (str): name of dataset, ‘femnist’ or ‘celeba’. + s_frac (float): fraction of the dataset to be used; default=0.3. + tr_frac (float): train set proportion for each task; default=0.8. + val_frac (float): valid set proportion for each task; default=0.0. + train_tasks_frac (float): fraction of test tasks; default=1.0. + transform: transform for x. + target_transform: transform for y. + + """ + def __init__(self, + root, + name, + s_frac=0.3, + tr_frac=0.8, + val_frac=0.0, + train_tasks_frac=1.0, + seed=123, + transform=None, + target_transform=None): + self.s_frac = s_frac + self.tr_frac = tr_frac + self.val_frac = val_frac + self.seed = seed + self.train_tasks_frac = train_tasks_frac + super(LEAF_CV, self).__init__(root, name, transform, target_transform) + files = os.listdir(self.processed_dir) + files = [f for f in files if f.startswith('task_')] + if len(files): + # Sort by idx + files.sort(key=lambda k: int(k[5:])) + + for file in files: + train_data, train_targets = torch.load( + osp.join(self.processed_dir, file, 'train.pt')) + test_data, test_targets = torch.load( + osp.join(self.processed_dir, file, 'test.pt')) + self.data_dict[int(file[5:])] = { + 'train': (train_data, train_targets), + 'test': (test_data, test_targets) + } + if osp.exists(osp.join(self.processed_dir, file, 'val.pt')): + val_data, val_targets = torch.load( + osp.join(self.processed_dir, file, 'val.pt')) + self.data_dict[int(file[5:])]['val'] = (val_data, + val_targets) + else: + raise RuntimeError( + 'Please delete ‘processed’ folder and try again!') + + @property + def raw_file_names(self): + names = [f'{self.name}_all_data.zip'] + return names + + def download(self): + # Download to `self.raw_dir`. + url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com' + os.makedirs(self.raw_dir, exist_ok=True) + for name in self.raw_file_names: + download_url(f'{url}/{name}', self.raw_dir) + + def __getitem__(self, index): + """ + Arguments: + index (int): Index + + :returns: + dict: {'train':[(image, target)], + 'test':[(image, target)], + 'val':[(image, target)]} + where target is the target class. 
+ """ + img_dict = {} + data = self.data_dict[index] + for key in data: + img_dict[key] = [] + imgs, targets = data[key] + for idx in range(targets.shape[0]): + img = np.resize(imgs[idx].numpy().astype(np.uint8), + IMAGE_SIZE[self.name]) + img = Image.fromarray(img, mode=MODE[self.name]) + target = targets[idx] + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + img_dict[key].append((img, targets[idx])) + + return img_dict + + def process(self): + raw_path = osp.join(self.raw_dir, "all_data") + files = os.listdir(raw_path) + files = [f for f in files if f.endswith('.json')] + + n_tasks = math.ceil(len(files) * self.s_frac) + random.shuffle(files) + files = files[:n_tasks] + + print("Preprocess data (Please leave enough space)...") + + idx = 0 + for num, file in enumerate(tqdm(files)): + + with open(osp.join(raw_path, file), 'r') as f: + raw_data = json.load(f) + + # Numpy to Tensor + for writer, v in raw_data['user_data'].items(): + data, targets = v['x'], v['y'] + + if len(v['x']) > 2: + data = torch.tensor(np.stack(data)) + targets = torch.LongTensor(np.stack(targets)) + else: + data = torch.tensor(data) + targets = torch.LongTensor(targets) + + train_data, test_data, train_targets, test_targets =\ + train_test_split( + data, + targets, + train_size=self.tr_frac, + random_state=self.seed + ) + + if self.val_frac > 0: + val_data, test_data, val_targets, test_targets = \ + train_test_split( + test_data, + test_targets, + train_size=self.val_frac / (1.-self.tr_frac), + random_state=self.seed + ) + + else: + val_data, val_targets = None, None + save_path = osp.join(self.processed_dir, f"task_{idx}") + os.makedirs(save_path, exist_ok=True) + + save_local_data(dir_path=save_path, + train_data=train_data, + train_targets=train_targets, + test_data=test_data, + test_targets=test_targets, + val_data=val_data, + val_targets=val_targets) + idx += 1 diff --git a/fgssl/cv/dataset/preprocess/celeba_preprocess.py b/fgssl/cv/dataset/preprocess/celeba_preprocess.py new file mode 100644 index 0000000..042b370 --- /dev/null +++ b/fgssl/cv/dataset/preprocess/celeba_preprocess.py @@ -0,0 +1,66 @@ +# ---------------------------------------------------------------------- # +# A preprocess script for JSON file all_data.json to json with images +# To get raw all_data.json, see: +# https://github.com/TalwalkarLab/leaf/tree/master/data/celeba +# ---------------------------------------------------------------------- # + +import json +import math +import numpy as np +import os +import sys +import copy +from PIL import Image + +from tqdm import tqdm + +MAX_USERS = 100 +size = (84, 84) + + +def name2json(name): + file_path = os.path.join('raw', 'img_align_celeba', name) + img = Image.open(file_path) + gray = img.convert('RGB') + gray.thumbnail(size, Image.ANTIALIAS) + gray = gray.resize(size) + arr = np.asarray(gray).copy().astype(np.uint8) + vec = arr.flatten() + vec = vec.tolist() + return vec + + +if __name__ == '__main__': + file = 'all_data/all_data.json' + + with open(file, 'r') as f: + raw_data = json.load(f) + + data = copy.deepcopy(raw_data) + for idx, user in enumerate(tqdm(raw_data['user_data'])): + img_names = raw_data['user_data'][user]['x'] + data['user_data'][user]['x'] = [] + for name in img_names: + js = name2json(name) + data['user_data'][user]['x'].append(js) + + # Save to several json files + + cnt = 0 + file_id = 0 + all_data = {'users': [], 'num_samples': [], 'user_data': {}} + + for idx, user in 
enumerate(tqdm(data['user_data'])): + all_data['users'].append(data['users'][idx]) + all_data['num_samples'].append(data['num_samples'][idx]) + all_data['user_data'][user] = data['user_data'][user] + cnt += 1 + + if cnt == MAX_USERS or idx == len(data['user_data']) - 1: + file_name = f'all_data_{file_id}.json' + file_path = os.path.join('new_all_data', file_name) + with open(file_path, 'w') as outfile: + json.dump(all_data, outfile) + file_id += 1 + cnt = 0 + all_data = {'users': [], 'num_samples': [], 'user_data': {}} diff --git a/fgssl/cv/model/__init__.py b/fgssl/cv/model/__init__.py new file mode 100644 index 0000000..0939f47 --- /dev/null +++ b/fgssl/cv/model/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11 +from federatedscope.cv.model.model_builder import get_cnn + +__all__ = ['ConvNet2', 'ConvNet5', 'VGG11', 'get_cnn'] diff --git a/fgssl/cv/model/cnn.py b/fgssl/cv/model/cnn.py new file mode 100644 index 0000000..dcab4fe --- /dev/null +++ b/fgssl/cv/model/cnn.py @@ -0,0 +1,191 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import Module +from torch.nn import Sequential +from torch.nn import Conv2d, BatchNorm2d +from torch.nn import Flatten +from torch.nn import Linear +from torch.nn import MaxPool2d +from torch.nn import ReLU + + +class ConvNet2(Module): + def __init__(self, + in_channels, + h=32, + w=32, + hidden=2048, + class_num=10, + use_bn=True, + dropout=.0): + super(ConvNet2, self).__init__() + + self.conv1 = Conv2d(in_channels, 32, 5, padding=2) + self.conv2 = Conv2d(32, 64, 5, padding=2) + self.use_bn = use_bn + if use_bn: + self.bn1 = BatchNorm2d(32) + self.bn2 = BatchNorm2d(64) + + self.fc1 = Linear((h // 2 // 2) * (w // 2 // 2) * 64, hidden) + self.fc2 = Linear(hidden, class_num) + + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(2) + self.dropout = dropout + + def forward(self, x): + x = self.bn1(self.conv1(x)) if self.use_bn else self.conv1(x) + x = self.maxpool(self.relu(x)) + x = self.bn2(self.conv2(x)) if self.use_bn else self.conv2(x) + x = self.maxpool(self.relu(x)) + x = Flatten()(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.relu(self.fc1(x)) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.fc2(x) + + return x + + +class ConvNet5(Module): + def __init__(self, + in_channels, + h=32, + w=32, + hidden=2048, + class_num=10, + dropout=.0): + super(ConvNet5, self).__init__() + + self.conv1 = Conv2d(in_channels, 32, 5, padding=2) + self.bn1 = BatchNorm2d(32) + + self.conv2 = Conv2d(32, 64, 5, padding=2) + self.bn2 = BatchNorm2d(64) + + self.conv3 = Conv2d(64, 64, 5, padding=2) + self.bn3 = BatchNorm2d(64) + + self.conv4 = Conv2d(64, 128, 5, padding=2) + self.bn4 = BatchNorm2d(128) + + self.conv5 = Conv2d(128, 128, 5, padding=2) + self.bn5 = BatchNorm2d(128) + + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(2) + + self.fc1 = Linear( + (h // 2 // 2 // 2 // 2 // 2) * (w // 2 // 2 // 2 // 2 // 2) * 128, + hidden) + self.fc2 = Linear(hidden, class_num) + + self.dropout = dropout + + def forward(self, x): + x = self.relu(self.bn1(self.conv1(x))) + x = self.maxpool(x) + + x = self.relu(self.bn2(self.conv2(x))) + x = self.maxpool(x) + + x = self.relu(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.relu(self.bn4(self.conv4(x))) + x = self.maxpool(x) + + x = self.relu(self.bn5(self.conv5(x))) + x = 
self.maxpool(x) + + x = Flatten()(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.relu(self.fc1(x)) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.fc2(x) + + return x + + +class VGG11(Module): + def __init__(self, + in_channels, + h=32, + w=32, + hidden=128, + class_num=10, + dropout=.0): + super(VGG11, self).__init__() + + self.conv1 = Conv2d(in_channels, 64, 3, padding=1) + self.bn1 = BatchNorm2d(64) + + self.conv2 = Conv2d(64, 128, 3, padding=1) + self.bn2 = BatchNorm2d(128) + + self.conv3 = Conv2d(128, 256, 3, padding=1) + self.bn3 = BatchNorm2d(256) + + self.conv4 = Conv2d(256, 256, 3, padding=1) + self.bn4 = BatchNorm2d(256) + + self.conv5 = Conv2d(256, 512, 3, padding=1) + self.bn5 = BatchNorm2d(512) + + self.conv6 = Conv2d(512, 512, 3, padding=1) + self.bn6 = BatchNorm2d(512) + + self.conv7 = Conv2d(512, 512, 3, padding=1) + self.bn7 = BatchNorm2d(512) + + self.conv8 = Conv2d(512, 512, 3, padding=1) + self.bn8 = BatchNorm2d(512) + + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(2) + + self.fc1 = Linear( + (h // 2 // 2 // 2 // 2 // 2) * (w // 2 // 2 // 2 // 2 // 2) * 512, + hidden) + self.fc2 = Linear(hidden, hidden) + self.fc3 = Linear(hidden, class_num) + + self.dropout = dropout + + def forward(self, x): + x = self.relu(self.bn1(self.conv1(x))) + x = self.maxpool(x) + + x = self.relu(self.bn2(self.conv2(x))) + x = self.maxpool(x) + + x = self.relu(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.relu(self.bn4(self.conv4(x))) + x = self.maxpool(x) + + x = self.relu(self.bn5(self.conv5(x))) + x = self.maxpool(x) + + x = self.relu(self.bn6(self.conv6(x))) + x = self.maxpool(x) + + x = self.relu(self.bn7(self.conv7(x))) + x = self.maxpool(x) + + x = self.relu(self.bn8(self.conv8(x))) + x = self.maxpool(x) + + x = Flatten()(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.relu(self.fc1(x)) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.relu(self.fc2(x)) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.fc3(x) + + return x diff --git a/fgssl/cv/model/model_builder.py b/fgssl/cv/model/model_builder.py new file mode 100644 index 0000000..aaba483 --- /dev/null +++ b/fgssl/cv/model/model_builder.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11 + + +def get_cnn(model_config, input_shape): + # check the task + # input_shape: (batch_size, in_channels, h, w) or (in_channels, h, w) + if model_config.type == 'convnet2': + model = ConvNet2(in_channels=input_shape[-3], + h=input_shape[-2], + w=input_shape[-1], + hidden=model_config.hidden, + class_num=model_config.out_channels, + dropout=model_config.dropout) + elif model_config.type == 'convnet5': + model = ConvNet5(in_channels=input_shape[-3], + h=input_shape[-2], + w=input_shape[-1], + hidden=model_config.hidden, + class_num=model_config.out_channels, + dropout=model_config.dropout) + elif model_config.type == 'vgg11': + model = VGG11(in_channels=input_shape[-3], + h=input_shape[-2], + w=input_shape[-1], + hidden=model_config.hidden, + class_num=model_config.out_channels, + dropout=model_config.dropout) + else: + raise ValueError(f'No model named {model_config.type}!') + + return model diff --git a/fgssl/cv/trainer/__init__.py b/fgssl/cv/trainer/__init__.py new file mode 100644 index 0000000..880c20a --- /dev/null +++ b/fgssl/cv/trainer/__init__.py @@ -0,0 
+1,31 @@ +""" +Copyright (c) 2021 Matthias Fey, Jiaxuan You + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/fgssl/cv/trainer/trainer.py b/fgssl/cv/trainer/trainer.py new file mode 100644 index 0000000..b00ebd6 --- /dev/null +++ b/fgssl/cv/trainer/trainer.py @@ -0,0 +1,15 @@ +from federatedscope.register import register_trainer +from federatedscope.core.trainers import GeneralTorchTrainer + + +class CVTrainer(GeneralTorchTrainer): + pass + + +def call_cv_trainer(trainer_type): + if trainer_type == 'cvtrainer': + trainer_builder = CVTrainer + return trainer_builder + + +register_trainer('cvtrainer', call_cv_trainer) diff --git a/fgssl/gfl/README.md b/fgssl/gfl/README.md new file mode 100644 index 0000000..5ea6b71 --- /dev/null +++ b/fgssl/gfl/README.md @@ -0,0 +1,291 @@ +# FederatedScope-GNN: Towards a Unified, Comprehensive and Efficient Package for Federated Graph Learning + +FederatedScope-GNN (FS-G) is a unified, comprehensive and efficient package for federated graph learning. We provide a hands-on tutorial here, while for more detailed tutorial, please refer to [FGL Tutorial](https://federatedscope.io/docs/graph/). + +## Quick Start + +Let’s start with a two-layer GCN on FedCora to familiarize you with FS-G. + +### Step 1. Installation + +The installation of FS-G follows FederatedScope, please refer to [Installation](https://github.com/alibaba/FederatedScope#step-1-installation). + +After installing the minimal version of FederatedScope, you should install extra dependencies ([PyG](https://github.com/pyg-team/pytorch_geometric), rdkit, and nltk) for the application version of FGL, run: + +```bash +conda install -y pyg==2.0.4 -c pyg +conda install -y rdkit=2021.09.4=py39hccf6a74_0 -c conda-forge +conda install -y nltk +``` + +Now, you have successfully installed the FGL version of FederatedScope. + +### Step 2. Run with exmaple config + +Now, we train a two-layer GCN on FedCora with FedAvg. + +```bash +python federatedscope/main.py --cfg federatedscope/gfl/baseline/pubmed_gat.yaml +``` + +For more details about customized configurations, see **Advanced**. + +## Reproduce the results in our paper + +We provide scripts (grid search to find optimal results) to reproduce the results of our experiments. 
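+
+For instance, the graph-level script `run_graph_level.sh` sweeps the GNN type (`gcn`, `gin`, `gat`), the learning rate and the number of local update steps over five seeds, and expands every grid point into a call to `federatedscope/main.py` with config overrides. A single expanded command looks roughly like this (illustrative values; the script picks `out_channels`, `hidden` and the splitter per dataset):
+
+```bash
+# One grid point from run_graph_level.sh (illustrative values; for proteins the
+# script uses out_channels=2, hidden=64 and the rand_chunk splitter)
+python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml \
+  device 0 data.type proteins data.splitter rand_chunk \
+  train.optimizer.lr 0.25 train.local_update_steps 4 \
+  model.type gcn model.out_channels 2 model.hidden 64 seed 1
+```
+
+The scripts for each task type are: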
+ +* Node-level tasks, please refer to `federatedscope/gfl/baseline/repro_exp/node_level/`: + + ```bash + # Example of FedAvg + cd federatedscope/gfl/baseline/repro_exp/node_level/ + bash run_node_level.sh 0 cora louvain + + # Example of FedAvg + bash run_node_level.sh 0 cora random + + # Example of FedOpt + bash run_node_level_opt.sh 0 cora louvain gcn 0.25 4 + + # Example of FedProx + bash run_node_level_prox.sh 0 cora louvain gcn 0.25 4 + ``` + +* Link-level tasks, please refer to `federatedscope/gfl/baseline/repro_exp/link_level/`: + + ```bash + cd federatedscope/gfl/baseline/repro_exp/link_level/ + + # Example of FedAvg + bash run_link_level_KG.sh 0 wn18 rel_type + + # Example of FedOpt + bash run_link_level_opt.sh 0 wn18 rel_type gcn 0.25 16 + + # Example of FedProx + bash run_link_level_prox.sh 7 wn18 rel_type gcn 0.25 16 + ``` + +* Graph-level tasks, please refer to `federatedscope/gfl/baseline/repro_exp/graph_level/`: + + ```bash + cd federatedscope/gfl/baseline/repro_exp/graph_level/ + + # Example of FedAvg + bash run_graph_level.sh 0 proteins + + # Example of FedOpt + bash run_graph_level_opt.sh 0 proteins gcn 0.25 4 + + # Example of FedProx + bash run_graph_level_prox.sh 0 proteins gcn 0.25 4 + ``` + +## Advanced + +### Start with built-in functions + +You can easily run through a customized `yaml` file: + +```yaml +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 0 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of dataset being split + client_num: 5 + # Number of communication round + total_round_num: 400 + +# Dataset related options +data: + # Root directory where the data stored + root: data/ + # Dataset name + type: cora + # Use Louvain algorithm to split `Cora` + splitter: 'louvain' + # Use fullbatch training, batch_size should be `1` + batch_size: 1 + +# Model related options +model: + # Model type + type: gcn + # Hidden dim + hidden: 64 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 7 + +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: nodefullbatch_trainer + +# Train related options +train: + # Number of local update steps + local_update_steps: 4 + # Optimizer related options + optimizer: + # Learning rate + lr: 0.25 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD + +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] +``` + +### Start with customized functions + +FS-G also provides `register` function to set up the FL. Here we provide an example about how to run your own model and data to FS-G. 
+ +* Load your data (write in `federatedscope/contrib/data/`): + + ```python + import copy + import numpy as np + + from torch_geometric.datasets import Planetoid + from federatedscope.core.splitters.graph import LouvainSplitter + from federatedscope.register import register_data + + + def my_cora(config=None): + path = config.data.root + + num_split = [232, 542, np.iinfo(np.int64).max] + dataset = Planetoid(path, + 'cora', + split='random', + num_train_per_class=num_split[0], + num_val=num_split[1], + num_test=num_split[2]) + global_data = copy.deepcopy(dataset)[0] + dataset = LouvainSplitter(config.federate.client_num)(dataset[0]) + + data_local_dict = dict() + for client_idx in range(len(dataset)): + data_local_dict[client_idx + 1] = dataset[client_idx] + + data_local_dict[0] = global_data + return data_local_dict, config + + + def call_my_data(config): + if config.data.type == "mycora": + data, modified_config = my_cora(config) + return data, modified_config + + + register_data("mycora", call_my_data) + + ``` + +* Build your model (write in `federatedscope/contrib/model/`): + + ```python + import torch + import torch.nn.functional as F + + from torch.nn import ModuleList + from torch_geometric.data import Data + from torch_geometric.nn import GCNConv + from federatedscope.register import register_model + + + class MyGCN(torch.nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden=64, + max_depth=2, + dropout=.0): + super(MyGCN, self).__init__() + self.convs = ModuleList() + for i in range(max_depth): + if i == 0: + self.convs.append(GCNConv(in_channels, hidden)) + elif (i + 1) == max_depth: + self.convs.append(GCNConv(hidden, out_channels)) + else: + self.convs.append(GCNConv(hidden, hidden)) + self.dropout = dropout + + def forward(self, data): + if isinstance(data, Data): + x, edge_index = data.x, data.edge_index + elif isinstance(data, tuple): + x, edge_index = data + else: + raise TypeError('Unsupported data type!') + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if (i + 1) == len(self.convs): + break + x = F.relu(F.dropout(x, p=self.dropout, training=self.training)) + return x + + + def gcnbuilder(model_config, input_shape): + x_shape, num_label, num_edge_features = input_shape + model = MyGCN(x_shape[-1], + model_config.out_channels, + hidden=model_config.hidden, + max_depth=model_config.layer, + dropout=model_config.dropout) + return model + + + def call_my_net(model_config, local_data): + # Please name your gnn model with prefix 'gnn_' + if model_config.type == "gnn_mygcn": + model = gcnbuilder(model_config, local_data) + return model + + + register_model("gnn_mygcn", call_my_net) + + ``` + +- Run with following command to start: + + ```bash + python federatedscope/main.py --cfg federatedscope/gfl/baseline/pubmed_gat.yaml data.type mycora model.type gnn_mygcn + ``` + +## Publications + +If you find FS-G useful for research or development, please cite the following [paper](https://arxiv.org/abs/2204.05562): + +```latex +@inproceedings{federatedscopegnn, + title = {FederatedScope-GNN: Towards a Unified, Comprehensive and Efficient Package for Federated Graph Learning}, + author = {Zhen Wang and Weirui Kuang and Yuexiang Xie and Liuyi Yao and Yaliang Li and Bolin Ding and Jingren Zhou}, + booktitle = {Proc.\ of the ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD'22)}, + year = {2022} +} +``` diff --git a/fgssl/gfl/__init__.py b/fgssl/gfl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/fgssl/gfl/baseline/__init__.py b/fgssl/gfl/baseline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fgssl/gfl/baseline/download.yaml b/fgssl/gfl/baseline/download.yaml new file mode 100644 index 0000000..9e33be6 --- /dev/null +++ b/fgssl/gfl/baseline/download.yaml @@ -0,0 +1,250 @@ +asyn: + use: false +attack: + alpha_TV: 0.001 + alpha_prop_loss: 0 + attack_method: '' + attacker_id: -1 + classifier_PIA: randomforest + edge_num: 100 + edge_path: edge_data/ + freq: 10 + info_diff_type: l2 + inject_round: 0 + insert_round: 100000 + label_type: dirty + max_ite: 400 + mean: + - 0.1307 + mia_is_simulate_in: false + mia_simulate_in_round: 20 + pgd_eps: 2 + pgd_lr: 0.1 + pgd_poisoning: false + poison_ratio: 0.5 + reconstruct_lr: 0.01 + reconstruct_optim: Adam + scale_para: 1.0 + scale_poisoning: false + self_epoch: 6 + self_lr: 0.05 + self_opt: false + setting: fix + std: + - 0.3081 + target_label_ind: -1 + trigger_path: trigger/ + trigger_type: edge +backend: torch +cfg_file: '' +criterion: + type: CrossEntropyLoss +data: + args: [] + batch_size: 64 + cSBM_phi: + - 0.5 + - 0.5 + - 0.5 + consistent_label_distribution: false + drop_last: false + loader: '' + num_steps: 30 + num_workers: 0 + pre_transform: [] + quadratic: + dim: 1 + max_curv: 12.5 + min_curv: 0.02 + root: data/ + save_data: false + server_holds_all: false + shuffle: true + sizes: + - 10 + - 5 + splits: + - 0.8 + - 0.1 + - 0.1 + splitter: louvain + splitter_args: [] + subsample: 1.0 + target_transform: [] + transform: [] + type: PubMed + walk_length: 2 +dataloader: + batch_size: 1 + drop_last: false + num_steps: 30 + num_workers: 0 + pin_memory: true + shuffle: true + sizes: + - 10 + - 5 + theta: -1 + type: pyg + walk_length: 2 +device: 0 +distribute: + use: false +early_stop: + delta: 0.0 + improve_indicator_mode: best + patience: 5 + the_smaller_the_better: true +eval: + best_res_update_round_wise_key: val_loss + count_flops: true + freq: 1 + metrics: + - acc + - correct + monitoring: [] + report: + - weighted_avg + - avg + - fairness + - raw + split: + - test + - val +expname: FedAvg_gcn_on_cora_lr0.25_lstep4 +expname_tag: '' +federate: + client_num: 5 + data_weighted_aggr: false + ignore_weight: false + join_in_info: [] + make_global_eval: true + merge_test_data: false + method: FedAvg + mode: standalone + online_aggr: false + resource_info_file: '' + restore_from: '' + sample_client_num: 5 + sample_client_rate: -1.0 + sampler: uniform + save_to: '' + share_local_model: false + total_round_num: 200 + unseen_clients_rate: 0.0 + use_diff: false + use_ss: false +fedopt: + use: false +fedprox: + use: false +fedsageplus: + a: 1.0 + b: 1.0 + c: 1.0 + fedgen_epoch: 200 + gen_hidden: 128 + hide_portion: 0.5 + loc_epoch: 1 + num_pred: 5 +finetune: + batch_or_epoch: epoch + before_eval: false + freeze_param: '' + local_update_steps: 1 + optimizer: + lr: 0.1 + type: SGD + scheduler: + type: '' +flitplus: + factor_ema: 0.8 + lambdavat: 0.5 + tmpFed: 0.5 + weightReg: 1.0 +gcflplus: + EPS_1: 0.05 + EPS_2: 0.1 + seq_length: 5 + standardize: false +grad: + grad_clip: -1.0 +hpo: + fedex: + cutoff: 0.0 + diff: false + eta0: -1.0 + flatten_ss: true + gamma: 0.0 + sched: auto + ss: '' + use: false + init_cand_num: 16 + larger_better: false + metric: client_summarized_weighted_avg.val_loss + num_workers: 0 + pbt: + max_stage: 5 + perf_threshold: 0.1 + scheduler: rs + sha: + budgets: [] + elim_rate: 3 + iter: 0 + ss: '' + table: + eps: 0.1 + idx: 0 + num: 27 + working_folder: hpo +model: + dropout: 0.5 + embed_size: 8 + 
graph_pooling: mean + hidden: + in_channels: 0 + input_shape: [] + layer: 2 + model_num_per_trainer: 1 + num_item: 0 + num_user: 0 + out_channels: 3 + task: node + type: gcn + use_bias: true +nbafl: + use: false +outdir: exp/FedAvg_gcn_on_cora_lr0.25_lstep4 +personalization: + K: 5 + beta: 1.0 + local_param: [] + local_update_steps: 4 + lr: 0.25 + regular_weight: 0.1 + share_non_trainable_para: false +print_decimal_digits: 6 +regularizer: + mu: 0.0 + type: '' +seed: 0 +sgdmf: + use: false +train: + batch_or_epoch: batch + local_update_steps: 4 + optimizer: + lr: 0.25 + type: SGD + weight_decay: 0.0005 + scheduler: + type: '' +trainer: + type: nodefullbatch_trainer +use_gpu: true +verbose: 1 +vertical: + use: false +wandb: + use: false + diff --git a/fgssl/gfl/baseline/example.yaml b/fgssl/gfl/baseline/example.yaml new file mode 100644 index 0000000..c6a72de --- /dev/null +++ b/fgssl/gfl/baseline/example.yaml @@ -0,0 +1,83 @@ +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 2 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of dataset being split + client_num: 5 + # Number of communication round + total_round_num: 200 + method: fgcl +# Dataset related options +data: + # Root directory where the data stored + root: data/ + splits: [0.6, 0.2, 0.2] + # Dataset name + type: citeseer + # Use Louvain algorithm to split `Cora` + splitter: 'louvain' +dataloader: + # Type of sampler + type: pyg + # Use fullbatch training, batch_size should be `1` + batch_size: 1 +# Model related options +model: + # Model type + type: gnn_mygat + # Hidden dim + hidden: 128 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 6 + layer: 2 +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: fgcl1 +seed: + 12345 +# Train related options +train: + # Number of local update steps + local_update_steps: 4 + # Optimizer related options + optimizer: + # Learning rate + lr: 5 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD +# scheduler: +## type: myscheduler +grad: + grad_clip: 0.01 +#hpo: +# scheduler: sha +# num_workers: 3 +# init_cand_num: 3 +# ss: toy_hpo_ss.yaml +# sha: +# budgets: [1, 1] +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + best_res_update_round_wise_key: 'val_loss' + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/example_aug.yaml b/fgssl/gfl/baseline/example_aug.yaml new file mode 100644 index 0000000..772c845 --- /dev/null +++ b/fgssl/gfl/baseline/example_aug.yaml @@ -0,0 +1,83 @@ +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 2 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of dataset being split + client_num: 5 + # Number of communication round + total_round_num: 150 + method: fgcl +# Dataset related options +data: + # Root directory where the data stored + root: data/ + splits: [0.6, 0.2, 0.2] + # Dataset name + type: citeseer + # Use Louvain algorithm to split `Cora` + splitter: 'random' +dataloader: + # Type of sampler + type: pyg + # Use fullbatch training, batch_size should be `1` + batch_size: 1 + +# Model 
related options +model: + # Model type + type: gnn_gcn_aug + # Hidden dim + hidden: 128 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 6 + layer: 2 +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: fgcl2 + +# Train related options +train: + # Number of local update steps + local_update_steps: 4 + # Optimizer related options + optimizer: + # Learning rate + lr: 1 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD +# scheduler: +## type: myscheduler +#grad: +# grad_clip: 0.01 +#hpo: +# scheduler: sha +# num_workers: 3 +# init_cand_num: 3 +# ss: toy_hpo_ss.yaml +# sha: +# budgets: [1, 1] +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + best_res_update_round_wise_key: 'val_acc' + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/example_gcn.yaml b/fgssl/gfl/baseline/example_gcn.yaml new file mode 100644 index 0000000..1bbd1e9 --- /dev/null +++ b/fgssl/gfl/baseline/example_gcn.yaml @@ -0,0 +1,83 @@ +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 2 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of dataset being split + client_num: 5 + # Number of communication round + total_round_num: 250 + method: fgcl +# Dataset related options +data: + # Root directory where the data stored + root: data/ + # Dataset name + type: pubmed + # Use Louvain algorithm to split `Cora` + splitter: 'random' +dataloader: + # Type of sampler + type: pyg + # Use fullbatch training, batch_size should be `1` + batch_size: 1 + +# Model related options +model: + # Model type + type: gnn_mygcn + # Hidden dim + hidden: 128 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 3 + layer: 2 +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: fgcl1 +seed: + 12345 +# Train related options +train: + # Number of local update steps + local_update_steps: 4 + # Optimizer related options + optimizer: + # Learning rate + lr: 5 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD +# scheduler: +## type: myscheduler +grad: + grad_clip: 0.001 +#hpo: +# scheduler: sha +# num_workers: 3 +# init_cand_num: 3 +# ss: toy_hpo_ss.yaml +# sha: +# budgets: [1, 1] +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + best_res_update_round_wise_key: 'val_acc' + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/example_pubmed.yaml b/fgssl/gfl/baseline/example_pubmed.yaml new file mode 100644 index 0000000..8d535c4 --- /dev/null +++ b/fgssl/gfl/baseline/example_pubmed.yaml @@ -0,0 +1,84 @@ +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 3 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of dataset being split + client_num: 7 + # Number of communication round + total_round_num: 250 + method: fgcl +# Dataset related options +data: + # Root directory where the data stored + root: data/ + splits: [0.6, 0.2, 0.2] + 
# Dataset name + type: pubmed + # Use Louvain algorithm to split `Cora` + splitter: 'random' +dataloader: + # Type of sampler + type: pyg + # Use fullbatch training, batch_size should be `1` + batch_size: 1 + +# Model related options +model: + # Model type + type: gnn_mygat + # Hidden dim + hidden: 128 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 3 + layer: 2 +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: fgcl1 +seed: + 12345 +# Train related options +train: + # Number of local update steps + local_update_steps: 4 + # Optimizer related options + optimizer: + # Learning rate + lr: 0.3 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD +# scheduler: +### type: myscheduler +#grad: +# grad_clip: 0.01 +#hpo: +# scheduler: sha +# num_workers: 3 +# init_cand_num: 3 +# ss: toy_hpo_ss.yaml +# sha: +# budgets: [1, 1] +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + best_res_update_round_wise_key: 'val_acc' + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/example_visual.yaml b/fgssl/gfl/baseline/example_visual.yaml new file mode 100644 index 0000000..dfd0706 --- /dev/null +++ b/fgssl/gfl/baseline/example_visual.yaml @@ -0,0 +1,83 @@ +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 2 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of dataset being split + client_num: 15 + # Number of communication round + total_round_num: 20 + method: fgcl +# Dataset related options +data: + # Root directory where the data stored + root: data/ + # Dataset name + type: citeseer + # Use Louvain algorithm to split `Cora` + splitter: 'random' +dataloader: + # Type of sampler + type: pyg + # Use fullbatch training, batch_size should be `1` + batch_size: 1 + +# Model related options +model: + # Model type + type: gnn_mygat + # Hidden dim + hidden: 128 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 6 + layer: 2 +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: fgcl1 +seed: + 12345 +# Train related options +train: + # Number of local update steps + local_update_steps: 150 + # Optimizer related options + optimizer: + # Learning rate + lr: 1 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD +# scheduler: +# type: myscheduler +grad: + grad_clip: 0.05 +#hpo: +# scheduler: she +# num_workers: 3 +# init_cand_num: 3 +# ss: toy_hpo_ss.yaml +# sha: +# budgets: [1, 1] +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + best_res_update_round_wise_key: 'val_acc' + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/fed_gcn.yaml b/fgssl/gfl/baseline/fed_gcn.yaml new file mode 100644 index 0000000..75b024f --- /dev/null +++ b/fgssl/gfl/baseline/fed_gcn.yaml @@ -0,0 +1,70 @@ +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 3 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of 
dataset being split + client_num: 5 + # Number of communication round + total_round_num: 400 +# Dataset related options +data: + # Root directory where the data stored + root: data/ + # Dataset name + type: cora + # Use Louvain algorithm to split `Cora` + splitter: 'louvain' +dataloader: + # Type of sampler + type: pyg + # Use fullbatch training, batch_size should be `1` + batch_size: 1 + +# Model related options +model: + # Model type + type: gcn + # Hidden dim + hidden: 128 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 7 + +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: nodefullbatch_trainer + +# Train related options +train: + # Number of local update steps + local_update_steps: 4 + # Optimizer related options + optimizer: + # Learning rate + lr: 0.25 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD + +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/fedavg_gcn_fullbatch_on_dblpnew.yaml b/fgssl/gfl/baseline/fedavg_gcn_fullbatch_on_dblpnew.yaml new file mode 100644 index 0000000..2878b70 --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_gcn_fullbatch_on_dblpnew.yaml @@ -0,0 +1,31 @@ +use_gpu: True +device: 7 +early_stop: + patience: 100 + improve_indicator_mode: mean +federate: + mode: standalone + make_global_eval: True + total_round_num: 400 +data: + root: data/ + type: dblp_conf + splits: [0.5, 0.2, 0.3] +dataloader: + type: pyg + batch_size: 1 +model: + type: gcn + hidden: 1024 + out_channels: 4 + task: node +train: + optimizer: + lr: 0.05 + weight_decay: 0.0005 +criterion: + type: CrossEntropyLoss +trainer: + type: nodefullbatch_trainer +eval: + metrics: ['acc', 'correct'] diff --git a/fgssl/gfl/baseline/fedavg_gcn_fullbatch_on_kg.yaml b/fgssl/gfl/baseline/fedavg_gcn_fullbatch_on_kg.yaml new file mode 100644 index 0000000..79b4f0a --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_gcn_fullbatch_on_kg.yaml @@ -0,0 +1,34 @@ +use_gpu: True +device: 1 +early_stop: + patience: 20 + improve_indicator_mode: mean +federate: + mode: standalone + make_global_eval: True + total_round_num: 400 + client_num: 5 +data: + root: data/ + type: wn18 + splitter: rel_type + pre_transform: ['Constant', {'value':1.0, 'cat':False}] +dataloader: + type: pyg +model: + type: gat + hidden: 64 + out_channels: 18 + task: link +train: + local_update_steps: 16 + optimizer: + lr: 0.25 + weight_decay: 0.0005 +criterion: + type: CrossEntropyLoss +trainer: + type: linkfullbatch_trainer +eval: + freq: 5 + metrics: ['hits@1', 'hits@5', 'hits@10'] diff --git a/fgssl/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml b/fgssl/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml new file mode 100644 index 0000000..b0b6c49 --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml @@ -0,0 +1,33 @@ +use_gpu: True +device: 0 +early_stop: + patience: 20 + improve_indicator_mode: mean +federate: + mode: 'standalone' + make_global_eval: True + total_round_num: 400 + client_num: 5 +data: + root: data/ + type: hiv + splitter: scaffold +dataloader: + type: pyg +model: + type: gcn + hidden: 64 + out_channels: 2 + task: graph +train: + local_update_steps: 16 + optimizer: + lr: 0.25 + weight_decay: 0.0005 +criterion: + type: CrossEntropyLoss +trainer: + type: graphminibatch_trainer +eval: + freq: 5 + metrics: ['acc', 
'correct', 'roc_auc'] diff --git a/fgssl/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml b/fgssl/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml new file mode 100644 index 0000000..c7865d4 --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml @@ -0,0 +1,36 @@ +use_gpu: True +device: 0 +early_stop: + patience: 20 + improve_indicator_mode: mean + the_smaller_the_better: False +federate: + mode: 'standalone' + make_global_eval: False + total_round_num: 100 + share_local_model: False +data: + root: data/ + type: cikmcup +dataloader: + type: pyg +model: + type: gin + hidden: 64 +personalization: + local_param: ['encoder_atom', 'encoder', 'clf'] +train: + batch_or_epoch: epoch + local_update_steps: 1 + optimizer: + weight_decay: 0.0005 + type: SGD +trainer: + type: graphminibatch_trainer +eval: + freq: 5 + metrics: ['imp_ratio'] + report: ['avg'] + best_res_update_round_wise_key: val_imp_ratio + count_flops: False + base: 0. \ No newline at end of file diff --git a/fgssl/gfl/baseline/fedavg_gin_minibatch_on_cikmcup_per_client.yaml b/fgssl/gfl/baseline/fedavg_gin_minibatch_on_cikmcup_per_client.yaml new file mode 100644 index 0000000..e1c8566 --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_gin_minibatch_on_cikmcup_per_client.yaml @@ -0,0 +1,147 @@ +client_1: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.1 + eval: + base: 0.263789 +client_2: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.01 + eval: + base: 0.289617 +client_3: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.001 + eval: + base: 0.355404 +client_4: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.01 + eval: + base: 0.176471 +client_5: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.0001 + eval: + base: 0.396825 +client_6: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.0005 + eval: + base: 0.261580 +client_7: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.01 + eval: + base: 0.302378 +client_8: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.05 + eval: + base: 0.211538 +client_9: + model: + out_channels: 1 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.1 + eval: + base: 0.059199 +client_10: + model: + out_channels: 10 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.05 + grad: + grad_clip: 1.0 + eval: + base: 0.007083 +client_11: + model: + out_channels: 1 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.05 + eval: + base: 0.734011 +client_12: + model: + out_channels: 1 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.01 + eval: + base: 1.361326 +client_13: + model: + out_channels: 12 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.05 + grad: + grad_clip: 1.0 + eval: + base: 0.004389 \ No newline at end of file diff --git a/fgssl/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml b/fgssl/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml new file 
mode 100644 index 0000000..a97aade --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml @@ -0,0 +1,37 @@ +use_gpu: True +device: 0 +early_stop: + patience: 20 + improve_indicator_mode: mean +federate: + mode: 'standalone' + make_global_eval: False + total_round_num: 400 + share_local_model: False +data: + root: data/ + type: graph_multi_domain_mol + pre_transform: ['Constant', {'value':1.0, 'cat':False}] +dataloader: + type: pyg +model: + type: gin + hidden: 64 + out_channels: 0 + task: graph +personalization: + local_param: ['encoder_atom', 'encoder', 'clf'] # to handle size-different pre & post layers + # local_param: [ 'encoder_atom', 'encoder', 'clf', 'norms' ] # pre, post + FedBN +train: + local_update_steps: 1 + optimizer: + lr: 0.5 + weight_decay: 0.0005 + type: SGD +criterion: + type: CrossEntropyLoss +trainer: + type: graphminibatch_trainer +eval: + freq: 5 + metrics: ['acc', 'correct'] diff --git a/fgssl/gfl/baseline/fedavg_gnn_minibatch_on_multi_task_total_samples_aggr.yaml b/fgssl/gfl/baseline/fedavg_gnn_minibatch_on_multi_task_total_samples_aggr.yaml new file mode 100644 index 0000000..e8933c1 --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_gnn_minibatch_on_multi_task_total_samples_aggr.yaml @@ -0,0 +1,38 @@ +use_gpu: True +device: 0 +early_stop: + patience: 20 + improve_indicator_mode: mean +federate: + mode: 'standalone' + make_global_eval: False + total_round_num: 400 + share_local_model: False + data_weighted_aggr: True +data: + root: data/ + type: graph_multi_domain_mix + pre_transform: ['Constant', {'value':1.0, 'cat':False}] +dataloader: + type: pyg +model: + type: gin + hidden: 64 + out_channels: 0 + task: graph +personalization: + local_param: ['encoder_atom', 'encoder', 'clf'] # to handle size-different pre & post layers + # local_param: [ 'encoder_atom', 'encoder', 'clf', 'norms' ] # pre, post + FedBN +train: + local_update_steps: 1 + optimizer: + lr: 0.5 + weight_decay: 0.0005 + type: SGD +criterion: + type: CrossEntropyLoss +trainer: + type: graphminibatch_trainer +eval: + freq: 5 + metrics: ['acc', 'correct'] diff --git a/fgssl/gfl/baseline/fedavg_gnn_node_fullbatch_citation.yaml b/fgssl/gfl/baseline/fedavg_gnn_node_fullbatch_citation.yaml new file mode 100644 index 0000000..965269e --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_gnn_node_fullbatch_citation.yaml @@ -0,0 +1,35 @@ +use_gpu: True +device: 0 +early_stop: + patience: 100 + improve_indicator_mode: mean +federate: + mode: standalone + make_global_eval: True + client_num: 5 + total_round_num: 400 +data: + root: data/ + type: cora + splitter: 'louvain' +dataloader: + type: pyg + batch_size: 1 +model: + type: gcn + hidden: 64 + dropout: 0.5 + out_channels: 7 + task: node +train: + local_update_steps: 4 + optimizer: + lr: 0.25 + weight_decay: 0.0005 + type: SGD +criterion: + type: CrossEntropyLoss +trainer: + type: nodefullbatch_trainer +eval: + metrics: ['acc', 'correct'] diff --git a/fgssl/gfl/baseline/fedavg_on_cSBM.yaml b/fgssl/gfl/baseline/fedavg_on_cSBM.yaml new file mode 100644 index 0000000..d0de6f9 --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_on_cSBM.yaml @@ -0,0 +1,36 @@ +use_gpu: True +device: 2 +early_stop: + patience: 200 + improve_indicator_mode: mean +# monitoring: ['dissim'] +federate: + mode: standalone + total_round_num: 400 +data: + root: data/ + type: 'csbm' + #type: 'csbm_data_feb_07_2022-00:19' + cSBM_phi: [0.1, 0.5, 0.9] +dataloader: + type: pyg + batch_size: 1 +model: + type: gpr + hidden: 256 + out_channels: 2 + task: node +#personalization: + #local_param: 
['prop1'] +train: + local_update_steps: 2 + optimizer: + lr: 0.5 + weight_decay: 0.0005 + type: SGD +criterion: + type: CrossEntropyLoss +trainer: + type: nodefullbatch_trainer +eval: + metrics: ['acc', 'correct'] diff --git a/fgssl/gfl/baseline/fedavg_sage_minibatch_on_dblpnew.yaml b/fgssl/gfl/baseline/fedavg_sage_minibatch_on_dblpnew.yaml new file mode 100644 index 0000000..29838bb --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_sage_minibatch_on_dblpnew.yaml @@ -0,0 +1,32 @@ +use_gpu: True +device: 0 +early_stop: + patience: 100 + improve_indicator_mode: mean +federate: + mode: standalone + make_global_eval: True + total_round_num: 400 +data: + root: data/ + type: dblp_conf +dataloader: + type: graphsaint-rw + batch_size: 256 +model: + type: sage + hidden: 1024 + out_channels: 4 + task: node +train: + local_update_steps: 16 + optimizer: + lr: 0.05 + weight_decay: 0.0005 + type: SGD +criterion: + type: CrossEntropyLoss +trainer: + type: nodeminibatch_trainer +eval: + metrics: ['acc', 'correct'] diff --git a/fgssl/gfl/baseline/fedavg_wpsn_on_cSBM.yaml b/fgssl/gfl/baseline/fedavg_wpsn_on_cSBM.yaml new file mode 100644 index 0000000..ef77a59 --- /dev/null +++ b/fgssl/gfl/baseline/fedavg_wpsn_on_cSBM.yaml @@ -0,0 +1,37 @@ +use_gpu: True +device: 2 +early_stop: + patience: 200 + improve_indicator_mode: mean +# monitoring: ['dissim'] +federate: + mode: standalone + total_round_num: 400 +data: + root: data/ + type: 'csbm' + #type: 'csbm_data_feb_05_2022-19:23' + cSBM_phi: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] +dataloader: + type: pyg +model: + type: gpr + hidden: 256 + out_channels: 2 + task: node +personalization: + local_param: ['prop1'] +train: + local_update_steps: 2 + optimizer: + lr: 0.5 + weight_decay: 0.0005 + type: SGD +criterion: + type: CrossEntropyLoss +trainer: + type: nodeminibatch_trainer +finetune: + local_update_steps: 2 +eval: + metrics: ['acc', 'correct'] diff --git a/fgssl/gfl/baseline/fedbn_gnn_minibatch_on_multi_task.yaml b/fgssl/gfl/baseline/fedbn_gnn_minibatch_on_multi_task.yaml new file mode 100644 index 0000000..9487f37 --- /dev/null +++ b/fgssl/gfl/baseline/fedbn_gnn_minibatch_on_multi_task.yaml @@ -0,0 +1,37 @@ +use_gpu: True +device: 0 +early_stop: + patience: 20 + improve_indicator_mode: mean +federate: + mode: 'standalone' + make_global_eval: False + total_round_num: 400 + share_local_model: False +data: + root: data/ + type: graph_multi_domain_mix + pre_transform: ['Constant', {'value':1.0, 'cat':False}] +dataloader: + type: pyg +model: + type: gin + hidden: 64 + out_channels: 0 + task: graph +personalization: + # local_param: ['encoder_atom', 'encoder', 'clf'] # to handle size-different pre & post layers + local_param: [ 'encoder_atom', 'encoder', 'clf', 'norms' ] # pre, post + FedBN +train: + local_update_steps: 16 + optimizer: + lr: 0.5 + weight_decay: 0.0005 + type: SGD +criterion: + type: CrossEntropyLoss +trainer: + type: graphminibatch_trainer +eval: + freq: 5 + metrics: ['acc', 'correct'] diff --git a/fgssl/gfl/baseline/fgcl_afg.yaml b/fgssl/gfl/baseline/fgcl_afg.yaml new file mode 100644 index 0000000..ccd591a --- /dev/null +++ b/fgssl/gfl/baseline/fgcl_afg.yaml @@ -0,0 +1,77 @@ +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 3 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of dataset being split + client_num: 5 + # Number of communication round + total_round_num: 300 + method: fgcl +# Dataset 
related options +data: + # Root directory where the data stored + fgcl: True + root: data/ + # Dataset name + type: cora + # Use Louvain algorithm to split `Cora` + splitter: 'louvain' +dataloader: + # Type of sampler + type: pyg + # Use fullbatch training, batch_size should be `1` + batch_size: 1 + +# Model related options +model: + # Model type + type: gnn_fgcl + # Hidden dim + hidden: 1024 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 7 + + +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: fgcl2 + + +# Train related options +train: + # Number of local update steps + local_update_steps: 4 + # Optimizer related op tions + optimizer: + # Learning rate + lr: 0.25 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD +# scheduler: +# type: myscheduler + + +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml b/fgssl/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml new file mode 100644 index 0000000..3153c31 --- /dev/null +++ b/fgssl/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml @@ -0,0 +1,36 @@ +use_gpu: True +device: 0 +early_stop: + patience: 20 + improve_indicator_mode: mean + the_smaller_the_better: False +federate: + mode: standalone + method: local + make_global_eval: False + total_round_num: 10 + share_local_model: False +data: + batch_size: 64 + root: data/ + type: cikmcup +dataloader: + type: pyg +model: + type: gin + hidden: 64 +personalization: + local_param: ['encoder_atom', 'encoder', 'clf'] +train: + batch_or_epoch: epoch + local_update_steps: 21 + optimizer: + weight_decay: 0.0005 + type: SGD +trainer: + type: graphminibatch_trainer +eval: + freq: 5 + report: ['avg'] + best_res_update_round_wise_key: val_loss + count_flops: False diff --git a/fgssl/gfl/baseline/isolated_gin_minibatch_on_cikmcup_per_client.yaml b/fgssl/gfl/baseline/isolated_gin_minibatch_on_cikmcup_per_client.yaml new file mode 100644 index 0000000..01b6414 --- /dev/null +++ b/fgssl/gfl/baseline/isolated_gin_minibatch_on_cikmcup_per_client.yaml @@ -0,0 +1,147 @@ +client_1: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.1 + eval: + metrics: ['acc'] +client_2: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.01 + eval: + metrics: ['acc'] +client_3: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.001 + eval: + metrics: ['acc'] +client_4: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.01 + eval: + metrics: ['acc'] +client_5: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.0001 + eval: + metrics: ['acc'] +client_6: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.0005 + eval: + metrics: ['acc'] +client_7: + model: + out_channels: 2 + task: graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.01 + eval: + metrics: ['acc'] +client_8: + model: + out_channels: 2 + task: 
graphClassification + criterion: + type: CrossEntropyLoss + train: + optimizer: + lr: 0.05 + eval: + metrics: ['acc'] +client_9: + model: + out_channels: 1 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.1 + eval: + metrics: ['mse'] +client_10: + model: + out_channels: 10 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.05 + grad: + grad_clip: 1.0 + eval: + metrics: ['mse'] +client_11: + model: + out_channels: 1 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.05 + eval: + metrics: ['mse'] +client_12: + model: + out_channels: 1 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.01 + eval: + metrics: ['mse'] +client_13: + model: + out_channels: 12 + task: graphRegression + criterion: + type: MSELoss + train: + optimizer: + lr: 0.05 + grad: + grad_clip: 1.0 + eval: + metrics: ['mse'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/local_gnn_node_fullbatch_citation.yaml b/fgssl/gfl/baseline/local_gnn_node_fullbatch_citation.yaml new file mode 100644 index 0000000..c8a298f --- /dev/null +++ b/fgssl/gfl/baseline/local_gnn_node_fullbatch_citation.yaml @@ -0,0 +1,32 @@ +use_gpu: True +device: 0 +early_stop: + patience: 100 + improve_indicator_mode: mean +federate: + make_global_eval: True + client_num: 5 + total_round_num: 400 + method: 'local' +data: + root: data/ + type: cora + splitter: 'louvain' +dataloader: + type: pyg + batch_size: 1 +model: + type: gcn + hidden: 64 + dropout: 0.5 + out_channels: 7 + task: node +train: + optimizer: + lr: 0.05 + weight_decay: 0.0005 + type: SGD +criterion: + type: CrossEntropyLoss +trainer: + type: graphfullbatch diff --git a/fgssl/gfl/baseline/model_change.yaml b/fgssl/gfl/baseline/model_change.yaml new file mode 100644 index 0000000..0dcb52f --- /dev/null +++ b/fgssl/gfl/baseline/model_change.yaml @@ -0,0 +1,75 @@ +# Whether to use GPU +use_gpu: True + +# Deciding which GPU to use +device: 3 + +# Federate learning related options +federate: + # `standalone` or `distributed` + mode: standalone + # Evaluate in Server or Client test set + make_global_eval: True + # Number of dataset being split + client_num: 5 + # Number of communication round + total_round_num: 400 + method: fgcl +# Dataset related options +data: + # Root directory where the data stored + root: data/ + # Dataset name + type: cora + # Use Louvain algorithm to split `Cora` + splitter: 'louvain' +dataloader: + # Type of sampler + type: pyg + # Use fullbatch training, batch_size should be `1` + batch_size: 1 + +# Model related options +model: + # Model type + type: gcn + # Hidden dim + hidden: 64 + # Dropout rate + dropout: 0.5 + # Number of Class of `Cora` + out_channels: 7 + + +# Criterion related options +criterion: + # Criterion type + type: CrossEntropyLoss + +# Trainer related options +trainer: + # Trainer type + type: fgcl1 + +# Train related options +train: + # Number of local update steps + local_update_steps: 4 + # Optimizer related options + optimizer: + # Learning rate + lr: 0.25 + # Weight decay + weight_decay: 0.0005 + # Optimizer type + type: SGD +# scheduler: +# type: myscheduler + + +# Evaluation related options +eval: + # Frequency of evaluation + freq: 1 + # Evaluation metrics, accuracy and number of correct items + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/args_graph_fedalgo.sh b/fgssl/gfl/baseline/repro_exp/graph_level/args_graph_fedalgo.sh new file mode 100644 
index 0000000..d64cad5 --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/args_graph_fedalgo.sh @@ -0,0 +1,37 @@ +# ---------------------------------------------------------------------- # +# FedOpt +# ---------------------------------------------------------------------- # + +# proteins +bash run_graph_level_opt.sh 0 proteins gcn 0.25 4 & + +bash run_graph_level_opt.sh 1 proteins gin 0.25 1 & + +bash run_graph_level_opt.sh 2 proteins gat 0.25 4 & + +# imdb-binary +bash run_graph_level_opt.sh 3 imdb-binary gcn 0.25 16 & + +bash run_graph_level_opt.sh 4 imdb-binary gin 0.01 16 & + +bash run_graph_level_opt.sh 5 imdb-binary gat 0.25 16 & + +# ---------------------------------------------------------------------- # +# FedProx +# ---------------------------------------------------------------------- # + +# proteins +bash run_graph_level_prox.sh 6 proteins gcn 0.25 4 & + +bash run_graph_level_prox.sh 7 proteins gin 0.25 1 & + +bash run_graph_level_prox.sh 1 proteins gat 0.25 4 & + +# imdb-binary +bash run_graph_level_prox.sh 2 imdb-binary gcn 0.25 16 & + +bash run_graph_level_prox.sh 3 imdb-binary gin 0.01 16 & + +bash run_graph_level_prox.sh 4 imdb-binary gat 0.25 16 & + + diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/args_multi_graph_fedalgo.sh b/fgssl/gfl/baseline/repro_exp/graph_level/args_multi_graph_fedalgo.sh new file mode 100644 index 0000000..6809320 --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/args_multi_graph_fedalgo.sh @@ -0,0 +1,23 @@ +# ---------------------------------------------------------------------- # +# FedOpt +# ---------------------------------------------------------------------- # + +# mol +bash run_multi_opt.sh 5 mol gcn 0.25 16 & + +bash run_multi_opt.sh 7 mol gin 0.25 4 & + +bash run_multi_opt.sh 5 mol gat 0.25 16 & + +# ---------------------------------------------------------------------- # +# FedProx +# ---------------------------------------------------------------------- # + +# mol +bash run_multi_prox.sh 7 mol gcn 0.25 16 & + +bash run_multi_prox.sh 5 mol gin 0.01 4 & + +bash run_multi_prox.sh 7 mol gat 0.25 16 & + + diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level.sh b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level.sh new file mode 100644 index 0000000..29fe3b8 --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level.sh @@ -0,0 +1,49 @@ +set -e + +cudaid=$1 +dataset=$2 + +cd ../../../../.. + +if [ ! -d "out" ];then + mkdir out +fi + +if [[ $dataset = 'hiv' ]]; then + out_channels=2 + hidden=64 + splitter='scaffold' +elif [[ $dataset = 'proteins' ]]; then + out_channels=2 + hidden=64 + splitter='rand_chunk' +elif [[ $dataset = 'imdb-binary' ]]; then + out_channels=2 + hidden=64 + splitter='lda' +else + out_channels=4 + hidden=1024 +fi + +echo "HPO starts..." + +gnns=('gcn' 'gin' 'gat') +lrs=(0.01 0.05 0.25) +local_updates=(1 4 16) + +for (( g=0; g<${#gnns[@]}; g++ )) +do + for (( i=0; i<${#lrs[@]}; i++ )) + do + for (( j=0; j<${#local_updates[@]}; j++ )) + do + for k in {1..5} + do + python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} train.optimizer.lr ${lrs[$i]} train.local_update_steps ${local_updates[$j]} model.type ${gnns[$g]} model.out_channels ${out_channels} model.hidden ${hidden} seed $k >>out/${gnns[$g]}_${lrs[$i]}_${local_updates[$j]}_on_${dataset}_${splitter}.log 2>&1 + done + done + done +done + +echo "HPO ends." 
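For reference, one cell of the grid swept above can be launched by hand with the same entry point and overrides that run_graph_level.sh composes; a minimal sketch, assuming the repository root as the working directory and GPU device 0 (both assumptions, not part of the scripts):

# Example: GIN on proteins (rand_chunk splitter), lr 0.25, 4 local update steps, seed 1,
# i.e. one configuration from the gnns x lrs x local_updates grid in run_graph_level.sh.
python federatedscope/main.py \
  --cfg federatedscope/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml \
  device 0 \
  data.type proteins data.splitter rand_chunk \
  train.optimizer.lr 0.25 train.local_update_steps 4 \
  model.type gin model.out_channels 2 model.hidden 64 \
  seed 1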
diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task.sh b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task.sh new file mode 100644 index 0000000..028d5e9 --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task.sh @@ -0,0 +1,48 @@ +set -e + +cudaid=$1 +dname=$2 + +cd ../../../../.. + +if [[ $dname = 'mol' ]]; then + dataset='graph_multi_domain_mol' +elif [[ $dname = 'mix' ]]; then + dataset='graph_multi_domain_mix' +elif [[ $dname = 'biochem' ]]; then + dataset='graph_multi_domain_biochem' +elif [[ $dname = 'v1' ]]; then + dataset='graph_multi_domain_kddcupv1' +else + dataset='graph_multi_domain_small' +fi + +if [ ! -d "out" ];then + mkdir out +fi + +out_channels=0 +hidden=64 +splitter='ooxx' + +echo "HPO starts..." + +gnns=('gcn' 'gin' 'gat') +lrs=(0.01 0.05 0.25) +local_updates=(1 4 16) + +for (( g=0; g<${#gnns[@]}; g++ )) +do + for (( i=0; i<${#lrs[@]}; i++ )) + do + for (( j=0; j<${#local_updates[@]}; j++ )) + do + for k in {1..5} + do + python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} train.optimizer.lr ${lrs[$i]} train.local_update_steps ${local_updates[$j]} model.type ${gnns[$g]} model.out_channels ${out_channels} model.hidden ${hidden} seed $k >>out/${gnns[$g]}_${lrs[$i]}_${local_updates[$j]}_on_${dataset}_${splitter}.log 2>&1 + done + done + done +done + +echo "HPO ends." diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task_bn.sh b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task_bn.sh new file mode 100644 index 0000000..50ede4f --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task_bn.sh @@ -0,0 +1,48 @@ +set -e + +cudaid=$1 +dname=$2 + +cd ../../../../.. + +if [[ $dname = 'mol' ]]; then + dataset='graph_multi_domain_mol' +elif [[ $dname = 'mix' ]]; then + dataset='graph_multi_domain_mix' +elif [[ $dname = 'biochem' ]]; then + dataset='graph_multi_domain_biochem' +elif [[ $dname = 'v1' ]]; then + dataset='graph_multi_domain_kddcupv1' +else + dataset='graph_multi_domain_small' +fi + +if [ ! -d "out_bn" ];then + mkdir out_bn +fi + +out_channels=0 +hidden=64 +splitter='ooxx' + +echo "HPO starts..." + +gnns=('gin') +lrs=(0.01 0.05 0.25) +local_updates=(1 4 16) + +for (( g=0; g<${#gnns[@]}; g++ )) +do + for (( i=0; i<${#lrs[@]}; i++ )) + do + for (( j=0; j<${#local_updates[@]}; j++ )) + do + for k in {1..5} + do + python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedbn_gnn_minibatch_on_multi_task.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} train.optimizer.lr ${lrs[$i]} train.local_update_steps ${local_updates[$j]} model.type ${gnns[$g]} model.out_channels ${out_channels} model.hidden ${hidden} seed $k >>out_bn/${gnns[$g]}_${lrs[$i]}_${local_updates[$j]}_on_${dataset}_${splitter}.log 2>&1 + done + done + done +done + +echo "HPO ends." diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task_bn_finetune.sh b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task_bn_finetune.sh new file mode 100644 index 0000000..f4b22ed --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_multi_task_bn_finetune.sh @@ -0,0 +1,48 @@ +set -e + +cudaid=$1 +dname=$2 + +cd ../../../../.. 
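+# Fine-tuning variant of the FedBN multi-task sweep: only GIN with lr 0.25 and
+# 16 local update steps is run (seeds 1..3), adding finetune.local_update_steps 16;
+# logs are written under out_bn_finetune/.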
+ +if [[ $dname = 'mol' ]]; then + dataset='graph_multi_domain_mol' +elif [[ $dname = 'mix' ]]; then + dataset='graph_multi_domain_mix' +elif [[ $dname = 'biochem' ]]; then + dataset='graph_multi_domain_biochem' +elif [[ $dname = 'v1' ]]; then + dataset='graph_multi_domain_kddcupv1' +else + dataset='graph_multi_domain_small' +fi + +if [ ! -d "out_bn_finetune" ];then + mkdir out_bn_finetune +fi + +out_channels=0 +hidden=64 +splitter='ooxx' + +echo "HPO starts..." + +gnns=('gin') +lrs=(0.25) +local_updates=(16) + +for (( g=0; g<${#gnns[@]}; g++ )) +do + for (( i=0; i<${#lrs[@]}; i++ )) + do + for (( j=0; j<${#local_updates[@]}; j++ )) + do + for k in {1..3} + do + python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedbn_gnn_minibatch_on_multi_task.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} train.optimizer.lr ${lrs[$i]} train.local_update_steps ${local_updates[$j]} model.type ${gnns[$g]} model.out_channels ${out_channels} model.hidden ${hidden} seed $k finetune.local_update_steps 16 >>out_bn_finetune/${gnns[$g]}_${lrs[$i]}_${local_updates[$j]}_on_${dataset}_${splitter}.log 2>&1 + done + done + done +done + +echo "HPO ends." diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_opt.sh b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_opt.sh new file mode 100644 index 0000000..09bc2a0 --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_opt.sh @@ -0,0 +1,45 @@ +set -e + +cudaid=$1 +dataset=$2 +gnn=$3 +lr=$4 +local_update=$5 + +cd ../../../../.. + +if [ ! -d "out" ];then + mkdir out +fi + +if [[ $dataset = 'hiv' ]]; then + out_channels=2 + hidden=64 + splitter='scaffold' +elif [[ $dataset = 'proteins' ]]; then + out_channels=2 + hidden=64 + splitter='rand_chunk' +elif [[ $dataset = 'imdb-binary' ]]; then + out_channels=2 + hidden=64 + splitter='lda' +else + out_channels=4 + hidden=1024 +fi + + +echo "HPO starts..." + +lr_servers=(0.5 0.1) + +for (( s=0; s<${#lr_servers[@]}; s++ )) +do + for k in {1..5} + do + python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} train.optimizer.lr ${lr} train.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k fedopt.use True fedopt.optimizer.lr ${lr_servers[$s]} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1 + done +done + +echo "HPO ends." diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_prox.sh b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_prox.sh new file mode 100644 index 0000000..4a17bea --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/run_graph_level_prox.sh @@ -0,0 +1,45 @@ +set -e + +cudaid=$1 +dataset=$2 +gnn=$3 +lr=$4 +local_update=$5 + +cd ../../../../.. + +if [ ! -d "out" ];then + mkdir out +fi + +if [[ $dataset = 'hiv' ]]; then + out_channels=2 + hidden=64 + splitter='scaffold' +elif [[ $dataset = 'proteins' ]]; then + out_channels=2 + hidden=64 + splitter='rand_chunk' +elif [[ $dataset = 'imdb-binary' ]]; then + out_channels=2 + hidden=64 + splitter='lda' +else + out_channels=4 + hidden=1024 +fi + + +echo "HPO starts..." 
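+# Sweep the FedProx proximal coefficient mu over the values below; each setting is
+# repeated for seeds 1..5 and its output appended to a log under out/ with a _prox suffix.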
+ +mu=(0.1 1.0 5.0) + +for (( s=0; s<${#mu[@]}; s++ )) +do + for k in {1..5} + do + python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} train.optimizer.lr ${lr} train.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k fedprox.use True fedprox.mu ${mu[$s]} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${mu[$s]}_prox.log 2>&1 + done +done + +echo "HPO ends." diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/run_multi_opt.sh b/fgssl/gfl/baseline/repro_exp/graph_level/run_multi_opt.sh new file mode 100644 index 0000000..17b28ee --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/run_multi_opt.sh @@ -0,0 +1,41 @@ +set -e + +cudaid=$1 +dname=$2 +gnn=$3 +lr=$4 +local_update=$5 + +cd ../../../../.. + +if [[ $dname = 'mol' ]]; then + dataset='graph_multi_domain_mol' +elif [[ $dname = 'mix' ]]; then + dataset='graph_multi_domain_mix' +elif [[ $dname = 'biochem' ]]; then + dataset='graph_multi_domain_biochem' +else + dataset='graph_multi_domain_small' +fi + +if [ ! -d "out" ];then + mkdir out +fi + +out_channels=0 +hidden=64 +splitter='ooxx' + +echo "HPO starts..." + +lr_servers=(0.5 0.1) + +for (( s=0; s<${#lr_servers[@]}; s++ )) +do + for k in {1..5} + do + python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} train.optimizer.lr ${lr} train.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k fedopt.use True fedopt.optimizer.lr ${lr_servers[$s]} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1 + done +done + +echo "HPO ends." diff --git a/fgssl/gfl/baseline/repro_exp/graph_level/run_multi_prox.sh b/fgssl/gfl/baseline/repro_exp/graph_level/run_multi_prox.sh new file mode 100644 index 0000000..efffc92 --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/graph_level/run_multi_prox.sh @@ -0,0 +1,42 @@ +set -e + +cudaid=$1 +dname=$2 +gnn=$3 +lr=$4 +local_update=$5 + +cd ../../../../.. + +if [[ $dname = 'mol' ]]; then + dataset='graph_multi_domain_mol' +elif [[ $dname = 'mix' ]]; then + dataset='graph_multi_domain_mix' +elif [[ $dname = 'biochem' ]]; then + dataset='graph_multi_domain_biochem' +else + dataset='graph_multi_domain_small' +fi + +if [ ! -d "out" ];then + mkdir out +fi + +out_channels=0 +hidden=64 +splitter='ooxx' + + +echo "HPO starts..." + +mu=(0.1 1.0 5.0) + +for (( s=0; s<${#mu[@]}; s++ )) +do + for k in {1..5} + do + python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} train.optimizer.lr ${lr} train.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k fedprox.use True fedprox.mu ${mu[$s]} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${mu[$s]}_prox.log 2>&1 + done +done + +echo "HPO ends." diff --git a/fgssl/gfl/baseline/repro_exp/hpo/run_hpo.sh b/fgssl/gfl/baseline/repro_exp/hpo/run_hpo.sh new file mode 100644 index 0000000..b516e2b --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/hpo/run_hpo.sh @@ -0,0 +1,29 @@ +set -e + +cudaid=$1 +dataset=$2 + +cd ../../../../.. + +if [ ! -d "hpo_${dataset}" ];then + mkdir hpo_${dataset} +fi + +if [ ! 
-d "hpo" ];then + mkdir hpo +fi + +rs=(1 2 4 8) +samples=(1 2 4 5) + +for (( s=0; s<${#samples[@]}; s++ )) +do + for (( r=0; r<${#rs[@]}; r++ )) + do + for k in {1..5} + do + python federatedscope/hpo.py --cfg federatedscope/gfl/baseline/fedavg_gnn_node_fullbatch_citation.yaml federate.sample_client_num ${samples[$s]} device ${cudaid} data.type ${dataset} hpo.r ${rs[$r]} seed $k >>hpo/hpo_on_${dataset}_${rs[$r]}_sample${samples[$s]}.log 2>&1 + rm hpo_${dataset}/* + done + done +done \ No newline at end of file diff --git a/fgssl/gfl/baseline/repro_exp/hpo/run_node_level_hpo.sh b/fgssl/gfl/baseline/repro_exp/hpo/run_node_level_hpo.sh new file mode 100644 index 0000000..b02b8b4 --- /dev/null +++ b/fgssl/gfl/baseline/repro_exp/hpo/run_node_level_hpo.sh @@ -0,0 +1,112 @@ +set -e + +cudaid=$1 +dataset=$2 +gnn='gcn' + +cd ../../../../.. + +if [ ! -d "out" ];then + mkdir out +fi + +if [[ $dataset = 'cora' ]]; then + out_channels=7 + hidden=64 + + num=16 + arry1=(0.5 0.0 0.25 16) + arry2=(0.5 0.0 0.01 16) + arry3=(0.0 0.0 0.01 1) + arry4=(0.0 0.0005 0.01 1) + arry5=(0.5 0.0 0.01 4) + arry6=(0.5 0.0 0.01 1) + arry7=(0.0 0.0 0.25 16) + arry8=(0.0 0.0005 0.25 1) + arry9=(0.5 0.0 0.05 1) + arry10=(0.5 0.0 0.25 4) + arry11=(0.0 0.0005 0.25 4) + arry12=(0.0 0.0005 0.01 4) + arry13=(0.5 0.0 0.25 1) + arry14=(0.0 0.0005 0.25 16) + arry15=(0.5 0.0 0.05 4) + arry16=(0.0 0.0005 0.01 16) + +elif [[ $dataset = 'citeseer' ]]; then + out_channels=6 + hidden=64 + + num=20 + arry1=(0.5 0.0 0.01 4) + arry2=(0.0 0.0005 0.01 1) + arry3=(0.5 0.0 0.05 4) + arry4=(0.5 0.0 0.25 1) + arry5=(0.0 0.0 0.05 16) + arry6=(0.0 0.0005 0.01 4) + arry7=(0.0 0.0005 0.05 1) + arry8=(0.0 0.0005 0.25 4) + arry9=(0.0 0.0005 0.05 16) + arry10=(0.0 0.0005 0.25 1) + arry11=(0.0 0.0 0.25 4) + arry12=(0.0 0.0 0.25 16) + arry13=(0.5 0.0 0.05 16) + arry14=(0.5 0.0 0.01 16) + arry15=(0.0 0.0 0.01 1) + arry16=(0.5 0.0 0.01 1) + arry17=(0.0 0.0005 0.05 4) + arry18=(0.0 0.0 0.25 1) + arry19=(0.0 0.0005 0.01 16) + arry20=(0.0 0.0 0.05 4) + +elif [[ $dataset = 'pubmed' ]]; then + out_channels=5 + hidden=64 + + num=15 + arry1=(0.5 0.0 0.05 1) + arry2=(0.5 0.0 0.01 16) + arry3=(0.0 0.0005 0.25 16) + arry4=(0.0 0.0005 0.01 4) + arry5=(0.5 0.0 0.25 4) + arry6=(0.5 0.0 0.25 16) + arry7=(0.0 0.0 0.25 4) + arry8=(0.0 0.0 0.01 1) + arry9=(0.0 0.0 0.05 1) + arry10=(0.0 0.0005 0.05 1) + arry11=(0.5 0.0 0.01 4) + arry12=(0.0 0.0 0.01 4) + arry13=(0.5 0.0 0.25 1) + arry14=(0.0 0.0005 0.01 1) + arry15=(0.5 0.0 0.01 1) + +else + out_channels=4 + hidden=1024 +fi + +echo "HPO starts..." + + +for (( i=1; i>out/${gnn}_${dropout}_${wd}_${lr}_${local_update}_on_${dataset}.log 2>&1 + done +done + +for (( i=1; i