Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

XP-357 make controller parametrisable #201

Merged
merged 4 commits into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions tests/controller_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ def test_random_controller():
np.ceil(fraction * len(participant_ids)) for fraction in fractions
]
for fraction, expected_length in zip(fractions, expected_lengths):
controller = RandomController(
participant_ids, fraction_of_participants=fraction
)
ids = controller.select_ids()
controller = RandomController(fraction_of_participants=fraction)
ids = controller.select_ids(participant_ids)
set_ids = set(ids)

# check that length of set_ids is as expected
Expand Down
13 changes: 6 additions & 7 deletions xain_fl/coordinator/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
logger = get_logger(__name__)


# TODO: raise exceptions for invalid attribute values: https://xainag.atlassian.net/browse/XP-387
class Coordinator:
"""Class implementing the main Coordinator logic. It is implemented as a
state machine that reacts to received messages.
Expand Down Expand Up @@ -60,6 +61,7 @@ class Coordinator:
fraction_of_participants (:obj:`float`, optional): The fraction of total
connected participants to be selected in a single round. Defaults to 1.0,
meaning that all connected participants will be selected.
It must be in the (0.0, 1.0] interval.
weights (:obj:`list` of :class:`~numpy.ndarray`, optional): The weights of
the global model. Defaults to [].
epochs (:obj:`int`, optional): Number of training iterations local to
Expand All @@ -76,7 +78,7 @@ class Coordinator:
# pylint: disable-msg=dangerous-default-value

DEFAULT_AGGREGATOR: Aggregator = FederatedAveragingAgg()
DEFAULT_CONTROLLER: Controller = RandomController(participant_ids=[])
DEFAULT_CONTROLLER: Controller = RandomController()

def __init__(
self,
Expand Down Expand Up @@ -187,13 +189,10 @@ def remove_participant(self, participant_id: str) -> None:
self.state = coordinator_pb2.State.STANDBY

def select_participant_ids_and_init_round(self) -> None:
"""Initiates the Controller, selects ids and initiates a Round.
"""Selects the participant ids and initiates a Round.
"""
self.controller = RandomController(
participant_ids=self.participants.ids(),
fraction_of_participants=self.fraction_of_participants,
)
selected_ids = self.controller.select_ids()
self.controller.fraction_of_participants = self.fraction_of_participants
selected_ids = self.controller.select_ids(self.participants.ids())
self.round = Round(selected_ids)

def _handle_rendezvous(
Expand Down
58 changes: 31 additions & 27 deletions xain_fl/fl/coordinator/controller.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,66 @@
"""Provides an abstract base class Controller and multiple sub-classes
such as CycleRandomController.
"""
"""Provides an abstract base class Controller and the RandomController
currently used by the Coordinator."""

from abc import ABC, abstractmethod
from typing import List

import numpy as np


# TODO: raise exceptions for invalid attribute values: https://xainag.atlassian.net/browse/XP-387
class Controller(ABC):
"""Abstract base class which provides an interface to the coordinator that
enables different selection strategies.
janpetschexain marked this conversation as resolved.
Show resolved Hide resolved

Attributes:
participant_ids (:obj:`list` of :obj:`str`): The list of IDs of the
all the available participants, a subset of which will be selected.
fraction_of_participants (:obj:`float`, optional): The fraction of total
participant ids to be selected. Defaults to 1.0, meaning that
all participant ids will be selected.
participant IDs to be selected. Defaults to 1.0, meaning that
all participant IDs will be selected. It must be in the (0.0, 1.0] interval.
"""

def __init__(
self, participant_ids: List[str], fraction_of_participants: float = 1.0
) -> None:
self.participant_ids: List[str] = participant_ids
def __init__(self, fraction_of_participants: float = 1.0) -> None:
janpetschexain marked this conversation as resolved.
Show resolved Hide resolved
self.fraction_of_participants: float = fraction_of_participants
self.num_ids_to_select: int = self.get_num_ids_to_select()

def get_num_ids_to_select(self) -> int:
"""Calculates how many participant ids need to be selected.
def get_num_ids_to_select(self, len_participant_ids: int) -> int:
"""Calculates how many participant IDs need to be selected.

Args:
len_participant_ids (:obj:`int`): The length of the list of IDs of all the
available participants.

Returns:
:obj:`int`: Number of participant ids to be selected
:obj:`int`: Number of participant IDs to be selected
"""
raw_num_ids_to_select = (
len(self.participant_ids) * self.fraction_of_participants
)
raw_num_ids_to_select = len_participant_ids * self.fraction_of_participants
max_valid_value = max(1, np.ceil(raw_num_ids_to_select))
minimum_valid_value = min(len(self.participant_ids), max_valid_value)
minimum_valid_value = min(len_participant_ids, max_valid_value)
return int(minimum_valid_value)

@abstractmethod
def select_ids(self) -> List[str]:
"""Returns the selected indices of next round
def select_ids(self, participant_ids: List[str]) -> List[str]:
"""Returns the selected indices of next round.

Args:
participant_ids (:obj:`list` of :obj:`str`): The list of IDs of all the
available participants, a subset of which will be selected.

Returns:
:obj:`list` of :obj:`str`: Unordered list of selected ids
:obj:`list` of :obj:`str`: List of selected participant IDs
"""
raise NotImplementedError("not implemented")


class RandomController(Controller):
def select_ids(self) -> List[str]:
def select_ids(self, participant_ids: List[str]) -> List[str]:
"""Randomly samples self.num_ids_to_select from the population of participants_ids,
without replacement.

Args:
participant_ids (:obj:`list` of :obj:`str`): The list of IDs of all the
available participants, a subset of which will be selected.

Returns:
:obj:`list` of :obj:`str`: List of selected participant ID's
:obj:`list` of :obj:`str`: List of selected participant IDs
"""
return np.random.choice(
self.participant_ids, size=self.num_ids_to_select, replace=False
)
num_ids_to_select = self.get_num_ids_to_select(len(participant_ids))
return np.random.choice(participant_ids, size=num_ids_to_select, replace=False)