Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add fab_hash to Run #4006

Merged
12 commits merged on Aug 15, 2024
2 changes: 1 addition & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _on_backoff(retry_state: RetryState) -> None:
runs[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
runs[run_id] = Run(run_id, "", "", {})
runs[run_id] = Run(run_id, "", "", "", {})

# Register context for this run
node_state.register_context(
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def get_run(run_id: int) -> Run:
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
get_run_response.run.fab_hash,
user_config_from_proto(get_run_response.run.override_config),
)

Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,13 @@ def get_run(run_id: int) -> Run:
# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
if res is None:
return Run(run_id, "", "", {})
return Run(run_id, "", "", "", {})

return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
res.run.fab_hash,
user_config_from_proto(res.run.override_config),
)

Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,8 +850,8 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
run_id=run.run_id,
fab_id=run.fab_id,
fab_version=run.fab_version,
fab_hash=run.fab_hash,
override_config=user_config_to_proto(run.override_config),
fab_hash="",
)
return proto

Expand All @@ -862,6 +862,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
run_id=run_proto.run_id,
fab_id=run_proto.fab_id,
fab_version=run_proto.fab_version,
fab_hash=run_proto.fab_hash,
override_config=user_config_from_proto(run_proto.override_config),
)
return run
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ def test_run_serialization_deserialization() -> None:
run_id=1,
fab_id="lorem",
fab_version="ipsum",
fab_hash="hash",
override_config=maker.user_config(),
)

Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class Run:
run_id: int
fab_id: str
fab_version: str
fab_hash: str
override_config: UserConfig


Expand Down
12 changes: 12 additions & 0 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.config import get_flwr_dir
from flwr.common.constant import (
MISSING_EXTRA_REST,
TRANSPORT_TYPE_GRPC_ADAPTER,
Expand All @@ -57,6 +58,7 @@
from .server_config import ServerConfig
from .strategy import Strategy
from .superlink.driver.driver_grpc import run_driver_api_grpc
from .superlink.ffs.ffs_factory import FfsFactory
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
from .superlink.fleet.grpc_bidi.grpc_server import (
generic_create_grpc_server,
Expand All @@ -72,6 +74,7 @@
ADDRESS_FLEET_API_REST = "0.0.0.0:9093"

DATABASE = ":flwr-in-memory-state:"
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"


def start_server( # pylint: disable=too-many-arguments,too-many-locals
Expand Down Expand Up @@ -211,10 +214,14 @@ def run_superlink() -> None:
# Initialize StateFactory
state_factory = StateFactory(args.database)

# Initialize FfsFactory
ffs_factory = FfsFactory(args.storage_dir)

# Start Driver API
driver_server: grpc.Server = run_driver_api_grpc(
address=driver_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
)

Expand Down Expand Up @@ -610,6 +617,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
"Flower will just create a state in memory.",
default=DATABASE,
)
parser.add_argument(
"--storage-dir",
help="The base directory to store the objects for the Flower File System.",
default=BASE_DIR,
)
parser.add_argument(
"--auth-list-public-keys",
type=str,
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _init_run(self) -> None:
run_id=res.run.run_id,
fab_id=res.run.fab_id,
fab_version=res.run.fab_version,
fab_hash=res.run.fab_hash,
override_config=user_config_from_proto(res.run.override_config),
)

Expand Down
6 changes: 4 additions & 2 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def setUp(self) -> None:
run_id=61016,
fab_id="mock/mock",
fab_version="v1.0.0",
fab_hash="hash",
override_config={"test_key": "test_value"},
)
state_factory = MagicMock(state=lambda: self.state)
Expand All @@ -101,6 +102,7 @@ def test_get_run(self) -> None:
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")
self.assertEqual(self.driver.run.fab_hash, "hash")
self.assertEqual(self.driver.run.override_config["test_key"], "test_value")

def test_get_nodes(self) -> None:
Expand Down Expand Up @@ -227,7 +229,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
# Prepare
state = StateFactory("").state()
self.driver = InMemoryDriver(
state.create_run("", "", {}), MagicMock(state=lambda: state)
state.create_run("", "", "", {}), MagicMock(state=lambda: state)
)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, SqliteState)
Expand All @@ -253,7 +255,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
# Prepare
state_factory = StateFactory(":flwr-in-memory-state:")
state = state_factory.state()
self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory)
self.driver = InMemoryDriver(state.create_run("", "", "", {}), state_factory)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, InMemoryState)

Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/server/superlink/driver/driver_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
add_DriverServicer_to_server,
)
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.state import StateFactory

from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
Expand All @@ -33,12 +34,14 @@
def run_driver_api_grpc(
address: str,
state_factory: StateFactory,
ffs_factory: FfsFactory,
certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> grpc.Server:
"""Run Driver API (gRPC, request-response)."""
# Create Driver API gRPC server
driver_servicer: grpc.Server = DriverServicer(
state_factory=state_factory,
ffs_factory=ffs_factory,
)
driver_add_servicer_to_server_fn = add_DriverServicer_to_server
driver_grpc_server = generic_create_grpc_server(
Expand Down
15 changes: 14 additions & 1 deletion src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,18 @@
Run,
)
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
from flwr.server.superlink.ffs import Ffs
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.state import State, StateFactory
from flwr.server.utils.validator import validate_task_ins_or_res


class DriverServicer(driver_pb2_grpc.DriverServicer):
"""Driver API servicer."""

def __init__(self, state_factory: StateFactory) -> None:
def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
self.state_factory = state_factory
self.ffs_factory = ffs_factory

def GetNodes(
self, request: GetNodesRequest, context: grpc.ServicerContext
Expand All @@ -71,9 +74,19 @@ def CreateRun(
"""Create run ID."""
log(DEBUG, "DriverServicer.CreateRun")
state: State = self.state_factory.state()
if request.HasField("fab") and request.fab.HasField("content"):
ffs: Ffs = self.ffs_factory.ffs()
fab_hash = ffs.put(request.fab.content, {})
_raise_if(
fab_hash != request.fab.hash_str,
f"FAB ({request.fab}) hash from request doesn't match contents",
)
else:
fab_hash = ""
run_id = state.create_run(
request.fab_id,
request.fab_version,
fab_hash,
user_config_from_proto(request.override_config),
)
return CreateRunResponse(run_id=run_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "", {})
run_id = self.state.create_run("", "", "", {})
request = GetRunRequest(run_id=run_id)
shared_secret = generate_shared_key(
self._client_private_key, self._server_public_key
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "", {})
run_id = self.state.create_run("", "", "", {})
request = GetRunRequest(run_id=run_id)
client_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(client_private_key, self._server_public_key)
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def register_messages_into_state(
"""Register `num_messages` into the state factory."""
state: InMemoryState = state_factory.state() # type: ignore
state.run_ids[run_id] = Run(
run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={}
run_id=run_id,
fab_id="Mock/mock",
fab_version="v1.0.0",
fab_hash="hash",
override_config={},
)
# Artificially add TaskIns to state so they can be processed
# by the Simulation Engine logic
Expand Down Expand Up @@ -192,7 +196,7 @@ def start_and_shutdown(
if not app_dir:
app_dir = _autoresolve_app_dir()

run = Run(run_id=1234, fab_id="", fab_version="", override_config={})
run = Run(run_id=1234, fab_id="", fab_version="", fab_hash="", override_config={})

start_vce(
num_supernodes=num_supernodes,
Expand Down
12 changes: 7 additions & 5 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,20 +277,22 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:

def create_run(
self,
fab_id: str,
fab_version: str,
fab_id: Optional[str],
fab_version: Optional[str],
fab_hash: Optional[str],
override_config: UserConfig,
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
"""Create a new run for the specified `fab_hash`."""
# Sample a random int64 as run_id
with self.lock:
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
self.run_ids[run_id] = Run(
run_id=run_id,
fab_id=fab_id,
fab_version=fab_version,
fab_id=fab_id if fab_id else "",
fab_version=fab_version if fab_version else "",
fab_hash=fab_hash if fab_hash else "",
override_config=override_config,
)
return run_id
Expand Down
24 changes: 17 additions & 7 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
run_id INTEGER UNIQUE,
fab_id TEXT,
fab_version TEXT,
fab_hash TEXT,
override_config TEXT
);
"""
Expand Down Expand Up @@ -617,8 +618,9 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:

def create_run(
self,
fab_id: str,
fab_version: str,
fab_id: Optional[str],
fab_version: Optional[str],
fab_hash: Optional[str],
override_config: UserConfig,
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
Expand All @@ -630,12 +632,19 @@ def create_run(
# If run_id does not exist
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
query = (
"INSERT INTO run (run_id, fab_id, fab_version, override_config)"
"VALUES (?, ?, ?, ?);"
)
self.query(
query, (run_id, fab_id, fab_version, json.dumps(override_config))
"INSERT INTO run "
"(run_id, fab_id, fab_version, fab_hash, override_config)"
"VALUES (?, ?, ?, ?, ?);"
)
if fab_hash:
self.query(
query, (run_id, "", "", fab_hash, json.dumps(override_config))
)
else:
self.query(
query,
(run_id, fab_id, fab_version, "", json.dumps(override_config)),
)
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -702,6 +711,7 @@ def get_run(self, run_id: int) -> Optional[Run]:
run_id=run_id,
fab_id=row["fab_id"],
fab_version=row["fab_version"],
fab_hash=row["fab_hash"],
override_config=json.loads(row["override_config"]),
)
except sqlite3.IntegrityError:
Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
@abc.abstractmethod
def create_run(
self,
fab_id: str,
fab_version: str,
fab_id: Optional[str],
fab_version: Optional[str],
fab_hash: Optional[str],
override_config: UserConfig,
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
"""Create a new run for the specified `fab_hash`."""

@abc.abstractmethod
def get_run(self, run_id: int) -> Optional[Run]:
Expand Down
Loading