diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index c2273a36a5db..6d719a7dd377 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -293,6 +293,16 @@ def delete_tasks(self, task_ids: set[UUID]) -> None: for task_id in task_res_to_be_deleted: del self.task_res_store[task_id] + def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: + """Get all TaskIns IDs for the given run_id.""" + task_id_list: set[UUID] = set() + with self.lock: + for task_id, task_ins in self.task_ins_store.items(): + if task_ins.run_id == run_id: + task_id_list.add(task_id) + + return task_id_list + def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None: """Delete tasks based on a set of TaskIns IDs.""" if not task_ids: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 05fb3c2f0cc6..c5ab7efa8cf2 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -142,6 +142,10 @@ def num_task_res(self) -> int: def delete_tasks(self, task_ids: set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" + @abc.abstractmethod + def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: + """Get all TaskIns IDs for the given run_id.""" + @abc.abstractmethod def create_node( self, ping_interval: float, public_key: Optional[bytes] = None diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 202fdf387277..15b97ee1a0a1 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -353,6 +353,43 @@ def test_store_and_delete_tasks(self) -> None: assert state.num_task_ins() == 2 assert state.num_task_res() == 1 + def test_get_task_ids_from_run_id(self) -> None: + """Test get_task_ids_from_run_id.""" + # Prepare + state = self.state_factory() + node_id = state.create_node(1e3) + run_id_0 = state.create_run(None, None, "8g13kl7", {}, ConfigsRecord()) + # Insert tasks with the same run_id + task_ins_0 = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id_0 + ) + task_ins_1 = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id_0 + ) + # Insert a task with a different run_id to ensure it does not appear in result + run_id_1 = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) + task_ins_2 = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id_1 + ) + + # Insert three TaskIns + task_id_0 = state.store_task_ins(task_ins=task_ins_0) + task_id_1 = state.store_task_ins(task_ins=task_ins_1) + task_id_2 = state.store_task_ins(task_ins=task_ins_2) + + assert task_id_0 + assert task_id_1 + assert task_id_2 + + expected_task_ids = {task_id_0, task_id_1} + + # Execute + result = state.get_task_ids_from_run_id(run_id_0) + bad_result = state.get_task_ids_from_run_id(15) + + self.assertEqual(len(bad_result), 0) + self.assertSetEqual(result, expected_task_ids) + # Init tests def test_init_state(self) -> None: """Test that state is initialized correctly.""" diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 54df4685bf9a..8e4043582d14 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -629,6 +629,25 @@ def delete_tasks(self, task_ids: set[UUID]) -> None: return None + def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: + """Get all TaskIns IDs for the given run_id.""" + if self.conn is None: + raise AttributeError("LinkState not initialized") + + query = """ + SELECT task_id + FROM task_ins + WHERE run_id = :run_id; + """ + + sint64_run_id = convert_uint64_to_sint64(run_id) + data = {"run_id": sint64_run_id} + + with self.conn: + rows = self.conn.execute(query, data).fetchall() + + return {UUID(row["task_id"]) for row in rows} + def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None: """Delete tasks based on a set of TaskIns IDs.""" if not task_ids: