Skip to content

Commit

Permalink
[2.4] Improve reliable message (#2452)
Browse files Browse the repository at this point in the history
* in dev

* refactor reliable msg

* integarte RM with XGB

* fix format

* make arg names consistent

* added arg check

* address pr comments; disable adaprot_test

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
  • Loading branch information
yanchengnv and YuanTingHsieh authored Apr 1, 2024
1 parent a5af581 commit 84a1a98
Show file tree
Hide file tree
Showing 15 changed files with 329 additions and 300 deletions.
6 changes: 6 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,12 @@ class ConfigVarName:
# client: timeout for submitTaskResult requests
SUBMIT_TASK_RESULT_TIMEOUT = "submit_task_result_timeout"

# client and server: max number of request workers for reliable message
RM_MAX_REQUEST_WORKERS = "rm_max_request_workers"

# client and server: query interval for reliable message
RM_QUERY_INTERVAL = "rm_query_interval"


class SystemVarName:
"""
Expand Down
292 changes: 233 additions & 59 deletions nvflare/apis/utils/reliable_message.py

Large diffs are not rendered by default.

54 changes: 0 additions & 54 deletions nvflare/apis/utils/reliable_sender.py

This file was deleted.

86 changes: 0 additions & 86 deletions nvflare/apis/utils/sender.py

This file was deleted.

32 changes: 13 additions & 19 deletions nvflare/app_opt/xgboost/histogram_based_v2/adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.utils.sender import Sender
from nvflare.apis.utils.reliable_message import ReliableMessage
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.runner import XGBRunner
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_object_type, check_positive_int


Expand Down Expand Up @@ -280,29 +281,16 @@ class XGBClientAdaptor(XGBAdaptor, ABC):
XGBClientAdaptor specifies commonly required methods for client adaptor implementations.
"""

def __init__(self, req_timeout: float):
def __init__(self, per_msg_timeout: float, tx_timeout: float):
"""Constructor of XGBClientAdaptor"""
XGBAdaptor.__init__(self)
self.engine = None
self.sender = None
self.stopped = False
self.rank = None
self.num_rounds = None
self.world_size = None
self.req_timeout = req_timeout

def set_sender(self, sender: Sender):
"""Set the sender to be used to send XGB operation requests to the server.
Args:
sender: the sender to be set
Returns: None
"""
if not isinstance(sender, Sender):
raise TypeError(f"sender must be Sender but got {type(sender)}")
self.sender = sender
self.per_msg_timeout = per_msg_timeout
self.tx_timeout = tx_timeout

def configure(self, config: dict, fl_ctx: FLContext):
"""Called by XGB Executor to configure the target.
Expand Down Expand Up @@ -352,8 +340,14 @@ def _send_request(self, op: str, req: Shareable) -> bytes:
req.set_header(Constant.MSG_KEY_XGB_OP, op)

with self.engine.new_context() as fl_ctx:
reply = self.sender.send_to_server(
Constant.TOPIC_XGB_REQUEST, req, self.req_timeout, fl_ctx, self.abort_signal
reply = ReliableMessage.send_request(
target=FQCN.ROOT_SERVER,
topic=Constant.TOPIC_XGB_REQUEST,
request=req,
per_msg_timeout=self.per_msg_timeout,
tx_timeout=self.tx_timeout,
abort_signal=self.abort_signal,
fl_ctx=fl_ctx,
)

if isinstance(reply, Shareable):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def start_controller(self, fl_ctx: FLContext):
message_handle_func=self._process_client_done,
)

ReliableMessage.enable(fl_ctx)
ReliableMessage.register_request_handler(
topic=Constant.TOPIC_XGB_REQUEST,
handler_f=self._process_xgb_request,
Expand Down
43 changes: 0 additions & 43 deletions nvflare/app_opt/xgboost/histogram_based_v2/adaptor_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,29 @@
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.sender import Sender, SimpleSender
from nvflare.app_opt.xgboost.histogram_based_v2.adaptor import XGBClientAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.utils.validation_utils import check_str
from nvflare.security.logging import secure_format_exception


class XGBExecutor(Executor):
def __init__(
self,
adaptor_component_id: str,
sender_id: str = None,
configure_task_name=Constant.CONFIG_TASK_NAME,
start_task_name=Constant.START_TASK_NAME,
req_timeout=100.0,
):
"""Executor for XGB.
Args:
adaptor_component_id: the component ID of client target adaptor
sender_id: The sender component id
configure_task_name: name of the config task
start_task_name: name of the start task
"""
Executor.__init__(self)
self.adaptor_component_id = adaptor_component_id

if sender_id:
check_str("sender_id", sender_id)
self.sender_id = sender_id

self.req_timeout = req_timeout
self.configure_task_name = configure_task_name
self.start_task_name = start_task_name
self.adaptor = None
Expand Down Expand Up @@ -85,12 +75,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
)
return

sender = self._get_sender(fl_ctx)
if not sender:
return

adaptor.set_abort_signal(self.abort_signal)
adaptor.set_sender(sender)
adaptor.initialize(fl_ctx)
self.adaptor = adaptor
elif event_type == EventType.END_RUN:
Expand Down Expand Up @@ -178,31 +163,3 @@ def _notify_client_done(self, rc, fl_ctx: FLContext):
fl_ctx=fl_ctx,
optional=True,
)

def _get_sender(self, fl_ctx: FLContext) -> Sender:
"""Get request sender to be used by this executor.
Args:
fl_ctx: the FL context
Returns:
A sender object
"""

if self.sender_id:
engine = fl_ctx.get_engine()
sender = engine.get_component(self.sender_id)
if not sender:
self.system_panic(f"cannot get component for {self.sender_id}", fl_ctx)
else:
if not isinstance(sender, Sender):
self.system_panic(
f"invalid component '{self.sender_id}': expect {Sender.__name__} but got {type(sender)}",
fl_ctx,
)
sender = None

else:
sender = SimpleSender()

return sender
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,17 @@ class GrpcClientAdaptor(XGBClientAdaptor, FederatedServicer):
federated gRPC client.
"""

def __init__(
self,
int_server_grpc_options=None,
in_process=False,
req_timeout=100,
):
def __init__(self, int_server_grpc_options=None, in_process=False, per_msg_timeout=10.0, tx_timeout=100.0):
"""Constructor method to initialize the object.
Args:
int_server_grpc_options: An optional list of key-value pairs (`channel_arguments`
in gRPC Core runtime) to configure the gRPC channel of internal `GrpcServer`.
in_process (bool): Specifies whether to start the `XGBRunner` in the same process or not.
req_timeout: Request timeout
per_msg_timeout: Request per-msg timeout
tx_timeout: timeout for the whole req transaction
"""
XGBClientAdaptor.__init__(self, req_timeout)
XGBClientAdaptor.__init__(self, per_msg_timeout, tx_timeout)
self.int_server_grpc_options = int_server_grpc_options
self.in_process = in_process
self.internal_xgb_server = None
Expand Down Expand Up @@ -210,7 +206,7 @@ def start(self, fl_ctx: FLContext):
if not port:
raise RuntimeError("failed to get a port for XGB server")

self.internal_server_addr = f"localhost:{port}"
self.internal_server_addr = f"127.0.0.1:{port}"
self.logger.info(f"Start internal server at {self.internal_server_addr}")
self.internal_xgb_server = GrpcServer(
addr=self.internal_server_addr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def start(self, fl_ctx: FLContext):
if not port:
raise RuntimeError("failed to get a port for XGB server")

server_addr = f"localhost:{port}"
server_addr = f"127.0.0.1:{port}"
self._start_server(addr=server_addr, port=port)

# start XGB client
Expand Down
12 changes: 12 additions & 0 deletions nvflare/app_opt/xgboost/histogram_based_v2/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_server_adaptor import GrpcServerAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.runners.server_runner import XGBServerRunner
from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_int, check_positive_number, check_str


class XGBFedController(XGBController):
Expand All @@ -33,6 +34,17 @@ def __init__(
client_ranks=None,
in_process=True,
):
check_positive_int("num_rounds", num_rounds)
check_str("configure_task_name", configure_task_name)
check_positive_number("configure_task_timeout", configure_task_timeout)
check_str("start_task_name", start_task_name)
check_positive_number("start_task_timeout", start_task_timeout)
check_positive_number("job_status_check_interval", job_status_check_interval)
check_positive_number("max_client_op_interval", max_client_op_interval)
check_positive_number("progress_timeout", progress_timeout)
if client_ranks is not None:
check_object_type("client_ranks", client_ranks, dict)

XGBController.__init__(
self,
adaptor_component_id="",
Expand Down
Loading

0 comments on commit 84a1a98

Please sign in to comment.