diff --git a/doc/source/how-to-run-simulations.rst b/doc/source/how-to-run-simulations.rst index 6e0520a79bf5..d1dcb511ed51 100644 --- a/doc/source/how-to-run-simulations.rst +++ b/doc/source/how-to-run-simulations.rst @@ -29,7 +29,7 @@ Running Flower simulations still require you to define your client class, a stra def client_fn(cid: str): # Return a standard Flower client - return MyFlowerClient() + return MyFlowerClient().to_client() # Launch the simulation hist = fl.simulation.start_simulation( diff --git a/e2e/pandas/simulation.py b/e2e/pandas/simulation.py index 91af84062712..b548b5ebb760 100644 --- a/e2e/pandas/simulation.py +++ b/e2e/pandas/simulation.py @@ -1,12 +1,8 @@ import flwr as fl -from client import FlowerClient +from client import client_fn from strategy import FedAnalytics -def client_fn(cid): - _ = cid - return FlowerClient() - hist = fl.simulation.start_simulation( client_fn=client_fn, num_clients=2, diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 93de7d7d8821..f8c8a725aec7 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -109,7 +109,7 @@ def handle_legacy_message_from_msgtype( client_fn: ClientFn, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most mod.""" - client = client_fn("-1") + client = client_fn(str(message.metadata.node_id)) client.set_context(context) diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index b159042588c9..9ee230942890 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -35,7 +35,7 @@ from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy from flwr.simulation.ray_transport.ray_actor import ( - DefaultActor, + ClientAppActor, VirtualClientEngineActor, VirtualClientEngineActorPool, pool_size_from_resources, @@ -83,7 +83,7 @@ def start_simulation( client_manager: Optional[ClientManager] = None, ray_init_args: Optional[Dict[str, Any]] = None, keep_initialised: Optional[bool] = False, - actor_type: Type[VirtualClientEngineActor] = DefaultActor, + actor_type: Type[VirtualClientEngineActor] = ClientAppActor, actor_kwargs: Optional[Dict[str, Any]] = None, actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT", ) -> History: @@ -139,10 +139,10 @@ def start_simulation( keep_initialised: Optional[bool] (default: False) Set to True to prevent `ray.shutdown()` in case `ray.is_initialized()=True`. - actor_type: VirtualClientEngineActor (default: DefaultActor) + actor_type: VirtualClientEngineActor (default: ClientAppActor) Optionally specify the type of actor to use. The actor object, which persists throughout the simulation, will be the process in charge of - running the clients' jobs (i.e. their `fit()` method). + executing a ClientApp wrapping input argument `client_fn`. actor_kwargs: Optional[Dict[str, Any]] (default: None) If you want to create your own Actor classes, you might need to pass diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 853566a4cbeb..974773a3f577 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -25,18 +25,12 @@ from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr import common -from flwr.client import Client, ClientFn +from flwr.client.clientapp import ClientApp from flwr.common.context import Context from flwr.common.logger import log -from flwr.simulation.ray_transport.utils import check_clientfn_returns_client +from flwr.common.message import Message -# All possible returns by a client -ClientRes = Union[ - common.GetPropertiesRes, common.GetParametersRes, common.FitRes, common.EvaluateRes -] -# A function to be executed by a client to obtain some results -JobFn = Callable[[Client], ClientRes] +ClientAppFn = Callable[[], ClientApp] class ClientException(Exception): @@ -58,27 +52,25 @@ def terminate(self) -> None: def run( self, - client_fn: ClientFn, - job_fn: JobFn, + client_app_fn: ClientAppFn, + message: Message, cid: str, context: Context, - ) -> Tuple[str, ClientRes, Context]: + ) -> Tuple[str, Message, Context]: """Run a client run.""" - # Execute tasks and return result + # Pass message through ClientApp and return a message # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy try: - # Instantiate client (check 'Client' type is returned) - client = check_clientfn_returns_client(client_fn(cid)) - # Inject context - client.set_context(context) - # Run client job - job_results = job_fn(client) - # Retrieve context (potentially updated) - updated_context = client.get_context() + # Load app + app: ClientApp = client_app_fn() + + # Handle task message + out_message = app(message=message, context=context) + except Exception as ex: client_trace = traceback.format_exc() - message = ( + mssg = ( "\n\tSomething went wrong when running your client run." "\n\tClient " + cid @@ -87,13 +79,13 @@ def run( + " was running its run." "\n\tException triggered on the client side: " + client_trace, ) - raise ClientException(str(message)) from ex + raise ClientException(str(mssg)) from ex - return cid, job_results, updated_context + return cid, out_message, context @ray.remote -class DefaultActor(VirtualClientEngineActor): +class ClientAppActor(VirtualClientEngineActor): """A Ray Actor class that runs client runs. Parameters @@ -237,16 +229,16 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, Context]) -> None: - """Take idle actor and assign it a client run. + def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> None: + """Take an idle actor and assign it to run a client app and Message. Submit a job to an actor by first removing it from the list of idle actors, then - check if this actor was flagged to be removed from the pool + check if this actor was flagged to be removed from the pool. """ - client_fn, job_fn, cid, context = value + app_fn, mssg, cid, context = value actor = self._idle_actors.pop() if self._check_and_remove_actor_from_pool(actor): - future = fn(actor, client_fn, job_fn, cid, context) + future = fn(actor, app_fn, mssg, cid, context) future_key = tuple(future) if isinstance(future, List) else future self._future_to_actor[future_key] = (self._next_task_index, actor, cid) self._next_task_index += 1 @@ -255,7 +247,7 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, Context]) -> None: self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, Context] + self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -295,7 +287,7 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]: + def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]: """Fetch result and updated context for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is @@ -303,9 +295,9 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]: """ try: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore - res_cid, res, updated_context = ray.get( + res_cid, out_mssg, updated_context = ray.get( future - ) # type: (str, ClientRes, Context) + ) # type: (str, Message, Context) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -322,7 +314,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]: # Reset mapping self._reset_cid_to_future_dict(cid) - return res, updated_context + return out_mssg, updated_context def _flag_actor_for_removal(self, actor_id_hex: str) -> None: """Flag actor that should be removed from pool.""" @@ -409,7 +401,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, Context]: + ) -> Tuple[Message, Context]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 894012dc6d70..ddac030b2ef0 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -17,107 +17,33 @@ import traceback from logging import ERROR -from typing import Dict, Optional, cast - -import ray +from typing import Optional from flwr import common -from flwr.client import Client, ClientFn -from flwr.client.client import ( - maybe_call_evaluate, - maybe_call_fit, - maybe_call_get_parameters, - maybe_call_get_properties, -) +from flwr.client import ClientFn +from flwr.client.clientapp import ClientApp from flwr.client.node_state import NodeState +from flwr.common.constant import ( + MESSAGE_TYPE_EVALUATE, + MESSAGE_TYPE_FIT, + MESSAGE_TYPE_GET_PARAMETERS, + MESSAGE_TYPE_GET_PROPERTIES, +) from flwr.common.logger import log -from flwr.server.client_proxy import ClientProxy -from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, - JobFn, - VirtualClientEngineActorPool, +from flwr.common.message import Message, Metadata +from flwr.common.recordset import RecordSet +from flwr.common.recordset_compat import ( + evaluateins_to_recordset, + fitins_to_recordset, + getparametersins_to_recordset, + getpropertiesins_to_recordset, + recordset_to_evaluateres, + recordset_to_fitres, + recordset_to_getparametersres, + recordset_to_getpropertiesres, ) - - -class RayClientProxy(ClientProxy): - """Flower client proxy which delegates work using Ray.""" - - def __init__(self, client_fn: ClientFn, cid: str, resources: Dict[str, float]): - super().__init__(cid) - self.client_fn = client_fn - self.resources = resources - - def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] - ) -> common.GetPropertiesRes: - """Return client's properties.""" - future_get_properties_res = launch_and_get_properties.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_get_properties_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetPropertiesRes, - res, - ) - - def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] - ) -> common.GetParametersRes: - """Return the current local model parameters.""" - future_paramseters_res = launch_and_get_parameters.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_paramseters_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetParametersRes, - res, - ) - - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: - """Train model parameters on the locally held dataset.""" - future_fit_res = launch_and_fit.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_fit_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.FitRes, - res, - ) - - def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] - ) -> common.EvaluateRes: - """Evaluate model parameters on the locally held dataset.""" - future_evaluate_res = launch_and_evaluate.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_evaluate_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.EvaluateRes, - res, - ) - - def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] - ) -> common.DisconnectRes: - """Disconnect and (optionally) reconnect later.""" - return common.DisconnectRes(reason="") # Nothing to do here (yet) +from flwr.server.client_proxy import ClientProxy +from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool class RayActorClientProxy(ClientProxy): @@ -127,15 +53,17 @@ def __init__( self, client_fn: ClientFn, cid: str, actor_pool: VirtualClientEngineActorPool ): super().__init__(cid) - self.client_fn = client_fn + + def _load_app() -> ClientApp: + return ClientApp(client_fn=client_fn) + + self.app_fn = _load_app self.actor_pool = actor_pool self.proxy_state = NodeState() - def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: - # The VCE is not exposed to TaskIns, it won't handle multilple runs - # For the time being, fixing run_id is a small compromise - # This will be one of the first points to address integrating VCE + DriverAPI - run_id = 0 + def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: + """Sumbit a message to the ActorPool.""" + run_id = message.metadata.run_id # Register state self.proxy_state.register_context(run_id=run_id) @@ -145,10 +73,12 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: try: self.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (self.client_fn, job_fn, self.cid, state), + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (self.app_fn, message, self.cid, state), + ) + out_mssg, updated_context = self.actor_pool.get_client_result( + self.cid, timeout ) - res, updated_context = self.actor_pool.get_client_result(self.cid, timeout) # Update state self.proxy_state.update_context(run_id=run_id, context=updated_context) @@ -162,134 +92,87 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: log(ERROR, ex) raise ex - return res + return out_mssg + + def _wrap_recordset_in_message( + self, + recordset: RecordSet, + message_type: str, + timeout: Optional[float], + ) -> Message: + """Wrap a RecordSet inside a Message.""" + return Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id="", + node_id=int(self.cid), + ttl=str(timeout) if timeout else "", + message_type=message_type, + ), + ) def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float] ) -> common.GetPropertiesRes: """Return client's properties.""" + recordset = getpropertiesins_to_recordset(ins) + message = self._wrap_recordset_in_message( + recordset, + message_type=MESSAGE_TYPE_GET_PROPERTIES, + timeout=timeout, + ) - def get_properties(client: Client) -> common.GetPropertiesRes: - return maybe_call_get_properties( - client=client, - get_properties_ins=ins, - ) + message_out = self._submit_job(message, timeout) - res = self._submit_job(get_properties, timeout) - - return cast( - common.GetPropertiesRes, - res, - ) + return recordset_to_getpropertiesres(message_out.content) def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float] ) -> common.GetParametersRes: """Return the current local model parameters.""" + recordset = getparametersins_to_recordset(ins) + message = self._wrap_recordset_in_message( + recordset, + message_type=MESSAGE_TYPE_GET_PARAMETERS, + timeout=timeout, + ) - def get_parameters(client: Client) -> common.GetParametersRes: - return maybe_call_get_parameters( - client=client, - get_parameters_ins=ins, - ) - - res = self._submit_job(get_parameters, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.GetParametersRes, - res, - ) + return recordset_to_getparametersres(message_out.content, keep_input=False) def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" + recordset = fitins_to_recordset( + ins, keep_input=True + ) # This must stay TRUE since ins are in-memory + message = self._wrap_recordset_in_message( + recordset, message_type=MESSAGE_TYPE_FIT, timeout=timeout + ) - def fit(client: Client) -> common.FitRes: - return maybe_call_fit( - client=client, - fit_ins=ins, - ) + message_out = self._submit_job(message, timeout) - res = self._submit_job(fit, timeout) - - return cast( - common.FitRes, - res, - ) + return recordset_to_fitres(message_out.content, keep_input=False) def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" + recordset = evaluateins_to_recordset( + ins, keep_input=True + ) # This must stay TRUE since ins are in-memory + message = self._wrap_recordset_in_message( + recordset, message_type=MESSAGE_TYPE_EVALUATE, timeout=timeout + ) - def evaluate(client: Client) -> common.EvaluateRes: - return maybe_call_evaluate( - client=client, - evaluate_ins=ins, - ) - - res = self._submit_job(evaluate, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.EvaluateRes, - res, - ) + return recordset_to_evaluateres(message_out.content) def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float] ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - - -@ray.remote -def launch_and_get_properties( - client_fn: ClientFn, cid: str, get_properties_ins: common.GetPropertiesIns -) -> common.GetPropertiesRes: - """Exectue get_properties remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - -@ray.remote -def launch_and_get_parameters( - client_fn: ClientFn, cid: str, get_parameters_ins: common.GetParametersIns -) -> common.GetParametersRes: - """Exectue get_parameters remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - -@ray.remote -def launch_and_fit( - client_fn: ClientFn, cid: str, fit_ins: common.FitIns -) -> common.FitRes: - """Exectue fit remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - -@ray.remote -def launch_and_evaluate( - client_fn: ClientFn, cid: str, evaluate_ins: common.EvaluateIns -) -> common.EvaluateRes: - """Exectue evaluate remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - -def _create_client(client_fn: ClientFn, cid: str) -> Client: - """Create a client instance.""" - # Materialize client - return client_fn(cid) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index b380d37d01c8..9ade31c323d8 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -17,19 +17,25 @@ from math import pi from random import shuffle -from typing import List, Tuple, Type, cast +from typing import Dict, List, Tuple, Type import ray from flwr.client import Client, NumPyClient -from flwr.common import Code, GetPropertiesRes, Status +from flwr.client.clientapp import ClientApp +from flwr.common import Config, Scalar from flwr.common.configsrecord import ConfigsRecord +from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES from flwr.common.context import Context +from flwr.common.message import Message, Metadata from flwr.common.recordset import RecordSet +from flwr.common.recordset_compat import ( + getpropertiesins_to_recordset, + recordset_to_getpropertiesres, +) +from flwr.common.recordset_compat_test import _get_valid_getpropertiesins from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, - DefaultActor, - JobFn, + ClientAppActor, VirtualClientEngineActor, VirtualClientEngineActorPool, ) @@ -42,34 +48,24 @@ class DummyClient(NumPyClient): def __init__(self, cid: str) -> None: self.cid = int(cid) - -def get_dummy_client(cid: str) -> Client: - """Return a DummyClient converted to Client type.""" - return DummyClient(cid).to_client() - - -# A dummy run -def job_fn(cid: str) -> JobFn: # pragma: no cover - """Construct a simple job with cid dependency.""" - - def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument - result = int(cid) * pi + def get_properties(self, config: Config) -> Dict[str, Scalar]: + """Return properties by doing a simple calculation.""" + result = int(self.cid) * pi # store something in context - client.numpy_client.context.state.set_configs( # type: ignore + self.context.state.set_configs( "result", record=ConfigsRecord({"result": str(result)}) ) + return {"result": result} - # now let's convert it to a GetPropertiesRes response - return GetPropertiesRes( - status=Status(Code(0), message="test"), properties={"result": result} - ) - return cid_times_pi +def get_dummy_client(cid: str) -> Client: + """Return a DummyClient converted to Client type.""" + return DummyClient(cid).to_client() def prep( - actor_type: Type[VirtualClientEngineActor] = DefaultActor, + actor_type: Type[VirtualClientEngineActor] = ClientAppActor, ) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover """Prepare ClientProxies and pool for tests.""" client_resources = {"num_cpus": 1, "num_gpus": 0.0} @@ -104,13 +100,23 @@ def test_cid_consistency_one_at_a_time() -> None: Submit one job and waits for completion. Then submits the next and so on """ proxies, _ = prep() + + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + # submit jobs one at a time for prox in proxies: - res = prox._submit_job( # pylint: disable=protected-access - job_fn=job_fn(prox.cid), timeout=None + message = prox._wrap_recordset_in_message( # pylint: disable=protected-access + recordset, + MESSAGE_TYPE_GET_PROPERTIES, + timeout=None, ) + message_out = prox._submit_job( # pylint: disable=protected-access + message=message, timeout=None + ) + + res = recordset_to_getpropertiesres(message_out.content) - res = cast(GetPropertiesRes, res) assert int(prox.cid) * pi == res.properties["result"] ray.shutdown() @@ -125,6 +131,9 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: proxies, _ = prep() run_id = 0 + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + # submit all jobs (collect later) shuffle(proxies) for prox in proxies: @@ -133,18 +142,24 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: # Retrieve state state = prox.proxy_state.retrieve_context(run_id=run_id) - job = job_fn(prox.cid) + message = prox._wrap_recordset_in_message( # pylint: disable=protected-access + recordset, + message_type=MESSAGE_TYPE_GET_PROPERTIES, + timeout=None, + ) prox.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (prox.client_fn, job, prox.cid, state), + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (prox.app_fn, message, prox.cid, state), ) # fetch results one at a time shuffle(proxies) for prox in proxies: - res, updated_context = prox.actor_pool.get_client_result(prox.cid, timeout=None) + message_out, updated_context = prox.actor_pool.get_client_result( + prox.cid, timeout=None + ) prox.proxy_state.update_context(run_id, context=updated_context) - res = cast(GetPropertiesRes, res) + res = recordset_to_getpropertiesres(message_out.content) assert int(prox.cid) * pi == res.properties["result"] assert ( @@ -163,20 +178,36 @@ def test_cid_consistency_without_proxies() -> None: num_clients = len(proxies) cids = [str(cid) for cid in range(num_clients)] + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + + def _load_app() -> ClientApp: + return ClientApp(client_fn=get_dummy_client) + # submit all jobs (collect later) shuffle(cids) for cid in cids: - job = job_fn(cid) + message = Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id="", + ttl="", + node_id=int(cid), + message_type=MESSAGE_TYPE_GET_PROPERTIES, + ), + ) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, Context(state=RecordSet())), + (_load_app, message, cid, Context(state=RecordSet())), ) # fetch results one at a time shuffle(cids) for cid in cids: - res, _ = pool.get_client_result(cid, timeout=None) - res = cast(GetPropertiesRes, res) + message_out, _ = pool.get_client_result(cid, timeout=None) + res = recordset_to_getpropertiesres(message_out.content) assert int(cid) * pi == res.properties["result"] ray.shutdown() diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py index dd9fb6b2aa85..3861164998a4 100644 --- a/src/py/flwr/simulation/ray_transport/utils.py +++ b/src/py/flwr/simulation/ray_transport/utils.py @@ -18,7 +18,6 @@ import warnings from logging import ERROR -from flwr.client import Client from flwr.common.logger import log try: @@ -60,25 +59,3 @@ def enable_tf_gpu_growth() -> None: log(ERROR, traceback.format_exc()) log(ERROR, ex) raise ex - - -def check_clientfn_returns_client(client: Client) -> Client: - """Warn once that clients returned in `clinet_fn` should be of type Client. - - This is here for backwards compatibility. If a ClientFn is provided returning - a different type of client (e.g. NumPyClient) we'll warn the user but convert - the client internally to `Client` by calling `.to_client()`. - """ - if not isinstance(client, Client): - mssg = ( - " Ensure your client is of type `flwr.client.Client`. Please convert it" - " using the `.to_client()` method before returning it" - " in the `client_fn` you pass to `start_simulation`." - " We have applied this conversion on your behalf." - " Not returning a `Client` might trigger an error in future" - " versions of Flower." - ) - - warnings.warn(mssg, DeprecationWarning, stacklevel=2) - client = client.to_client() - return client