Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Consolidate run_id and node_id creation logic #3569

Merged
merged 2 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
PING_RANDOM_RANGE = (-0.1, 0.1)
PING_MAX_INTERVAL = 1e300

# IDs
RUN_ID_NUM_BYTES = 8
NODE_ID_NUM_BYTES = 8
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"

Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
"""Tests for in-memory driver."""


import os
import time
import unittest
from typing import Iterable, List, Tuple
from unittest.mock import MagicMock, patch
from uuid import uuid4

from flwr.common import RecordSet
from flwr.common.constant import PING_MAX_INTERVAL
from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL
from flwr.common.message import Error
from flwr.common.serde import (
error_to_proto,
Expand All @@ -34,6 +33,7 @@
from flwr.common.typing import Run
from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611
from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory
from flwr.server.superlink.state.utils import generate_rand_int_from_bytes

from .inmemory_driver import InMemoryDriver

Expand Down Expand Up @@ -82,7 +82,7 @@ def setUp(self) -> None:
self.num_nodes = 42
self.state = MagicMock()
self.state.get_nodes.return_value = [
int.from_bytes(os.urandom(8), "little", signed=True)
generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
for _ in range(self.num_nodes)
]
self.state.get_run.return_value = Run(
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
"""In-memory State implementation."""


import os
import threading
import time
from logging import ERROR
from typing import Dict, List, Optional, Set, Tuple
from uuid import UUID, uuid4

from flwr.common import log, now
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
from flwr.common.typing import Run
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.state.state import State
from flwr.server.utils import validate_task_ins_or_res

from .utils import make_node_unavailable_taskres
from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres


class InMemoryState(State): # pylint: disable=R0902,R0904
Expand Down Expand Up @@ -216,7 +216,7 @@ def create_node(
) -> int:
"""Create, store in state, and return `node_id`."""
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)

with self.lock:
if node_id in self.node_ids:
Expand Down Expand Up @@ -279,7 +279,7 @@ def create_run(self, fab_id: str, fab_version: str) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
# Sample a random int64 as run_id
with self.lock:
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
self.run_ids[run_id] = Run(
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""SQLite based implemenation of server state."""


import os
import re
import sqlite3
import time
Expand All @@ -24,14 +23,15 @@
from uuid import UUID, uuid4

from flwr.common import log, now
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
from flwr.common.typing import Run
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.utils.validator import validate_task_ins_or_res

from .state import State
from .utils import make_node_unavailable_taskres
from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres

SQL_CREATE_TABLE_NODE = """
CREATE TABLE IF NOT EXISTS node(
Expand Down Expand Up @@ -541,7 +541,7 @@ def create_node(
) -> int:
"""Create, store in state, and return `node_id`."""
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)

query = "SELECT node_id FROM node WHERE public_key = :public_key;"
row = self.query(query, {"public_key": public_key})
Expand Down Expand Up @@ -616,7 +616,7 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
def create_run(self, fab_id: str, fab_version: str) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
# Sample a random int64 as run_id
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

# Check conflicts
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
Expand Down
6 changes: 6 additions & 0 deletions src/py/flwr/server/superlink/state/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import time
from logging import ERROR
from os import urandom
from uuid import uuid4

from flwr.common import log
Expand All @@ -31,6 +32,11 @@
)


def generate_rand_int_from_bytes(num_bytes: int) -> int:
"""Generate a random `num_bytes` integer."""
return int.from_bytes(urandom(num_bytes), "little", signed=True)


def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
"""Generate a TaskRes with a node unavailable error from a TaskIns."""
current_time = time.time()
Expand Down