Skip to content

Commit

Permalink
feat(framework) Add get TaskIns from run_id to LinkState (#4682)
Browse files Browse the repository at this point in the history
Co-authored-by: Javier <jafermarq@users.noreply.github.com>
  • Loading branch information
chongshenng and jafermarq authored Dec 13, 2024
1 parent 46c2372 commit 331ed9b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ def delete_tasks(self, task_ids: set[UUID]) -> None:
for task_id in task_res_to_be_deleted:
del self.task_res_store[task_id]

def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
"""Get all TaskIns IDs for the given run_id."""
task_id_list: set[UUID] = set()
with self.lock:
for task_id, task_ins in self.task_ins_store.items():
if task_ins.run_id == run_id:
task_id_list.add(task_id)

return task_id_list

def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
"""Delete tasks based on a set of TaskIns IDs."""
if not task_ids:
Expand Down
4 changes: 4 additions & 0 deletions src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def num_task_res(self) -> int:
def delete_tasks(self, task_ids: set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""

@abc.abstractmethod
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
"""Get all TaskIns IDs for the given run_id."""

@abc.abstractmethod
def create_node(
self, ping_interval: float, public_key: Optional[bytes] = None
Expand Down
37 changes: 37 additions & 0 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,43 @@ def test_store_and_delete_tasks(self) -> None:
assert state.num_task_ins() == 2
assert state.num_task_res() == 1

def test_get_task_ids_from_run_id(self) -> None:
"""Test get_task_ids_from_run_id."""
# Prepare
state = self.state_factory()
node_id = state.create_node(1e3)
run_id_0 = state.create_run(None, None, "8g13kl7", {}, ConfigsRecord())
# Insert tasks with the same run_id
task_ins_0 = create_task_ins(
consumer_node_id=node_id, anonymous=False, run_id=run_id_0
)
task_ins_1 = create_task_ins(
consumer_node_id=node_id, anonymous=False, run_id=run_id_0
)
# Insert a task with a different run_id to ensure it does not appear in result
run_id_1 = state.create_run(None, None, "9f86d08", {}, ConfigsRecord())
task_ins_2 = create_task_ins(
consumer_node_id=node_id, anonymous=False, run_id=run_id_1
)

# Insert three TaskIns
task_id_0 = state.store_task_ins(task_ins=task_ins_0)
task_id_1 = state.store_task_ins(task_ins=task_ins_1)
task_id_2 = state.store_task_ins(task_ins=task_ins_2)

assert task_id_0
assert task_id_1
assert task_id_2

expected_task_ids = {task_id_0, task_id_1}

# Execute
result = state.get_task_ids_from_run_id(run_id_0)
bad_result = state.get_task_ids_from_run_id(15)

self.assertEqual(len(bad_result), 0)
self.assertSetEqual(result, expected_task_ids)

# Init tests
def test_init_state(self) -> None:
"""Test that state is initialized correctly."""
Expand Down
19 changes: 19 additions & 0 deletions src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,25 @@ def delete_tasks(self, task_ids: set[UUID]) -> None:

return None

def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
"""Get all TaskIns IDs for the given run_id."""
if self.conn is None:
raise AttributeError("LinkState not initialized")

query = """
SELECT task_id
FROM task_ins
WHERE run_id = :run_id;
"""

sint64_run_id = convert_uint64_to_sint64(run_id)
data = {"run_id": sint64_run_id}

with self.conn:
rows = self.conn.execute(query, data).fetchall()

return {UUID(row["task_id"]) for row in rows}

def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
"""Delete tasks based on a set of TaskIns IDs."""
if not task_ids:
Expand Down

0 comments on commit 331ed9b

Please sign in to comment.