Skip to content

Commit

Permalink
- implement a new task_group filtering decorator in Assigner class
Browse files Browse the repository at this point in the history
- update all the sub-classes that use task_groups to use the decorator
- update fedeval sample workspace to use default assigner, tasks and aggregator
- use of federated-evaluation/aggregator.yaml for FedEval specific workspace example to use round_number as 1
- removed assigner and tasks yaml from defaults/federated-evaluation, superseded by default assigner/tasks
- Rebase 21-Jan-2025.2
- added additional checks for assigner sub-classes that might not have task_groups
- Addressing review comments
- Updated existing test cases for Assigner sub-classes
- Remove hard-coded setting in assigner for torch_cnn_mnist ws, refer to default as in other Workspaces
- Use aggregator supplied --task_group to override the assinger selected_task_group
- update existing test cases of aggregator cli
- add test cases for the decorator
Signed-off-by: Shailesh Pant <shailesh.pant@intel.com>
  • Loading branch information
ishaileshpant committed Jan 22, 2025
1 parent 8104144 commit d2b46f0
Show file tree
Hide file tree
Showing 17 changed files with 115 additions and 44 deletions.
12 changes: 2 additions & 10 deletions openfl-workspace/torch_cnn_mnist/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,8 @@ aggregator:
rounds_to_train: 2
write_logs: false
template: openfl.component.aggregator.Aggregator
assigner:
settings:
task_groups:
- name: learning
percentage: 1.0
tasks:
- aggregated_model_validation
- train
- locally_tuned_model_validation
template: openfl.component.RandomGroupedAssigner
assigner :
defaults : plan/defaults/assigner.yaml
collaborator:
settings:
db_store_rounds: 1
Expand Down
8 changes: 5 additions & 3 deletions openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ network :
defaults : plan/defaults/network.yaml

assigner :
defaults : plan/defaults/federated-evaluation/assigner.yaml

defaults : plan/defaults/assigner.yaml
settings :
selected_task_group : evaluation

tasks :
defaults : plan/defaults/federated-evaluation/tasks_torch.yaml
defaults : plan/defaults/tasks_torch.yaml

compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ settings :
db_store_rounds : 2
persist_checkpoint: True
persistent_db_path: local_state/tensor.db
task_group: learning
4 changes: 4 additions & 0 deletions openfl-workspace/workspace/plan/defaults/assigner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ settings :
- aggregated_model_validation
- train
- locally_tuned_model_validation
- name : evaluation
percentage : 1.0
tasks :
- aggregated_model_validation

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Component Module."""

from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.assigner.assigner import Assigner
Expand Down
14 changes: 12 additions & 2 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,25 @@ def __init__(
self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

self.rounds_to_train = rounds_to_train
if self.task_group == "evaluation":
self.rounds_to_train = 1
logger.info(
f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}"
)

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

# if the collaborator requests a delta, this value is set to true
self.authorized_cols = authorized_cols
self.uuid = aggregator_uuid
self.federation_uuid = federation_uuid
# # override the assigner selected_task_group
# # FIXME check the case of CustomAssigner as base class Assigner is redefined
# # and doesn't have selected_task_group as attribute
# assigner.selected_task_group = task_group
self.assigner = assigner
self.quit_job_sent_to = []

Expand Down
1 change: 1 addition & 0 deletions openfl/component/assigner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Assigner Module."""

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
Expand Down
50 changes: 49 additions & 1 deletion openfl/component/assigner/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

"""Assigner module."""

import logging
from functools import wraps

logger = logging.getLogger(__name__)


class Assigner:
r"""
Expand Down Expand Up @@ -35,18 +40,27 @@ class Assigner:
\* - ``tasks`` argument is taken from ``tasks`` section of FL plan YAML file.
"""

def __init__(self, tasks, authorized_cols, rounds_to_train, **kwargs):
def __init__(
self,
tasks,
authorized_cols,
rounds_to_train,
selected_task_group: str = "learning",
**kwargs,
):
"""Initializes the Assigner.
Args:
tasks (list of object): List of tasks to assign.
authorized_cols (list of str): Collaborators.
rounds_to_train (int): Number of training rounds.
selected_task_group (str, optional): Selected task_group. Defaults to "learning".
**kwargs: Additional keyword arguments.
"""
self.tasks = tasks
self.authorized_cols = authorized_cols
self.rounds = rounds_to_train
self.selected_task_group = selected_task_group
self.all_tasks_in_groups = []

self.task_group_collaborators = {}
Expand Down Expand Up @@ -93,3 +107,37 @@ def get_aggregation_type_for_task(self, task_name):
if "aggregation_type" not in self.tasks[task_name]:
return None
return self.tasks[task_name]["aggregation_type"]

@classmethod
def task_group_filtering(cls, func):
"""Decorator to filter task groups based on selected_task_group.
This decorator should be applied to define_task_assignments() method
in Assigner subclasses to handle task_group filtering.
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
# First check if selection of task_group is applicable
if hasattr(self, "selected_task_group"):
# Verify task_groups exists before attempting filtering
if not hasattr(self, "task_groups"):
logger.warning(
"Task group specified for selection but no task_groups found. "
"Skipping filtering. This might be intentional for custom assigners."
)
return func(self, *args, **kwargs)

assert self.task_groups, "No task_groups defined in assigner."

# Perform the filtering
self.task_groups = [
group for group in self.task_groups if group["name"] == self.selected_task_group
]

assert self.task_groups, f"No task groups found for : {self.selected_task_group}"

# Call the original method
return func(self, *args, **kwargs)

return wrapper
7 changes: 5 additions & 2 deletions openfl/component/assigner/random_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner import Assigner


class RandomGroupedAssigner(Assigner):
Expand All @@ -33,16 +33,19 @@ class RandomGroupedAssigner(Assigner):
\* - Plan setting.
"""

task_group_filtering = Assigner.task_group_filtering

def __init__(self, task_groups, **kwargs):
"""Initializes the RandomGroupedAssigner.
Args:
task_groups (list of object): Task groups to assign.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments, including mode.
"""
self.task_groups = task_groups
super().__init__(**kwargs)

@task_group_filtering
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.
Expand Down
3 changes: 3 additions & 0 deletions openfl/component/assigner/static_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class StaticGroupedAssigner(Assigner):
\* - Plan setting.
"""

task_group_filtering = Assigner.task_group_filtering

def __init__(self, task_groups, **kwargs):
"""Initializes the StaticGroupedAssigner.
Expand All @@ -42,6 +44,7 @@ def __init__(self, task_groups, **kwargs):
self.task_groups = task_groups
super().__init__(**kwargs)

@task_group_filtering
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.
Expand Down
1 change: 1 addition & 0 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def start_(plan, authorized_cols, task_group):
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

logger.info("🧿 Starting the Aggregator Service.")
Expand Down
24 changes: 14 additions & 10 deletions tests/openfl/component/assigner/test_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def assigner():

def test_get_aggregation_type_for_task_none(assigner):
"""Assert that aggregation type of custom task is None."""
task_name = 'test_name'
task_name = "test_name"
tasks = {task_name: {}}

assigner = assigner(tasks, None, None)
Expand All @@ -31,11 +31,9 @@ def test_get_aggregation_type_for_task_none(assigner):

def test_get_aggregation_type_for_task(assigner):
"""Assert that aggregation type of task is getting correctly."""
task_name = 'test_name'
test_aggregation_type = 'test_aggregation_type'
tasks = {task_name: {
'aggregation_type': test_aggregation_type
}}
task_name = "test_name"
test_aggregation_type = "test_aggregation_type"
tasks = {task_name: {"aggregation_type": test_aggregation_type}}
assigner = assigner(tasks, None, None)

aggregation_type = assigner.get_aggregation_type_for_task(task_name)
Expand All @@ -46,13 +44,19 @@ def test_get_aggregation_type_for_task(assigner):
def test_get_all_tasks_for_round(assigner):
"""Assert that assigner tasks object is list."""
assigner = Assigner(None, None, None)
tasks = assigner.get_all_tasks_for_round('test')
tasks = assigner.get_all_tasks_for_round("test")

assert isinstance(tasks, list)


class TestNotImplError(TestCase):
def test_task_group_filtering_no_task_groups(assigner):
"""Assert that task_group_filtering does not filter when no task_groups are defined."""
assigner = Assigner(None,None,None)
assigner.selected_task_group = "test_group"
assigner.define_task_assignments()
assert not hasattr(assigner, "task_groups")

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.

class TestNotImplError(TestCase):
def test_define_task_assignments(self):
# TODO: define_task_assignments is defined as a mock in multiple fixtures,
# which leads the function to behave as a mock here and other tests.
Expand All @@ -61,9 +65,9 @@ def test_define_task_assignments(self):
def test_get_tasks_for_collaborator(self):
with self.assertRaises(NotImplementedError):
assigner = Assigner(None, None, None)
assigner.get_tasks_for_collaborator('col1', 0)
assigner.get_tasks_for_collaborator("col1", 0)

def test_get_collaborators_for_task(self):
with self.assertRaises(NotImplementedError):
assigner = Assigner(None, None, None)
assigner.get_collaborators_for_task('task_name', 0)
assigner.get_collaborators_for_task("task_name", 0)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def task_groups():
"""Initialize task groups."""
task_groups = [
{
'name': 'train_and_validate',
'name': 'learning',
'percentage': 1.0,
'tasks': [
'aggregated_model_validation',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def task_groups(authorized_cols):
"""Initialize task groups."""
task_groups = [
{
'name': 'train_and_validate',
'name': 'learning',
'percentage': 1.0,
'collaborators': authorized_cols,
'tasks': [
Expand Down
15 changes: 15 additions & 0 deletions tests/openfl/interface/test_aggregator_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def test_aggregator_start(mock_parse):
'settings': {
'task_group': 'learning'
}
},
'assigner': {
'settings': {
'selected_task_group': 'learning'
}
}
}
mock_parse.return_value = mock_plan
Expand Down Expand Up @@ -54,6 +59,11 @@ def test_aggregator_start_illegal_plan(mock_parse, mock_is_directory_traversal):
'settings': {
'task_group': 'learning'
}
},
'assigner': {
'settings': {
'selected_task_group': 'learning'
}
}
}
mock_parse.return_value = mock_plan
Expand Down Expand Up @@ -83,6 +93,11 @@ def test_aggregator_start_illegal_cols(mock_parse, mock_is_directory_traversal):
'settings': {
'task_group': 'learning'
}
},
'assigner': {
'settings': {
'selected_task_group': 'learning'
}
}
}
mock_parse.return_value = mock_plan
Expand Down

0 comments on commit d2b46f0

Please sign in to comment.