Skip to content

Commit

Permalink
Added metrics_writer back and fixed GRPC error reply
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Apr 19, 2024
1 parent 1a4ee4e commit 2b18bfa
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import grpc

import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
Expand Down Expand Up @@ -135,7 +137,9 @@ def Allgather(self, request: pb2.AllgatherRequest, context):
return pb2.AllgatherReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_gather exception: {secure_format_exception(ex)}")
return None
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.AllgatherReply(receive_buffer=None)

def AllgatherV(self, request: pb2.AllgatherVRequest, context):
try:
Expand All @@ -147,7 +151,9 @@ def AllgatherV(self, request: pb2.AllgatherVRequest, context):
return pb2.AllgatherVReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_gather_v exception: {secure_format_exception(ex)}")
return None
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.AllgatherVReply(receive_buffer=None)

def Allreduce(self, request: pb2.AllreduceRequest, context):
try:
Expand All @@ -161,7 +167,9 @@ def Allreduce(self, request: pb2.AllreduceRequest, context):
return pb2.AllreduceReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_reduce exception: {secure_format_exception(ex)}")
return None
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.AllreduceReply(receive_buffer=None)

def Broadcast(self, request: pb2.BroadcastRequest, context):
try:
Expand All @@ -174,4 +182,6 @@ def Broadcast(self, request: pb2.BroadcastRequest, context):
return pb2.BroadcastReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_broadcast exception: {secure_format_exception(ex)}")
return None
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.BroadcastReply(receive_buffer=None)
3 changes: 3 additions & 0 deletions nvflare/app_opt/xgboost/histogram_based_v2/fed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
int_server_grpc_options=None,
req_timeout=60.0,
model_file_name="model.json",
metrics_writer_id: str = None,
in_process=True,
):
XGBExecutor.__init__(
Expand All @@ -43,6 +44,7 @@ def __init__(
self.use_gpus = use_gpus
self.int_server_grpc_options = int_server_grpc_options
self.model_file_name = model_file_name
self.metrics_writer_id = metrics_writer_id
self.in_process = in_process

def get_adaptor(self, fl_ctx: FLContext):
Expand All @@ -53,6 +55,7 @@ def get_adaptor(self, fl_ctx: FLContext):
verbose_eval=self.verbose_eval,
use_gpus=self.use_gpus,
model_file_name=self.model_file_name,
metrics_writer_id=self.metrics_writer_id,
)
runner.initialize(fl_ctx)
adaptor = GrpcClientAdaptor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import threading

import nvflare.app_common.xgb.proto.federated_pb2 as pb2
from nvflare.app_common.xgb.proto.federated_pb2_grpc import FederatedServicer
import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2
from nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2_grpc import FederatedServicer
from nvflare.fuel.utils.obj_utils import get_logger


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
import os
import time

import nvflare.app_common.xgb.proto.federated_pb2 as pb2
import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2
from nvflare.apis.fl_component import FLComponent
from nvflare.app_common.xgb.defs import Constant
from nvflare.app_common.xgb.grpc_client import GrpcClient
from nvflare.app_common.xgb.runners.xgb_runner import AppRunner
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.grpc_client import GrpcClient
from nvflare.app_opt.xgboost.histogram_based_v2.runners.xgb_runner import AppRunner


class MockClientRunner(AppRunner, FLComponent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.xgb.adaptors.grpc_server_adaptor import GrpcServerAdaptor
from nvflare.app_common.xgb.controller import XGBController
from nvflare.app_common.xgb.defs import Constant
from nvflare.app_common.xgb.mock.mock_server_runner import MockServerRunner
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_server_adaptor import GrpcServerAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.controller import XGBController
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.mock.mock_server_runner import MockServerRunner


class MockXGBController(XGBController):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from typing import Dict, List, Tuple

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.xgb.aggr import Aggregator
from nvflare.app_common.xgb.defs import Constant
from nvflare.app_common.xgb.sec.data_converter import (
from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import (
AggregationContext,
DataConverter,
FeatureAggregationResult,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.xgb.adaptors.grpc_client_adaptor import GrpcClientAdaptor
from nvflare.app_common.xgb.executor import XGBExecutor
from nvflare.app_common.xgb.mock.mock_client_runner import MockClientRunner
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_client_adaptor import GrpcClientAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.executor import XGBExecutor
from nvflare.app_opt.xgboost.histogram_based_v2.mock.mock_client_runner import MockClientRunner


class MockXGBExecutor(XGBExecutor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
import json
import time

import nvflare.app_common.xgb.proto.federated_pb2 as pb2
import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2
from nvflare.apis.fl_component import FLComponent
from nvflare.app_common.xgb.defs import Constant
from nvflare.app_common.xgb.grpc_client import GrpcClient
from nvflare.app_common.xgb.runners.xgb_runner import AppRunner
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.grpc_client import GrpcClient
from nvflare.app_opt.xgboost.histogram_based_v2.runners.xgb_runner import AppRunner


def encode_msg(msg: dict):
Expand Down Expand Up @@ -51,7 +51,7 @@ def run(self, ctx: dict):
self.training_stopped = True
return

#### fake bcst
# fake bcst
data = {
"op": "none",
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.xgb.adaptors.grpc_client_adaptor import GrpcClientAdaptor
from nvflare.app_common.xgb.executor import XGBExecutor
from nvflare.app_common.xgb.mock.mock_secure_client_runner import MockSecureClientRunner
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_client_adaptor import GrpcClientAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.executor import XGBExecutor
from nvflare.app_opt.xgboost.histogram_based_v2.mock.mock_secure_client_runner import MockSecureClientRunner


class MockSecureXGBExecutor(XGBExecutor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.app_common.xgb.defs import Constant
from nvflare.app_common.xgb.grpc_server import GrpcServer
from nvflare.app_common.xgb.mock.aggr_servicer import AggrServicer
from nvflare.app_common.xgb.runners.xgb_runner import AppRunner
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.grpc_server import GrpcServer
from nvflare.app_opt.xgboost.histogram_based_v2.mock.aggr_servicer import AggrServicer
from nvflare.app_opt.xgboost.histogram_based_v2.runners.xgb_runner import AppRunner


class MockServerRunner(AppRunner):
Expand Down
4 changes: 2 additions & 2 deletions nvflare/app_opt/xgboost/histogram_based_v2/mock/run_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import os
import time

import nvflare.app_common.xgb.proto.federated_pb2 as pb2
from nvflare.app_common.xgb.grpc_client import GrpcClient
import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2
from nvflare.app_opt.xgboost.histogram_based_v2.grpc_client import GrpcClient


def main():
Expand Down
4 changes: 2 additions & 2 deletions nvflare/app_opt/xgboost/histogram_based_v2/mock/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import argparse
import logging

from nvflare.app_common.xgb.grpc_server import GrpcServer
from nvflare.app_common.xgb.mock.aggr_servicer import AggrServicer
from nvflare.app_opt.xgboost.histogram_based_v2.grpc_server import GrpcServer
from nvflare.app_opt.xgboost.histogram_based_v2.mock.aggr_servicer import AggrServicer


def main():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.tracking.log_writer import LogWriter
from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.runners.xgb_runner import AppRunner
from nvflare.app_opt.xgboost.histogram_based_v2.tb import TensorBoardCallback
from nvflare.app_opt.xgboost.histogram_based_v2.xgb_params import XGBoostParams
from nvflare.app_opt.xgboost.metrics_cb import MetricsCallback
from nvflare.fuel.utils.import_utils import optional_import
from nvflare.fuel.utils.obj_utils import get_logger

Expand All @@ -36,6 +38,7 @@ def __init__(
verbose_eval,
use_gpus,
model_file_name,
metrics_writer_id: str = None,
):
FLComponent.__init__(self)
self.early_stopping_rounds = early_stopping_rounds
Expand All @@ -55,13 +58,20 @@ def __init__(
self._tb_dir = None
self._model_dir = None
self._stopped = False
self._metrics_writer_id = metrics_writer_id
self._metrics_writer = None

def initialize(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
self._data_loader = engine.get_component(self.data_loader_id)
if not isinstance(self._data_loader, XGBDataLoader):
self.system_panic(f"data_loader should be type XGBDataLoader but got {type(self._data_loader)}", fl_ctx)

if self._metrics_writer_id:
self._metrics_writer = engine.get_component(self._metrics_writer_id)
if not isinstance(self._metrics_writer, LogWriter):
self.system_panic("writer should be type LogWriter", fl_ctx)

def _xgb_train(self, params: XGBoostParams, train_data, val_data) -> xgb.core.Booster:
"""XGBoost training logic.
Expand All @@ -75,6 +85,9 @@ def _xgb_train(self, params: XGBoostParams, train_data, val_data) -> xgb.core.Bo
watchlist = [(val_data, "eval"), (train_data, "train")]

callbacks = [callback.EvaluationMonitor(rank=self._rank)]
if self._metrics_writer:
callbacks.append(MetricsCallback(self._metrics_writer))

tensorboard, flag = optional_import(module="torch.utils.tensorboard")
if flag and self._tb_dir:
callbacks.append(TensorBoardCallback(self._tb_dir, tensorboard))
Expand Down
38 changes: 38 additions & 0 deletions nvflare/app_opt/xgboost/metrics_cb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import xgboost.callback

from nvflare.app_common.tracking.log_writer import LogWriter


class MetricsCallback(xgboost.callback.TrainingCallback):
def __init__(self, writer: LogWriter):
xgboost.callback.TrainingCallback.__init__(self)
if not isinstance(writer, LogWriter):
raise RuntimeError("MetricsCallback: writer is not valid.")
self.writer = writer

def after_iteration(self, model, epoch: int, evals_log: xgboost.callback.TrainingCallback.EvalsLog):
if not evals_log:
return False

data_type = self.writer.get_default_metric_data_type()
for data, metric in evals_log.items():
record = {}
for metric_name, log in metric.items():
score = log[-1][0] if isinstance(log[-1], tuple) else log[-1]
record[metric_name] = score
self.writer.write(tag=f"{data}_metrics", value=record, data_type=data_type, global_step=epoch)
return False

0 comments on commit 2b18bfa

Please sign in to comment.