diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index ce29b3edb30e..f14959589458 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -46,6 +46,9 @@ PING_RANDOM_RANGE = (-0.1, 0.1) PING_MAX_INTERVAL = 1e300 +# IDs +RUN_ID_NUM_BYTES = 8 +NODE_ID_NUM_BYTES = 8 GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit" diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index eff38f548826..0cc1c5a53e13 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -15,7 +15,6 @@ """Tests for in-memory driver.""" -import os import time import unittest from typing import Iterable, List, Tuple @@ -23,7 +22,7 @@ from uuid import uuid4 from flwr.common import RecordSet -from flwr.common.constant import PING_MAX_INTERVAL +from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL from flwr.common.message import Error from flwr.common.serde import ( error_to_proto, @@ -34,6 +33,7 @@ from flwr.common.typing import Run from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory +from flwr.server.superlink.state.utils import generate_rand_int_from_bytes from .inmemory_driver import InMemoryDriver @@ -82,7 +82,7 @@ def setUp(self) -> None: self.num_nodes = 42 self.state = MagicMock() self.state.get_nodes.return_value = [ - int.from_bytes(os.urandom(8), "little", signed=True) + generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) for _ in range(self.num_nodes) ] self.state.get_run.return_value = Run( diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index da9c754c3115..5a4e4eb0fd9a 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -15,7 +15,6 @@ """In-memory State implementation.""" -import os import threading import time from logging import ERROR @@ -23,12 +22,13 @@ from uuid import UUID, uuid4 from flwr.common import log, now +from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES from flwr.common.typing import Run from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state.state import State from flwr.server.utils import validate_task_ins_or_res -from .utils import make_node_unavailable_taskres +from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres class InMemoryState(State): # pylint: disable=R0902,R0904 @@ -216,7 +216,7 @@ def create_node( ) -> int: """Create, store in state, and return `node_id`.""" # Sample a random int64 as node_id - node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) with self.lock: if node_id in self.node_ids: @@ -279,7 +279,7 @@ def create_run(self, fab_id: str, fab_version: str) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id with self.lock: - run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) if run_id not in self.run_ids: self.run_ids[run_id] = Run( diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 4df9470ded62..725f7c2dff4b 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -15,7 +15,6 @@ """SQLite based implemenation of server state.""" -import os import re import sqlite3 import time @@ -24,6 +23,7 @@ from uuid import UUID, uuid4 from flwr.common import log, now +from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES from flwr.common.typing import Run from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 @@ -31,7 +31,7 @@ from flwr.server.utils.validator import validate_task_ins_or_res from .state import State -from .utils import make_node_unavailable_taskres +from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres SQL_CREATE_TABLE_NODE = """ CREATE TABLE IF NOT EXISTS node( @@ -541,7 +541,7 @@ def create_node( ) -> int: """Create, store in state, and return `node_id`.""" # Sample a random int64 as node_id - node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) query = "SELECT node_id FROM node WHERE public_key = :public_key;" row = self.query(query, {"public_key": public_key}) @@ -616,7 +616,7 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: def create_run(self, fab_id: str, fab_version: str) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id - run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) # Check conflicts query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" diff --git a/src/py/flwr/server/superlink/state/utils.py b/src/py/flwr/server/superlink/state/utils.py index 233a90946cc7..b12a87ac998d 100644 --- a/src/py/flwr/server/superlink/state/utils.py +++ b/src/py/flwr/server/superlink/state/utils.py @@ -17,6 +17,7 @@ import time from logging import ERROR +from os import urandom from uuid import uuid4 from flwr.common import log @@ -31,6 +32,11 @@ ) +def generate_rand_int_from_bytes(num_bytes: int) -> int: + """Generate a random `num_bytes` integer.""" + return int.from_bytes(urandom(num_bytes), "little", signed=True) + + def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes: """Generate a TaskRes with a node unavailable error from a TaskIns.""" current_time = time.time()