Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

break(framework) Remove support for client_ids in start_simulation #3699

Merged
merged 19 commits into from
Jul 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/py/flwr/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,19 @@ def warn_deprecated_feature(name: str) -> None:
)


def warn_unsupported_feature(name: str) -> None:
    """Warn the user when they use an unsupported feature.

    Parameters
    ----------
    name : str
        Text describing the unsupported feature. It is passed as the
        argument for the ``%s`` placeholder of the warning message.
    """
    # The message is handed to `log` with `name` as a separate argument, so the
    # `%s` substitution is left to the logging call rather than done eagerly here.
    log(
        WARN,
        """UNSUPPORTED FEATURE: %s

This is an unsupported feature. It will be removed
entirely in future versions of Flower.
""",
        name,
    )


def set_logger_propagation(
child_logger: logging.Logger, value: bool = True
) -> logging.Logger:
Expand Down
63 changes: 39 additions & 24 deletions src/py/flwr/simulation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@

from flwr.client import ClientFnExt
from flwr.common import EventType, event
from flwr.common.logger import log, set_logger_propagation
from flwr.common.constant import NODE_ID_NUM_BYTES
from flwr.common.logger import log, set_logger_propagation, warn_unsupported_feature
from flwr.server.client_manager import ClientManager
from flwr.server.history import History
from flwr.server.server import Server, init_defaults, run_fl
from flwr.server.server_config import ServerConfig
from flwr.server.strategy import Strategy
from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
from flwr.simulation.ray_transport.ray_actor import (
ClientAppActor,
VirtualClientEngineActor,
Expand All @@ -51,7 +53,7 @@
`start_simulation(
*,
client_fn: ClientFn,
num_clients: Optional[int] = None,
num_clients: int,
clients_ids: Optional[List[str]] = None,
client_resources: Optional[Dict[str, float]] = None,
server: Optional[Server] = None,
Expand All @@ -70,13 +72,29 @@

"""

NodeToPartitionMapping = Dict[int, int]


def _create_node_id_to_partition_mapping(
    num_clients: int,
) -> NodeToPartitionMapping:
    """Build a mapping from unique random node IDs to partition IDs.

    Partition IDs are handed out in order, ``0`` up to ``num_clients - 1``.
    Each one is paired with a node ID obtained from
    ``generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)`` that is not already
    present in the mapping.

    Parameters
    ----------
    num_clients : int
        Number of entries to create.

    Returns
    -------
    NodeToPartitionMapping
        Dictionary of the form ``{node_id: partition_id}``.
    """
    mapping: NodeToPartitionMapping = {}  # {node-id: partition-id}
    next_partition = 0
    while next_partition < num_clients:
        candidate = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
        if candidate in mapping:
            # This node ID is already assigned: draw a new one.
            continue
        mapping[candidate] = next_partition
        next_partition += 1
    return mapping


# pylint: disable=too-many-arguments,too-many-statements,too-many-branches
def start_simulation(
*,
client_fn: ClientFnExt,
num_clients: Optional[int] = None,
clients_ids: Optional[List[str]] = None,
num_clients: int,
clients_ids: Optional[List[str]] = None, # UNSUPPORTED, WILL BE REMOVED
client_resources: Optional[Dict[str, float]] = None,
server: Optional[Server] = None,
config: Optional[ServerConfig] = None,
Expand All @@ -102,13 +120,14 @@ def start_simulation(
(model, dataset, hyperparameters, ...) should be (re-)created in either the
call to `client_fn` or the call to any of the client methods (e.g., load
evaluation data in the `evaluate` method itself).
num_clients : Optional[int]
The total number of clients in this simulation. This must be set if
`clients_ids` is not set and vice-versa.
num_clients : int
The total number of clients in this simulation.
clients_ids : Optional[List[str]]
UNSUPPORTED, WILL BE REMOVED. USE `num_clients` INSTEAD.
List `client_id`s for each client. This is only required if
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
`num_clients` is not set. Setting both `num_clients` and `clients_ids`
with `len(clients_ids)` not equal to `num_clients` generates an error.
Using this argument will raise an error.
client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`)
CPU and GPU resources for a single client. Supported keys
are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by
Expand Down Expand Up @@ -158,7 +177,6 @@ def start_simulation(
is an advanced feature. For all details, please refer to the Ray documentation:
https://docs.ray.io/en/latest/ray-core/scheduling/index.html


Returns
-------
hist : flwr.server.history.History
Expand All @@ -170,6 +188,14 @@ def start_simulation(
{"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
)

if clients_ids is not None:
    # `clients_ids` is not used any more: tell the user what to pass instead,
    # log the error and stop rather than silently ignoring the supplied IDs.
    warn_unsupported_feature(
        "Passing `clients_ids` to `start_simulation` is deprecated and no longer "
        "used by `start_simulation`. Use `num_clients` exclusively instead."
    )
    log(ERROR, "`clients_ids` argument used.")
    sys.exit()

# Set logger propagation
loop: Optional[asyncio.AbstractEventLoop] = None
try:
Expand All @@ -196,20 +222,8 @@ def start_simulation(
initialized_config,
)

# clients_ids takes precedence
cids: List[str]
if clients_ids is not None:
if (num_clients is not None) and (len(clients_ids) != num_clients):
log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
sys.exit()
else:
cids = clients_ids
else:
if num_clients is None:
log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
sys.exit()
else:
cids = [str(x) for x in range(num_clients)]
# Create node-id to partition-id mapping
nodes_mapping = _create_node_id_to_partition_mapping(num_clients)

# Default arguments for Ray initialization
if not ray_init_args:
Expand Down Expand Up @@ -308,10 +322,11 @@ def update_resources(f_stop: threading.Event) -> None:
)

# Register one RayClientProxy object for each client with the ClientManager
for cid in cids:
for node_id, partition_id in nodes_mapping.items():
client_proxy = RayActorClientProxy(
client_fn=client_fn,
cid=cid,
node_id=node_id,
partition_id=partition_id,
actor_pool=pool,
)
initialized_server.client_manager().register(client=client_proxy)
Expand Down
22 changes: 15 additions & 7 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,22 @@ class RayActorClientProxy(ClientProxy):
"""Flower client proxy which delegates work using Ray."""

def __init__(
self, client_fn: ClientFnExt, cid: str, actor_pool: VirtualClientEngineActorPool
self,
client_fn: ClientFnExt,
node_id: int,
partition_id: int,
actor_pool: VirtualClientEngineActorPool,
):
super().__init__(cid)
super().__init__(cid=str(node_id))
self.node_id = node_id
self.partition_id = partition_id

def _load_app() -> ClientApp:
return ClientApp(client_fn=client_fn)

self.app_fn = _load_app
self.actor_pool = actor_pool
self.proxy_state = NodeState(partition_id=int(self.cid))
self.proxy_state = NodeState(partition_id=self.partition_id)

def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
"""Submit a message to the ActorPool."""
Expand All @@ -67,11 +73,13 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:

try:
self.actor_pool.submit_client_job(
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
(self.app_fn, message, self.cid, state),
lambda a, a_fn, mssg, partition_id, state: a.run.remote(
a_fn, mssg, partition_id, state
),
(self.app_fn, message, str(self.partition_id), state),
)
out_mssg, updated_context = self.actor_pool.get_client_result(
self.cid, timeout
str(self.partition_id), timeout
)

# Update state
Expand Down Expand Up @@ -103,7 +111,7 @@ def _wrap_recordset_in_message(
message_id="",
group_id=str(group_id) if group_id is not None else "",
src_node_id=0,
dst_node_id=int(self.cid),
dst_node_id=self.node_id,
reply_to_message="",
ttl=timeout if timeout else DEFAULT_TTL,
message_type=message_type,
Expand Down
43 changes: 22 additions & 21 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
recordset_to_getpropertiesres,
)
from flwr.common.recordset_compat_test import _get_valid_getpropertiesins
from flwr.simulation.app import _create_node_id_to_partition_mapping
from flwr.simulation.ray_transport.ray_actor import (
ClientAppActor,
VirtualClientEngineActor,
Expand Down Expand Up @@ -68,9 +69,7 @@ def get_dummy_client(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
"""Return a DummyClient converted to Client type."""
if partition_id is None:
raise ValueError("`partition_id` is not set.")
return DummyClient(partition_id).to_client()
return DummyClient(node_id).to_client()


def prep(
Expand All @@ -91,13 +90,15 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:

# Create 373 client proxies
num_proxies = 373 # a prime number
mapping = _create_node_id_to_partition_mapping(num_proxies)
proxies = [
RayActorClientProxy(
client_fn=get_dummy_client,
cid=str(cid),
node_id=node_id,
partition_id=partition_id,
actor_pool=pool,
)
for cid in range(num_proxies)
for node_id, partition_id in mapping.items()
]

return proxies, pool
Expand Down Expand Up @@ -127,7 +128,7 @@ def test_cid_consistency_one_at_a_time() -> None:

res = recordset_to_getpropertiesres(message_out.content)

assert int(prox.cid) * pi == res.properties["result"]
assert int(prox.node_id) * pi == res.properties["result"]

ray.shutdown()

Expand Down Expand Up @@ -160,21 +161,21 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
)
prox.actor_pool.submit_client_job(
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
(prox.app_fn, message, prox.cid, state),
(prox.app_fn, message, str(prox.node_id), state),
)

# fetch results one at a time
shuffle(proxies)
for prox in proxies:
message_out, updated_context = prox.actor_pool.get_client_result(
prox.cid, timeout=None
str(prox.node_id), timeout=None
)
prox.proxy_state.update_context(run_id, context=updated_context)
res = recordset_to_getpropertiesres(message_out.content)

assert int(prox.cid) * pi == res.properties["result"]
assert prox.node_id * pi == res.properties["result"]
assert (
str(int(prox.cid) * pi)
str(prox.node_id * pi)
== prox.proxy_state.retrieve_context(run_id).state.configs_records[
"result"
]["result"]
Expand All @@ -187,7 +188,7 @@ def test_cid_consistency_without_proxies() -> None:
"""Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy."""
proxies, pool = prep()
num_clients = len(proxies)
cids = [str(cid) for cid in range(num_clients)]
node_ids = list(range(num_clients))

getproperties_ins = _get_valid_getpropertiesins()
recordset = getpropertiesins_to_recordset(getproperties_ins)
Expand All @@ -196,36 +197,36 @@ def _load_app() -> ClientApp:
return ClientApp(client_fn=get_dummy_client)

# submit all jobs (collect later)
shuffle(cids)
for cid in cids:
shuffle(node_ids)
for node_id in node_ids:
message = Message(
content=recordset,
metadata=Metadata(
run_id=0,
message_id="",
group_id=str(0),
src_node_id=0,
dst_node_id=12345,
dst_node_id=node_id,
reply_to_message="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
)
pool.submit_client_job(
lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state),
lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state),
(
_load_app,
message,
cid,
Context(state=RecordSet(), partition_id=int(cid)),
str(node_id),
Context(state=RecordSet(), partition_id=node_id),
),
)

# fetch results one at a time
shuffle(cids)
for cid in cids:
message_out, _ = pool.get_client_result(cid, timeout=None)
shuffle(node_ids)
for node_id in node_ids:
message_out, _ = pool.get_client_result(str(node_id), timeout=None)
res = recordset_to_getpropertiesres(message_out.content)
assert int(cid) * pi == res.properties["result"]
assert node_id * pi == res.properties["result"]

ray.shutdown()