Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Generate error TaskRes for invalid cases in LinkState.get_task_res #4369

Merged
merged 36 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
c20d229
remove limit arg from get_task_res
panh99 Oct 10, 2024
9cecf3a
Merge branch 'main' into rm-limit-gettaskres
panh99 Oct 10, 2024
d792eb4
add utility func
panh99 Oct 10, 2024
282eba9
Merge remote-tracking branch 'origin/main' into err-reply-get-task-res
panh99 Oct 10, 2024
e41e7ef
update get_task_res of in-mem-state
panh99 Oct 10, 2024
4cc7fc8
temp commit
panh99 Oct 10, 2024
5083ee8
temp
panh99 Oct 21, 2024
0e6df75
Merge branch 'main' into err-reply-get-task-res
panh99 Oct 22, 2024
38d10e8
complete get_task_res
panh99 Oct 23, 2024
441fe9b
tmp
panh99 Oct 23, 2024
0d364a4
use time.time
panh99 Oct 23, 2024
e79a9ba
rm node availibility check
panh99 Oct 24, 2024
668549f
format
panh99 Oct 24, 2024
e1cd5f3
Merge remote-tracking branch 'origin/main' into err-reply-get-task-res
panh99 Oct 24, 2024
466ca85
improve doc
panh99 Oct 24, 2024
988062b
refactor code
panh99 Oct 24, 2024
75bd174
Merge branch 'main' into err-reply-get-task-res
panh99 Oct 24, 2024
d7f1a6a
fix unit test
panh99 Oct 24, 2024
255085b
Merge remote-tracking branch 'refs/remotes/origin/err-reply-get-task-…
panh99 Oct 24, 2024
6a88ede
Merge remote-tracking branch 'origin/main' into err-reply-get-task-res
panh99 Oct 24, 2024
9d3263d
merge with main
panh99 Oct 24, 2024
7fcd794
rm node-unavailable error
panh99 Oct 24, 2024
8d3d5f6
Merge branch 'main' into err-reply-get-task-res
panh99 Oct 24, 2024
a0b1416
Merge branch 'main' into err-reply-get-task-res
panh99 Oct 25, 2024
450f405
Merge remote-tracking branch 'origin/main' into err-reply-get-task-res
panh99 Oct 25, 2024
11988f9
update docstring
panh99 Oct 25, 2024
143db6d
Merge branch 'main' into err-reply-get-task-res
panh99 Oct 25, 2024
22d9150
Merge branch 'main' into err-reply-get-task-res
panh99 Nov 5, 2024
7d08a02
format
panh99 Nov 5, 2024
61056b5
Merge branch 'main' into err-reply-get-task-res
panh99 Nov 5, 2024
6415337
Merge branch 'main' into err-reply-get-task-res
panh99 Nov 5, 2024
4702beb
Merge branch 'main' into err-reply-get-task-res
danieljanes Nov 5, 2024
8eece3e
Merge branch 'main' into err-reply-get-task-res
panh99 Nov 6, 2024
2b6a415
Merge remote-tracking branch 'origin/main' into err-reply-get-task-res
panh99 Nov 7, 2024
b346da5
Merge remote-tracking branch 'origin/main' into err-reply-get-task-res
panh99 Nov 15, 2024
6b390ec
update based on comments
panh99 Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class ErrorCode:
UNKNOWN = 0
LOAD_CLIENT_APP_EXCEPTION = 1
CLIENT_APP_RAISED_EXCEPTION = 2
MESSAGE_UNAVAILABLE = 3
REPLY_MESSAGE_UNAVAILABLE = 4

def __new__(cls) -> ErrorCode:
"""Prevent instantiation."""
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,6 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
reply_tos = get_replies(self.driver, msg_ids, node_id)

# Assert
self.assertEqual(reply_tos, msg_ids)
self.assertEqual(set(reply_tos), set(msg_ids))
self.assertEqual(len(state.task_res_store), 0)
self.assertEqual(len(state.task_ins_store), 0)
83 changes: 55 additions & 28 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
generate_rand_int_from_bytes,
has_valid_sub_status,
is_valid_transition,
verify_found_taskres,
verify_taskins_ids,
)


Expand Down Expand Up @@ -67,12 +69,13 @@ def __init__(self) -> None:
self.federation_options: dict[int, ConfigsRecord] = {}
self.task_ins_store: dict[UUID, TaskIns] = {}
self.task_res_store: dict[UUID, TaskRes] = {}
self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}

self.node_public_keys: set[bytes] = set()
self.server_public_key: Optional[bytes] = None
self.server_private_key: Optional[bytes] = None

self.lock = threading.Lock()
self.lock = threading.RLock()
panh99 marked this conversation as resolved.
Show resolved Hide resolved

def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
"""Store one TaskIns."""
Expand Down Expand Up @@ -222,42 +225,50 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
task_res.task_id = str(task_id)
with self.lock:
self.task_res_store[task_id] = task_res
self.task_ins_id_to_task_res_id[UUID(task_ins_id)] = task_id

# Return the new task_id
return task_id

def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
"""Get all TaskRes that have not been delivered yet."""
"""Get TaskRes for the given TaskIns IDs."""
ret: dict[UUID, TaskRes] = {}

with self.lock:
# Find TaskRes that were not delivered yet
task_res_list: list[TaskRes] = []
replied_task_ids: set[UUID] = set()
for _, task_res in self.task_res_store.items():
reply_to = UUID(task_res.task.ancestry[0])

# Check if corresponding TaskIns exists and is not expired
task_ins = self.task_ins_store.get(reply_to)
if task_ins is None:
log(WARNING, "TaskIns with task_id %s does not exist.", reply_to)
task_ids.remove(reply_to)
continue

if task_ins.task.created_at + task_ins.task.ttl <= time.time():
log(WARNING, "TaskIns with task_id %s is expired.", reply_to)
task_ids.remove(reply_to)
continue

if reply_to in task_ids and task_res.task.delivered_at == "":
task_res_list.append(task_res)
replied_task_ids.add(reply_to)

# Mark all of them as delivered
current = time.time()

# Verify TaskIns IDs
ret = verify_taskins_ids(
inquired_taskins_ids=task_ids,
found_taskins_dict=self.task_ins_store,
current_time=current,
)

# Find all TaskRes
task_res_found: list[TaskRes] = []
for task_id in task_ids:
# If TaskRes exists and is not delivered, add it to the list
if task_res_id := self.task_ins_id_to_task_res_id.get(task_id):
task_res = self.task_res_store[task_res_id]
if task_res.task.delivered_at == "":
task_res_found.append(task_res)
tmp_ret_dict = verify_found_taskres(
inquired_taskins_ids=task_ids,
found_taskins_dict=self.task_ins_store,
found_taskres_list=task_res_found,
current_time=current,
)
ret.update(tmp_ret_dict)

# Mark existing TaskRes to be returned as delivered
delivered_at = now().isoformat()
for task_res in task_res_list:
for task_res in task_res_found:
task_res.task.delivered_at = delivered_at

# Return TaskRes
return task_res_list
# Cleanup
self._force_delete_tasks_by_ids(set(ret.keys()))

return list(ret.values())

def delete_tasks(self, task_ids: set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""
Expand All @@ -278,9 +289,25 @@ def delete_tasks(self, task_ids: set[UUID]) -> None:

for task_id in task_ins_to_be_deleted:
del self.task_ins_store[task_id]
del self.task_ins_id_to_task_res_id[task_id]
for task_id in task_res_to_be_deleted:
del self.task_res_store[task_id]

def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
"""Delete tasks based on a set of TaskIns IDs."""
if not task_ids:
return

with self.lock:
for task_id in task_ids:
# Delete TaskIns
if task_id in self.task_ins_store:
del self.task_ins_store[task_id]
# Delete TaskRes
if task_id in self.task_ins_id_to_task_res_id:
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
del self.task_res_store[task_res_id]

def num_task_ins(self) -> int:
"""Calculate the number of task_ins in store.

Expand Down
24 changes: 19 additions & 5 deletions src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,27 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:

@abc.abstractmethod
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
"""Get TaskRes for task_ids.
"""Get TaskRes for the given TaskIns IDs.

Usually, the ServerAppIo API calls this method to get results for instructions
it has previously scheduled.
This method is typically called by the ServerAppIo API to obtain
results (TaskRes) for previously scheduled instructions (TaskIns).
For each task_id provided, this method returns one of the following responses:

Retrieves all TaskRes for the given `task_ids` and returns and empty list of
none could be found.
- An error TaskRes if the corresponding TaskIns does not exist or has expired.
- An error TaskRes if the corresponding TaskRes exists but has expired.
- The valid TaskRes if the TaskIns has a corresponding valid TaskRes.
- Nothing if the TaskIns is still valid and waiting for a TaskRes.

Parameters
----------
task_ids : set[UUID]
A set of TaskIns IDs for which to retrieve results (TaskRes).

Returns
-------
list[TaskRes]
A list of TaskRes corresponding to the given task IDs. If no
TaskRes could be found for any of the task IDs, an empty list is returned.
"""

@abc.abstractmethod
Expand Down
19 changes: 8 additions & 11 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,6 @@ def test_store_and_delete_tasks(self) -> None:
# - State has three TaskIns, all of them delivered
# - State has two TaskRes, one of the delivered, the other not

assert state.num_task_ins() == 3
assert state.num_task_res() == 2

# Execute
state.delete_tasks(task_ids={task_id_0, task_id_1, task_id_2})

# Assert
assert state.num_task_ins() == 2
assert state.num_task_res() == 1

Expand Down Expand Up @@ -932,8 +925,8 @@ def test_get_task_ins_not_return_expired(self) -> None:
task_ins_list = state.get_task_ins(node_id=1, limit=None)
assert len(task_ins_list) == 0

def test_get_task_res_not_return_expired(self) -> None:
"""Test get_task_res not to return TaskRes if its TaskIns is expired."""
def test_get_task_res_expired_task_ins(self) -> None:
"""Test get_task_res to return error TaskRes if its TaskIns has expired."""
# Prepare
state = self.state_factory()
node_id = state.create_node(1e3)
Expand Down Expand Up @@ -961,7 +954,9 @@ def test_get_task_res_not_return_expired(self) -> None:
task_res_list = state.get_task_res(task_ids={task_id})

# Assert
assert len(task_res_list) == 0
assert len(task_res_list) == 1
assert task_res_list[0].task.HasField("error")
assert state.num_task_ins() == state.num_task_res() == 0

def test_get_task_res_returns_empty_for_missing_taskins(self) -> None:
"""Test that get_task_res returns an empty result when the corresponding TaskIns
Expand All @@ -983,7 +978,9 @@ def test_get_task_res_returns_empty_for_missing_taskins(self) -> None:
task_res_list = state.get_task_res(task_ids={UUID(task_ins_id)})

# Assert
assert len(task_res_list) == 0
assert len(task_res_list) == 1
assert task_res_list[0].task.HasField("error")
assert state.num_task_ins() == state.num_task_res() == 0

def test_get_task_res_return_if_not_expired(self) -> None:
"""Test get_task_res to return TaskRes if its TaskIns exists and is not
Expand Down
Loading