Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

break(framework) Remove support for client_ids in start_simulation #3699

Merged
merged 19 commits into from
Jul 7, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/py/flwr/simulation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@

from flwr.client import ClientFn
from flwr.common import EventType, event
from flwr.common.constant import NODE_ID_NUM_BYTES
from flwr.common.logger import log, set_logger_propagation
from flwr.server.client_manager import ClientManager
from flwr.server.history import History
from flwr.server.server import Server, init_defaults, run_fl
from flwr.server.server_config import ServerConfig
from flwr.server.strategy import Strategy
from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
from flwr.simulation.ray_transport.ray_actor import (
ClientAppActor,
VirtualClientEngineActor,
Expand Down Expand Up @@ -70,13 +72,29 @@

"""

# Type alias: maps a node_id (key) to its assigned partition_id (value).
NodeToPartitionMapping = Dict[int, int]


def _create_node_id_to_partition_mapping(
    partition_ids: List[int],
) -> NodeToPartitionMapping:
    """Assign a unique, randomly drawn node_id to every entry of `partition_ids`.

    Parameters
    ----------
    partition_ids : List[int]
        Partition identifiers to which node ids should be assigned. Each entry
        of the list receives its own node id.

    Returns
    -------
    NodeToPartitionMapping
        Dictionary keyed by node_id whose values are the corresponding
        partition ids, inserted in the order of `partition_ids`.
    """
    assigned: NodeToPartitionMapping = {}  # {node_id: partition_id}
    for partition_id in partition_ids:
        candidate = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
        # Draw again for as long as the candidate collides with a node_id
        # that has already been handed out.
        while candidate in assigned:
            candidate = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
        assigned[candidate] = partition_id
    return assigned


# pylint: disable=too-many-arguments,too-many-statements,too-many-branches
def start_simulation(
*,
client_fn: ClientFn,
num_clients: Optional[int] = None,
clients_ids: Optional[List[str]] = None,
clients_ids: Optional[Union[List[int], List[str]]] = None,
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
client_resources: Optional[Dict[str, float]] = None,
server: Optional[Server] = None,
config: Optional[ServerConfig] = None,
Expand Down Expand Up @@ -105,10 +123,11 @@ def start_simulation(
num_clients : Optional[int]
The total number of clients in this simulation. This must be set if
`clients_ids` is not set and vice-versa.
clients_ids : Optional[List[str]]
clients_ids : Optional[Union[List[int], List[str]]]
List `client_id`s for each client. This is only required if
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
`num_clients` is not set. Setting both `num_clients` and `clients_ids`
with `len(clients_ids)` not equal to `num_clients` generates an error.
If list contains `str` values, they will be converted to `int`.
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`)
CPU and GPU resources for a single client. Supported keys
are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by
Expand Down Expand Up @@ -197,19 +216,21 @@ def start_simulation(
)

# clients_ids takes precedence
cids: List[str]
if clients_ids is not None:
if (num_clients is not None) and (len(clients_ids) != num_clients):
log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
sys.exit()
else:
cids = clients_ids
partition_ids = [int(cid) for cid in clients_ids]
else:
if num_clients is None:
log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
sys.exit()
else:
cids = [str(x) for x in range(num_clients)]
partition_ids = list(range(num_clients))

# Create node-id to partition-id mapping
nodes_mapping = _create_node_id_to_partition_mapping(partition_ids)

# Default arguments for Ray initialization
if not ray_init_args:
Expand Down Expand Up @@ -308,10 +329,11 @@ def update_resources(f_stop: threading.Event) -> None:
)

# Register one RayClientProxy object for each client with the ClientManager
for cid in cids:
for node_id, partition_id in nodes_mapping.items():
client_proxy = RayActorClientProxy(
client_fn=client_fn,
cid=cid,
node_id=node_id,
partition_id=partition_id,
actor_pool=pool,
)
initialized_server.client_manager().register(client=client_proxy)
Expand Down
22 changes: 15 additions & 7 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,22 @@ class RayActorClientProxy(ClientProxy):
"""Flower client proxy which delegates work using Ray."""

def __init__(
self, client_fn: ClientFn, cid: str, actor_pool: VirtualClientEngineActorPool
self,
client_fn: ClientFn,
node_id: int,
partition_id: int,
actor_pool: VirtualClientEngineActorPool,
):
super().__init__(cid)
super().__init__(cid=str(node_id))
self.node_id = node_id
self.partition_id = partition_id

def _load_app() -> ClientApp:
return ClientApp(client_fn=client_fn)

self.app_fn = _load_app
self.actor_pool = actor_pool
self.proxy_state = NodeState(partition_id=int(self.cid))
self.proxy_state = NodeState(partition_id=self.partition_id)

def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
"""Submit a message to the ActorPool."""
Expand All @@ -67,11 +73,13 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:

try:
self.actor_pool.submit_client_job(
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
(self.app_fn, message, self.cid, state),
lambda a, a_fn, mssg, partition_id, state: a.run.remote(
a_fn, mssg, partition_id, state
),
(self.app_fn, message, str(self.partition_id), state),
)
out_mssg, updated_context = self.actor_pool.get_client_result(
self.cid, timeout
str(self.partition_id), timeout
)

# Update state
Expand Down Expand Up @@ -103,7 +111,7 @@ def _wrap_recordset_in_message(
message_id="",
group_id=str(group_id) if group_id is not None else "",
src_node_id=0,
dst_node_id=int(self.cid),
dst_node_id=self.node_id,
reply_to_message="",
ttl=timeout if timeout else DEFAULT_TTL,
message_type=message_type,
Expand Down
17 changes: 10 additions & 7 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
recordset_to_getpropertiesres,
)
from flwr.common.recordset_compat_test import _get_valid_getpropertiesins
from flwr.simulation.app import _create_node_id_to_partition_mapping
from flwr.simulation.ray_transport.ray_actor import (
ClientAppActor,
VirtualClientEngineActor,
Expand Down Expand Up @@ -87,13 +88,15 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:

# Create 373 client proxies
num_proxies = 373 # a prime number
mapping = _create_node_id_to_partition_mapping(list(range(num_proxies)))
proxies = [
RayActorClientProxy(
client_fn=get_dummy_client,
cid=str(cid),
node_id=node_id,
partition_id=partition_id,
actor_pool=pool,
)
for cid in range(num_proxies)
for node_id, partition_id in mapping.items()
]

return proxies, pool
Expand Down Expand Up @@ -123,7 +126,7 @@ def test_cid_consistency_one_at_a_time() -> None:

res = recordset_to_getpropertiesres(message_out.content)

assert int(prox.cid) * pi == res.properties["result"]
assert int(prox.partition_id) * pi == res.properties["result"]

ray.shutdown()

Expand Down Expand Up @@ -156,21 +159,21 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
)
prox.actor_pool.submit_client_job(
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
(prox.app_fn, message, prox.cid, state),
(prox.app_fn, message, str(prox.partition_id), state),
)

# fetch results one at a time
shuffle(proxies)
for prox in proxies:
message_out, updated_context = prox.actor_pool.get_client_result(
prox.cid, timeout=None
str(prox.partition_id), timeout=None
)
prox.proxy_state.update_context(run_id, context=updated_context)
res = recordset_to_getpropertiesres(message_out.content)

assert int(prox.cid) * pi == res.properties["result"]
assert prox.partition_id * pi == res.properties["result"]
assert (
str(int(prox.cid) * pi)
str(prox.partition_id * pi)
== prox.proxy_state.retrieve_context(run_id).state.configs_records[
"result"
]["result"]
Expand Down