Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom order and early termination to CyclicController #2387

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ReturnCode(object):
VALIDATE_TYPE_UNKNOWN = "VALIDATE_TYPE_UNKNOWN"
EMPTY_RESULT = "EMPTY_RESULT"
UNSAFE_JOB = "UNSAFE_JOB"
EARLY_TERMINATION = "EARLY_TERMINATION"
SERVER_NOT_READY = "SERVER_NOT_READY"
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"

Expand Down
114 changes: 79 additions & 35 deletions nvflare/app_common/workflows/cyclic_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import gc
import random
from typing import List, Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
Expand Down Expand Up @@ -50,7 +51,8 @@ def __init__(
task_check_period: float = 0.5,
persist_every_n_rounds: int = 1,
snapshot_every_n_rounds: int = 1,
order: str = RelayOrder.FIXED,
order: Union[str, List[str]] = RelayOrder.FIXED,
allow_early_termination=False,
):
"""A sample implementation to demonstrate how to use relay method for Cyclic Federated Learning.

Expand All @@ -67,11 +69,13 @@ def __init__(
If n is 0 then no persist.
snapshot_every_n_rounds (int, optional): persist the server state every n rounds. Defaults to 1.
If n is 0 then no persist.
order (str, optional): the order of relay.
If FIXED means the same order for every round.
If RANDOM means random order for every round.
If RANDOM_WITHOUT_SAME_IN_A_ROW means every round the order gets shuffled but a client will never be
run twice in a row (in different round).
order (Union[str, List[str]], optional): The order of relay.
- If a string is provided:
- "FIXED": Same order for every round.
- "RANDOM": Random order for every round.
- "RANDOM_WITHOUT_SAME_IN_A_ROW": Shuffled order, no repetition in consecutive rounds.
- If a list of strings is provided, it represents a custom order for relay.
allow_early_termination: whether to allow early workflow termination from clients

Raises:
TypeError: when any of input arguments does not have correct type
Expand All @@ -90,13 +94,14 @@ def __init__(
if not isinstance(task_name, str):
raise TypeError("task_name must be a string but got {}".format(type(task_name)))

if order not in SUPPORTED_ORDERS:
raise ValueError(f"order must be in {SUPPORTED_ORDERS}")
if order not in SUPPORTED_ORDERS and not isinstance(order, list):
raise ValueError(f"order must be in {SUPPORTED_ORDERS} or a list")

self._num_rounds = num_rounds
self._start_round = 0
self._end_round = self._start_round + self._num_rounds
self._current_round = 0
self._is_done = False
self._last_learnable = None
self.persistor_id = persistor_id
self.shareable_generator_id = shareable_generator_id
Expand All @@ -109,6 +114,7 @@ def __init__(
self._participating_clients = None
self._last_client = None
self._order = order
self._allow_early_termination = allow_early_termination

def start_controller(self, fl_ctx: FLContext):
self.log_debug(fl_ctx, "starting controller")
Expand All @@ -129,46 +135,79 @@ def start_controller(self, fl_ctx: FLContext):
fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=True)
self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)

self._participating_clients = self._engine.get_clients()
self._participating_clients: List[Client] = self._engine.get_clients()
if len(self._participating_clients) <= 1:
self.system_panic("Not enough client sites.", fl_ctx)
self._last_client = None

def _get_relay_orders(self, fl_ctx: FLContext):
targets = list(self._participating_clients)
if len(targets) <= 1:
self.system_panic("Not enough client sites.", fl_ctx)
if self._order == RelayOrder.RANDOM:
random.shuffle(targets)
elif self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
random.shuffle(targets)
if self._last_client == targets[0]:
targets = targets.append(targets.pop(0))
def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
    """Compute the ordered list of clients to relay to for one round.

    Args:
        fl_ctx: FL context, passed through to ``system_panic`` on failure.

    Returns:
        The clients in relay order, or ``None`` after ``system_panic`` has been
        called because the order cannot be satisfied (too few participating
        clients, an empty custom order, or a custom order naming a client that
        is not participating).

    Side effects:
        Records the last client of the returned order in ``self._last_client``.
    """
    if len(self._participating_clients) <= 1:
        self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx)
        return None

    if isinstance(self._order, list):
        if not self._order:
            # An empty custom order yields no relay target; without this guard
            # the ``targets[-1]`` lookup below would raise IndexError.
            self.system_panic("Custom relay order is empty.", fl_ctx)
            return None

        # Custom order: resolve each requested client name against the participating clients.
        active_clients_map = {t.name: t for t in self._participating_clients}
        targets = []
        for c_name in self._order:
            if c_name not in active_clients_map:
                self.system_panic(f"Required client site ({c_name}) is not in active clients.", fl_ctx)
                return None
            targets.append(active_clients_map[c_name])
    else:
        targets = list(self._participating_clients)
        if self._order in (RelayOrder.RANDOM, RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW):
            random.shuffle(targets)
            if self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW and self._last_client == targets[0]:
                # The previous round ended on this client: rotate it to the back
                # so it does not run twice in a row across rounds.
                targets.append(targets.pop(0))

    self._last_client = targets[-1]
    return targets

def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
result = client_task.result
rc = result.get_return_code()
client_name = client_task.client.name

# Raise errors if ReturnCode is not OK.
if rc and rc != ReturnCode.OK:
self.system_panic(
f"Result from {client_name} is bad, error code: {rc}. "
f"{self.__class__.__name__} exiting at round {self._current_round}.",
fl_ctx=fl_ctx,
)
return False
def _stop_workflow(self, task: Task):
    """Cancel ``task`` and flag the workflow as finished.

    Setting ``self._is_done`` makes ``control_flow`` return at the start of
    its next round instead of scheduling another relay.
    """
    self.cancel_task(task)
    self._is_done = True

def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
# submitted shareable is stored in client_task.result
# we need to update task.data with that shareable so the next target
# will get the updated shareable
task = client_task.task

# update the global learnable with the received result (shareable)
# e.g. the received result could be weight_diffs, the learnable could be full weights.
self._last_learnable = self.shareable_generator.shareable_to_learnable(client_task.result, fl_ctx)
result = client_task.result
if isinstance(result, Shareable):
YuanTingHsieh marked this conversation as resolved.
Show resolved Hide resolved
# update the global learnable with the received result (shareable)
# e.g. the received result could be weight_diffs, the learnable could be full weights.
rc = result.get_return_code()
try:
self._last_learnable = self.shareable_generator.shareable_to_learnable(result, fl_ctx)
except Exception as ex:
if rc != ReturnCode.EARLY_TERMINATION:
self._stop_workflow(task)
self.log_error(fl_ctx, f"exception {secure_format_exception(ex)} from shareable_to_learnable")
return
else:
self.log_warning(
fl_ctx,
f"ignored {secure_format_exception(ex)} from shareable_to_learnable in early termination",
)

if rc == ReturnCode.EARLY_TERMINATION:
if self._allow_early_termination:
# the workflow is done
self._stop_workflow(task)
self.log_info(fl_ctx, f"Stopping workflow due to {rc} from client {client_task.client.name}")
return
else:
self.log_warning(
fl_ctx,
f"Ignored {rc} from client {client_task.client.name} because early termination is not allowed",
)
else:
self._stop_workflow(task)
self.log_error(
fl_ctx,
f"Stopping workflow due to result from client {client_task.client.name} is not a Shareable",
)
return

# prepare task shareable data for next client
task.data = self.shareable_generator.learnable_to_shareable(self._last_learnable, fl_ctx)
Expand All @@ -183,6 +222,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
self.log_debug(fl_ctx, "Cyclic starting.")

for self._current_round in range(self._start_round, self._end_round):
if self._is_done:
return

if abort_signal.triggered:
return

Expand All @@ -191,6 +233,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):

# Task for one cyclic
targets = self._get_relay_orders(fl_ctx)
if targets is None:
return
targets_names = [t.name for t in targets]
self.log_debug(fl_ctx, f"Relay on {targets_names}")

Expand Down
134 changes: 134 additions & 0 deletions tests/unit_test/app_common/workflow/cyclic_ctl_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import uuid
from unittest.mock import Mock, patch

import pytest

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learnable import Learnable
from nvflare.app_common.workflows.cyclic_ctl import CyclicController, RelayOrder

# Random per-run identifiers, passed as the second argument when building the test clients.
SITE_1_ID = uuid.uuid4()
SITE_2_ID = uuid.uuid4()
SITE_3_ID = uuid.uuid4()

# Each case is (order, active_clients, expected_result) for test_get_relay_orders:
#   order            - value passed to CyclicController(order=...)
#   active_clients   - clients assigned to ctl._participating_clients
#   expected_result  - clients expected back from _get_relay_orders, in order
ORDER_TEST_CASES = [
    # FIXED keeps the participating clients in their given order.
    (
        RelayOrder.FIXED,
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
    ),
    # Custom order identical to the participating order.
    (
        ["site-1", "site-2"],
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
    ),
    # Custom order reversing the participating order.
    (
        ["site-2", "site-1"],
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
        [Client("site-2", SITE_2_ID), Client("site-1", SITE_1_ID)],
    ),
    # Custom order over three clients, unrelated to the participating order.
    (
        ["site-2", "site-1", "site-3"],
        [Client("site-3", SITE_3_ID), Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
        [Client("site-2", SITE_2_ID), Client("site-1", SITE_1_ID), Client("site-3", SITE_3_ID)],
    ),
]


def gen_shareable(is_early_termination: bool = False, is_not_shareable: bool = False):
    """Build a fake client result for the ``_process_result`` tests.

    Args:
        is_early_termination: when True, the returned Shareable carries
            ``ReturnCode.EARLY_TERMINATION`` as its return code.
        is_not_shareable: when True, return a plain list instead of a
            Shareable; this takes precedence over ``is_early_termination``.

    Returns:
        A ``Shareable``, or the list ``[1, 2, 3]`` when ``is_not_shareable`` is set.
    """
    if not is_not_shareable:
        fake_result = Shareable()
        if is_early_termination:
            fake_result.set_return_code(ReturnCode.EARLY_TERMINATION)
        return fake_result

    # Deliberately the wrong type, to exercise the "result is not a Shareable" path.
    return [1, 2, 3]


# Results fed to test_process_result: an early-termination Shareable and a non-Shareable value.
# That test expects each of them to cancel the task and set ctl._is_done.
PROCESS_RESULT_TEST_CASES = [gen_shareable(is_early_termination=True), gen_shareable(is_not_shareable=True)]


class TestCyclicController:
    """Unit tests for CyclicController relay ordering, control flow and result processing."""

    @pytest.mark.parametrize("order,active_clients,expected_result", ORDER_TEST_CASES)
    def test_get_relay_orders(self, order, active_clients, expected_result):
        """_get_relay_orders returns exactly the expected clients, in the expected order."""
        ctl = CyclicController(order=order)
        ctx = FLContext()
        ctl._participating_clients = active_clients
        targets = ctl._get_relay_orders(ctx)

        # zip() stops at the shorter sequence, so check the result exists and
        # has the right length before comparing element by element.
        assert targets is not None
        assert len(targets) == len(expected_result)
        for c, e_c in zip(targets, expected_result):
            assert c.name == e_c.name
            assert c.token == e_c.token

    def test_control_flow_call_relay_and_wait(self):
        """A one-round control_flow schedules relay_and_wait exactly once."""
        with patch("nvflare.app_common.workflows.cyclic_ctl.CyclicController.relay_and_wait") as mock_relay:
            ctl = CyclicController(persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1)
            ctl.shareable_generator = Mock()
            ctl._participating_clients = [
                Client("site-3", SITE_3_ID),
                Client("site-1", SITE_1_ID),
                Client("site-2", SITE_2_ID),
            ]

            abort_signal = Signal()
            fl_ctx = FLContext()

            with patch.object(ctl.shareable_generator, "learnable_to_shareable") as mock_to_shareable, patch.object(
                ctl.shareable_generator, "shareable_to_learnable"
            ) as mock_to_learnable:
                mock_to_shareable.return_value = Shareable()
                mock_to_learnable.return_value = Learnable()

                ctl.control_flow(abort_signal, fl_ctx)

                mock_relay.assert_called_once()

    @pytest.mark.parametrize("return_result", PROCESS_RESULT_TEST_CASES)
    def test_process_result(self, return_result):
        """Early-termination and non-Shareable results cancel the task and mark the workflow done."""
        ctl = CyclicController(
            persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1, allow_early_termination=True
        )
        ctl.shareable_generator = Mock()
        ctl._participating_clients = [
            Client("site-3", SITE_3_ID),
            Client("site-1", SITE_1_ID),
            Client("site-2", SITE_2_ID),
        ]

        fl_ctx = FLContext()
        with patch.object(ctl, "cancel_task") as mock_cancel, patch.object(
            ctl.shareable_generator, "learnable_to_shareable"
        ) as mock_to_shareable, patch.object(ctl.shareable_generator, "shareable_to_learnable") as mock_to_learnable:
            mock_to_shareable.return_value = Shareable()
            mock_to_learnable.return_value = Learnable()

            client_task = ClientTask(
                client=Mock(),
                task=Task(
                    name="__test_task",
                    data=Shareable(),
                ),
            )
            client_task.result = return_result
            ctl._process_result(client_task, fl_ctx)
            mock_cancel.assert_called_once()
            assert ctl._is_done is True
Loading