Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Enable passing an existing Context to ServerApp startup function #4364

Merged
merged 9 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
from flwr.common.object_ref import load_app
from flwr.common.typing import UserConfig
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
CreateRunRequest,
Expand All @@ -46,13 +45,14 @@
from .server_app import LoadServerAppError, ServerApp


# pylint: disable-next=too-many-arguments,too-many-positional-arguments
def run(
driver: Driver,
context: Context,
server_app_dir: str,
server_app_run_config: UserConfig,
server_app_attr: Optional[str] = None,
loaded_server_app: Optional[ServerApp] = None,
) -> None:
) -> Context:
"""Run ServerApp with a given Driver."""
if not (server_app_attr is None) ^ (loaded_server_app is None):
raise ValueError(
Expand All @@ -78,15 +78,11 @@ def _load() -> ServerApp:

server_app = _load()

# Initialize Context
context = Context(
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
)

# Call ServerApp
server_app(driver=driver, context=context)

log(DEBUG, "ServerApp finished running.")
return context


# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
Expand Down Expand Up @@ -225,11 +221,19 @@ def run_server_app() -> None:
root_certificates,
)

# Initialize Context
context = Context(
node_id=0,
node_config={},
state=RecordSet(),
run_config=server_app_run_config,
)

# Run the ServerApp with the Driver
run(
driver=driver,
context=context,
server_app_dir=app_path,
server_app_run_config=server_app_run_config,
server_app_attr=server_app_attr,
)

Expand Down
16 changes: 12 additions & 4 deletions src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from flwr.cli.config_utils import load_and_validate
from flwr.client import ClientApp
from flwr.common import EventType, event, log, now
from flwr.common import Context, EventType, RecordSet, event, log, now
from flwr.common.config import get_fused_config_from_dir, parse_config_args
from flwr.common.constant import RUN_ID_NUM_BYTES, Status
from flwr.common.logger import (
Expand All @@ -40,7 +40,7 @@
)
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.server.driver import Driver, InMemoryDriver
from flwr.server.run_serverapp import run as run_server_app
from flwr.server.run_serverapp import run as _run
from flwr.server.server_app import ServerApp
from flwr.server.superlink.fleet import vce
from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
Expand Down Expand Up @@ -333,11 +333,19 @@ def server_th_with_start_checks(
log(INFO, "Enabling GPU growth for Tensorflow on the server thread.")
enable_gpu_growth()

# Initialize Context
context = Context(
node_id=0,
node_config={},
state=RecordSet(),
run_config=_server_app_run_config,
)

# Run ServerApp
run_server_app(
_run(
driver=_driver,
context=context,
server_app_dir=_server_app_dir,
server_app_run_config=_server_app_run_config,
server_app_attr=_server_app_attr,
loaded_server_app=_server_app,
)
Expand Down