Skip to content

Commit

Permalink
refactor(framework) Enable passing an existing Context to `ServerAp…
Browse files Browse the repository at this point in the history
…p` startup function (#4364)
  • Loading branch information
jafermarq authored Oct 24, 2024
1 parent eb6d9be commit 8c449f5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
22 changes: 13 additions & 9 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
from flwr.common.object_ref import load_app
from flwr.common.typing import UserConfig
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
CreateRunRequest,
Expand All @@ -46,13 +45,14 @@
from .server_app import LoadServerAppError, ServerApp


# pylint: disable-next=too-many-arguments,too-many-positional-arguments
def run(
driver: Driver,
context: Context,
server_app_dir: str,
server_app_run_config: UserConfig,
server_app_attr: Optional[str] = None,
loaded_server_app: Optional[ServerApp] = None,
) -> None:
) -> Context:
"""Run ServerApp with a given Driver."""
if not (server_app_attr is None) ^ (loaded_server_app is None):
raise ValueError(
Expand All @@ -78,15 +78,11 @@ def _load() -> ServerApp:

server_app = _load()

# Initialize Context
context = Context(
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
)

# Call ServerApp
server_app(driver=driver, context=context)

log(DEBUG, "ServerApp finished running.")
return context


# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
Expand Down Expand Up @@ -225,11 +221,19 @@ def run_server_app() -> None:
root_certificates,
)

# Initialize Context
context = Context(
node_id=0,
node_config={},
state=RecordSet(),
run_config=server_app_run_config,
)

# Run the ServerApp with the Driver
run(
driver=driver,
context=context,
server_app_dir=app_path,
server_app_run_config=server_app_run_config,
server_app_attr=server_app_attr,
)

Expand Down
16 changes: 12 additions & 4 deletions src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from flwr.cli.config_utils import load_and_validate
from flwr.client import ClientApp
from flwr.common import EventType, event, log, now
from flwr.common import Context, EventType, RecordSet, event, log, now
from flwr.common.config import get_fused_config_from_dir, parse_config_args
from flwr.common.constant import RUN_ID_NUM_BYTES, Status
from flwr.common.logger import (
Expand All @@ -40,7 +40,7 @@
)
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.server.driver import Driver, InMemoryDriver
from flwr.server.run_serverapp import run as run_server_app
from flwr.server.run_serverapp import run as _run
from flwr.server.server_app import ServerApp
from flwr.server.superlink.fleet import vce
from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
Expand Down Expand Up @@ -333,11 +333,19 @@ def server_th_with_start_checks(
log(INFO, "Enabling GPU growth for Tensorflow on the server thread.")
enable_gpu_growth()

# Initialize Context
context = Context(
node_id=0,
node_config={},
state=RecordSet(),
run_config=_server_app_run_config,
)

# Run ServerApp
run_server_app(
_run(
driver=_driver,
context=context,
server_app_dir=_server_app_dir,
server_app_run_config=_server_app_run_config,
server_app_attr=_server_app_attr,
loaded_server_app=_server_app,
)
Expand Down

0 comments on commit 8c449f5

Please sign in to comment.