Skip to content

Commit

Permalink
Make flower-simulation accept ClientApp and ServerApp objects (#…
Browse files Browse the repository at this point in the history
…3024)

Co-authored-by: Daniel J. Beutel <daniel@flower.ai>
  • Loading branch information
jafermarq and danieljanes authored Mar 5, 2024
1 parent d16807a commit c05df2a
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 84 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ flower-fleet-api = "flwr.server:run_fleet_api"
flower-superlink = "flwr.server:run_superlink"
flower-client-app = "flwr.client:run_client_app"
flower-server-app = "flwr.server:run_server_app"
flower-simulation = "flwr.simulation:run_simulation"
flower-simulation = "flwr.simulation:run_simulation_from_cli"

[tool.poetry.dependencies]
python = "^3.8"
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def run_superlink() -> None:
f_stop = asyncio.Event() # Does nothing
_run_fleet_api_vce(
num_supernodes=args.num_supernodes,
client_app_module_name=args.client_app,
client_app_attr=args.client_app,
backend_name=args.backend,
backend_config_json_stream=args.backend_config,
working_dir=args.dir,
Expand Down Expand Up @@ -438,7 +438,7 @@ def _run_fleet_api_grpc_rere(
# pylint: disable=too-many-arguments
def _run_fleet_api_vce(
num_supernodes: int,
client_app_module_name: str,
client_app_attr: str,
backend_name: str,
backend_config_json_stream: str,
working_dir: str,
Expand All @@ -449,7 +449,7 @@ def _run_fleet_api_vce(

start_vce(
num_supernodes=num_supernodes,
client_app_module_name=client_app_module_name,
client_app_attr=client_app_attr,
backend_name=backend_name,
backend_config_json_stream=backend_config_json_stream,
state_factory=state_factory,
Expand Down
23 changes: 14 additions & 9 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


import argparse
import asyncio
import sys
from logging import DEBUG, WARN
from pathlib import Path
Expand All @@ -30,17 +29,27 @@


def run(
server_app_attr: str,
driver: Driver,
server_app_dir: str,
stop_event: Optional[asyncio.Event] = None,
server_app_attr: Optional[str] = None,
loaded_server_app: Optional[ServerApp] = None,
) -> None:
"""Run ServerApp with a given Driver."""
if not (server_app_attr is None) ^ (loaded_server_app is None):
raise ValueError(
"Either `server_app_attr` or `loaded_server_app` should be set "
"but not both. "
)

if server_app_dir is not None:
sys.path.insert(0, server_app_dir)

# Load ServerApp if needed
def _load() -> ServerApp:
server_app: ServerApp = load_server_app(server_app_attr)
if server_app_attr:
server_app: ServerApp = load_server_app(server_app_attr)
if loaded_server_app:
server_app = loaded_server_app
return server_app

server_app = _load()
Expand All @@ -52,10 +61,6 @@ def _load() -> ServerApp:
server_app(driver=driver, context=context)

log(DEBUG, "ServerApp finished running.")
# Upon completion, trigger stop event if one was passed
if stop_event is not None:
log(DEBUG, "Triggering stop event.")
stop_event.set()


def run_server_app() -> None:
Expand Down Expand Up @@ -117,7 +122,7 @@ def run_server_app() -> None:
)

# Run the Server App with the Driver
run(server_app_attr, driver, server_app_dir)
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)

# Clean up
driver.__del__() # pylint: disable=unnecessary-dunder-call
Expand Down
18 changes: 15 additions & 3 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,23 @@ async def run(

# pylint: disable=too-many-arguments,unused-argument,too-many-locals
def start_vce(
client_app_module_name: str,
backend_name: str,
backend_config_json_stream: str,
working_dir: str,
f_stop: asyncio.Event,
client_app: Optional[ClientApp] = None,
client_app_attr: Optional[str] = None,
num_supernodes: Optional[int] = None,
state_factory: Optional[StateFactory] = None,
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
) -> None:
"""Start Fleet API with the Simulation Engine."""
if client_app_attr is not None and client_app is not None:
raise ValueError(
"Both `client_app_attr` and `client_app` are provided, "
"but only one is allowed."
)

if num_supernodes is not None and existing_nodes_mapping is not None:
raise ValueError(
"Both `num_supernodes` and `existing_nodes_mapping` are provided, "
Expand Down Expand Up @@ -292,10 +299,15 @@ def backend_fn() -> Backend:
"""Instantiate a Backend."""
return backend_type(backend_config, work_dir=working_dir)

log(INFO, "client_app_module_name = %s", client_app_module_name)
log(INFO, "client_app_attr = %s", client_app_attr)

# Load ClientApp if needed
def _load() -> ClientApp:
app: ClientApp = load_client_app(client_app_module_name)

if client_app_attr:
app: ClientApp = load_client_app(client_app_attr)
if client_app:
app = client_app
return app

app_fn = _load
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _autoresolve_working_dir(rel_client_app_dir: str = "backend") -> str:
# pylint: disable=too-many-arguments
def start_and_shutdown(
backend: str = "ray",
clientapp_module: str = "raybackend_test:client_app",
client_app_attr: str = "raybackend_test:client_app",
working_dir: str = "",
num_supernodes: Optional[int] = None,
state_factory: Optional[StateFactory] = None,
Expand Down Expand Up @@ -162,7 +162,7 @@ def start_and_shutdown(

start_vce(
num_supernodes=num_supernodes,
client_app_module_name=clientapp_module,
client_app_attr=client_app_attr,
backend_name=backend,
backend_config_json_stream=backend_config,
state_factory=state_factory,
Expand All @@ -183,7 +183,7 @@ def test_erroneous_no_supernodes_client_mapping(self) -> None:
with self.assertRaises(ValueError):
start_and_shutdown(duration=2)

def test_erroneous_clientapp_module_name(self) -> None:
def test_erroneous_client_app_attr(self) -> None:
"""Tests attempt to load a ClientApp that can't be found."""
num_messages = 7
num_nodes = 59
Expand All @@ -193,7 +193,7 @@ def test_erroneous_clientapp_module_name(self) -> None:
)
with self.assertRaises(RuntimeError):
start_and_shutdown(
clientapp_module="totally_fictitious_app:client",
client_app_attr="totally_fictitious_app:client",
state_factory=state_factory,
nodes_mapping=nodes_mapping,
)
Expand Down
7 changes: 2 additions & 5 deletions src/py/flwr/simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import importlib

from flwr.simulation.run_simulation import run_simulation
from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli

is_ray_installed = importlib.util.find_spec("ray") is not None

Expand All @@ -36,7 +36,4 @@ def start_simulation(*args, **kwargs): # type: ignore
raise ImportError(RAY_IMPORT_ERROR)


__all__ = [
"start_simulation",
"run_simulation",
]
__all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"]
Loading

0 comments on commit c05df2a

Please sign in to comment.