diff --git a/nvflare/app_opt/xgboost/tree_based/executor.py b/nvflare/app_opt/xgboost/tree_based/executor.py index 95b228c7ec..f7c8c69f15 100644 --- a/nvflare/app_opt/xgboost/tree_based/executor.py +++ b/nvflare/app_opt/xgboost/tree_based/executor.py @@ -26,7 +26,7 @@ from nvflare.apis.signal import Signal from nvflare.app_common.app_constant import AppConstants from nvflare.app_opt.xgboost.data_loader import XGBDataLoader -from nvflare.app_opt.xgboost.tree_based.shareable_generator import update_model +from nvflare.app_opt.xgboost.tree_based.utils import update_model from nvflare.fuel.utils.import_utils import optional_import from nvflare.security.logging import secure_format_exception diff --git a/nvflare/app_opt/xgboost/tree_based/shareable_generator.py b/nvflare/app_opt/xgboost/tree_based/shareable_generator.py index 85da2e3e3e..7fdb4f705b 100644 --- a/nvflare/app_opt/xgboost/tree_based/shareable_generator.py +++ b/nvflare/app_opt/xgboost/tree_based/shareable_generator.py @@ -20,44 +20,7 @@ from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, model_learnable_to_dxo from nvflare.app_common.abstract.shareable_generator import ShareableGenerator from nvflare.app_common.app_constant import AppConstants - - -def _get_xgboost_model_attr(xgb_model): - num_parallel_tree = int( - xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_parallel_tree"] - ) - num_trees = int(xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"]) - return num_parallel_tree, num_trees - - -def update_model(prev_model, model_update): - if not prev_model: - return model_update - else: - # Append all trees - # get the parameters - pre_num_parallel_tree, pre_num_trees = _get_xgboost_model_attr(prev_model) - cur_num_parallel_tree, add_num_trees = _get_xgboost_model_attr(model_update) - - # check num_parallel_tree, should be consistent - if cur_num_parallel_tree != pre_num_parallel_tree: - raise ValueError( - f"add_num_parallel_tree should not change, previous {pre_num_parallel_tree}, current {add_num_parallel_tree}" - ) - prev_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] = str( - pre_num_trees + cur_num_parallel_tree - ) - # append the new trees - append_info = model_update["learner"]["gradient_booster"]["model"]["trees"] - for tree_ct in range(cur_num_parallel_tree): - append_info[tree_ct]["id"] = pre_num_trees + tree_ct - prev_model["learner"]["gradient_booster"]["model"]["trees"].append(append_info[tree_ct]) - prev_model["learner"]["gradient_booster"]["model"]["tree_info"].append(0) - # append iteration_indptr - prev_model["learner"]["gradient_booster"]["model"]["iteration_indptr"].append( - pre_num_trees + cur_num_parallel_tree - ) - return prev_model +from nvflare.app_opt.xgboost.tree_based.utils import update_model class XGBModelShareableGenerator(ShareableGenerator): diff --git a/nvflare/app_opt/xgboost/tree_based/utils.py b/nvflare/app_opt/xgboost/tree_based/utils.py new file mode 100644 index 0000000000..9e68c2fd26 --- /dev/null +++ b/nvflare/app_opt/xgboost/tree_based/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def _get_xgboost_model_attr(xgb_model): + num_parallel_tree = int( + xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_parallel_tree"] + ) + num_trees = int(xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"]) + return num_parallel_tree, num_trees + + +def update_model(prev_model, model_update): + if not prev_model: + return model_update + else: + # Append all trees + # get the parameters + pre_num_parallel_tree, pre_num_trees = _get_xgboost_model_attr(prev_model) + cur_num_parallel_tree, add_num_trees = _get_xgboost_model_attr(model_update) + + # check num_parallel_tree, should be consistent + if cur_num_parallel_tree != pre_num_parallel_tree: + raise ValueError( + f"add_num_parallel_tree should not change, previous {pre_num_parallel_tree}, current {cur_num_parallel_tree}" + ) + prev_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] = str( + pre_num_trees + cur_num_parallel_tree + ) + # append the new trees + append_info = model_update["learner"]["gradient_booster"]["model"]["trees"] + for tree_ct in range(cur_num_parallel_tree): + append_info[tree_ct]["id"] = pre_num_trees + tree_ct + prev_model["learner"]["gradient_booster"]["model"]["trees"].append(append_info[tree_ct]) + prev_model["learner"]["gradient_booster"]["model"]["tree_info"].append(0) + # append iteration_indptr + prev_model["learner"]["gradient_booster"]["model"]["iteration_indptr"].append( + pre_num_trees + cur_num_parallel_tree + ) + return prev_model