From 12b963ab487fca9d8291978a575b1c3f07c7e908 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Wed, 12 Jul 2023 19:11:57 +0200 Subject: [PATCH] Model builders API update (#1320) --------- Co-authored-by: Dmitry Razdoburdin <> Co-authored-by: Nikolay Petrov Co-authored-by: Alexander Andreev --- daal4py/__init__.py | 5 + daal4py/mb/__init__.py | 20 ++ daal4py/mb/model_builders.py | 222 ++++++++++++++++++ daal4py/sklearn/ensemble/GBTDAAL.py | 74 +++--- daal4py/sklearn/ensemble/__init__.py | 4 +- doc/daal4py/examples.rst | 6 +- doc/daal4py/model-builders.rst | 6 +- ...st_batch.py => model_builders_catboost.py} | 21 +- ...bm_batch.py => model_builders_lightgbm.py} | 20 +- ...ost_batch.py => model_builders_xgboost.py} | 20 +- setup.py | 1 + src/gbt_convertors.pyx | 46 ++-- tests/run_examples.py | 6 +- tests/test_examples.py | 8 +- 14 files changed, 352 insertions(+), 107 deletions(-) create mode 100644 daal4py/mb/__init__.py create mode 100644 daal4py/mb/model_builders.py rename examples/daal4py/{gbt_cls_model_create_from_catboost_batch.py => model_builders_catboost.py} (82%) rename examples/daal4py/{gbt_cls_model_create_from_lightgbm_batch.py => model_builders_lightgbm.py} (83%) rename examples/daal4py/{gbt_cls_model_create_from_xgboost_batch.py => model_builders_xgboost.py} (83%) diff --git a/daal4py/__init__.py b/daal4py/__init__.py index b011a14514..f3b7181c27 100644 --- a/daal4py/__init__.py +++ b/daal4py/__init__.py @@ -49,3 +49,8 @@ '[c]sh/psxevars.[c]sh may solve the issue.\n') raise + +from . import mb +from . import sklearn + +__all__ = ['mb', 'sklearn'] diff --git a/daal4py/mb/__init__.py b/daal4py/mb/__init__.py new file mode 100644 index 0000000000..279681ca07 --- /dev/null +++ b/daal4py/mb/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +from .model_builders import GBTDAALBaseModel, convert_model + +__all__ = ['GBTDAALBaseModel', 'convert_model'] diff --git a/daal4py/mb/model_builders.py b/daal4py/mb/model_builders.py new file mode 100644 index 0000000000..c3f2a99be7 --- /dev/null +++ b/daal4py/mb/model_builders.py @@ -0,0 +1,222 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#=============================================================================== + +# daal4py Model builders API + +import numpy as np +import daal4py as d4p + +try: + from pandas import DataFrame + from pandas.core.dtypes.cast import find_common_type + pandas_is_imported = True +except (ImportError, ModuleNotFoundError): + pandas_is_imported = False + + +def parse_dtype(dt): + if dt == np.double: + return "double" + if dt == np.single: + return "float" + raise ValueError(f"Input array has unexpected dtype = {dt}") + + +def getFPType(X): + if pandas_is_imported: + if isinstance(X, DataFrame): + dt = find_common_type(X.dtypes.tolist()) + return parse_dtype(dt) + + dt = getattr(X, 'dtype', None) + return parse_dtype(dt) + + +class GBTDAALBaseModel: + def _get_params_from_lightgbm(self, params): + self.n_classes_ = params["num_tree_per_iteration"] + objective_fun = params["objective"] + if self.n_classes_ <= 2: + if "binary" in objective_fun: # nClasses == 1 + self.n_classes_ = 2 + + self.n_features_in_ = params["max_feature_idx"] + 1 + + def _get_params_from_xgboost(self, params): + self.n_classes_ = int(params["learner"]["learner_model_param"]["num_class"]) + objective_fun = params["learner"]["learner_train_param"]["objective"] + if self.n_classes_ <= 2: + if objective_fun in ["binary:logistic", "binary:logitraw"]: + self.n_classes_ = 2 + + self.n_features_in_ = int(params["learner"]["learner_model_param"]["num_feature"]) + + def _get_params_from_catboost(self, params): + if 'class_params' in params['model_info']: + self.n_classes_ = len(params['model_info']['class_params']['class_to_label']) + self.n_features_in_ = len(params['features_info']['float_features']) + + def _convert_model_from_lightgbm(self, booster): + lgbm_params = d4p.get_lightgbm_params(booster) + self.daal_model_ = d4p.get_gbt_model_from_lightgbm(booster, lgbm_params) + self._get_params_from_lightgbm(lgbm_params) + + def _convert_model_from_xgboost(self, booster): + xgb_params = d4p.get_xgboost_params(booster) + self.daal_model_ = d4p.get_gbt_model_from_xgboost(booster, xgb_params) + self._get_params_from_xgboost(xgb_params) + + def _convert_model_from_catboost(self, booster): + catboost_params = d4p.get_catboost_params(booster) + self.daal_model_ = d4p.get_gbt_model_from_catboost(booster) + self._get_params_from_catboost(catboost_params) + + def _convert_model(self, model): + (submodule_name, class_name) = (model.__class__.__module__, + model.__class__.__name__) + self_class_name = self.__class__.__name__ + + # Build GBTDAALClassifier from LightGBM + if (submodule_name, class_name) == ("lightgbm.sklearn", "LGBMClassifier"): + if self_class_name == "GBTDAALClassifier": + self._convert_model_from_lightgbm(model.booster_) + else: + raise TypeError(f"Only GBTDAALClassifier can be created from\ + {submodule_name}.{class_name} (got {self_class_name})") + # Build GBTDAALClassifier from XGBoost + elif (submodule_name, class_name) == ("xgboost.sklearn", "XGBClassifier"): + if self_class_name == "GBTDAALClassifier": + self._convert_model_from_xgboost(model.get_booster()) + else: + raise TypeError(f"Only GBTDAALClassifier can be created from\ + {submodule_name}.{class_name} (got {self_class_name})") + # Build GBTDAALClassifier from CatBoost + elif (submodule_name, class_name) == ("catboost.core", "CatBoostClassifier"): + if self_class_name == "GBTDAALClassifier": + self._convert_model_from_catboost(model) + else: + raise TypeError(f"Only GBTDAALClassifier can be created from\ + {submodule_name}.{class_name} (got 
{self_class_name})") + # Build GBTDAALRegressor from LightGBM + elif (submodule_name, class_name) == ("lightgbm.sklearn", "LGBMRegressor"): + if self_class_name == "GBTDAALRegressor": + self._convert_model_from_lightgbm(model.booster_) + else: + raise TypeError(f"Only GBTDAALRegressor can be created from\ + {submodule_name}.{class_name} (got {self_class_name})") + # Build GBTDAALRegressor from XGBoost + elif (submodule_name, class_name) == ("xgboost.sklearn", "XGBRegressor"): + if self_class_name == "GBTDAALRegressor": + self._convert_model_from_xgboost(model.get_booster()) + else: + raise TypeError(f"Only GBTDAALRegressor can be created from\ + {submodule_name}.{class_name} (got {self_class_name})") + # Build GBTDAALRegressor from CatBoost + elif (submodule_name, class_name) == ("catboost.core", "CatBoostRegressor"): + if self_class_name == "GBTDAALRegressor": + self._convert_model_from_catboost(model) + else: + raise TypeError(f"Only GBTDAALRegressor can be created from\ + {submodule_name}.{class_name} (got {self_class_name})") + # Build GBTDAALModel from LightGBM + elif (submodule_name, class_name) == ("lightgbm.basic", "Booster"): + if self_class_name == "GBTDAALModel": + self._convert_model_from_lightgbm(model) + else: + raise TypeError(f"Only GBTDAALModel can be created from\ + {submodule_name}.{class_name} (got {self_class_name})") + # Build GBTDAALModel from XGBoost + elif (submodule_name, class_name) == ("xgboost.core", "Booster"): + if self_class_name == "GBTDAALModel": + self._convert_model_from_xgboost(model) + else: + raise TypeError(f"Only GBTDAALModel can be created from\ + {submodule_name}.{class_name} (got {self_class_name})") + # Build GBTDAALModel from CatBoost + elif (submodule_name, class_name) == ("catboost.core", "CatBoost"): + if self_class_name == "GBTDAALModel": + self._convert_model_from_catboost(model) + else: + raise TypeError(f"Only GBTDAALModel can be created from\ + {submodule_name}.{class_name} (got {self_class_name})") + else: + raise TypeError(f"Unknown model format {submodule_name}.{class_name}") + + def _predict_classification(self, X, fptype, resultsToEvaluate): + if X.shape[1] != self.n_features_in_: + raise ValueError('Shape of input is different from what was seen in `fit`') + + if not hasattr(self, 'daal_model_'): + raise ValueError(( + "The class {} instance does not have 'daal_model_' attribute set. " + "Call 'fit' with appropriate arguments before using this method.") + .format(type(self).__name__)) + + # Prediction + predict_algo = d4p.gbt_classification_prediction( + fptype=fptype, + nClasses=self.n_classes_, + resultsToEvaluate=resultsToEvaluate) + predict_result = predict_algo.compute(X, self.daal_model_) + + if resultsToEvaluate == "computeClassLabels": + return predict_result.prediction.ravel().astype(np.int64, copy=False) + else: + return predict_result.probabilities + + def _predict_regression(self, X, fptype): + if X.shape[1] != self.n_features_in_: + raise ValueError('Shape of input is different from what was seen in `fit`') + + if not hasattr(self, 'daal_model_'): + raise ValueError(( + "The class {} instance does not have 'daal_model_' attribute set. 
" + "Call 'fit' with appropriate arguments before using this method.").format( + type(self).__name__)) + + # Prediction + predict_algo = d4p.gbt_regression_prediction(fptype=fptype) + predict_result = predict_algo.compute(X, self.daal_model_) + + return predict_result.prediction.ravel() + + +class GBTDAALModel(GBTDAALBaseModel): + def __init__(self): + pass + + def predict(self, X): + fptype = getFPType(X) + if self._is_regression: + return self._predict_regression(X, fptype) + else: + return self._predict_classification(X, fptype, "computeClassLabels") + + def predict_proba(self, X): + fptype = getFPType(X) + if self._is_regression: + raise NotImplementedError("Can't predict probabilities for regression task") + else: + return self._predict_classification(X, fptype, "computeClassProbabilities") + + +def convert_model(model): + gbm = GBTDAALModel() + gbm._convert_model(model) + + gbm._is_regression = isinstance(gbm.daal_model_, d4p.gbt_regression_model) + + return gbm diff --git a/daal4py/sklearn/ensemble/GBTDAAL.py b/daal4py/sklearn/ensemble/GBTDAAL.py index 5b9c7a2e26..1ea795b2cd 100644 --- a/daal4py/sklearn/ensemble/GBTDAAL.py +++ b/daal4py/sklearn/ensemble/GBTDAAL.py @@ -27,7 +27,7 @@ from .._utils import getFPType -class GBTDAALBase(BaseEstimator): +class GBTDAALBase(BaseEstimator, d4p.mb.GBTDAALBaseModel): def __init__(self, split_method='inexact', max_iterations=50, @@ -101,6 +101,11 @@ def _check_params(self): raise ValueError('Parameter "min_bin_size" must be ' 'non-zero positive integer value.') + allow_nan_ = False + + def _more_tags(self): + return {"allow_nan": self.allow_nan_} + class GBTDAALClassifier(GBTDAALBase, ClassifierMixin): def fit(self, X, y): @@ -165,41 +170,28 @@ def fit(self, X, y): return self def _predict(self, X, resultsToEvaluate): + # Input validation + if not self.allow_nan_: + X = check_array(X, dtype=[np.single, np.double]) + else: + X = check_array(X, dtype=[np.single, np.double], force_all_finite='allow-nan') + # Check is fit had been called check_is_fitted(self, ['n_features_in_', 'n_classes_']) - # Input validation - X = check_array(X, dtype=[np.single, np.double]) - if X.shape[1] != self.n_features_in_: - raise ValueError('Shape of input is different from what was seen in `fit`') - # Trivial case if self.n_classes_ == 1: return np.full(X.shape[0], self.classes_[0]) - if not hasattr(self, 'daal_model_'): - raise ValueError(( - "The class {} instance does not have 'daal_model_' attribute set. 
" - "Call 'fit' with appropriate arguments before using this method.").format( - type(self).__name__)) - - # Define type of data fptype = getFPType(X) - - # Prediction - predict_algo = d4p.gbt_classification_prediction( - fptype=fptype, - nClasses=self.n_classes_, - resultsToEvaluate=resultsToEvaluate) - predict_result = predict_algo.compute(X, self.daal_model_) + predict_result = self._predict_classification(X, fptype, resultsToEvaluate) if resultsToEvaluate == "computeClassLabels": # Decode labels le = preprocessing.LabelEncoder() le.classes_ = self.classes_ - return le.inverse_transform( - predict_result.prediction.ravel().astype(np.int64, copy=False)) - return predict_result.probabilities + return le.inverse_transform(predict_result) + return predict_result def predict(self, X): return self._predict(X, "computeClassLabels") @@ -218,6 +210,14 @@ def predict_log_proba(self, X): return proba + def convert_model(model): + gbm = GBTDAALClassifier() + gbm._convert_model(model) + + gbm.classes_ = model.classes_ + gbm.allow_nan_ = True + return gbm + class GBTDAALRegressor(GBTDAALBase, RegressorMixin): def fit(self, X, y): @@ -264,25 +264,21 @@ def fit(self, X, y): return self def predict(self, X): - # Check is fit had been called - check_is_fitted(self, ['n_features_in_']) - # Input validation - X = check_array(X, dtype=[np.single, np.double]) - if X.shape[1] != self.n_features_in_: - raise ValueError('Shape of input is different from what was seen in `fit`') + if not self.allow_nan_: + X = check_array(X, dtype=[np.single, np.double]) + else: + X = check_array(X, dtype=[np.single, np.double], force_all_finite='allow-nan') - if not hasattr(self, 'daal_model_'): - raise ValueError(( - "The class {} instance does not have 'daal_model_' attribute set. " - "Call 'fit' with appropriate arguments before using this method.").format( - type(self).__name__)) + # Check is fit had been called + check_is_fitted(self, ['n_features_in_']) - # Define type of data fptype = getFPType(X) + return self._predict_regression(X, fptype) - # Prediction - predict_algo = d4p.gbt_regression_prediction(fptype=fptype) - predict_result = predict_algo.compute(X, self.daal_model_) + def convert_model(model): + gbm = GBTDAALRegressor() + gbm._convert_model(model) - return predict_result.prediction.ravel() + gbm.allow_nan_ = True + return gbm diff --git a/daal4py/sklearn/ensemble/__init__.py b/daal4py/sklearn/ensemble/__init__.py index ada8550520..15e97b423c 100644 --- a/daal4py/sklearn/ensemble/__init__.py +++ b/daal4py/sklearn/ensemble/__init__.py @@ -19,5 +19,5 @@ from .GBTDAAL import (GBTDAALClassifier, GBTDAALRegressor) from .AdaBoostClassifier import AdaBoostClassifier -__all__ = ['RandomForestClassifier', 'RandomForestRegressor', 'GBTDAALClassifier', - 'GBTDAALRegressor', 'AdaBoostClassifier'] +__all__ = ['RandomForestClassifier', 'RandomForestRegressor', + 'GBTDAALClassifier', 'GBTDAALRegressor', 'AdaBoostClassifier'] diff --git a/doc/daal4py/examples.rst b/doc/daal4py/examples.rst index 49b70d3e6e..792c15a138 100755 --- a/doc/daal4py/examples.rst +++ b/doc/daal4py/examples.rst @@ -25,9 +25,9 @@ General usage Building models from Gradient Boosting frameworks -- `XGBoost* model conversion `_ -- `LightGBM* model conversion `_ -- `CatBoost* model conversion `_ +- `XGBoost* model conversion `_ +- `LightGBM* model conversion `_ +- `CatBoost* model conversion `_ Principal Component Analysis (PCA) Transform diff --git a/doc/daal4py/model-builders.rst b/doc/daal4py/model-builders.rst index 90d4a5e93f..146e191787 100644 --- 
a/doc/daal4py/model-builders.rst +++ b/doc/daal4py/model-builders.rst @@ -90,9 +90,9 @@ Examples --------------------------------- Model Builders models conversion -- `XGBoost model conversion `_ -- `LightGBM model conversion `_ -- `CatBoost model conversion `_ +- `XGBoost model conversion `_ +- `LightGBM model conversion `_ +- `CatBoost model conversion `_ Articles and Blog Posts --------------------------------- diff --git a/examples/daal4py/gbt_cls_model_create_from_catboost_batch.py b/examples/daal4py/model_builders_catboost.py similarity index 82% rename from examples/daal4py/gbt_cls_model_create_from_catboost_batch.py rename to examples/daal4py/model_builders_catboost.py index 19a90bc487..ba497fef87 100644 --- a/examples/daal4py/gbt_cls_model_create_from_catboost_batch.py +++ b/examples/daal4py/model_builders_catboost.py @@ -26,7 +26,7 @@ def pd_read_csv(f, c=None, t=np.float64): return pd.read_csv(f, usecols=c, delimiter=',', header=None, dtype=t) -def main(readcsv=pd_read_csv, method='defaultDense'): +def main(readcsv=pd_read_csv): # Path to data train_file = "./data/batch/df_classification_train.csv" test_file = "./data/batch/df_classification_test.csv" @@ -44,12 +44,12 @@ def main(readcsv=pd_read_csv, method='defaultDense'): # training parameters setting params = { 'reg_lambda': 1, - 'max_depth': 8, - 'num_leaves': 2**8, + 'max_depth': 6, + 'num_leaves': 2**6, 'verbose': 0, 'objective': 'MultiClass', 'learning_rate': 0.3, - 'n_estimators': 100, + 'n_estimators': 25, 'classes_count': 5, } @@ -62,19 +62,14 @@ def main(readcsv=pd_read_csv, method='defaultDense'): cb_errors_count = np.count_nonzero(cb_prediction - np.ravel(y_test)) # Conversion to daal4py - daal_model = d4p.get_gbt_model_from_catboost(cb_model) + daal_model = d4p.mb.convert_model(cb_model) # daal4py prediction - daal_predict_algo = d4p.gbt_classification_prediction( - nClasses=params['classes_count'], - resultsToEvaluate="computeClassLabels", - fptype='float' - ) - daal_prediction = daal_predict_algo.compute(X_test, daal_model) - daal_errors_count = np.count_nonzero(daal_prediction.prediction - y_test) + daal_prediction = daal_model.predict(X_test) + daal_errors_count = np.count_nonzero(daal_prediction - np.ravel(y_test)) assert np.absolute(cb_errors_count - daal_errors_count) == 0 - return (cb_prediction, cb_errors_count, np.ravel(daal_prediction.prediction), + return (cb_prediction, cb_errors_count, daal_prediction, daal_errors_count, np.ravel(y_test)) diff --git a/examples/daal4py/gbt_cls_model_create_from_lightgbm_batch.py b/examples/daal4py/model_builders_lightgbm.py similarity index 83% rename from examples/daal4py/gbt_cls_model_create_from_lightgbm_batch.py rename to examples/daal4py/model_builders_lightgbm.py index 5bc7cbeb5c..b8b5279496 100644 --- a/examples/daal4py/gbt_cls_model_create_from_lightgbm_batch.py +++ b/examples/daal4py/model_builders_lightgbm.py @@ -26,7 +26,7 @@ def pd_read_csv(f, c=None, t=np.float64): return pd.read_csv(f, usecols=c, delimiter=',', header=None, dtype=t) -def main(readcsv=pd_read_csv, method='defaultDense'): +def main(readcsv=pd_read_csv): # Path to data train_file = "./data/batch/df_classification_train.csv" test_file = "./data/batch/df_classification_test.csv" @@ -50,12 +50,13 @@ def main(readcsv=pd_read_csv, method='defaultDense'): 'scale_pos_weight': 2, 'lambda_l2': 1, 'alpha': 0.9, - 'max_depth': 8, - 'num_leaves': 2**8, + 'max_depth': 6, + 'num_leaves': 2**6, 'verbose': -1, 'objective': 'multiclass', 'learning_rate': 0.3, 'num_class': 5, + 'n_estimators': 25 } # Training 
@@ -66,19 +67,14 @@ def main(readcsv=pd_read_csv, method='defaultDense'): lgb_errors_count = np.count_nonzero(lgb_prediction - np.ravel(y_test)) # Conversion to daal4py - daal_model = d4p.get_gbt_model_from_lightgbm(lgb_model) + daal_model = d4p.mb.convert_model(lgb_model) # daal4py prediction - daal_predict_algo = d4p.gbt_classification_prediction( - nClasses=params["num_class"], - resultsToEvaluate="computeClassLabels", - fptype='float' - ) - daal_prediction = daal_predict_algo.compute(X_test, daal_model) - daal_errors_count = np.count_nonzero(daal_prediction.prediction - y_test) + daal_prediction = daal_model.predict(X_test) + daal_errors_count = np.count_nonzero(daal_prediction - np.ravel(y_test)) assert np.absolute(lgb_errors_count - daal_errors_count) == 0 - return (lgb_prediction, lgb_errors_count, np.ravel(daal_prediction.prediction), + return (lgb_prediction, lgb_errors_count, daal_prediction, daal_errors_count, np.ravel(y_test)) diff --git a/examples/daal4py/gbt_cls_model_create_from_xgboost_batch.py b/examples/daal4py/model_builders_xgboost.py similarity index 83% rename from examples/daal4py/gbt_cls_model_create_from_xgboost_batch.py rename to examples/daal4py/model_builders_xgboost.py index 80557d1e86..2c29327df5 100644 --- a/examples/daal4py/gbt_cls_model_create_from_xgboost_batch.py +++ b/examples/daal4py/model_builders_xgboost.py @@ -26,7 +26,7 @@ def pd_read_csv(f, c=None, t=np.float64): return pd.read_csv(f, usecols=c, delimiter=',', header=None, dtype=t) -def main(readcsv=pd_read_csv, method='defaultDense'): +def main(readcsv=pd_read_csv): # Path to data train_file = "./data/batch/df_classification_train.csv" test_file = "./data/batch/df_classification_test.csv" @@ -47,12 +47,13 @@ def main(readcsv=pd_read_csv, method='defaultDense'): 'scale_pos_weight': 2, 'lambda_l2': 1, 'alpha': 0.9, - 'max_depth': 8, - 'num_leaves': 2**8, + 'max_depth': 6, + 'num_leaves': 2**6, 'verbosity': 0, 'objective': 'multi:softmax', 'learning_rate': 0.3, 'num_class': 5, + 'n_estimators': 25, } # Training @@ -63,19 +64,14 @@ def main(readcsv=pd_read_csv, method='defaultDense'): xgb_errors_count = np.count_nonzero(xgb_prediction - np.ravel(y_test)) # Conversion to daal4py - daal_model = d4p.get_gbt_model_from_xgboost(xgb_model) + daal_model = d4p.mb.convert_model(xgb_model) # daal4py prediction - daal_predict_algo = d4p.gbt_classification_prediction( - nClasses=params["num_class"], - resultsToEvaluate="computeClassLabels", - fptype='float' - ) - daal_prediction = daal_predict_algo.compute(X_test, daal_model) - daal_errors_count = np.count_nonzero(daal_prediction.prediction - y_test) + daal_prediction = daal_model.predict(X_test) + daal_errors_count = np.count_nonzero(daal_prediction - np.ravel(y_test)) assert np.absolute(xgb_errors_count - daal_errors_count) == 0 - return (xgb_prediction, xgb_errors_count, np.ravel(daal_prediction.prediction), + return (xgb_prediction, xgb_errors_count, daal_prediction, daal_errors_count, np.ravel(y_test)) diff --git a/setup.py b/setup.py index c48e019f06..aaac0ab582 100644 --- a/setup.py +++ b/setup.py @@ -414,6 +414,7 @@ def run(self): packages_with_tests = [ 'daal4py', 'daal4py.oneapi', + 'daal4py.mb', 'daal4py.sklearn', 'daal4py.sklearn.cluster', 'daal4py.sklearn.decomposition', diff --git a/src/gbt_convertors.pyx b/src/gbt_convertors.pyx index f6c10b96bf..5a330bc2c3 100755 --- a/src/gbt_convertors.pyx +++ b/src/gbt_convertors.pyx @@ -21,14 +21,35 @@ import json import re from time import time -def get_gbt_model_from_lightgbm(model: Any) -> Any: +def 
get_lightgbm_params(booster): + return booster.dump_model() + +def get_xgboost_params(booster): + return json.loads(booster.save_config()) + +def get_catboost_params(booster): + dump_filename = f"catboost_model_{getpid()}_{time()}" + + # Dump model in file + booster.save_model(dump_filename, 'json') + + # Read json with model + with open(dump_filename) as file: + model_data = json.load(file) + + # Delete dump file + remove(dump_filename) + return model_data + +def get_gbt_model_from_lightgbm(model: Any, lgb_model = None) -> Any: class Node: def __init__(self, tree: Dict[str, Any], parent_id: int, position: int): self.tree = tree self.parent_id = parent_id self.position = position - lgb_model = model.dump_model() + if lgb_model is None: + lgb_model = get_lightgbm_params(model) n_features = lgb_model["max_feature_idx"] + 1 n_iterations = len(lgb_model["tree_info"]) / lgb_model["num_tree_per_iteration"] @@ -117,7 +138,7 @@ def get_gbt_model_from_lightgbm(model: Any) -> Any: return mb.model() -def get_gbt_model_from_xgboost(booster: Any) -> Any: +def get_gbt_model_from_xgboost(booster: Any, xgb_config=None) -> Any: class Node: def __init__(self, tree: Dict, parent_id: int, position: int): self.tree = tree @@ -131,7 +152,9 @@ def get_gbt_model_from_xgboost(booster: Any) -> Any: booster.feature_names = [str(i) for i in lst] trees_arr = booster.get_dump(dump_format="json") - xgb_config = json.loads(booster.save_config()) + if xgb_config is None: + xgb_config = get_xgboost_params(booster) + n_features = int(xgb_config["learner"]["learner_model_param"]["num_feature"]) n_classes = int(xgb_config["learner"]["learner_model_param"]["num_class"]) @@ -238,22 +261,13 @@ def get_gbt_model_from_xgboost(booster: Any) -> Any: return mb.model() -def get_gbt_model_from_catboost(model: Any) -> Any: +def get_gbt_model_from_catboost(model: Any, model_data=None) -> Any: if not model.is_fitted(): raise RuntimeError( "Model should be fitted before exporting to daal4py.") - dump_filename = f"catboost_model_{getpid()}_{time()}" - - # Dump model in file - model.save_model(dump_filename, 'json') - - # Read json with model - with open(dump_filename) as file: - model_data = json.load(file) - - # Delete dump file - remove(dump_filename) + if model_data is None: + model_data = get_catboost_params(model) if 'categorical_features' in model_data['features_info']: raise NotImplementedError( diff --git a/tests/run_examples.py b/tests/run_examples.py index f9c77abd8e..7f74deb23f 100755 --- a/tests/run_examples.py +++ b/tests/run_examples.py @@ -157,9 +157,9 @@ def check_library(rule): req_library = defaultdict(lambda: []) req_library['basic_statistics_spmd.py'] = ['dpctl', 'mpi4py'] -req_library['gbt_cls_model_create_from_lightgbm_batch.py'] = ['lightgbm'] -req_library['gbt_cls_model_create_from_xgboost_batch.py'] = ['xgboost'] -req_library['gbt_cls_model_create_from_catboost_batch.py'] = ['catboost'] +req_library['model_builders_lightgbm.py'] = ['lightgbm'] +req_library['model_builders_xgboost.py'] = ['xgboost'] +req_library['model_builders_catboost.py'] = ['catboost'] req_library['basic_statistics_spmd.py'] = ['dpctl', 'mpi4py'] req_library['kmeans_spmd.py'] = ['dpctl', 'mpi4py'] req_library['knn_bf_classification_spmd.py'] = ['dpctl', 'mpi4py'] diff --git a/tests/test_examples.py b/tests/test_examples.py index 458ae7dc2c..d35dfdd9f2 100755 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -199,11 +199,11 @@ def test_svm_batch(self): ('distributions_normal_batch',), ('distributions_uniform_batch',), ('em_gmm_batch', 
'em_gmm.csv', lambda r: r.covariances[0]), - ('gbt_cls_model_create_from_lightgbm_batch', None, None, + ('model_builders_lightgbm', None, None, ((2020, 'P', 2), (2021, 'B', 109)), ['lightgbm']), - ('gbt_cls_model_create_from_xgboost_batch', None, None, + ('model_builders_xgboost', None, None, ((2020, 'P', 2), (2021, 'B', 109)), ['xgboost']), - ('gbt_cls_model_create_from_catboost_batch', None, None, + ('model_builders_catboost', None, None, (2021, 'P', 4), ['catboost']), ('gradient_boosted_classification_batch',), ('gradient_boosted_regression_batch',), @@ -291,7 +291,7 @@ def call(self, ex): if any(ex.__name__.startswith(x) for x in ['adaboost', 'brownboost', 'stump_classification', - 'gbt_cls_model_create', + 'model_builders', 'decision_forest']): self.skipTest("not supporting CSR") method = \
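
For reference, a minimal usage sketch of the daal4py.mb API that this patch introduces (the sketch itself is not part of the patch): it follows the pattern of examples/daal4py/model_builders_xgboost.py, but the synthetic dataset, the hyperparameters, and the use of the XGBoost scikit-learn wrapper are illustrative assumptions rather than code from this change.

# Hedged usage sketch for the new daal4py.mb model-builders API.
# Assumptions (not from the patch): xgboost and scikit-learn are installed;
# the synthetic data and hyperparameters below are illustrative only.
import daal4py as d4p
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=20, n_classes=5,
                           n_informative=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Train a plain XGBoost classifier.
clf = xgb.XGBClassifier(max_depth=6, n_estimators=25, random_state=0)
clf.fit(X_train, y_train)

# convert_model() builds a GBTDAALModel from a raw booster object
# (xgboost.core.Booster, lightgbm.basic.Booster, or catboost.core.CatBoost).
daal_model = d4p.mb.convert_model(clf.get_booster())

# Inference through oneDAL: class labels (int64) and class probabilities.
labels = daal_model.predict(X_test)
proba = daal_model.predict_proba(X_test)

# As in the patch's examples, the converted model should make the same
# number of errors on the test set as the original booster.
xgb_errors = np.count_nonzero(clf.predict(X_test) - y_test)
daal_errors = np.count_nonzero(labels - y_test)
assert xgb_errors == daal_errors

When scikit-learn style estimators are preferred, the patch also adds convert_model() on GBTDAALClassifier and GBTDAALRegressor in daal4py/sklearn/ensemble/GBTDAAL.py; those wrappers perform the same conversion, carry over classes_ from the source estimator, and set allow_nan_ = True so that input validation tolerates NaN values.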