Skip to content

Commit

Permalink
design change, broken commit
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen committed Jan 18, 2024
1 parent 84082f1 commit 90301d5
Show file tree
Hide file tree
Showing 23 changed files with 210 additions and 421 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.utils.fl_model_utils import FLModelUtils
from nvflare.app_common.workflows import wf_comm as flare
from nvflare.app_common.workflows.wf_comm.wf_comm_api_spec import (
from nvflare.app_common import wf_comm as flare
from nvflare.app_common.wf_comm.wf_comm_api_spec import (
CURRENT_ROUND,
DATA,
MIN_RESPONSES,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@
task_result_filters = []

workflows = [
{
{
id = "fed_avg"
path = "nvflare.app_opt.pt.wf_controller.PTWFController"
args {
result_pull_interval = 5
task_name = "train"
wf_class_path = "fedavg_pt.PTFedAvg",
wf_args {
min_clients = 2
num_rounds = 2
output_path = "/tmp/nvflare/fedavg/mode.pth"
stop_cond = "accuracy >= 55"
}
communicator {
path = "nvflare.app_common.wf_comm.wf_communicator.WFCommunicator"
args = {}
}
strategy {
path = "fedavg_pt.PTFedAvg"
args {
min_clients = 2
num_rounds = 2
output_path = "/tmp/nvflare/fedavg/mode.pth"
stop_cond = "accuracy >= 55"
}
serializers = ["nvflare.app_opt.pt.decomposers.TensorDecomposer"]

}
}
]

components = []

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper
from nvflare.app_common.utils.fl_model_utils import FLModelUtils
from nvflare.app_common.utils.math_utils import parse_compare_criteria
from nvflare.app_common.workflows import wf_comm as flare
from nvflare.app_common.workflows.wf_comm.wf_comm_api_spec import (
from nvflare.app_common import wf_comm as flare
from nvflare.app_common.wf_comm.wf_comm_api_spec import (
CURRENT_ROUND,
DATA,
MIN_RESPONSES,
Expand Down Expand Up @@ -64,10 +64,11 @@ def __init__(
else:
self.stop_criteria = None

self.flare_comm = flare.get_wf_comm_api()
self.flare_comm = None

def run(self):
self.logger.info("start Fed Avg Workflow\n \n")
self.flare_comm = flare.get_wf_comm_api()

start = self.start_round
end = self.start_round + self.num_rounds
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,9 @@

import torch

# from fedavg import FedAvg
from fedavg_intime import FedAvg

from fedavg import FedAvg
from nvflare.app_common.abstract.fl_model import FLModel

# to use in_time aggregate version of FedAvg
# you change the import to 'from fedavg_intime import FedAvg'


class PTFedAvg(FedAvg):
def __init__(
Expand All @@ -32,9 +27,8 @@ def __init__(
output_path: str,
start_round: int = 1,
stop_cond: str = None,
model_selection_rule: str = None,
):
super().__init__(min_clients, num_rounds, output_path, start_round, stop_cond, model_selection_rule)
super().__init__(min_clients, num_rounds, output_path, start_round, stop_cond)

def save_model(self, model: FLModel, file_path: str):
if not file_path:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from km_analysis import kaplan_meier_analysis

from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.workflows import wf_comm as flare
from nvflare.app_common.workflows.wf_comm.wf_comm_api_spec import (
from nvflare.app_common import wf_comm as flare
from nvflare.app_common.wf_comm.wf_comm_api_spec import (
CURRENT_ROUND,
DATA,
MIN_RESPONSES,
Expand Down
2 changes: 1 addition & 1 deletion nvflare/apis/utils/fl_context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def generate_log_message(fl_ctx: FLContext, msg: str):
_task_name = "task_name"
_task_id = "task_id"
_rc = "peer_rc"
_wf = "wf"
_wf = "strategy"

all_kvs = {_identity_: fl_ctx.get_identity_name()}
my_run = fl_ctx.get_job_id()
Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_common/ccwf/server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self,
num_rounds: int,
start_round: int = 0,
task_name_prefix: str = "wf",
task_name_prefix: str = "strategy",
configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
end_workflow_timeout=Constant.END_WORKFLOW_TIMEOUT,
start_task_timeout=Constant.START_TASK_TIMEOUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.app_common.wf_comm.wf_comm_api_spec import WFCommAPISpec
from nvflare.fuel.message.data_bus import DataBus

class EventManager:
def __init__(self, message_bus):
self.message_bus = message_bus
data_bus = DataBus()

def fire_event(self, event_name, event_data=None):
self.message_bus.publish(event_name, event_data)

def get_wf_comm_api() -> WFCommAPISpec:
return data_bus.receive_messages("wf_comm_api")
Loading

0 comments on commit 90301d5

Please sign in to comment.