diff --git a/research/fed-bpt/job_templates/fedbpt/config_fed_server.conf b/research/fed-bpt/job_templates/fedbpt/config_fed_server.conf index a142fbcf33..66422969c4 100644 --- a/research/fed-bpt/job_templates/fedbpt/config_fed_server.conf +++ b/research/fed-bpt/job_templates/fedbpt/config_fed_server.conf @@ -33,32 +33,22 @@ # seed for CMA-ES algorithm seed = 42 - - # The GlobalES controller use an persistor to load the model and save the model. - # The persistent component can be identified by component ID specified here. - # persistor_id = "persistor" } } ] # List of components used in the server side workflow. components = [ -# { -# # This is the persistence component used in above workflow. -# # PTFileModelPersistor is a Pytorch persistor which save/read the model to/from file. - -# id = "persistor" -# path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" -# -# # the persitor class take model class as argument -# # This imply that the model is initialized from the server-side. -# # The initialized model will be broadcast to all the clients to start the training. -# args.model.path = "{model_class_path}" -#}, { id = "receiver" path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" args.events = ["fed.analytix_log_stats"] + }, + { + # we use this component so the client api `flare.init()` can get required information + id = "register_decomposer" + path = "decomposer_widget.RegisterDecomposer" + args {} } ] diff --git a/research/fed-bpt/src/global_es.py b/research/fed-bpt/src/global_es.py index bee3bd2256..572922e582 100644 --- a/research/fed-bpt/src/global_es.py +++ b/research/fed-bpt/src/global_es.py @@ -16,7 +16,6 @@ import cma import numpy as np -from cma_decomposer import register_decomposers from nvflare.app_common.abstract.fl_model import FLModel from nvflare.app_common.workflows.fedavg import FedAvg @@ -63,9 +62,6 @@ def __init__(self, *args, frac=1, sigma=1, intrinsic_dim=500, seed=42, bound=0, self.bound = bound def run(self) -> None: - # We serialize CMAEvolutionStrategy object directly. This requires registering custom decomposers. - register_decomposers() - local_cma_mu = 0.0 m = max(int(self.frac * self._min_clients), 1)