diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 6f931e81eefa..d117f280b38d 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -17,7 +17,7 @@ import subprocess import sys from logging import ERROR, INFO -from typing import Optional +from typing import Dict, Optional from typing_extensions import override @@ -53,18 +53,29 @@ def _connect(self) -> None: ) self.stub = DriverStub(channel) - def _create_run(self, fab_id: str, fab_version: str) -> int: + def _create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: if self.stub is None: self._connect() assert self.stub is not None - req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version) + req = CreateRunRequest( + fab_id=fab_id, + fab_version=fab_version, + override_config=override_config, + ) res = self.stub.CreateRun(request=req) return int(res.run_id) @override - def start_run(self, fab_file: bytes) -> Optional[RunTracker]: + def start_run( + self, fab_file: bytes, override_config: Dict[str, str] + ) -> Optional[RunTracker]: """Start run using the Flower Deployment Engine.""" try: # Install FAB to flwr dir @@ -79,7 +90,7 @@ def start_run(self, fab_file: bytes) -> Optional[RunTracker]: ) # Call SuperLink to create run - run_id: int = self._create_run(fab_id, fab_version) + run_id: int = self._create_run(fab_id, fab_version, override_config) log(INFO, "Created run %s", str(run_id)) # Start ServerApp diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index e5ef2bd59a79..61a7bc289af3 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -45,7 +45,10 @@ def StartRun( """Create run ID.""" log(INFO, "ExecServicer.StartRun") - run = self.executor.start_run(request.fab_file) + run = self.executor.start_run( + request.fab_file, + dict(request.override_config.items()), + ) if run is None: log(ERROR, "Executor failed to start run") diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index 41f67b74c48b..edc91df4530e 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -36,7 +36,7 @@ def test_start_run() -> None: run_res.proc = proc executor = MagicMock() - executor.start_run = lambda _: run_res + executor.start_run = lambda _, __: run_res context_mock = MagicMock() diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index f85ac4c157fc..85b6e5c3e095 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from subprocess import Popen -from typing import Optional +from typing import Dict, Optional @dataclass @@ -33,8 +33,7 @@ class Executor(ABC): @abstractmethod def start_run( - self, - fab_file: bytes, + self, fab_file: bytes, override_config: Dict[str, str] ) -> Optional[RunTracker]: """Start a run using the given Flower FAB ID and version.