diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 7acd7535382..e20eee78e63 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -398,7 +398,7 @@ def _on_backoff(retry_state: RetryState) -> None: runs[run_id] = get_run(run_id) # If get_run is None, i.e., in grpc-bidi mode else: - runs[run_id] = Run(run_id, "", "", {}) + runs[run_id] = Run(run_id, "", "", "", {}) # Register context for this run node_state.register_context( diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index af74125140e..155beb8a563 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -286,6 +286,7 @@ def get_run(run_id: int) -> Run: run_id, get_run_response.run.fab_id, get_run_response.run.fab_version, + get_run_response.run.fab_hash, user_config_from_proto(get_run_response.run.override_config), ) diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 2da320622c1..3f9147304fd 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -358,12 +358,13 @@ def get_run(run_id: int) -> Run: # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) if res is None: - return Run(run_id, "", "", {}) + return Run(run_id, "", "", "", {}) return Run( run_id, res.run.fab_id, res.run.fab_version, + res.run.fab_hash, user_config_from_proto(res.run.override_config), ) diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 819b113d041..76265b9836d 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -850,8 +850,8 @@ def run_to_proto(run: typing.Run) -> ProtoRun: run_id=run.run_id, fab_id=run.fab_id, fab_version=run.fab_version, + fab_hash=run.fab_hash, override_config=user_config_to_proto(run.override_config), - fab_hash="", ) return proto @@ -862,6 +862,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run: run_id=run_proto.run_id, fab_id=run_proto.fab_id, fab_version=run_proto.fab_version, + fab_hash=run_proto.fab_hash, override_config=user_config_from_proto(run_proto.override_config), ) return run diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index f3279c10e34..013d04a32fd 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -519,6 +519,7 @@ def test_run_serialization_deserialization() -> None: run_id=1, fab_id="lorem", fab_version="ipsum", + fab_hash="hash", override_config=maker.user_config(), ) diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 68a2b5825f0..b1dec8d0420 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -214,6 +214,7 @@ class Run: run_id: int fab_id: str fab_version: str + fab_hash: str override_config: UserConfig diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 822defdb5b1..fc60cb331b7 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -34,6 +34,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address +from flwr.common.config import get_flwr_dir from flwr.common.constant import ( MISSING_EXTRA_REST, TRANSPORT_TYPE_GRPC_ADAPTER, @@ -57,6 +58,7 @@ from .server_config import ServerConfig from .strategy import Strategy from .superlink.driver.driver_grpc import run_driver_api_grpc +from .superlink.ffs.ffs_factory import FfsFactory from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer from .superlink.fleet.grpc_bidi.grpc_server import ( generic_create_grpc_server, @@ -72,6 +74,7 @@ ADDRESS_FLEET_API_REST = "0.0.0.0:9093" DATABASE = ":flwr-in-memory-state:" +BASE_DIR = get_flwr_dir() / "superlink" / "ffs" def start_server( # pylint: disable=too-many-arguments,too-many-locals @@ -211,10 +214,14 @@ def run_superlink() -> None: # Initialize StateFactory state_factory = StateFactory(args.database) + # Initialize FfsFactory + ffs_factory = FfsFactory(args.storage_dir) + # Start Driver API driver_server: grpc.Server = run_driver_api_grpc( address=driver_address, state_factory=state_factory, + ffs_factory=ffs_factory, certificates=certificates, ) @@ -610,6 +617,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None: "Flower will just create a state in memory.", default=DATABASE, ) + parser.add_argument( + "--storage-dir", + help="The base directory to store the objects for the Flower File System.", + default=BASE_DIR, + ) parser.add_argument( "--auth-list-public-keys", type=str, diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 60439892d94..80ce9623ab3 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -131,6 +131,7 @@ def _init_run(self) -> None: run_id=res.run.run_id, fab_id=res.run.fab_id, fab_version=res.run.fab_version, + fab_hash=res.run.fab_hash, override_config=user_config_from_proto(res.run.override_config), ) diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index d0f32e830f7..610452e225b 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -89,6 +89,7 @@ def setUp(self) -> None: run_id=61016, fab_id="mock/mock", fab_version="v1.0.0", + fab_hash="hash", override_config={"test_key": "test_value"}, ) state_factory = MagicMock(state=lambda: self.state) @@ -101,6 +102,7 @@ def test_get_run(self) -> None: self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") + self.assertEqual(self.driver.run.fab_hash, "hash") self.assertEqual(self.driver.run.override_config["test_key"], "test_value") def test_get_nodes(self) -> None: @@ -227,7 +229,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: # Prepare state = StateFactory("").state() self.driver = InMemoryDriver( - state.create_run("", "", {}), MagicMock(state=lambda: state) + state.create_run("", "", "", {}), MagicMock(state=lambda: state) ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, SqliteState) @@ -253,7 +255,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: # Prepare state_factory = StateFactory(":flwr-in-memory-state:") state = state_factory.state() - self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory) + self.driver = InMemoryDriver(state.create_run("", "", "", {}), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, InMemoryState) diff --git a/src/py/flwr/server/superlink/driver/driver_grpc.py b/src/py/flwr/server/superlink/driver/driver_grpc.py index 78293548194..b7b914206f7 100644 --- a/src/py/flwr/server/superlink/driver/driver_grpc.py +++ b/src/py/flwr/server/superlink/driver/driver_grpc.py @@ -24,6 +24,7 @@ from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611 add_DriverServicer_to_server, ) +from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.state import StateFactory from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server @@ -33,12 +34,14 @@ def run_driver_api_grpc( address: str, state_factory: StateFactory, + ffs_factory: FfsFactory, certificates: Optional[Tuple[bytes, bytes, bytes]], ) -> grpc.Server: """Run Driver API (gRPC, request-response).""" # Create Driver API gRPC server driver_servicer: grpc.Server = DriverServicer( state_factory=state_factory, + ffs_factory=ffs_factory, ) driver_add_servicer_to_server_fn = add_DriverServicer_to_server driver_grpc_server = generic_create_grpc_server( diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 7819c587e85..8236f9b50d7 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -43,6 +43,8 @@ Run, ) from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 +from flwr.server.superlink.ffs import Ffs +from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.state import State, StateFactory from flwr.server.utils.validator import validate_task_ins_or_res @@ -50,8 +52,9 @@ class DriverServicer(driver_pb2_grpc.DriverServicer): """Driver API servicer.""" - def __init__(self, state_factory: StateFactory) -> None: + def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None: self.state_factory = state_factory + self.ffs_factory = ffs_factory def GetNodes( self, request: GetNodesRequest, context: grpc.ServicerContext @@ -71,9 +74,19 @@ def CreateRun( """Create run ID.""" log(DEBUG, "DriverServicer.CreateRun") state: State = self.state_factory.state() + if request.HasField("fab") and request.fab.HasField("content"): + ffs: Ffs = self.ffs_factory.ffs() + fab_hash = ffs.put(request.fab.content, {}) + _raise_if( + fab_hash != request.fab.hash_str, + f"FAB ({request.fab}) hash from request doesn't match contents", + ) + else: + fab_hash = "" run_id = state.create_run( request.fab_id, request.fab_version, + fab_hash, user_config_from_proto(request.override_config), ) return CreateRunResponse(run_id=run_id) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 798e7143558..b87a4293a77 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "", {}) + run_id = self.state.create_run("", "", "", {}) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._client_private_key, self._server_public_key @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "", {}) + run_id = self.state.create_run("", "", "", {}) request = GetRunRequest(run_id=run_id) client_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(client_private_key, self._server_public_key) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 70c8669ad88..28ed23cf650 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -109,7 +109,11 @@ def register_messages_into_state( """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore state.run_ids[run_id] = Run( - run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={} + run_id=run_id, + fab_id="Mock/mock", + fab_version="v1.0.0", + fab_hash="hash", + override_config={}, ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic @@ -192,7 +196,7 @@ def start_and_shutdown( if not app_dir: app_dir = _autoresolve_app_dir() - run = Run(run_id=1234, fab_id="", fab_version="", override_config={}) + run = Run(run_id=1234, fab_id="", fab_version="", fab_hash="", override_config={}) start_vce( num_supernodes=num_supernodes, diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index beb25ba4e84..fde8fe41912 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -277,11 +277,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: def create_run( self, - fab_id: str, - fab_version: str, + fab_id: Optional[str], + fab_version: Optional[str], + fab_hash: Optional[str], override_config: UserConfig, ) -> int: - """Create a new run for the specified `fab_id` and `fab_version`.""" + """Create a new run for the specified `fab_hash`.""" # Sample a random int64 as run_id with self.lock: run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) @@ -289,8 +290,9 @@ def create_run( if run_id not in self.run_ids: self.run_ids[run_id] = Run( run_id=run_id, - fab_id=fab_id, - fab_version=fab_version, + fab_id=fab_id if fab_id else "", + fab_version=fab_version if fab_version else "", + fab_hash=fab_hash if fab_hash else "", override_config=override_config, ) return run_id diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index bd3b6ebabd8..93b3cd63ca7 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -65,6 +65,7 @@ run_id INTEGER UNIQUE, fab_id TEXT, fab_version TEXT, + fab_hash TEXT, override_config TEXT ); """ @@ -617,8 +618,9 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: def create_run( self, - fab_id: str, - fab_version: str, + fab_id: Optional[str], + fab_version: Optional[str], + fab_hash: Optional[str], override_config: UserConfig, ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" @@ -630,12 +632,19 @@ def create_run( # If run_id does not exist if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: query = ( - "INSERT INTO run (run_id, fab_id, fab_version, override_config)" - "VALUES (?, ?, ?, ?);" - ) - self.query( - query, (run_id, fab_id, fab_version, json.dumps(override_config)) + "INSERT INTO run " + "(run_id, fab_id, fab_version, fab_hash, override_config)" + "VALUES (?, ?, ?, ?, ?);" ) + if fab_hash: + self.query( + query, (run_id, "", "", fab_hash, json.dumps(override_config)) + ) + else: + self.query( + query, + (run_id, fab_id, fab_version, "", json.dumps(override_config)), + ) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -702,6 +711,7 @@ def get_run(self, run_id: int) -> Optional[Run]: run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"], + fab_hash=row["fab_hash"], override_config=json.loads(row["override_config"]), ) except sqlite3.IntegrityError: diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 23c95805948..80d3b799bce 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -159,11 +159,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: @abc.abstractmethod def create_run( self, - fab_id: str, - fab_version: str, + fab_id: Optional[str], + fab_version: Optional[str], + fab_hash: Optional[str], override_config: UserConfig, ) -> int: - """Create a new run for the specified `fab_id` and `fab_version`.""" + """Create a new run for the specified `fab_hash`.""" @abc.abstractmethod def get_run(self, run_id: int) -> Optional[Run]: diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 5f0d23ffc4d..3efce9ca0c8 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -52,7 +52,7 @@ def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("Mock/mock", "v1.0.0", {"test_key": "test_value"}) + run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) # Execute run = state.get_run(run_id) @@ -60,8 +60,7 @@ def test_create_and_get_run(self) -> None: # Assert assert run is not None assert run.run_id == run_id - assert run.fab_id == "Mock/mock" - assert run.fab_version == "v1.0.0" + assert run.fab_hash == "9f86d08" assert run.override_config["test_key"] == "test_value" def test_get_task_ins_empty(self) -> None: @@ -91,7 +90,7 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -126,7 +125,7 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_ins_0 = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -200,7 +199,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -215,7 +214,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -229,7 +228,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -243,7 +242,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -260,7 +259,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -303,7 +302,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, @@ -324,7 +323,7 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) # Execute retrieved_node_ids = state.get_nodes(run_id) @@ -336,7 +335,7 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) node_ids = [] # Execute @@ -353,7 +352,7 @@ def test_create_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) # Execute node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -369,7 +368,7 @@ def test_create_node_public_key_twice(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -391,7 +390,7 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) node_id = state.create_node(ping_interval=10) # Execute @@ -406,7 +405,7 @@ def test_delete_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -423,7 +422,7 @@ def test_delete_node_public_key_none(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) node_id = 0 # Execute & Assert @@ -442,7 +441,7 @@ def test_delete_node_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute & Assert @@ -461,7 +460,7 @@ def test_get_node_id_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) # Execute state.create_node(ping_interval=10, public_key=public_key) @@ -476,7 +475,7 @@ def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_run("mock/mock", "v1.0.0", {}) + state.create_run(None, None, "9f86d08", {}) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -490,7 +489,7 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -508,7 +507,7 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) task_0 = create_task_res( producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) @@ -609,7 +608,7 @@ def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) @@ -628,7 +627,7 @@ def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0", {}) + run_id = state.create_run(None, None, "9f86d08", {}) node_id_0 = state.create_node(ping_interval=90) node_id_1 = state.create_node(ping_interval=30) # Create and store TaskIns diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 51799074ef6..257a066433f 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -163,6 +163,7 @@ def run_simulation_from_cli() -> None: run_id=run_id, fab_id="", fab_version="", + fab_hash="", override_config=override_config, ) @@ -529,7 +530,9 @@ def _run_simulation( # If no `Run` object is set, create one if run is None: run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) - run = Run(run_id=run_id, fab_id="", fab_version="", override_config={}) + run = Run( + run_id=run_id, fab_id="", fab_version="", fab_hash="", override_config={} + ) args = ( num_supernodes,