From f271ac748c8b52a763f90a557f414df8db1bfe9c Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 10 Jan 2024 16:22:14 +0000 Subject: [PATCH 01/25] removed old (unused) clientproxy --- .../ray_transport/ray_client_proxy.py | 139 +----------------- 1 file changed, 1 insertion(+), 138 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 5c05850dfd2f..7e7678837b14 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -17,9 +17,7 @@ import traceback from logging import ERROR -from typing import Dict, Optional, cast - -import ray +from typing import Optional, cast from flwr import common from flwr.client import Client, ClientFn @@ -39,87 +37,6 @@ ) -class RayClientProxy(ClientProxy): - """Flower client proxy which delegates work using Ray.""" - - def __init__(self, client_fn: ClientFn, cid: str, resources: Dict[str, float]): - super().__init__(cid) - self.client_fn = client_fn - self.resources = resources - - def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] - ) -> common.GetPropertiesRes: - """Return client's properties.""" - future_get_properties_res = launch_and_get_properties.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_get_properties_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetPropertiesRes, - res, - ) - - def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] - ) -> common.GetParametersRes: - """Return the current local model parameters.""" - future_paramseters_res = launch_and_get_parameters.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_paramseters_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetParametersRes, - res, - ) - - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: - """Train model parameters on the locally held dataset.""" - future_fit_res = launch_and_fit.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_fit_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.FitRes, - res, - ) - - def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] - ) -> common.EvaluateRes: - """Evaluate model parameters on the locally held dataset.""" - future_evaluate_res = launch_and_evaluate.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_evaluate_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.EvaluateRes, - res, - ) - - def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] - ) -> common.DisconnectRes: - """Disconnect and (optionally) reconnect later.""" - return common.DisconnectRes(reason="") # Nothing to do here (yet) - - class RayActorClientProxy(ClientProxy): """Flower client proxy which delegates work using Ray.""" @@ -239,57 +156,3 @@ def reconnect( ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - - -@ray.remote -def launch_and_get_properties( - client_fn: ClientFn, cid: str, get_properties_ins: common.GetPropertiesIns -) -> common.GetPropertiesRes: - """Exectue get_properties remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - -@ray.remote -def launch_and_get_parameters( - client_fn: ClientFn, cid: str, get_parameters_ins: common.GetParametersIns -) -> common.GetParametersRes: - """Exectue get_parameters remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - -@ray.remote -def launch_and_fit( - client_fn: ClientFn, cid: str, fit_ins: common.FitIns -) -> common.FitRes: - """Exectue fit remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - -@ray.remote -def launch_and_evaluate( - client_fn: ClientFn, cid: str, evaluate_ins: common.EvaluateIns -) -> common.EvaluateRes: - """Exectue evaluate remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - -def _create_client(client_fn: ClientFn, cid: str) -> Client: - """Create a client instance.""" - # Materialize client - return client_fn(cid) From bfdf6d471be6ff21549ef55b15ae44ef002b4ba9 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 10 Jan 2024 17:30:07 +0000 Subject: [PATCH 02/25] works --- .../simulation/ray_transport/ray_actor.py | 85 +++++++++----- .../ray_transport/ray_client_proxy.py | 109 +++++++++--------- 2 files changed, 108 insertions(+), 86 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 38af3f08daa2..69f87932e4f3 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -25,18 +25,20 @@ from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr import common -from flwr.client import Client, ClientFn +from flwr.client import ClientFn +from flwr.client.message_handler.message_handler import ( + UnexpectedServerMessage, + UnknownServerMessage, + _evaluate, + _fit, + _get_parameters, + _get_properties, + get_server_message_from_task_ins, + wrap_client_message_in_task_res, +) from flwr.client.run_state import RunState from flwr.common.logger import log -from flwr.simulation.ray_transport.utils import check_clientfn_returns_client - -# All possible returns by a client -ClientRes = Union[ - common.GetPropertiesRes, common.GetParametersRes, common.FitRes, common.EvaluateRes -] -# A function to be executed by a client to obtain some results -JobFn = Callable[[Client], ClientRes] +from flwr.proto.task_pb2 import TaskIns, TaskRes class ClientException(Exception): @@ -59,23 +61,42 @@ def terminate(self) -> None: def run( self, client_fn: ClientFn, - job_fn: JobFn, + task_ins: TaskIns, cid: str, state: RunState, - ) -> Tuple[str, ClientRes, RunState]: - """Run a client run.""" - # Execute tasks and return result - # return also cid which is needed to ensure results - # from the pool are correctly assigned to each ClientProxy + ) -> Tuple[str, TaskRes, RunState]: + """Instantiate client and run TaskIns.""" try: - # Instantiate client (check 'Client' type is returned) - client = check_clientfn_returns_client(client_fn(cid)) - # Inject state + # Ideally we would be simply call `handle()` but we can't + # this is because we need to pass `cid` to `client_fn` + # Still, most of the code below is borrowed from handle() and the functions it calls internally + server_msg = server_msg = get_server_message_from_task_ins( + task_ins, exclude_reconnect_ins=False + ) + field = server_msg.WhichOneof("msg") + + # Must be handled elsewhere + if field == "reconnect_ins": + raise UnexpectedServerMessage() + + # Instantiate the client + client = client_fn(cid) client.set_state(state) - # Run client job - job_results = job_fn(client) - # Retrieve state (potentially updated) - updated_state = client.get_state() + # Execute task + message = None + if field == "get_properties_ins": + message = _get_properties(client, server_msg.get_properties_ins) + elif field == "get_parameters_ins": + message = _get_parameters(client, server_msg.get_parameters_ins) + elif field == "fit_ins": + message = _fit(client, server_msg.fit_ins) + elif field == "evaluate_ins": + message = _evaluate(client, server_msg.evaluate_ins) + else: + raise UnknownServerMessage() + + task_res = wrap_client_message_in_task_res(message) + except Exception as ex: client_trace = traceback.format_exc() message = ( @@ -89,7 +110,7 @@ def run( ) raise ClientException(str(message)) from ex - return cid, job_results, updated_state + return cid, task_res, client.get_state() @ray.remote @@ -237,16 +258,16 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RunState]) -> None: + def submit(self, fn: Any, value: Tuple[ClientFn, TaskIns, str, RunState]) -> None: """Take idle actor and assign it a client run. Submit a job to an actor by first removing it from the list of idle actors, then check if this actor was flagged to be removed from the pool """ - client_fn, job_fn, cid, state = value + client_fn, task_ins, cid, state = value actor = self._idle_actors.pop() if self._check_and_remove_actor_from_pool(actor): - future = fn(actor, client_fn, job_fn, cid, state) + future = fn(actor, client_fn, task_ins, cid, state) future_key = tuple(future) if isinstance(future, List) else future self._future_to_actor[future_key] = (self._next_task_index, actor, cid) self._next_task_index += 1 @@ -254,8 +275,8 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RunState]) -> None: # Update with future self._cid_to_future[cid]["future"] = future_key - def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, RunState] + def submit_task_ins( + self, actor_fn: Any, job: Tuple[ClientFn, TaskIns, str, RunState] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -295,7 +316,7 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RunState]: + def _fetch_future_result(self, cid: str) -> Tuple[TaskRes, RunState]: """Fetch result and updated state for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is @@ -305,7 +326,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RunState]: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore res_cid, res, updated_state = ray.get( future - ) # type: (str, ClientRes, RunState) + ) # type: (str, TaskRes, RunState) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -409,7 +430,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, RunState]: + ) -> Tuple[TaskRes, RunState]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 7e7678837b14..4e51bb24715d 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -17,24 +17,22 @@ import traceback from logging import ERROR -from typing import Optional, cast +from typing import Optional, Union, cast from flwr import common from flwr.client import Client, ClientFn from flwr.client.client import ( maybe_call_evaluate, maybe_call_fit, - maybe_call_get_parameters, maybe_call_get_properties, ) from flwr.client.node_state import NodeState +from flwr.common import serde from flwr.common.logger import log +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes +from flwr.proto.transport_pb2 import ClientMessage, ServerMessage from flwr.server.client_proxy import ClientProxy -from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, - JobFn, - VirtualClientEngineActorPool, -) +from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool class RayActorClientProxy(ClientProxy): @@ -48,7 +46,7 @@ def __init__( self.actor_pool = actor_pool self.proxy_state = NodeState() - def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: + def _submit_taskins(self, task_ins: TaskIns, timeout: Optional[float]) -> TaskRes: # The VCE is not exposed to TaskIns, it won't handle multilple runs # For the time being, fixing run_id is a small compromise # This will be one of the first points to address integrating VCE + DriverAPI @@ -61,11 +59,15 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: state = self.proxy_state.retrieve_runstate(run_id=run_id) try: - self.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (self.client_fn, job_fn, self.cid, state), + self.actor_pool.submit_task_ins( + lambda a, c_fn, t_ins, cid, state: a.run.remote( + c_fn, t_ins, cid, state + ), + (self.client_fn, task_ins, self.cid, state), + ) + task_res, updated_state = self.actor_pool.get_client_result( + self.cid, timeout ) - res, updated_state = self.actor_pool.get_client_result(self.cid, timeout) # Update state self.proxy_state.update_runstate(run_id=run_id, run_state=updated_state) @@ -79,77 +81,76 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: log(ERROR, ex) raise ex - return res + return task_res + + def _submit_server_message_to_pool( + self, server_msg: ServerMessage, timeout + ) -> ClientMessage: + task_ins = TaskIns( + task_id="", + group_id="", + run_id=0, + task=Task(ancestry=[], legacy_server_message=server_msg), + ) + + # Submit + task_res = self._submit_taskins(task_ins, timeout) + + # To client message + return serde.client_message_from_proto(task_res.task.legacy_client_message) def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float] ) -> common.GetPropertiesRes: """Return client's properties.""" - def get_properties(client: Client) -> common.GetPropertiesRes: - return maybe_call_get_properties( - client=client, - get_properties_ins=ins, - ) + ins_proto = serde.get_properties_ins_to_proto(ins) + server_msg = ServerMessage(get_properties_ins=ins_proto) - res = self._submit_job(get_properties, timeout) + # Submit (block until completed) + client_msg = self._submit_server_message_to_pool(server_msg, timeout) - return cast( - common.GetPropertiesRes, - res, - ) + # Return as legacy type + return client_msg.get_properties_res def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float] ) -> common.GetParametersRes: """Return the current local model parameters.""" + ins_proto = serde.get_parameters_ins_to_proto(ins) + server_msg = ServerMessage(get_parameters_ins=ins_proto) - def get_parameters(client: Client) -> common.GetParametersRes: - return maybe_call_get_parameters( - client=client, - get_parameters_ins=ins, - ) - - res = self._submit_job(get_parameters, timeout) + # Submit (block until completed) + client_msg = self._submit_server_message_to_pool(server_msg, timeout) - return cast( - common.GetParametersRes, - res, - ) + # Return as legacy type + return client_msg.get_parameters_res def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" - def fit(client: Client) -> common.FitRes: - return maybe_call_fit( - client=client, - fit_ins=ins, - ) + ins_proto = serde.fit_ins_to_proto(ins) + server_msg = ServerMessage(fit_ins=ins_proto) - res = self._submit_job(fit, timeout) + # Submit (block until completed) + client_msg = self._submit_server_message_to_pool(server_msg, timeout) - return cast( - common.FitRes, - res, - ) + # Return as legacy type + return client_msg.fit_res def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" - def evaluate(client: Client) -> common.EvaluateRes: - return maybe_call_evaluate( - client=client, - evaluate_ins=ins, - ) + ins_proto = serde.evaluate_ins_to_proto(ins) + server_msg = ServerMessage(evaluate_ins=ins_proto) - res = self._submit_job(evaluate, timeout) + # Submit (block until completed) + client_msg = self._submit_server_message_to_pool(server_msg, timeout) - return cast( - common.EvaluateRes, - res, - ) + # Return as legacy type + return client_msg.evaluate_res def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float] From fbad2d6538dfe182df2d4d17085fdbbedbb8f3c9 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 10 Jan 2024 17:34:12 +0000 Subject: [PATCH 03/25] format --- src/py/flwr/simulation/ray_transport/ray_actor.py | 2 +- .../simulation/ray_transport/ray_client_proxy.py | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 69f87932e4f3..ab8db9c1dcfb 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -80,7 +80,7 @@ def run( raise UnexpectedServerMessage() # Instantiate the client - client = client_fn(cid) + client = client_fn(cid) # client_fn must return Client type client.set_state(state) # Execute task message = None diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 4e51bb24715d..576815914a6c 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -17,15 +17,10 @@ import traceback from logging import ERROR -from typing import Optional, Union, cast +from typing import Optional from flwr import common -from flwr.client import Client, ClientFn -from flwr.client.client import ( - maybe_call_evaluate, - maybe_call_fit, - maybe_call_get_properties, -) +from flwr.client import ClientFn from flwr.client.node_state import NodeState from flwr.common import serde from flwr.common.logger import log @@ -103,7 +98,6 @@ def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float] ) -> common.GetPropertiesRes: """Return client's properties.""" - ins_proto = serde.get_properties_ins_to_proto(ins) server_msg = ServerMessage(get_properties_ins=ins_proto) @@ -128,7 +122,6 @@ def get_parameters( def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" - ins_proto = serde.fit_ins_to_proto(ins) server_msg = ServerMessage(fit_ins=ins_proto) @@ -142,7 +135,6 @@ def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" - ins_proto = serde.evaluate_ins_to_proto(ins) server_msg = ServerMessage(evaluate_ins=ins_proto) From ae07103447a15d24b6359a377df029e7ad59abae Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 10 Jan 2024 17:44:45 +0000 Subject: [PATCH 04/25] ensure Client type returned --- src/py/flwr/simulation/ray_transport/ray_actor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index ab8db9c1dcfb..617e7d1db732 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -40,6 +40,8 @@ from flwr.common.logger import log from flwr.proto.task_pb2 import TaskIns, TaskRes +from .utils import check_clientfn_returns_client + class ClientException(Exception): """Raised when client side logic crashes with an exception.""" @@ -80,7 +82,9 @@ def run( raise UnexpectedServerMessage() # Instantiate the client - client = client_fn(cid) # client_fn must return Client type + client = check_clientfn_returns_client( + client_fn(cid) + ) # client_fn must return Client type client.set_state(state) # Execute task message = None From cc9f52698dc31ece94e95f9e0a5fe233aefb94e1 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 29 Jan 2024 17:03:33 +0000 Subject: [PATCH 05/25] remove unused --- .../ray_transport/ray_client_proxy.py | 138 +----------------- 1 file changed, 1 insertion(+), 137 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 894012dc6d70..4c8df6c25396 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -17,9 +17,8 @@ import traceback from logging import ERROR -from typing import Dict, Optional, cast +from typing import Optional, cast -import ray from flwr import common from flwr.client import Client, ClientFn @@ -39,87 +38,6 @@ ) -class RayClientProxy(ClientProxy): - """Flower client proxy which delegates work using Ray.""" - - def __init__(self, client_fn: ClientFn, cid: str, resources: Dict[str, float]): - super().__init__(cid) - self.client_fn = client_fn - self.resources = resources - - def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] - ) -> common.GetPropertiesRes: - """Return client's properties.""" - future_get_properties_res = launch_and_get_properties.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_get_properties_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetPropertiesRes, - res, - ) - - def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] - ) -> common.GetParametersRes: - """Return the current local model parameters.""" - future_paramseters_res = launch_and_get_parameters.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_paramseters_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetParametersRes, - res, - ) - - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: - """Train model parameters on the locally held dataset.""" - future_fit_res = launch_and_fit.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_fit_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.FitRes, - res, - ) - - def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] - ) -> common.EvaluateRes: - """Evaluate model parameters on the locally held dataset.""" - future_evaluate_res = launch_and_evaluate.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_evaluate_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.EvaluateRes, - res, - ) - - def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] - ) -> common.DisconnectRes: - """Disconnect and (optionally) reconnect later.""" - return common.DisconnectRes(reason="") # Nothing to do here (yet) - - class RayActorClientProxy(ClientProxy): """Flower client proxy which delegates work using Ray.""" @@ -239,57 +157,3 @@ def reconnect( ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - - -@ray.remote -def launch_and_get_properties( - client_fn: ClientFn, cid: str, get_properties_ins: common.GetPropertiesIns -) -> common.GetPropertiesRes: - """Exectue get_properties remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - -@ray.remote -def launch_and_get_parameters( - client_fn: ClientFn, cid: str, get_parameters_ins: common.GetParametersIns -) -> common.GetParametersRes: - """Exectue get_parameters remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - -@ray.remote -def launch_and_fit( - client_fn: ClientFn, cid: str, fit_ins: common.FitIns -) -> common.FitRes: - """Exectue fit remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - -@ray.remote -def launch_and_evaluate( - client_fn: ClientFn, cid: str, evaluate_ins: common.EvaluateIns -) -> common.EvaluateRes: - """Exectue evaluate remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - -def _create_client(client_fn: ClientFn, cid: str) -> Client: - """Create a client instance.""" - # Materialize client - return client_fn(cid) From 36a73dc980d9180da00dd767904965b6c9c90751 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 29 Jan 2024 18:42:08 +0000 Subject: [PATCH 06/25] w/ Message and Flower callable --- .../simulation/ray_transport/ray_actor.py | 47 +++--- .../ray_transport/ray_client_proxy.py | 153 ++++++++++-------- .../ray_transport/ray_client_proxy_test.py | 2 +- 3 files changed, 113 insertions(+), 89 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 853566a4cbeb..b5d2e3bc3660 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -26,10 +26,11 @@ from ray.util.actor_pool import ActorPool from flwr import common -from flwr.client import Client, ClientFn +from flwr.client import Client +from flwr.client.flower import Flower from flwr.common.context import Context from flwr.common.logger import log -from flwr.simulation.ray_transport.utils import check_clientfn_returns_client +from flwr.common.message import Message # All possible returns by a client ClientRes = Union[ @@ -38,6 +39,8 @@ # A function to be executed by a client to obtain some results JobFn = Callable[[Client], ClientRes] +FlowerFn = Callable[[], Flower] + class ClientException(Exception): """Raised when client side logic crashes with an exception.""" @@ -58,27 +61,25 @@ def terminate(self) -> None: def run( self, - client_fn: ClientFn, - job_fn: JobFn, + app_fn: FlowerFn, + message: Message, cid: str, context: Context, - ) -> Tuple[str, ClientRes, Context]: + ) -> Tuple[str, Message, Context]: """Run a client run.""" # Execute tasks and return result # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy try: - # Instantiate client (check 'Client' type is returned) - client = check_clientfn_returns_client(client_fn(cid)) - # Inject context - client.set_context(context) - # Run client job - job_results = job_fn(client) - # Retrieve context (potentially updated) - updated_context = client.get_context() + # Load app + app: Flower = app_fn() + + # Handle task message + out_message = app(message=message, context=context) + except Exception as ex: client_trace = traceback.format_exc() - message = ( + mssg = ( "\n\tSomething went wrong when running your client run." "\n\tClient " + cid @@ -87,9 +88,9 @@ def run( + " was running its run." "\n\tException triggered on the client side: " + client_trace, ) - raise ClientException(str(message)) from ex + raise ClientException(str(mssg)) from ex - return cid, job_results, updated_context + return cid, out_message, context @ray.remote @@ -237,16 +238,16 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, Context]) -> None: + def submit(self, fn: Any, value: Tuple[FlowerFn, Message, str, Context]) -> None: """Take idle actor and assign it a client run. Submit a job to an actor by first removing it from the list of idle actors, then check if this actor was flagged to be removed from the pool """ - client_fn, job_fn, cid, context = value + app_fn, mssg, cid, context = value actor = self._idle_actors.pop() if self._check_and_remove_actor_from_pool(actor): - future = fn(actor, client_fn, job_fn, cid, context) + future = fn(actor, app_fn, mssg, cid, context) future_key = tuple(future) if isinstance(future, List) else future self._future_to_actor[future_key] = (self._next_task_index, actor, cid) self._next_task_index += 1 @@ -255,7 +256,7 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, Context]) -> None: self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, Context] + self, actor_fn: Any, job: Tuple[FlowerFn, Message, str, Context] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -295,7 +296,7 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]: + def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]: """Fetch result and updated context for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is @@ -305,7 +306,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore res_cid, res, updated_context = ray.get( future - ) # type: (str, ClientRes, Context) + ) # type: (str, Message, Context) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -409,7 +410,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, Context]: + ) -> Tuple[Message, Context]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 4c8df6c25396..64d73e24b4a1 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -17,25 +17,32 @@ import traceback from logging import ERROR -from typing import Optional, cast - +from typing import Optional from flwr import common -from flwr.client import Client, ClientFn -from flwr.client.client import ( - maybe_call_evaluate, - maybe_call_fit, - maybe_call_get_parameters, - maybe_call_get_properties, -) +from flwr.client import ClientFn +from flwr.client.flower import Flower from flwr.client.node_state import NodeState +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) from flwr.common.logger import log -from flwr.server.client_proxy import ClientProxy -from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, - JobFn, - VirtualClientEngineActorPool, +from flwr.common.message import Message, Metadata +from flwr.common.recordset_compat import ( + evaluateins_to_recordset, + fitins_to_recordset, + getparametersins_to_recordset, + getpropertiesins_to_recordset, + recordset_to_evaluateres, + recordset_to_fitres, + recordset_to_getparametersres, + recordset_to_getpropertiesres, ) +from flwr.server.client_proxy import ClientProxy +from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool class RayActorClientProxy(ClientProxy): @@ -45,15 +52,19 @@ def __init__( self, client_fn: ClientFn, cid: str, actor_pool: VirtualClientEngineActorPool ): super().__init__(cid) - self.client_fn = client_fn + + def _load_app() -> Flower: + return Flower(client_fn=client_fn) + + self.app_fn = _load_app self.actor_pool = actor_pool self.proxy_state = NodeState() - def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: + def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: # The VCE is not exposed to TaskIns, it won't handle multilple runs # For the time being, fixing run_id is a small compromise # This will be one of the first points to address integrating VCE + DriverAPI - run_id = 0 + run_id = message.metadata.run_id # Register state self.proxy_state.register_context(run_id=run_id) @@ -63,10 +74,12 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: try: self.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (self.client_fn, job_fn, self.cid, state), + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (self.app_fn, message, self.cid, state), + ) + out_mssg, updated_context = self.actor_pool.get_client_result( + self.cid, timeout ) - res, updated_context = self.actor_pool.get_client_result(self.cid, timeout) # Update state self.proxy_state.update_context(run_id=run_id, context=updated_context) @@ -80,77 +93,87 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: log(ERROR, ex) raise ex - return res + return out_mssg def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float] ) -> common.GetPropertiesRes: """Return client's properties.""" + recordset = getpropertiesins_to_recordset(ins) + message = Message( + message=recordset, + metadata=Metadata( + run_id=0, + task_id="", + group_id="", + ttl="", + task_type=TASK_TYPE_GET_PROPERTIES, + ), + ) - def get_properties(client: Client) -> common.GetPropertiesRes: - return maybe_call_get_properties( - client=client, - get_properties_ins=ins, - ) - - res = self._submit_job(get_properties, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.GetPropertiesRes, - res, - ) + return recordset_to_getpropertiesres(message_out.message) def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float] ) -> common.GetParametersRes: """Return the current local model parameters.""" + recordset = getparametersins_to_recordset(ins) + message = Message( + message=recordset, + metadata=Metadata( + run_id=0, + task_id="", + group_id="", + ttl="", + task_type=TASK_TYPE_GET_PARAMETERS, + ), + ) - def get_parameters(client: Client) -> common.GetParametersRes: - return maybe_call_get_parameters( - client=client, - get_parameters_ins=ins, - ) - - res = self._submit_job(get_parameters, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.GetParametersRes, - res, - ) + return recordset_to_getparametersres(message_out.message, keep_input=False) def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" + recordset = fitins_to_recordset( + ins, keep_input=True + ) # This must stay TRUE since ins are in-memory + message = Message( + message=recordset, + metadata=Metadata( + run_id=0, + task_id="", + group_id="", + ttl="", + task_type=TASK_TYPE_FIT, + ), + ) - def fit(client: Client) -> common.FitRes: - return maybe_call_fit( - client=client, - fit_ins=ins, - ) - - res = self._submit_job(fit, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.FitRes, - res, - ) + return recordset_to_fitres(message_out.message, keep_input=False) def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" + recordset = evaluateins_to_recordset(ins, keep_input=True) + message = Message( + message=recordset, + metadata=Metadata( + run_id=0, + task_id="", + group_id="", + ttl="", + task_type=TASK_TYPE_EVALUATE, + ), + ) - def evaluate(client: Client) -> common.EvaluateRes: - return maybe_call_evaluate( - client=client, - evaluate_ins=ins, - ) - - res = self._submit_job(evaluate, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.EvaluateRes, - res, - ) + return recordset_to_evaluateres(message_out.message) def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float] diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index b380d37d01c8..55abad6ef38a 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -107,7 +107,7 @@ def test_cid_consistency_one_at_a_time() -> None: # submit jobs one at a time for prox in proxies: res = prox._submit_job( # pylint: disable=protected-access - job_fn=job_fn(prox.cid), timeout=None + message=job_fn(prox.cid), timeout=None ) res = cast(GetPropertiesRes, res) From 9a3c23c40bd11c41a3874ddfe5bd9cb0a540f566 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 29 Jan 2024 18:48:48 +0000 Subject: [PATCH 07/25] w/ previous --- src/py/flwr/simulation/ray_transport/ray_client_proxy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 64d73e24b4a1..cf4238baa871 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -159,7 +159,9 @@ def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" - recordset = evaluateins_to_recordset(ins, keep_input=True) + recordset = evaluateins_to_recordset( + ins, keep_input=True + ) # This must stay TRUE since ins are in-memory message = Message( message=recordset, metadata=Metadata( From adb4c982599b1a77fbe627a9265c210f797338d7 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 6 Feb 2024 15:34:07 +0000 Subject: [PATCH 08/25] rewrapping client_fn --- src/py/flwr/simulation/ray_transport/ray_client_proxy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index cf4238baa871..94399c979a0f 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -54,7 +54,9 @@ def __init__( super().__init__(cid) def _load_app() -> Flower: - return Flower(client_fn=client_fn) + def wrap(cid: str): + return client_fn(self.cid) + return Flower(client_fn=wrap) self.app_fn = _load_app self.actor_pool = actor_pool From 982857069dbad765679692856eee919b43e952a1 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 6 Feb 2024 20:49:15 +0000 Subject: [PATCH 09/25] wip --- src/py/flwr/client/grpc_client/connection.py | 1 + src/py/flwr/client/grpc_client/connection_test.py | 2 ++ src/py/flwr/client/message_handler/message_handler.py | 4 +++- src/py/flwr/client/message_handler/message_handler_test.py | 2 ++ .../middleware/secure_aggregation/secaggplus_middleware.py | 2 +- .../secure_aggregation/secaggplus_middleware_test.py | 4 ++-- src/py/flwr/client/middleware/utils_test.py | 4 +++- src/py/flwr/common/message.py | 3 +++ src/py/flwr/common/serde.py | 2 ++ src/py/flwr/common/serde_test.py | 3 +++ 10 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index bd1ea5fab307..4578b88a57b2 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -172,6 +172,7 @@ def receive() -> Message: run_id=0, task_id=str(uuid.uuid4()), group_id="", + node_id="", ttl="", task_type=task_type, ), diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index 8631c4a9b12b..0acb8c1ac86c 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -48,6 +48,7 @@ run_id=0, task_id="", group_id="", + node_id="", ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), @@ -60,6 +61,7 @@ run_id=0, task_id="", group_id="", + node_id="", ttl="", task_type="reconnect", ), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index e61e151da5e9..cecaf51b542b 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -92,6 +92,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: run_id=0, task_id="", group_id="", + node_id="", ttl="", task_type="reconnect", ), @@ -108,7 +109,7 @@ def handle_legacy_message_from_tasktype( client_fn: ClientFn, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most middleware layer.""" - client = client_fn("-1") + client = client_fn(message.metadata.node_id) client.set_context(context) @@ -153,6 +154,7 @@ def handle_legacy_message_from_tasktype( run_id=0, # Non-user defined task_id="", # Non-user defined group_id="", # Non-user defined + node_id="", ttl="", task_type=task_type, ), diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index ad09ca95abc7..d992ae1eba21 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -124,6 +124,7 @@ def test_client_without_get_properties() -> None: run_id=0, task_id=str(uuid.uuid4()), group_id="", + node_id="", ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), @@ -162,6 +163,7 @@ def test_client_with_get_properties() -> None: run_id=0, task_id=str(uuid.uuid4()), group_id="", + node_id="", ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py index 885dc4d9cbf5..e352e111ddf7 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py @@ -207,7 +207,7 @@ def secaggplus_middleware( # Return message return Message( - metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + metadata=Metadata(0, "", "", "", "", TASK_TYPE_FIT), message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}), ) diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py index 8ec52d71cbdd..3a64a0da1325 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py @@ -57,7 +57,7 @@ def get_test_handler( def empty_ffn(_: Message, _2: Context) -> Message: return Message( - metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + metadata=Metadata(0, "", "", "", "", TASK_TYPE_FIT), message=RecordSet(), ) @@ -65,7 +65,7 @@ def empty_ffn(_: Message, _2: Context) -> Message: def func(configs: Dict[str, ConfigsRecordValues]) -> Dict[str, ConfigsRecordValues]: in_msg = Message( - metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + metadata=Metadata(0, "", "", "", "", TASK_TYPE_FIT), message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}), ) out_msg = app(in_msg, ctxt) diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index a4e1f6e87599..c10c3134e0db 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -74,7 +74,9 @@ def app(message: Message, context: Context) -> Message: def _get_dummy_flower_message() -> Message: return Message( message=RecordSet(), - metadata=Metadata(run_id=0, task_id="", group_id="", ttl="", task_type="mock"), + metadata=Metadata( + run_id=0, task_id="", group_id="", node_id="", ttl="", task_type="mock" + ), ) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index f693d8e27bc3..d21a3a39232b 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -33,6 +33,8 @@ class Metadata: group_id : str An identifier for grouping tasks. In some settings this is used as the FL round. + node_id : str + An identifier for the node running a task. ttl : str Time-to-live for this task. task_type : str @@ -43,6 +45,7 @@ class Metadata: run_id: int task_id: str group_id: str + node_id: str ttl: str task_type: str diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 92c4e2cdad00..c4bbc490ac8f 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -563,6 +563,7 @@ def message_from_taskins(taskins: TaskIns) -> Message: run_id=taskins.run_id, task_id=taskins.task_id, group_id=taskins.group_id, + node_id="", ttl=taskins.task.ttl, task_type=taskins.task.task_type, ) @@ -592,6 +593,7 @@ def message_from_taskres(taskres: TaskRes) -> Message: run_id=taskres.run_id, task_id=taskres.task_id, group_id=taskres.group_id, + node_id="", ttl=taskres.task.ttl, task_type=taskres.task.task_type, ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 2a229a87e399..3c1c3d8ca55b 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -219,6 +219,7 @@ def metadata(self) -> Metadata: run_id=self.rng.randint(0, 1 << 30), task_id=self.get_str(64), group_id=self.get_str(30), + node_id=self.get_str(30), ttl=self.get_str(10), task_type=self.get_str(10), ) @@ -309,6 +310,7 @@ def test_message_to_and_from_taskins() -> None: run_id=0, task_id="", group_id="", + node_id="", ttl=metadata.ttl, task_type=metadata.task_type, ), @@ -337,6 +339,7 @@ def test_message_to_and_from_taskres() -> None: run_id=0, task_id="", group_id="", + node_id="", ttl=metadata.ttl, task_type=metadata.task_type, ), From 046ac41db095103c986f2eb0fb36c1f3d87fe447 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 7 Feb 2024 20:16:18 +0000 Subject: [PATCH 10/25] node_id is of type int --- src/py/flwr/client/grpc_client/connection.py | 2 +- src/py/flwr/client/grpc_client/connection_test.py | 4 ++-- .../flwr/client/message_handler/message_handler.py | 6 +++--- .../client/message_handler/message_handler_test.py | 4 ++-- .../secure_aggregation/secaggplus_middleware.py | 2 +- .../secure_aggregation/secaggplus_middleware_test.py | 4 ++-- src/py/flwr/client/middleware/utils_test.py | 2 +- src/py/flwr/common/message.py | 4 ++-- src/py/flwr/common/serde.py | 4 ++-- src/py/flwr/common/serde_test.py | 12 +++++++++--- 10 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 4578b88a57b2..6e4edf21ec9e 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -172,7 +172,7 @@ def receive() -> Message: run_id=0, task_id=str(uuid.uuid4()), group_id="", - node_id="", + node_id=0, ttl="", task_type=task_type, ), diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index 0acb8c1ac86c..9f61a930dc34 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -48,7 +48,7 @@ run_id=0, task_id="", group_id="", - node_id="", + node_id=0, ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), @@ -61,7 +61,7 @@ run_id=0, task_id="", group_id="", - node_id="", + node_id=0, ttl="", task_type="reconnect", ), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index cecaf51b542b..ef5c14f74b5a 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -92,7 +92,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: run_id=0, task_id="", group_id="", - node_id="", + node_id=0, ttl="", task_type="reconnect", ), @@ -109,7 +109,7 @@ def handle_legacy_message_from_tasktype( client_fn: ClientFn, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most middleware layer.""" - client = client_fn(message.metadata.node_id) + client = client_fn(str(message.metadata.node_id)) client.set_context(context) @@ -154,7 +154,7 @@ def handle_legacy_message_from_tasktype( run_id=0, # Non-user defined task_id="", # Non-user defined group_id="", # Non-user defined - node_id="", + node_id=0, ttl="", task_type=task_type, ), diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index d992ae1eba21..5889e4f6fb98 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -124,7 +124,7 @@ def test_client_without_get_properties() -> None: run_id=0, task_id=str(uuid.uuid4()), group_id="", - node_id="", + node_id=0, ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), @@ -163,7 +163,7 @@ def test_client_with_get_properties() -> None: run_id=0, task_id=str(uuid.uuid4()), group_id="", - node_id="", + node_id=0, ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py index e352e111ddf7..fb9538f0116e 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py @@ -207,7 +207,7 @@ def secaggplus_middleware( # Return message return Message( - metadata=Metadata(0, "", "", "", "", TASK_TYPE_FIT), + metadata=Metadata(0, "", "", 0, "", TASK_TYPE_FIT), message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}), ) diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py index 3a64a0da1325..b5082c3a14e0 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py @@ -57,7 +57,7 @@ def get_test_handler( def empty_ffn(_: Message, _2: Context) -> Message: return Message( - metadata=Metadata(0, "", "", "", "", TASK_TYPE_FIT), + metadata=Metadata(0, "", "", 0, "", TASK_TYPE_FIT), message=RecordSet(), ) @@ -65,7 +65,7 @@ def empty_ffn(_: Message, _2: Context) -> Message: def func(configs: Dict[str, ConfigsRecordValues]) -> Dict[str, ConfigsRecordValues]: in_msg = Message( - metadata=Metadata(0, "", "", "", "", TASK_TYPE_FIT), + metadata=Metadata(0, "", "", 0, "", TASK_TYPE_FIT), message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}), ) out_msg = app(in_msg, ctxt) diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index c10c3134e0db..4fb243f2b282 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -75,7 +75,7 @@ def _get_dummy_flower_message() -> Message: return Message( message=RecordSet(), metadata=Metadata( - run_id=0, task_id="", group_id="", node_id="", ttl="", task_type="mock" + run_id=0, task_id="", group_id="", node_id=0, ttl="", task_type="mock" ), ) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index d21a3a39232b..ce389d56ffb0 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -33,7 +33,7 @@ class Metadata: group_id : str An identifier for grouping tasks. In some settings this is used as the FL round. - node_id : str + node_id : int An identifier for the node running a task. ttl : str Time-to-live for this task. @@ -45,7 +45,7 @@ class Metadata: run_id: int task_id: str group_id: str - node_id: str + node_id: int ttl: str task_type: str diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index c4bbc490ac8f..63f44701852c 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -563,7 +563,7 @@ def message_from_taskins(taskins: TaskIns) -> Message: run_id=taskins.run_id, task_id=taskins.task_id, group_id=taskins.group_id, - node_id="", + node_id=0, ttl=taskins.task.ttl, task_type=taskins.task.task_type, ) @@ -593,7 +593,7 @@ def message_from_taskres(taskres: TaskRes) -> Message: run_id=taskres.run_id, task_id=taskres.task_id, group_id=taskres.group_id, - node_id="", + node_id=0, ttl=taskres.task.ttl, task_type=taskres.task.task_type, ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 3c1c3d8ca55b..a99fad25fc17 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -219,7 +219,7 @@ def metadata(self) -> Metadata: run_id=self.rng.randint(0, 1 << 30), task_id=self.get_str(64), group_id=self.get_str(30), - node_id=self.get_str(30), + node_id=self.rng.randint(0, 1 << 30), ttl=self.get_str(10), task_type=self.get_str(10), ) @@ -310,7 +310,7 @@ def test_message_to_and_from_taskins() -> None: run_id=0, task_id="", group_id="", - node_id="", + node_id=metadata.node_id, ttl=metadata.ttl, task_type=metadata.task_type, ), @@ -324,6 +324,9 @@ def test_message_to_and_from_taskins() -> None: taskins.group_id = metadata.group_id deserialized = message_from_taskins(taskins) + # update node_id + deserialized.metadata.node_id = metadata.node_id + # Assert assert original.message == deserialized.message assert metadata == deserialized.metadata @@ -339,7 +342,7 @@ def test_message_to_and_from_taskres() -> None: run_id=0, task_id="", group_id="", - node_id="", + node_id=metadata.node_id, ttl=metadata.ttl, task_type=metadata.task_type, ), @@ -353,6 +356,9 @@ def test_message_to_and_from_taskres() -> None: taskres.group_id = metadata.group_id deserialized = message_from_taskres(taskres) + # update node_id + deserialized.metadata.node_id = metadata.node_id + # Assert assert original.message == deserialized.message assert metadata == deserialized.metadata From b15ba3591c5b5ac578733454d5dc0dd75c557098 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 7 Feb 2024 20:19:23 +0000 Subject: [PATCH 11/25] leave `node_id` unused --- src/py/flwr/client/message_handler/message_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index ef5c14f74b5a..e6d127fe573d 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -109,7 +109,7 @@ def handle_legacy_message_from_tasktype( client_fn: ClientFn, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most middleware layer.""" - client = client_fn(str(message.metadata.node_id)) + client = client_fn("-1") client.set_context(context) From 48914d846293d25b8985dba7807382eebbb04142 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 7 Feb 2024 20:38:03 +0000 Subject: [PATCH 12/25] integration --- .../client/message_handler/message_handler.py | 2 +- .../ray_transport/ray_client_proxy.py | 60 +++++++------------ 2 files changed, 23 insertions(+), 39 deletions(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index e6d127fe573d..ef5c14f74b5a 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -109,7 +109,7 @@ def handle_legacy_message_from_tasktype( client_fn: ClientFn, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most middleware layer.""" - client = client_fn("-1") + client = client_fn(str(message.metadata.node_id)) client.set_context(context) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 94399c979a0f..8bd799e01f03 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -31,6 +31,7 @@ ) from flwr.common.logger import log from flwr.common.message import Message, Metadata +from flwr.common.recordset import RecordSet from flwr.common.recordset_compat import ( evaluateins_to_recordset, fitins_to_recordset, @@ -54,9 +55,7 @@ def __init__( super().__init__(cid) def _load_app() -> Flower: - def wrap(cid: str): - return client_fn(self.cid) - return Flower(client_fn=wrap) + return Flower(client_fn=client_fn) self.app_fn = _load_app self.actor_pool = actor_pool @@ -97,22 +96,30 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: return out_mssg - def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] - ) -> common.GetPropertiesRes: - """Return client's properties.""" - recordset = getpropertiesins_to_recordset(ins) - message = Message( + def _wrap_recordset_in_message( + self, recordset: RecordSet, task_type: str + ) -> Message: + return Message( message=recordset, metadata=Metadata( run_id=0, task_id="", group_id="", + node_id=int(self.cid), ttl="", - task_type=TASK_TYPE_GET_PROPERTIES, + task_type=task_type, ), ) + def get_properties( + self, ins: common.GetPropertiesIns, timeout: Optional[float] + ) -> common.GetPropertiesRes: + """Return client's properties.""" + recordset = getpropertiesins_to_recordset(ins) + message = self._wrap_recordset_in_message( + recordset, task_type=TASK_TYPE_GET_PROPERTIES + ) + message_out = self._submit_job(message, timeout) return recordset_to_getpropertiesres(message_out.message) @@ -122,15 +129,8 @@ def get_parameters( ) -> common.GetParametersRes: """Return the current local model parameters.""" recordset = getparametersins_to_recordset(ins) - message = Message( - message=recordset, - metadata=Metadata( - run_id=0, - task_id="", - group_id="", - ttl="", - task_type=TASK_TYPE_GET_PARAMETERS, - ), + message = self._wrap_recordset_in_message( + recordset, task_type=TASK_TYPE_GET_PARAMETERS ) message_out = self._submit_job(message, timeout) @@ -142,16 +142,7 @@ def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: recordset = fitins_to_recordset( ins, keep_input=True ) # This must stay TRUE since ins are in-memory - message = Message( - message=recordset, - metadata=Metadata( - run_id=0, - task_id="", - group_id="", - ttl="", - task_type=TASK_TYPE_FIT, - ), - ) + message = self._wrap_recordset_in_message(recordset, task_type=TASK_TYPE_FIT) message_out = self._submit_job(message, timeout) @@ -164,15 +155,8 @@ def evaluate( recordset = evaluateins_to_recordset( ins, keep_input=True ) # This must stay TRUE since ins are in-memory - message = Message( - message=recordset, - metadata=Metadata( - run_id=0, - task_id="", - group_id="", - ttl="", - task_type=TASK_TYPE_EVALUATE, - ), + message = self._wrap_recordset_in_message( + recordset, task_type=TASK_TYPE_EVALUATE ) message_out = self._submit_job(message, timeout) From 5e8ed70514456b1444493d92f0cfe9c8e2cd181e Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 7 Feb 2024 21:35:05 +0000 Subject: [PATCH 13/25] updated docstrings and tests --- .../simulation/ray_transport/ray_actor.py | 19 +--- .../ray_transport/ray_client_proxy.py | 3 +- .../ray_transport/ray_client_proxy_test.py | 95 ++++++++++++------- 3 files changed, 68 insertions(+), 49 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index b5d2e3bc3660..3dcd54009700 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -25,20 +25,11 @@ from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr import common -from flwr.client import Client from flwr.client.flower import Flower from flwr.common.context import Context from flwr.common.logger import log from flwr.common.message import Message -# All possible returns by a client -ClientRes = Union[ - common.GetPropertiesRes, common.GetParametersRes, common.FitRes, common.EvaluateRes -] -# A function to be executed by a client to obtain some results -JobFn = Callable[[Client], ClientRes] - FlowerFn = Callable[[], Flower] @@ -67,7 +58,7 @@ def run( context: Context, ) -> Tuple[str, Message, Context]: """Run a client run.""" - # Execute tasks and return result + # Pass message through app and return a message # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy try: @@ -239,10 +230,10 @@ def add_actors_to_pool(self, num_actors: int) -> None: self.num_actors += num_actors def submit(self, fn: Any, value: Tuple[FlowerFn, Message, str, Context]) -> None: - """Take idle actor and assign it a client run. + """Take an idle actor and assign it to run a client app and Message. Submit a job to an actor by first removing it from the list of idle actors, then - check if this actor was flagged to be removed from the pool + check if this actor was flagged to be removed from the pool. """ app_fn, mssg, cid, context = value actor = self._idle_actors.pop() @@ -304,7 +295,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]: """ try: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore - res_cid, res, updated_context = ray.get( + res_cid, out_mssg, updated_context = ray.get( future ) # type: (str, Message, Context) except ray.exceptions.RayActorError as ex: @@ -323,7 +314,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]: # Reset mapping self._reset_cid_to_future_dict(cid) - return res, updated_context + return out_mssg, updated_context def _flag_actor_for_removal(self, actor_id_hex: str) -> None: """Flag actor that should be removed from pool.""" diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 8bd799e01f03..33b67aea6d5d 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -62,7 +62,7 @@ def _load_app() -> Flower: self.proxy_state = NodeState() def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: - # The VCE is not exposed to TaskIns, it won't handle multilple runs + """Sumbit a message to the AcotrPool.""" # For the time being, fixing run_id is a small compromise # This will be one of the first points to address integrating VCE + DriverAPI run_id = message.metadata.run_id @@ -99,6 +99,7 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: def _wrap_recordset_in_message( self, recordset: RecordSet, task_type: str ) -> Message: + """Wrap a RecordSet inside a Message.""" return Message( message=recordset, metadata=Metadata( diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 55abad6ef38a..3a687f800c70 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -17,19 +17,25 @@ from math import pi from random import shuffle -from typing import List, Tuple, Type, cast +from typing import Dict, List, Tuple, Type import ray from flwr.client import Client, NumPyClient -from flwr.common import Code, GetPropertiesRes, Status +from flwr.client.flower import Flower +from flwr.common import Config, Scalar from flwr.common.configsrecord import ConfigsRecord +from flwr.common.constant import TASK_TYPE_GET_PROPERTIES from flwr.common.context import Context +from flwr.common.message import Message, Metadata from flwr.common.recordset import RecordSet +from flwr.common.recordset_compat import ( + getpropertiesins_to_recordset, + recordset_to_getpropertiesres, +) +from flwr.common.recordset_compat_test import _get_valid_getpropertiesins from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, DefaultActor, - JobFn, VirtualClientEngineActor, VirtualClientEngineActorPool, ) @@ -42,30 +48,20 @@ class DummyClient(NumPyClient): def __init__(self, cid: str) -> None: self.cid = int(cid) - -def get_dummy_client(cid: str) -> Client: - """Return a DummyClient converted to Client type.""" - return DummyClient(cid).to_client() - - -# A dummy run -def job_fn(cid: str) -> JobFn: # pragma: no cover - """Construct a simple job with cid dependency.""" - - def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument - result = int(cid) * pi + def get_properties(self, config: Config) -> Dict[str, Scalar]: + """Return properties by doing a simple calculation.""" + result = int(self.cid) * pi # store something in context - client.numpy_client.context.state.set_configs( # type: ignore + self.context.state.set_configs( "result", record=ConfigsRecord({"result": str(result)}) ) + return {"result": result} - # now let's convert it to a GetPropertiesRes response - return GetPropertiesRes( - status=Status(Code(0), message="test"), properties={"result": result} - ) - return cid_times_pi +def get_dummy_client(cid: str) -> Client: + """Return a DummyClient converted to Client type.""" + return DummyClient(cid).to_client() def prep( @@ -104,13 +100,21 @@ def test_cid_consistency_one_at_a_time() -> None: Submit one job and waits for completion. Then submits the next and so on """ proxies, _ = prep() + + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + # submit jobs one at a time for prox in proxies: - res = prox._submit_job( # pylint: disable=protected-access - message=job_fn(prox.cid), timeout=None + message = prox._wrap_recordset_in_message( # pylint: disable=protected-access + recordset, TASK_TYPE_GET_PROPERTIES ) + message_out = prox._submit_job( # pylint: disable=protected-access + message=message, timeout=None + ) + + res = recordset_to_getpropertiesres(message_out.message) - res = cast(GetPropertiesRes, res) assert int(prox.cid) * pi == res.properties["result"] ray.shutdown() @@ -125,6 +129,9 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: proxies, _ = prep() run_id = 0 + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + # submit all jobs (collect later) shuffle(proxies) for prox in proxies: @@ -133,18 +140,22 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: # Retrieve state state = prox.proxy_state.retrieve_context(run_id=run_id) - job = job_fn(prox.cid) + message = prox._wrap_recordset_in_message( # pylint: disable=protected-access + recordset, TASK_TYPE_GET_PROPERTIES + ) prox.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (prox.client_fn, job, prox.cid, state), + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (prox.app_fn, message, prox.cid, state), ) # fetch results one at a time shuffle(proxies) for prox in proxies: - res, updated_context = prox.actor_pool.get_client_result(prox.cid, timeout=None) + message_out, updated_context = prox.actor_pool.get_client_result( + prox.cid, timeout=None + ) prox.proxy_state.update_context(run_id, context=updated_context) - res = cast(GetPropertiesRes, res) + res = recordset_to_getpropertiesres(message_out.message) assert int(prox.cid) * pi == res.properties["result"] assert ( @@ -163,20 +174,36 @@ def test_cid_consistency_without_proxies() -> None: num_clients = len(proxies) cids = [str(cid) for cid in range(num_clients)] + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + + def _load_app() -> Flower: + return Flower(client_fn=get_dummy_client) + # submit all jobs (collect later) shuffle(cids) for cid in cids: - job = job_fn(cid) + message = Message( + message=recordset, + metadata=Metadata( + run_id=0, + task_id="", + group_id="", + node_id=int(cid), + ttl="", + task_type=TASK_TYPE_GET_PROPERTIES, + ), + ) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, Context(state=RecordSet())), + (_load_app, message, cid, Context(state=RecordSet())), ) # fetch results one at a time shuffle(cids) for cid in cids: - res, _ = pool.get_client_result(cid, timeout=None) - res = cast(GetPropertiesRes, res) + message_out, _ = pool.get_client_result(cid, timeout=None) + res = recordset_to_getpropertiesres(message_out.message) assert int(cid) * pi == res.properties["result"] ray.shutdown() From 3eb115fb2c62ea56d214a70f543b59349e7f638d Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 8 Feb 2024 15:26:36 +0000 Subject: [PATCH 14/25] wip --- src/py/flwr/common/serde.py | 4 ++-- src/py/flwr/common/serde_test.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 63f44701852c..1348346beb32 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -563,7 +563,7 @@ def message_from_taskins(taskins: TaskIns) -> Message: run_id=taskins.run_id, task_id=taskins.task_id, group_id=taskins.group_id, - node_id=0, + node_id=taskins.task.consumer.node_id, ttl=taskins.task.ttl, task_type=taskins.task.task_type, ) @@ -593,7 +593,7 @@ def message_from_taskres(taskres: TaskRes) -> Message: run_id=taskres.run_id, task_id=taskres.task_id, group_id=taskres.group_id, - node_id=0, + node_id=taskres.task.consumer.node_id, ttl=taskres.task.ttl, task_type=taskres.task.task_type, ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index a99fad25fc17..926d5d6286ef 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -322,10 +322,11 @@ def test_message_to_and_from_taskins() -> None: taskins.run_id = metadata.run_id taskins.task_id = metadata.task_id taskins.group_id = metadata.group_id + taskins.task.consumer = metadata.node_id deserialized = message_from_taskins(taskins) # update node_id - deserialized.metadata.node_id = metadata.node_id + # deserialized.metadata.node_id = metadata.node_id # Assert assert original.message == deserialized.message @@ -354,10 +355,11 @@ def test_message_to_and_from_taskres() -> None: taskres.run_id = metadata.run_id taskres.task_id = metadata.task_id taskres.group_id = metadata.group_id + # taskres.task.consumer = metadata.node_id # <------------ can't be done (it's read only) deserialized = message_from_taskres(taskres) # update node_id - deserialized.metadata.node_id = metadata.node_id + # deserialized.metadata.node_id = metadata.node_id # Assert assert original.message == deserialized.message From 87fb2ba8764b061130c2c51cb3e180621b1513d5 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 8 Feb 2024 15:42:23 +0000 Subject: [PATCH 15/25] fix --- src/py/flwr/common/serde_test.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 926d5d6286ef..4fbf9151a78a 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -322,12 +322,9 @@ def test_message_to_and_from_taskins() -> None: taskins.run_id = metadata.run_id taskins.task_id = metadata.task_id taskins.group_id = metadata.group_id - taskins.task.consumer = metadata.node_id + taskins.task.consumer.node_id = metadata.node_id deserialized = message_from_taskins(taskins) - # update node_id - # deserialized.metadata.node_id = metadata.node_id - # Assert assert original.message == deserialized.message assert metadata == deserialized.metadata @@ -355,12 +352,9 @@ def test_message_to_and_from_taskres() -> None: taskres.run_id = metadata.run_id taskres.task_id = metadata.task_id taskres.group_id = metadata.group_id - # taskres.task.consumer = metadata.node_id # <------------ can't be done (it's read only) + taskres.task.consumer.node_id = metadata.node_id deserialized = message_from_taskres(taskres) - # update node_id - # deserialized.metadata.node_id = metadata.node_id - # Assert assert original.message == deserialized.message assert metadata == deserialized.metadata From 89f232c0d55cb484b4b195998d180fd040ad1766 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 12 Feb 2024 19:06:03 +0000 Subject: [PATCH 16/25] add default values for Metadata and Message --- src/py/flwr/client/grpc_client/connection.py | 3 -- .../client/grpc_client/connection_test.py | 10 ------ .../client/message_handler/message_handler.py | 10 ------ .../message_handler/message_handler_test.py | 8 ----- .../mod/secure_aggregation/secaggplus_mod.py | 2 +- .../secure_aggregation/secaggplus_mod_test.py | 4 +-- src/py/flwr/client/mod/utils_test.py | 4 +-- src/py/flwr/common/message.py | 31 ++++++++++--------- src/py/flwr/common/serde_test.py | 12 ++----- 9 files changed, 23 insertions(+), 61 deletions(-) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 6e4edf21ec9e..3956968f4668 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -169,11 +169,8 @@ def receive() -> Message: # Construct Message return Message( metadata=Metadata( - run_id=0, task_id=str(uuid.uuid4()), - group_id="", node_id=0, - ttl="", task_type=task_type, ), message=recordset, diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index c5ef4f83ab8d..d086bc329c24 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -45,11 +45,6 @@ MESSAGE_GET_PROPERTIES = Message( metadata=Metadata( - run_id=0, - task_id="", - group_id="", - node_id=0, - ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), message=compat.getpropertiesres_to_recordset( @@ -58,11 +53,6 @@ ) MESSAGE_DISCONNECT = Message( metadata=Metadata( - run_id=0, - task_id="", - group_id="", - node_id=0, - ttl="", task_type="reconnect", ), message=RecordSet(configs={"config": ConfigsRecord({"reason": 0})}), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 4b01bb74593e..cd12b795a2ba 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -89,11 +89,6 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: recordset.set_configs("config", ConfigsRecord({"reason": reason})) out_message = Message( metadata=Metadata( - run_id=0, - task_id="", - group_id="", - node_id=0, - ttl="", task_type="reconnect", ), message=recordset, @@ -151,11 +146,6 @@ def handle_legacy_message_from_tasktype( # Return Message out_message = Message( metadata=Metadata( - run_id=0, # Non-user defined - task_id="", # Non-user defined - group_id="", # Non-user defined - node_id=0, - ttl="", task_type=task_type, ), message=out_recordset, diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 5889e4f6fb98..e6980eb17f6a 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -121,11 +121,7 @@ def test_client_without_get_properties() -> None: recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) message = Message( metadata=Metadata( - run_id=0, task_id=str(uuid.uuid4()), - group_id="", - node_id=0, - ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), message=recordset, @@ -160,11 +156,7 @@ def test_client_with_get_properties() -> None: recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) message = Message( metadata=Metadata( - run_id=0, task_id=str(uuid.uuid4()), - group_id="", - node_id=0, - ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), message=recordset, diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 0cb56bbbaa04..f00efec8e32e 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -207,7 +207,7 @@ def secaggplus_mod( # Return message return Message( - metadata=Metadata(0, "", "", 0, "", TASK_TYPE_FIT), + metadata=Metadata(task_type=TASK_TYPE_FIT), message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}), ) diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 197c24573c6d..e2a60e18f57f 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -57,7 +57,7 @@ def get_test_handler( def empty_ffn(_: Message, _2: Context) -> Message: return Message( - metadata=Metadata(0, "", "", 0, "", TASK_TYPE_FIT), + metadata=Metadata(task_type=TASK_TYPE_FIT), message=RecordSet(), ) @@ -65,7 +65,7 @@ def empty_ffn(_: Message, _2: Context) -> Message: def func(configs: Dict[str, ConfigsRecordValues]) -> Dict[str, ConfigsRecordValues]: in_msg = Message( - metadata=Metadata(0, "", "", 0, "", TASK_TYPE_FIT), + metadata=Metadata(task_type=TASK_TYPE_FIT), message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}), ) out_msg = app(in_msg, ctxt) diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index 5d457b418de6..49ea7b4e4cfd 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -74,9 +74,7 @@ def app(message: Message, context: Context) -> Message: def _get_dummy_flower_message() -> Message: return Message( message=RecordSet(), - metadata=Metadata( - run_id=0, task_id="", group_id="", node_id=0, ttl="", task_type="mock" - ), + metadata=Metadata(task_type="mock"), ) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index ce389d56ffb0..a6727e1951b6 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -15,7 +15,8 @@ """Message.""" -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional from .recordset import RecordSet @@ -26,28 +27,28 @@ class Metadata: Parameters ---------- - run_id : int + run_id : int (default: 0) An identifier for the current run. - task_id : str + task_id : str (default: "") An identifier for the current task. - group_id : str + group_id : str (default: "") An identifier for grouping tasks. In some settings this is used as the FL round. - node_id : int + node_id : Optional[int] (default: None) An identifier for the node running a task. - ttl : str + ttl : str (default: "") Time-to-live for this task. - task_type : str + task_type : str (default: "") A string that encodes the action to be executed on the receiving end. """ - run_id: int - task_id: str - group_id: str - node_id: int - ttl: str - task_type: str + run_id: int = 0 + task_id: str = "" + group_id: str = "" + node_id: Optional[int] = None + ttl: str = "" + task_type: str = "" @dataclass @@ -63,5 +64,5 @@ class Message: logic to a client, or vice-versa) or that will be sent to it. """ - metadata: Metadata - message: RecordSet + metadata: Metadata = field(default_factory=Metadata) + message: RecordSet = field(default_factory=RecordSet) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 4fbf9151a78a..82c2c5b2158b 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -219,7 +219,7 @@ def metadata(self) -> Metadata: run_id=self.rng.randint(0, 1 << 30), task_id=self.get_str(64), group_id=self.get_str(30), - node_id=self.rng.randint(0, 1 << 30), + node_id=self.rng.randint(0, 1 << 63), ttl=self.get_str(10), task_type=self.get_str(10), ) @@ -307,9 +307,6 @@ def test_message_to_and_from_taskins() -> None: metadata = maker.metadata() original = Message( metadata=Metadata( - run_id=0, - task_id="", - group_id="", node_id=metadata.node_id, ttl=metadata.ttl, task_type=metadata.task_type, @@ -322,7 +319,7 @@ def test_message_to_and_from_taskins() -> None: taskins.run_id = metadata.run_id taskins.task_id = metadata.task_id taskins.group_id = metadata.group_id - taskins.task.consumer.node_id = metadata.node_id + taskins.task.consumer.node_id = cast(int, metadata.node_id) deserialized = message_from_taskins(taskins) # Assert @@ -337,9 +334,6 @@ def test_message_to_and_from_taskres() -> None: metadata = maker.metadata() original = Message( metadata=Metadata( - run_id=0, - task_id="", - group_id="", node_id=metadata.node_id, ttl=metadata.ttl, task_type=metadata.task_type, @@ -352,7 +346,7 @@ def test_message_to_and_from_taskres() -> None: taskres.run_id = metadata.run_id taskres.task_id = metadata.task_id taskres.group_id = metadata.group_id - taskres.task.consumer.node_id = metadata.node_id + taskres.task.consumer.node_id = cast(int, metadata.node_id) deserialized = message_from_taskres(taskres) # Assert From 38f116517908067a7c4d67e9d0cc1f42384e3be5 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 12 Feb 2024 21:45:53 +0100 Subject: [PATCH 17/25] fix --- src/py/flwr/common/message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index d66fb48c0dd0..3f068843b3b3 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -15,7 +15,7 @@ """Message.""" -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Optional from .recordset import RecordSet From 30f745350c59d57453c29cf4176295fd1545d58a Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 12 Feb 2024 21:57:20 +0100 Subject: [PATCH 18/25] update w/ `content` and `flowerapp` --- .../simulation/ray_transport/ray_actor.py | 6 +++--- .../ray_transport/ray_client_proxy.py | 20 ++++++++----------- .../ray_transport/ray_client_proxy_test.py | 18 +++++++---------- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 3dcd54009700..631b9260874d 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -25,12 +25,12 @@ from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr.client.flower import Flower +from flwr.client.clientapp import ClientApp from flwr.common.context import Context from flwr.common.logger import log from flwr.common.message import Message -FlowerFn = Callable[[], Flower] +FlowerFn = Callable[[], ClientApp] class ClientException(Exception): @@ -63,7 +63,7 @@ def run( # from the pool are correctly assigned to each ClientProxy try: # Load app - app: Flower = app_fn() + app: ClientApp = app_fn() # Handle task message out_message = app(message=message, context=context) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 33b67aea6d5d..b949372a32b1 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -21,7 +21,7 @@ from flwr import common from flwr.client import ClientFn -from flwr.client.flower import Flower +from flwr.client.clientapp import ClientApp from flwr.client.node_state import NodeState from flwr.common.constant import ( TASK_TYPE_EVALUATE, @@ -54,8 +54,8 @@ def __init__( ): super().__init__(cid) - def _load_app() -> Flower: - return Flower(client_fn=client_fn) + def _load_app() -> ClientApp: + return ClientApp(client_fn=client_fn) self.app_fn = _load_app self.actor_pool = actor_pool @@ -101,13 +101,9 @@ def _wrap_recordset_in_message( ) -> Message: """Wrap a RecordSet inside a Message.""" return Message( - message=recordset, + content=recordset, metadata=Metadata( - run_id=0, - task_id="", - group_id="", node_id=int(self.cid), - ttl="", task_type=task_type, ), ) @@ -123,7 +119,7 @@ def get_properties( message_out = self._submit_job(message, timeout) - return recordset_to_getpropertiesres(message_out.message) + return recordset_to_getpropertiesres(message_out.content) def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float] @@ -136,7 +132,7 @@ def get_parameters( message_out = self._submit_job(message, timeout) - return recordset_to_getparametersres(message_out.message, keep_input=False) + return recordset_to_getparametersres(message_out.content, keep_input=False) def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" @@ -147,7 +143,7 @@ def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: message_out = self._submit_job(message, timeout) - return recordset_to_fitres(message_out.message, keep_input=False) + return recordset_to_fitres(message_out.content, keep_input=False) def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] @@ -162,7 +158,7 @@ def evaluate( message_out = self._submit_job(message, timeout) - return recordset_to_evaluateres(message_out.message) + return recordset_to_evaluateres(message_out.content) def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float] diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 3a687f800c70..fb8a3df96be9 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -22,7 +22,7 @@ import ray from flwr.client import Client, NumPyClient -from flwr.client.flower import Flower +from flwr.client.clientapp import ClientApp from flwr.common import Config, Scalar from flwr.common.configsrecord import ConfigsRecord from flwr.common.constant import TASK_TYPE_GET_PROPERTIES @@ -113,7 +113,7 @@ def test_cid_consistency_one_at_a_time() -> None: message=message, timeout=None ) - res = recordset_to_getpropertiesres(message_out.message) + res = recordset_to_getpropertiesres(message_out.content) assert int(prox.cid) * pi == res.properties["result"] @@ -155,7 +155,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: prox.cid, timeout=None ) prox.proxy_state.update_context(run_id, context=updated_context) - res = recordset_to_getpropertiesres(message_out.message) + res = recordset_to_getpropertiesres(message_out.content) assert int(prox.cid) * pi == res.properties["result"] assert ( @@ -177,20 +177,16 @@ def test_cid_consistency_without_proxies() -> None: getproperties_ins = _get_valid_getpropertiesins() recordset = getpropertiesins_to_recordset(getproperties_ins) - def _load_app() -> Flower: - return Flower(client_fn=get_dummy_client) + def _load_app() -> ClientApp: + return ClientApp(client_fn=get_dummy_client) # submit all jobs (collect later) shuffle(cids) for cid in cids: message = Message( - message=recordset, + content=recordset, metadata=Metadata( - run_id=0, - task_id="", - group_id="", node_id=int(cid), - ttl="", task_type=TASK_TYPE_GET_PROPERTIES, ), ) @@ -203,7 +199,7 @@ def _load_app() -> Flower: shuffle(cids) for cid in cids: message_out, _ = pool.get_client_result(cid, timeout=None) - res = recordset_to_getpropertiesres(message_out.message) + res = recordset_to_getpropertiesres(message_out.content) assert int(cid) * pi == res.properties["result"] ray.shutdown() From 5122793d6df8fade75f0b66a407d74e5081c67cc Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 14 Feb 2024 15:24:04 +0100 Subject: [PATCH 19/25] fix pandas e2e --- e2e/pandas/simulation.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/e2e/pandas/simulation.py b/e2e/pandas/simulation.py index 91af84062712..b548b5ebb760 100644 --- a/e2e/pandas/simulation.py +++ b/e2e/pandas/simulation.py @@ -1,12 +1,8 @@ import flwr as fl -from client import FlowerClient +from client import client_fn from strategy import FedAnalytics -def client_fn(cid): - _ = cid - return FlowerClient() - hist = fl.simulation.start_simulation( client_fn=client_fn, num_clients=2, From d1e71b1d1f12adefbaf7148c6ca32ce6704852e0 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 14 Feb 2024 15:45:22 +0100 Subject: [PATCH 20/25] . --- src/py/flwr/common/message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index db2ace3bbdd7..9258edccbcd5 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -26,7 +26,7 @@ class Metadata: Parameters ---------- - run_id : int (default: 0) + run_id : int An identifier for the current run. message_id : str An identifier for the current message. From 31f4ded8a1cb552092a9ca1c293da24555333a0d Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 14 Feb 2024 16:00:22 +0100 Subject: [PATCH 21/25] updated docs; removed unused utility function --- doc/source/how-to-run-simulations.rst | 2 +- src/py/flwr/simulation/ray_transport/utils.py | 22 ------------------- 2 files changed, 1 insertion(+), 23 deletions(-) diff --git a/doc/source/how-to-run-simulations.rst b/doc/source/how-to-run-simulations.rst index 6e0520a79bf5..d1dcb511ed51 100644 --- a/doc/source/how-to-run-simulations.rst +++ b/doc/source/how-to-run-simulations.rst @@ -29,7 +29,7 @@ Running Flower simulations still require you to define your client class, a stra def client_fn(cid: str): # Return a standard Flower client - return MyFlowerClient() + return MyFlowerClient().to_client() # Launch the simulation hist = fl.simulation.start_simulation( diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py index dd9fb6b2aa85..be6a609f0916 100644 --- a/src/py/flwr/simulation/ray_transport/utils.py +++ b/src/py/flwr/simulation/ray_transport/utils.py @@ -60,25 +60,3 @@ def enable_tf_gpu_growth() -> None: log(ERROR, traceback.format_exc()) log(ERROR, ex) raise ex - - -def check_clientfn_returns_client(client: Client) -> Client: - """Warn once that clients returned in `clinet_fn` should be of type Client. - - This is here for backwards compatibility. If a ClientFn is provided returning - a different type of client (e.g. NumPyClient) we'll warn the user but convert - the client internally to `Client` by calling `.to_client()`. - """ - if not isinstance(client, Client): - mssg = ( - " Ensure your client is of type `flwr.client.Client`. Please convert it" - " using the `.to_client()` method before returning it" - " in the `client_fn` you pass to `start_simulation`." - " We have applied this conversion on your behalf." - " Not returning a `Client` might trigger an error in future" - " versions of Flower." - ) - - warnings.warn(mssg, DeprecationWarning, stacklevel=2) - client = client.to_client() - return client From 87f0041abb383c1b15a34baf821c9f42c895f2d5 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 14 Feb 2024 16:02:37 +0100 Subject: [PATCH 22/25] format --- src/py/flwr/simulation/ray_transport/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py index be6a609f0916..3861164998a4 100644 --- a/src/py/flwr/simulation/ray_transport/utils.py +++ b/src/py/flwr/simulation/ray_transport/utils.py @@ -18,7 +18,6 @@ import warnings from logging import ERROR -from flwr.client import Client from flwr.common.logger import log try: From 794cd249db632e9c498de5762234f7d5262c078f Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 14 Feb 2024 19:00:24 +0000 Subject: [PATCH 23/25] Apply suggestions from code review Co-authored-by: Daniel J. Beutel --- src/py/flwr/simulation/ray_transport/ray_actor.py | 4 ++-- src/py/flwr/simulation/ray_transport/ray_client_proxy.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 631b9260874d..fd1fc8a34d3d 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -52,13 +52,13 @@ def terminate(self) -> None: def run( self, - app_fn: FlowerFn, + client_app_fn: FlowerFn, message: Message, cid: str, context: Context, ) -> Tuple[str, Message, Context]: """Run a client run.""" - # Pass message through app and return a message + # Pass message through ClientApp and return a message # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy try: diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index b5ebb3eaf306..99355606f335 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -62,7 +62,7 @@ def _load_app() -> ClientApp: self.proxy_state = NodeState() def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: - """Sumbit a message to the AcotrPool.""" + """Sumbit a message to the ActorPool.""" # For the time being, fixing run_id is a small compromise # This will be one of the first points to address integrating VCE + DriverAPI run_id = message.metadata.run_id From 812fe5456e8882110d3d2dbf9c3320eea2e9c821 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 14 Feb 2024 20:18:25 +0100 Subject: [PATCH 24/25] timout->ttl; renaming; other minor fixes --- .../simulation/ray_transport/ray_actor.py | 10 ++++---- .../ray_transport/ray_client_proxy.py | 25 ++++++++++++------- .../ray_transport/ray_client_proxy_test.py | 8 ++++-- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index fd1fc8a34d3d..c9fa43bcb3ca 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -30,7 +30,7 @@ from flwr.common.logger import log from flwr.common.message import Message -FlowerFn = Callable[[], ClientApp] +ClientAppFn = Callable[[], ClientApp] class ClientException(Exception): @@ -52,7 +52,7 @@ def terminate(self) -> None: def run( self, - client_app_fn: FlowerFn, + client_app_fn: ClientAppFn, message: Message, cid: str, context: Context, @@ -63,7 +63,7 @@ def run( # from the pool are correctly assigned to each ClientProxy try: # Load app - app: ClientApp = app_fn() + app: ClientApp = client_app_fn() # Handle task message out_message = app(message=message, context=context) @@ -229,7 +229,7 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[FlowerFn, Message, str, Context]) -> None: + def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> None: """Take an idle actor and assign it to run a client app and Message. Submit a job to an actor by first removing it from the list of idle actors, then @@ -247,7 +247,7 @@ def submit(self, fn: Any, value: Tuple[FlowerFn, Message, str, Context]) -> None self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[FlowerFn, Message, str, Context] + self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 99355606f335..ddac030b2ef0 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -63,8 +63,6 @@ def _load_app() -> ClientApp: def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: """Sumbit a message to the ActorPool.""" - # For the time being, fixing run_id is a small compromise - # This will be one of the first points to address integrating VCE + DriverAPI run_id = message.metadata.run_id # Register state @@ -97,7 +95,10 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: return out_mssg def _wrap_recordset_in_message( - self, recordset: RecordSet, task_type: str + self, + recordset: RecordSet, + message_type: str, + timeout: Optional[float], ) -> Message: """Wrap a RecordSet inside a Message.""" return Message( @@ -107,8 +108,8 @@ def _wrap_recordset_in_message( message_id="", group_id="", node_id=int(self.cid), - ttl="", - message_type=task_type, + ttl=str(timeout) if timeout else "", + message_type=message_type, ), ) @@ -118,7 +119,9 @@ def get_properties( """Return client's properties.""" recordset = getpropertiesins_to_recordset(ins) message = self._wrap_recordset_in_message( - recordset, task_type=MESSAGE_TYPE_GET_PROPERTIES + recordset, + message_type=MESSAGE_TYPE_GET_PROPERTIES, + timeout=timeout, ) message_out = self._submit_job(message, timeout) @@ -131,7 +134,9 @@ def get_parameters( """Return the current local model parameters.""" recordset = getparametersins_to_recordset(ins) message = self._wrap_recordset_in_message( - recordset, task_type=MESSAGE_TYPE_GET_PARAMETERS + recordset, + message_type=MESSAGE_TYPE_GET_PARAMETERS, + timeout=timeout, ) message_out = self._submit_job(message, timeout) @@ -143,7 +148,9 @@ def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: recordset = fitins_to_recordset( ins, keep_input=True ) # This must stay TRUE since ins are in-memory - message = self._wrap_recordset_in_message(recordset, task_type=MESSAGE_TYPE_FIT) + message = self._wrap_recordset_in_message( + recordset, message_type=MESSAGE_TYPE_FIT, timeout=timeout + ) message_out = self._submit_job(message, timeout) @@ -157,7 +164,7 @@ def evaluate( ins, keep_input=True ) # This must stay TRUE since ins are in-memory message = self._wrap_recordset_in_message( - recordset, task_type=MESSAGE_TYPE_EVALUATE + recordset, message_type=MESSAGE_TYPE_EVALUATE, timeout=timeout ) message_out = self._submit_job(message, timeout) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index a48dbf7ae455..f3049667ad0a 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -107,7 +107,9 @@ def test_cid_consistency_one_at_a_time() -> None: # submit jobs one at a time for prox in proxies: message = prox._wrap_recordset_in_message( # pylint: disable=protected-access - recordset, MESSAGE_TYPE_GET_PROPERTIES + recordset, + MESSAGE_TYPE_GET_PROPERTIES, + timeout=None, ) message_out = prox._submit_job( # pylint: disable=protected-access message=message, timeout=None @@ -141,7 +143,9 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: state = prox.proxy_state.retrieve_context(run_id=run_id) message = prox._wrap_recordset_in_message( # pylint: disable=protected-access - recordset, MESSAGE_TYPE_GET_PROPERTIES + recordset, + message_type=MESSAGE_TYPE_GET_PROPERTIES, + timeout=None, ) prox.actor_pool.submit_client_job( lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), From a977a62c64c60c8b8473c2c7d8712a6dad2a3ff6 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 14 Feb 2024 21:54:31 +0100 Subject: [PATCH 25/25] renamed DefaultActor --- src/py/flwr/simulation/app.py | 8 ++++---- src/py/flwr/simulation/ray_transport/ray_actor.py | 2 +- .../simulation/ray_transport/ray_client_proxy_test.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index b159042588c9..9ee230942890 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -35,7 +35,7 @@ from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy from flwr.simulation.ray_transport.ray_actor import ( - DefaultActor, + ClientAppActor, VirtualClientEngineActor, VirtualClientEngineActorPool, pool_size_from_resources, @@ -83,7 +83,7 @@ def start_simulation( client_manager: Optional[ClientManager] = None, ray_init_args: Optional[Dict[str, Any]] = None, keep_initialised: Optional[bool] = False, - actor_type: Type[VirtualClientEngineActor] = DefaultActor, + actor_type: Type[VirtualClientEngineActor] = ClientAppActor, actor_kwargs: Optional[Dict[str, Any]] = None, actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT", ) -> History: @@ -139,10 +139,10 @@ def start_simulation( keep_initialised: Optional[bool] (default: False) Set to True to prevent `ray.shutdown()` in case `ray.is_initialized()=True`. - actor_type: VirtualClientEngineActor (default: DefaultActor) + actor_type: VirtualClientEngineActor (default: ClientAppActor) Optionally specify the type of actor to use. The actor object, which persists throughout the simulation, will be the process in charge of - running the clients' jobs (i.e. their `fit()` method). + executing a ClientApp wrapping input argument `client_fn`. actor_kwargs: Optional[Dict[str, Any]] (default: None) If you want to create your own Actor classes, you might need to pass diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index c9fa43bcb3ca..974773a3f577 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -85,7 +85,7 @@ def run( @ray.remote -class DefaultActor(VirtualClientEngineActor): +class ClientAppActor(VirtualClientEngineActor): """A Ray Actor class that runs client runs. Parameters diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index f3049667ad0a..9ade31c323d8 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -35,7 +35,7 @@ ) from flwr.common.recordset_compat_test import _get_valid_getpropertiesins from flwr.simulation.ray_transport.ray_actor import ( - DefaultActor, + ClientAppActor, VirtualClientEngineActor, VirtualClientEngineActorPool, ) @@ -65,7 +65,7 @@ def get_dummy_client(cid: str) -> Client: def prep( - actor_type: Type[VirtualClientEngineActor] = DefaultActor, + actor_type: Type[VirtualClientEngineActor] = ClientAppActor, ) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover """Prepare ClientProxies and pool for tests.""" client_resources = {"num_cpus": 1, "num_gpus": 0.0}