diff --git a/demo/guide-python/multioutput_regression.py b/demo/guide-python/multioutput_regression.py index 375377e4e4b5..078ec6b7dbaf 100644 --- a/demo/guide-python/multioutput_regression.py +++ b/demo/guide-python/multioutput_regression.py @@ -7,6 +7,12 @@ https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_regression_multioutput.html#sphx-glr-auto-examples-ensemble-plot-random-forest-regression-multioutput-py See :doc:`/tutorials/multioutput` for more information. + +.. note:: + + This feature is experimental. For the `multi_output_tree` strategy, many features are + still missing. + """ import argparse @@ -40,11 +46,18 @@ def gen_circle() -> Tuple[np.ndarray, np.ndarray]: return X, y -def rmse_model(plot_result: bool): +def rmse_model(plot_result: bool, strategy: str): """Draw a circle with 2-dim coordinate as target variables.""" X, y = gen_circle() # Train a regressor on it - reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64) + reg = xgb.XGBRegressor( + tree_method="hist", + n_estimators=128, + n_jobs=16, + max_depth=8, + multi_strategy=strategy, + subsample=0.6, + ) reg.fit(X, y, eval_set=[(X, y)]) y_predt = reg.predict(X) @@ -52,7 +65,7 @@ def rmse_model(plot_result: bool): plot_predt(y, y_predt, "multi") -def custom_rmse_model(plot_result: bool) -> None: +def custom_rmse_model(plot_result: bool, strategy: str) -> None: """Train using Python implementation of Squared Error.""" # As the experimental support status, custom objective doesn't support matrix as @@ -88,9 +101,10 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]: { "tree_method": "hist", "num_target": y.shape[1], + "multi_strategy": strategy, }, dtrain=Xy, - num_boost_round=100, + num_boost_round=128, obj=squared_log, evals=[(Xy, "Train")], evals_result=results, @@ -107,6 +121,16 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]: parser.add_argument("--plot", choices=[0, 1], type=int, default=1) args = parser.parse_args() # Train with builtin RMSE objective - rmse_model(args.plot == 1) + # - One model per output. + rmse_model(args.plot == 1, "one_output_per_tree") + + # - One model for all outputs; this is still a work in progress, and many features + # are missing. + rmse_model(args.plot == 1, "multi_output_tree") + # Train with custom objective. - custom_rmse_model(args.plot == 1) + # - One model per output. + custom_rmse_model(args.plot == 1, "one_output_per_tree") + # - One model for all outputs; this is still a work in progress, and many features + # are missing. + custom_rmse_model(args.plot == 1, "multi_output_tree") diff --git a/doc/parameter.rst b/doc/parameter.rst index ac566af749f9..1e703dacd9a8 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -226,6 +226,18 @@ Parameters for Tree Booster list is a group of indices of features that are allowed to interact with each other. See :doc:`/tutorials/feature_interaction_constraint` for more information. +* ``multi_strategy``, [default = ``one_output_per_tree``] + + .. versionadded:: 2.0.0 + + .. note:: This parameter is a work in progress. + + - The strategy used for training multi-target models, including multi-target regression + and multi-class classification. See :doc:`/tutorials/multioutput` for more information. + + - ``one_output_per_tree``: One model for each target. + - ``multi_output_tree``: Use multi-target trees. + ..
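Grounded in the demo changes above, here is a minimal sketch of selecting the strategy through the native ``xgb.train`` interface. The two-target synthetic data and the hyperparameters are illustrative only, and it is assumed that the default ``reg:squarederror`` objective accepts the 2-dim label:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(1994)
    X = rng.normal(size=(256, 4))
    # Two regression targets, stacked column-wise.
    y = np.stack([X[:, 0] + X[:, 1], X[:, 2] - X[:, 3]], axis=1)

    Xy = xgb.DMatrix(X, label=y)
    booster = xgb.train(
        {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
        Xy,
        num_boost_round=8,
    )
    assert booster.predict(Xy).shape == (256, 2)  # one column per target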
_cat-param: Parameters for Categorical Feature diff --git a/doc/tutorials/multioutput.rst b/doc/tutorials/multioutput.rst index 280fb106f247..983002aed499 100644 --- a/doc/tutorials/multioutput.rst +++ b/doc/tutorials/multioutput.rst @@ -11,7 +11,11 @@ can be simultaneously classified as both sci-fi and comedy. For detailed explanation of terminologies related to different multi-output models please refer to the :doc:`scikit-learn user guide `. -Internally, XGBoost builds one model for each target similar to sklearn meta estimators, +********************************** +Training with One-Model-Per-Target +********************************** + +By default, XGBoost builds one model for each target, similar to sklearn meta estimators, with the added benefit of reusing data and other integrated features like SHAP. For a worked example of regression, see :ref:`sphx_glr_python_examples_multioutput_regression.py`. For multi-label classification, @@ -36,3 +40,26 @@ dense matrix for labels. The feature is still under development with limited support from objectives and metrics. + +************************* +Training with Vector Leaf +************************* + +.. versionadded:: 2.0 + +.. note:: + + This is still a work in progress, and many features are missing. + +XGBoost can optionally build multi-output trees, with the size of each leaf equal to the +number of targets, when the tree method `hist` is used. The behavior can be controlled by +the ``multi_strategy`` training parameter, which can take the value `one_output_per_tree` +(the default) for building one model per target or `multi_output_tree` for building +multi-output trees. + +.. code-block:: python + + clf = xgb.XGBClassifier(tree_method="hist", multi_strategy="multi_output_tree") + +See :ref:`sphx_glr_python_examples_multioutput_regression.py` for a worked example with +regression. diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 1d4e35a94be5..08e1ded09e95 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -286,8 +286,8 @@ struct LearnerModelParamLegacy; * \brief Strategy for building multi-target models. */ enum class MultiStrategy : std::int32_t { - kComposite = 0, - kMonolithic = 1, + kOneOutputPerTree = 0, + kMultiOutputTree = 1, }; /** @@ -317,7 +317,7 @@ struct LearnerModelParam { /** * \brief Strategy for building multi-target models. */ - MultiStrategy multi_strategy{MultiStrategy::kComposite}; + MultiStrategy multi_strategy{MultiStrategy::kOneOutputPerTree}; LearnerModelParam() = default; // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep // this one as an immutable copy. @@ -338,7 +338,7 @@ struct LearnerModelParam { void Copy(LearnerModelParam const& that); [[nodiscard]] bool IsVectorLeaf() const noexcept { - return multi_strategy == MultiStrategy::kMonolithic; + return multi_strategy == MultiStrategy::kMultiOutputTree; } [[nodiscard]] bst_target_t OutputLength() const noexcept { return this->num_output_group; } [[nodiscard]] bst_target_t LeafLength() const noexcept { diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 3d6bcc962017..65e9de6ba8b4 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -530,17 +530,17 @@ class TensorView { /** * \brief Number of items in the tensor. */ - LINALG_HD [[nodiscard]] std::size_t Size() const { return size_; } + [[nodiscard]] LINALG_HD std::size_t Size() const { return size_; } /** * \brief Whether this is a contiguous array, both C and F contiguous returns true.
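Complementing the classifier one-liner in the tutorial hunk above, a minimal regression sketch of the same switch through the sklearn interface; the dataset shape mirrors the ``mtreg`` test dataset defined elsewhere in this patch, and the hyperparameters are illustrative only:

.. code-block:: python

    from sklearn.datasets import make_regression

    import xgboost as xgb

    # Three targets; a vector-leaf tree carries all of them in one tree per round.
    X, y = make_regression(n_samples=128, n_features=2, n_targets=3, random_state=0)
    reg = xgb.XGBRegressor(
        tree_method="hist",
        multi_strategy="multi_output_tree",
        n_estimators=8,
    )
    reg.fit(X, y)
    assert reg.predict(X).shape == (128, 3)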
*/ - LINALG_HD [[nodiscard]] bool Contiguous() const { + [[nodiscard]] LINALG_HD bool Contiguous() const { return data_.size() == this->Size() || this->CContiguous() || this->FContiguous(); } /** * \brief Whether it's a c-contiguous array. */ - LINALG_HD [[nodiscard]] bool CContiguous() const { + [[nodiscard]] LINALG_HD bool CContiguous() const { StrideT stride; static_assert(std::is_same::value); // It's contiguous if the stride can be calculated from shape. @@ -550,7 +550,7 @@ class TensorView { /** * \brief Whether it's a f-contiguous array. */ - LINALG_HD [[nodiscard]] bool FContiguous() const { + [[nodiscard]] LINALG_HD bool FContiguous() const { StrideT stride; static_assert(std::is_same::value); // It's contiguous if the stride can be calculated from shape. diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 3204f5a2a61e..805eb75b36fc 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -312,6 +312,19 @@ def task(i: int) -> float: needs to be set to have categorical feature support. See :doc:`Categorical Data ` and :ref:`cat-param` for details. + multi_strategy : Optional[str] + + .. versionadded:: 2.0.0 + + .. note:: This parameter is a work in progress. + + The strategy used for training multi-target models, including multi-target + regression and multi-class classification. See :doc:`/tutorials/multioutput` for + more information. + + - ``one_output_per_tree``: One model for each target. + - ``multi_output_tree``: Use multi-target trees. + eval_metric : Optional[Union[str, List[str], Callable]] .. versionadded:: 1.6.0 @@ -624,6 +637,7 @@ def __init__( feature_types: Optional[FeatureTypes] = None, max_cat_to_onehot: Optional[int] = None, max_cat_threshold: Optional[int] = None, + multi_strategy: Optional[str] = None, eval_metric: Optional[Union[str, List[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, callbacks: Optional[List[TrainingCallback]] = None, @@ -670,6 +684,7 @@ def __init__( self.feature_types = feature_types self.max_cat_to_onehot = max_cat_to_onehot self.max_cat_threshold = max_cat_threshold + self.multi_strategy = multi_strategy self.eval_metric = eval_metric self.early_stopping_rounds = early_stopping_rounds self.callbacks = callbacks diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index bb13b5523ed2..20a4c681e142 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -10,11 +10,9 @@ import platform import socket import sys -import zipfile from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from io import StringIO -from pathlib import Path from platform import system from typing import ( Any, @@ -29,7 +27,6 @@ TypedDict, Union, ) -from urllib import request import numpy as np import pytest @@ -38,6 +35,13 @@ import xgboost as xgb from xgboost.core import ArrayLike from xgboost.sklearn import SklObjective +from xgboost.testing.data import ( + get_california_housing, + get_cancer, + get_digits, + get_sparse, + memory, +) hypothesis = pytest.importorskip("hypothesis") @@ -45,13 +49,8 @@ from hypothesis import strategies from hypothesis.extra.numpy import arrays -joblib = pytest.importorskip("joblib") datasets = pytest.importorskip("sklearn.datasets") -Memory = joblib.Memory - -memory = Memory("./cachedir", verbose=0) - PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str}) @@ -353,137 +352,6 @@ def __repr__(self) ->
str: return self.name -@memory.cache -def get_california_housing() -> Tuple[np.ndarray, np.ndarray]: - data = datasets.fetch_california_housing() - return data.data, data.target - - -@memory.cache -def get_digits() -> Tuple[np.ndarray, np.ndarray]: - data = datasets.load_digits() - return data.data, data.target - - -@memory.cache -def get_cancer() -> Tuple[np.ndarray, np.ndarray]: - return datasets.load_breast_cancer(return_X_y=True) - - -@memory.cache -def get_sparse() -> Tuple[np.ndarray, np.ndarray]: - rng = np.random.RandomState(199) - n = 2000 - sparsity = 0.75 - X, y = datasets.make_regression(n, random_state=rng) - flag = rng.binomial(1, sparsity, X.shape) - for i in range(X.shape[0]): - for j in range(X.shape[1]): - if flag[i, j]: - X[i, j] = np.nan - return X, y - - -@memory.cache -def get_ames_housing() -> Tuple[np.ndarray, np.ndarray]: - """ - Number of samples: 1460 - Number of features: 20 - Number of categorical features: 10 - Number of numerical features: 10 - """ - from sklearn.datasets import fetch_openml - - X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True) - - categorical_columns_subset: List[str] = [ - "BldgType", # 5 cats, no nan - "GarageFinish", # 3 cats, nan - "LotConfig", # 5 cats, no nan - "Functional", # 7 cats, no nan - "MasVnrType", # 4 cats, nan - "HouseStyle", # 8 cats, no nan - "FireplaceQu", # 5 cats, nan - "ExterCond", # 5 cats, no nan - "ExterQual", # 4 cats, no nan - "PoolQC", # 3 cats, nan - ] - - numerical_columns_subset: List[str] = [ - "3SsnPorch", - "Fireplaces", - "BsmtHalfBath", - "HalfBath", - "GarageCars", - "TotRmsAbvGrd", - "BsmtFinSF1", - "BsmtFinSF2", - "GrLivArea", - "ScreenPorch", - ] - - X = X[categorical_columns_subset + numerical_columns_subset] - X[categorical_columns_subset] = X[categorical_columns_subset].astype("category") - return X, y - - -@memory.cache -def get_mq2008( - dpath: str, -) -> Tuple[ - sparse.csr_matrix, - np.ndarray, - np.ndarray, - sparse.csr_matrix, - np.ndarray, - np.ndarray, - sparse.csr_matrix, - np.ndarray, - np.ndarray, -]: - from sklearn.datasets import load_svmlight_files - - src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip" - target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip") - if not os.path.exists(target): - request.urlretrieve(url=src, filename=target) - - with zipfile.ZipFile(target, "r") as f: - f.extractall(path=dpath) - - ( - x_train, - y_train, - qid_train, - x_test, - y_test, - qid_test, - x_valid, - y_valid, - qid_valid, - ) = load_svmlight_files( - ( - Path(dpath) / "MQ2008" / "Fold1" / "train.txt", - Path(dpath) / "MQ2008" / "Fold1" / "test.txt", - Path(dpath) / "MQ2008" / "Fold1" / "vali.txt", - ), - query_id=True, - zero_based=False, - ) - - return ( - x_train, - y_train, - qid_train, - x_test, - y_test, - qid_test, - x_valid, - y_valid, - qid_valid, - ) - - # pylint: disable=too-many-arguments,too-many-locals @memory.cache def make_categorical( @@ -738,20 +606,7 @@ def random_csc(t_id: int) -> sparse.csc_matrix: TestDataset( "calif_housing-l1", get_california_housing, "reg:absoluteerror", "mae" ), - TestDataset("digits", get_digits, "multi:softmax", "mlogloss"), TestDataset("cancer", get_cancer, "binary:logistic", "logloss"), - TestDataset( - "mtreg", - lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3), - "reg:squarederror", - "rmse", - ), - TestDataset( - "mtreg-l1", - lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3), - "reg:absoluteerror", - "mae", - ), TestDataset("sparse", get_sparse, 
"reg:squarederror", "rmse"), TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"), TestDataset( @@ -764,37 +619,71 @@ def random_csc(t_id: int) -> sparse.csc_matrix: ) -@strategies.composite -def _dataset_weight_margin(draw: Callable) -> TestDataset: - data: TestDataset = draw(_unweighted_datasets_strategy) - if draw(strategies.booleans()): - data.w = draw( - arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)) - ) - if draw(strategies.booleans()): - num_class = 1 - if data.objective == "multi:softmax": - num_class = int(np.max(data.y) + 1) - elif data.name.startswith("mtreg"): - num_class = data.y.shape[1] - - data.margin = draw( - arrays( - np.float64, - (data.y.shape[0] * num_class), - elements=strategies.floats(0.5, 1.0), +def make_datasets_with_margin( + unweighted_strategy: strategies.SearchStrategy, +) -> Callable: + """Factory function for creating strategies that generates datasets with weight and + base margin. + + """ + + @strategies.composite + def weight_margin(draw: Callable) -> TestDataset: + data: TestDataset = draw(unweighted_strategy) + if draw(strategies.booleans()): + data.w = draw( + arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)) ) - ) - assert data.margin is not None - if num_class != 1: - data.margin = data.margin.reshape(data.y.shape[0], num_class) + if draw(strategies.booleans()): + num_class = 1 + if data.objective == "multi:softmax": + num_class = int(np.max(data.y) + 1) + elif data.name.startswith("mtreg"): + num_class = data.y.shape[1] + + data.margin = draw( + arrays( + np.float64, + (data.y.shape[0] * num_class), + elements=strategies.floats(0.5, 1.0), + ) + ) + assert data.margin is not None + if num_class != 1: + data.margin = data.margin.reshape(data.y.shape[0], num_class) + + return data + + return weight_margin - return data +# A strategy for drawing from a set of example datasets. May add random weights to the +# dataset +dataset_strategy = make_datasets_with_margin(_unweighted_datasets_strategy)() + + +_unweighted_multi_datasets_strategy = strategies.sampled_from( + [ + TestDataset("digits", get_digits, "multi:softmax", "mlogloss"), + TestDataset( + "mtreg", + lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3), + "reg:squarederror", + "rmse", + ), + TestDataset( + "mtreg-l1", + lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3), + "reg:absoluteerror", + "mae", + ), + ] +) -# A strategy for drawing from a set of example datasets -# May add random weights to the dataset -dataset_strategy = _dataset_weight_margin() +# A strategy for drawing from a set of multi-target/multi-class datasets. 
+multi_dataset_strategy = make_datasets_with_margin( + _unweighted_multi_datasets_strategy +)() def non_increasing(L: Sequence[float], tolerance: float = 1e-4) -> bool: diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index a9ea0019c4b5..477d0cf3d6f0 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -1,13 +1,20 @@ """Utilities for data generation.""" -from typing import Any, Generator, Tuple, Union +import os +import zipfile +from typing import Any, Generator, List, Tuple, Union +from urllib import request import numpy as np import pytest from numpy.random import Generator as RNG +from scipy import sparse import xgboost from xgboost.data import pandas_pyarrow_mapper +joblib = pytest.importorskip("joblib") +memory = joblib.Memory("./cachedir", verbose=0) + def np_dtypes( n_samples: int, n_features: int @@ -195,3 +202,141 @@ def check_inf(rng: RNG) -> None: with pytest.raises(ValueError, match="Input data contains `inf`"): xgboost.DMatrix(X, y) + + +@memory.cache +def get_california_housing() -> Tuple[np.ndarray, np.ndarray]: + """Fetch the California housing dataset from sklearn.""" + datasets = pytest.importorskip("sklearn.datasets") + data = datasets.fetch_california_housing() + return data.data, data.target + + +@memory.cache +def get_digits() -> Tuple[np.ndarray, np.ndarray]: + """Fetch the digits dataset from sklearn.""" + datasets = pytest.importorskip("sklearn.datasets") + data = datasets.load_digits() + return data.data, data.target + + +@memory.cache +def get_cancer() -> Tuple[np.ndarray, np.ndarray]: + """Fetch the breast cancer dataset from sklearn.""" + datasets = pytest.importorskip("sklearn.datasets") + return datasets.load_breast_cancer(return_X_y=True) + + +@memory.cache +def get_sparse() -> Tuple[np.ndarray, np.ndarray]: + """Generate a sparse dataset.""" + datasets = pytest.importorskip("sklearn.datasets") + rng = np.random.RandomState(199) + n = 2000 + sparsity = 0.75 + X, y = datasets.make_regression(n, random_state=rng) + flag = rng.binomial(1, sparsity, X.shape) + for i in range(X.shape[0]): + for j in range(X.shape[1]): + if flag[i, j]: + X[i, j] = np.nan + return X, y + + +@memory.cache +def get_ames_housing() -> Tuple[np.ndarray, np.ndarray]: + """ + Number of samples: 1460 + Number of features: 20 + Number of categorical features: 10 + Number of numerical features: 10 + """ + datasets = pytest.importorskip("sklearn.datasets") + X, y = datasets.fetch_openml(data_id=42165, as_frame=True, return_X_y=True) + + categorical_columns_subset: List[str] = [ + "BldgType", # 5 cats, no nan + "GarageFinish", # 3 cats, nan + "LotConfig", # 5 cats, no nan + "Functional", # 7 cats, no nan + "MasVnrType", # 4 cats, nan + "HouseStyle", # 8 cats, no nan + "FireplaceQu", # 5 cats, nan + "ExterCond", # 5 cats, no nan + "ExterQual", # 4 cats, no nan + "PoolQC", # 3 cats, nan + ] + + numerical_columns_subset: List[str] = [ + "3SsnPorch", + "Fireplaces", + "BsmtHalfBath", + "HalfBath", + "GarageCars", + "TotRmsAbvGrd", + "BsmtFinSF1", + "BsmtFinSF2", + "GrLivArea", + "ScreenPorch", + ] + + X = X[categorical_columns_subset + numerical_columns_subset] + X[categorical_columns_subset] = X[categorical_columns_subset].astype("category") + return X, y + + +@memory.cache +def get_mq2008( + dpath: str, +) -> Tuple[ + sparse.csr_matrix, + np.ndarray, + np.ndarray, + sparse.csr_matrix, + np.ndarray, + np.ndarray, + sparse.csr_matrix, + np.ndarray, + np.ndarray, +]: + """Fetch the mq2008 dataset.""" + 
datasets = pytest.importorskip("sklearn.datasets") + src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip" + target = os.path.join(dpath, "MQ2008.zip") + if not os.path.exists(target): + request.urlretrieve(url=src, filename=target) + + with zipfile.ZipFile(target, "r") as f: + f.extractall(path=dpath) + + ( + x_train, + y_train, + qid_train, + x_test, + y_test, + qid_test, + x_valid, + y_valid, + qid_valid, + ) = datasets.load_svmlight_files( + ( + os.path.join(dpath, "MQ2008/Fold1/train.txt"), + os.path.join(dpath, "MQ2008/Fold1/test.txt"), + os.path.join(dpath, "MQ2008/Fold1/vali.txt"), + ), + query_id=True, + zero_based=False, + ) + + return ( + x_train, + y_train, + qid_train, + x_test, + y_test, + qid_test, + x_valid, + y_valid, + qid_valid, + ) diff --git a/python-package/xgboost/testing/params.py b/python-package/xgboost/testing/params.py index 3af3306da40e..e6ba73e1f541 100644 --- a/python-package/xgboost/testing/params.py +++ b/python-package/xgboost/testing/params.py @@ -4,8 +4,8 @@ import pytest -hypothesis = pytest.importorskip("hypothesis") -from hypothesis import strategies # pylint:disable=wrong-import-position +strategies = pytest.importorskip("hypothesis.strategies") + exact_parameter_strategy = strategies.fixed_dictionaries( { @@ -41,6 +41,26 @@ and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide") ) +hist_multi_parameter_strategy = strategies.fixed_dictionaries( + { + "max_depth": strategies.integers(1, 11), + "max_leaves": strategies.integers(0, 1024), + "max_bin": strategies.integers(2, 512), + "multi_strategy": strategies.sampled_from( + ["multi_output_tree", "one_output_per_tree"] + ), + "grow_policy": strategies.sampled_from(["lossguide", "depthwise"]), + "min_child_weight": strategies.floats(0.5, 2.0), + # We cannot enable subsampling as the training loss can increase + # 'subsample': strategies.floats(0.5, 1.0), + "colsample_bytree": strategies.floats(0.5, 1.0), + "colsample_bylevel": strategies.floats(0.5, 1.0), + } +).filter( + lambda x: (cast(int, x["max_depth"]) > 0 or cast(int, x["max_leaves"]) > 0) + and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide") +) + cat_parameter_strategy = strategies.fixed_dictionaries( { "max_cat_to_onehot": strategies.integers(1, 128), diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 8908364f2738..1af0206be080 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -55,6 +55,7 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows *out_dim = 2; shape.resize(*out_dim); shape.front() = rows; + // chunksize can be 1 if it's softmax shape.back() = std::min(groups, chunksize); } break; diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 87eb0ec208cd..aaf271934474 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -359,6 +359,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b HistogramCuts *cuts) { size_t required_cuts = std::min(summary.size, static_cast(max_bin)); auto &cut_values = cuts->cut_values_.HostVector(); + // we use the min_value as the first (0th) element, hence starting from 1. for (size_t i = 1; i < required_cuts; ++i) { bst_float cpt = summary.data[i].value; if (i == 1 || cpt > cut_values.back()) { @@ -419,8 +420,8 @@ void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { } else { AddCutPoint(a, max_num_bins, cuts); // push a value that is greater than anything - const bst_float cpt = (a.size > 0) ? 
a.data[a.size - 1].value - : cuts->min_vals_.HostVector()[fid]; + const bst_float cpt = + (a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid]; // this must be bigger than last value in a scale const bst_float last = cpt + (fabs(cpt) + 1e-5f); cuts->cut_values_.HostVector().push_back(last); diff --git a/src/common/quantile.h b/src/common/quantile.h index c8dcf6adad91..a19b4bbb0d01 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -352,19 +352,6 @@ struct WQSummary { prev_rmax = data[i].rmax; } } - // check consistency of the summary - inline bool Check(const char *msg) const { - const float tol = 10.0f; - for (size_t i = 0; i < this->size; ++i) { - if (data[i].rmin + data[i].wmin > data[i].rmax + tol || - data[i].rmin < -1e-6f || data[i].rmax < -1e-6f) { - LOG(INFO) << "---------- WQSummary::Check did not pass ----------"; - this->Print(); - return false; - } - } - return true; - } }; /*! \brief try to do efficient pruning */ diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index c7ac492c9a14..dc6fb55e8028 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -257,6 +257,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, } iter.Reset(); CHECK_EQ(rbegin, Info().num_row_); + CHECK_EQ(this->ghist_->Features(), Info().num_col_); /** * Generate column matrix diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 34915d53eaad..a912d6a75d81 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -10,6 +10,7 @@ #include #include +#include // for uint32_t #include #include #include @@ -27,9 +28,11 @@ #include "xgboost/host_device_vector.h" #include "xgboost/json.h" #include "xgboost/logging.h" +#include "xgboost/model.h" #include "xgboost/objective.h" #include "xgboost/predictor.h" -#include "xgboost/string_view.h" +#include "xgboost/string_view.h" // for StringView +#include "xgboost/tree_model.h" // for RegTree #include "xgboost/tree_updater.h" namespace xgboost::gbm { @@ -131,6 +134,12 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) { // set, since only experts are expected to do so. return; } + if (model_.learner_model_param->IsVectorLeaf()) { + CHECK(tparam_.tree_method == TreeMethod::kHist) + << "Only the hist tree method is supported for building multi-target trees with vector " + "leaf."; + } + // tparam_ is set before calling this function. 
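// Note: the vector-leaf check above runs before the `kAuto` early return below, so
// `hist` has to be requested explicitly whenever `multi_output_tree` is in use.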
if (tparam_.tree_method != TreeMethod::kAuto) { return; } @@ -175,12 +184,12 @@ void GBTree::ConfigureUpdaters() { case TreeMethod::kExact: tparam_.updater_seq = "grow_colmaker,prune"; break; - case TreeMethod::kHist: - LOG(INFO) << - "Tree method is selected to be 'hist', which uses a " - "single updater grow_quantile_histmaker."; + case TreeMethod::kHist: { + LOG(INFO) << "Tree method is selected to be 'hist', which uses a single updater " + "grow_quantile_histmaker."; tparam_.updater_seq = "grow_quantile_histmaker"; break; + } case TreeMethod::kGPUHist: { common::AssertGPUSupport(); tparam_.updater_seq = "grow_gpu_hist"; @@ -209,11 +218,9 @@ void CopyGradient(HostDeviceVector const* in_gpair, int32_t n_thre GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair); } else { std::vector &tmp_h = out_gpair->HostVector(); - auto nsize = static_cast(out_gpair->Size()); - const auto &gpair_h = in_gpair->ConstHostVector(); - common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) { - tmp_h[i] = gpair_h[i * n_groups + group_id]; - }); + const auto& gpair_h = in_gpair->ConstHostVector(); + common::ParallelFor(out_gpair->Size(), n_threads, + [&](auto i) { tmp_h[i] = gpair_h[i * n_groups + group_id]; }); } } @@ -234,6 +241,7 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector const CHECK_EQ(model_.param.num_parallel_tree, trees.size()); CHECK_EQ(model_.param.num_parallel_tree, 1) << "Boosting random forest is not supported for current objective."; + CHECK(!trees.front()->IsMultiTarget()) << "Update tree leaf" << MTNotImplemented(); CHECK_EQ(trees.size(), model_.param.num_parallel_tree); for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) { auto const& position = node_position.at(tree_idx); @@ -245,17 +253,18 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector const void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, PredictionCacheEntry* predt, ObjFunction const* obj) { std::vector>> new_trees; - const int ngroup = model_.learner_model_param->num_output_group; + const int ngroup = model_.learner_model_param->OutputLength(); ConfigureWithKnownData(this->cfg_, p_fmat); monitor_.Start("BoostNewTrees"); + // Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let // `gpu_id` be the single source of determining what algorithms to run, but that will // break a lots of existing code. auto device = tparam_.tree_method != TreeMethod::kGPUHist ? Context::kCpuId : ctx_->gpu_id; - auto out = linalg::TensorView{ + auto out = linalg::MakeTensorView( + device, device == Context::kCpuId ? predt->predictions.HostSpan() : predt->predictions.DeviceSpan(), - {static_cast(p_fmat->Info().num_row_), static_cast(ngroup)}, - device}; + p_fmat->Info().num_row_, model_.learner_model_param->OutputLength()); CHECK_NE(ngroup, 0); if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) { @@ -266,7 +275,13 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, // position is negated if the row is sampled out. std::vector> node_position; - if (ngroup == 1) { + if (model_.learner_model_param->IsVectorLeaf()) { + std::vector> ret; + BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret); + UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret); + // The prediction cache is not updated yet.
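// (QuantileHistMaker::UpdatePredictionCache returns false for the multi-target
// builder for now, so the cache is refreshed by a regular prediction pass instead.)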
+ new_trees.push_back(std::move(ret)); + } else if (model_.learner_model_param->OutputLength() == 1) { std::vector> ret; BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret); @@ -383,11 +398,15 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fma } // update the trees - CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_) - << "Mismatching size between number of rows from input data and size of " - "gradient vector."; + auto n_out = model_.learner_model_param->OutputLength() * p_fmat->Info().num_row_; + StringView msg{ + "Mismatching size between number of rows from input data and size of gradient vector."}; + if (!model_.learner_model_param->IsVectorLeaf() && p_fmat->Info().num_row_ != 0) { + CHECK_EQ(n_out % gpair->Size(), 0) << msg; + } else { + CHECK_EQ(gpair->Size(), n_out) << msg; + } - CHECK(out_position); out_position->resize(new_trees.size()); // Rescale learning rate according to the size of trees @@ -402,8 +421,12 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fma void GBTree::CommitModel(std::vector>>&& new_trees) { monitor_.Start("CommitModel"); - for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) { - model_.CommitModel(std::move(new_trees[gid]), gid); + if (this->model_.learner_model_param->IsVectorLeaf()) { + model_.CommitModel(std::move(new_trees[0]), 0); + } else { + for (std::uint32_t gid = 0; gid < model_.learner_model_param->OutputLength(); ++gid) { + model_.CommitModel(std::move(new_trees[gid]), gid); + } } monitor_.Stop("CommitModel"); } @@ -564,11 +587,10 @@ void GBTree::PredictBatch(DMatrix* p_fmat, if (out_preds->version == 0) { // out_preds->Size() can be non-zero as it's initialized here before any // tree is built at the 0^th iterator. 
- predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, - model_); + predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, model_); } - uint32_t tree_begin, tree_end; + std::uint32_t tree_begin, tree_end; std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees."; if (tree_end > tree_begin) { @@ -577,7 +599,7 @@ void GBTree::PredictBatch(DMatrix* p_fmat, if (reset) { out_preds->version = 0; } else { - uint32_t delta = layer_end - out_preds->version; + std::uint32_t delta = layer_end - out_preds->version; out_preds->Update(delta); } } @@ -770,6 +792,7 @@ class Dart : public GBTree { void PredictBatchImpl(DMatrix *p_fmat, PredictionCacheEntry *p_out_preds, bool training, unsigned layer_begin, unsigned layer_end) const { + CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented(); auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat); CHECK(predictor); predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, @@ -830,6 +853,7 @@ class Dart : public GBTree { void InplacePredict(std::shared_ptr p_fmat, float missing, PredictionCacheEntry* p_out_preds, uint32_t layer_begin, unsigned layer_end) const override { + CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented(); uint32_t tree_begin, tree_end; std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); auto n_groups = model_.learner_model_param->num_output_group; diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index eb99822f3327..b64532c614e9 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -139,14 +139,22 @@ struct DartTrainParam : public XGBoostParameter { namespace detail { // From here on, layer becomes concrete trees. -inline std::pair LayerToTree(gbm::GBTreeModel const &model, - size_t layer_begin, - size_t layer_end) { - bst_group_t groups = model.learner_model_param->num_output_group; - uint32_t tree_begin = layer_begin * groups * model.param.num_parallel_tree; - uint32_t tree_end = layer_end * groups * model.param.num_parallel_tree; +inline std::pair LayerToTree(gbm::GBTreeModel const& model, + std::uint32_t layer_begin, + std::uint32_t layer_end) { + std::uint32_t tree_begin; + std::uint32_t tree_end; + if (model.learner_model_param->IsVectorLeaf()) { + tree_begin = layer_begin * model.param.num_parallel_tree; + tree_end = layer_end * model.param.num_parallel_tree; + } else { + bst_group_t groups = model.learner_model_param->OutputLength(); + tree_begin = layer_begin * groups * model.param.num_parallel_tree; + tree_end = layer_end * groups * model.param.num_parallel_tree; + } + if (tree_end == 0) { - tree_end = static_cast(model.trees.size()); + tree_end = model.trees.size(); } if (model.trees.size() != 0) { CHECK_LE(tree_begin, tree_end); @@ -234,22 +242,25 @@ class GBTree : public GradientBooster { void LoadModel(Json const& in) override; // Number of trees per layer. 
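For illustration, here is the layer-to-tree slicing arithmetic from the hunk above transcribed into Python; the function and argument names are hypothetical, and the real logic lives in `detail::LayerToTree`:

def layer_to_tree(layer_begin, layer_end, num_parallel_tree, n_groups,
                  vector_leaf, n_total_trees):
    # A vector-leaf tree covers every output at once, so a boosting layer holds
    # only `num_parallel_tree` trees; otherwise there is one tree per output group.
    per_layer = num_parallel_tree if vector_leaf else num_parallel_tree * n_groups
    tree_begin = layer_begin * per_layer
    tree_end = layer_end * per_layer
    if tree_end == 0:  # `layer_end == 0` means "up to the last tree".
        tree_end = n_total_trees
    return tree_begin, tree_end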
- auto LayerTrees() const { - auto n_trees = model_.learner_model_param->num_output_group * model_.param.num_parallel_tree; - return n_trees; + [[nodiscard]] std::uint32_t LayerTrees() const { + if (model_.learner_model_param->IsVectorLeaf()) { + return model_.param.num_parallel_tree; + } + return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength(); } // slice the trees, out must be already allocated void Slice(int32_t layer_begin, int32_t layer_end, int32_t step, GradientBooster *out, bool* out_of_bound) const override; - int32_t BoostedRounds() const override { + [[nodiscard]] std::int32_t BoostedRounds() const override { CHECK_NE(model_.param.num_parallel_tree, 0); CHECK_NE(model_.learner_model_param->num_output_group, 0); + return model_.trees.size() / this->LayerTrees(); } - bool ModelFitted() const override { + [[nodiscard]] bool ModelFitted() const override { return !model_.trees.empty() || !model_.trees_to_update.empty(); } diff --git a/src/learner.cc b/src/learner.cc index 14f57a5ba5ba..9b1d65ce6206 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -326,7 +326,7 @@ struct LearnerTrainParam : public XGBoostParameter { std::string booster; std::string objective; // This is a training parameter and is not saved (nor loaded) in the model. - MultiStrategy multi_strategy{MultiStrategy::kComposite}; + MultiStrategy multi_strategy{MultiStrategy::kOneOutputPerTree}; // declare parameters DMLC_DECLARE_PARAMETER(LearnerTrainParam) { @@ -339,12 +339,12 @@ struct LearnerTrainParam : public XGBoostParameter { .set_default("reg:squarederror") .describe("Objective function used for obtaining gradient."); DMLC_DECLARE_FIELD(multi_strategy) - .add_enum("composite", MultiStrategy::kComposite) - .add_enum("monolithic", MultiStrategy::kMonolithic) - .set_default(MultiStrategy::kComposite) + .add_enum("one_output_per_tree", MultiStrategy::kOneOutputPerTree) + .add_enum("multi_output_tree", MultiStrategy::kMultiOutputTree) + .set_default(MultiStrategy::kOneOutputPerTree) .describe( - "Strategy used for training multi-target models. `monolithic` means building one " - "single tree for all targets."); + "Strategy used for training multi-target models. 
`multi_output_tree` means building " + "one single tree for all targets."); } }; diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 00116ebdb2ad..386f0d53d6b0 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -145,7 +145,6 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size()); auto d_group_ptr = p_cache->DataGroupPtr(ctx); - auto n_groups = info.group_ptr_.size() - 1; auto d_inv_idcg = p_cache->InvIDCG(ctx); auto d_sorted_idx = p_cache->SortedIdx(ctx, d_predt.Values()); @@ -171,7 +170,6 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, bool minus, std::shared_ptr p_cache) { auto d_group_ptr = p_cache->DataGroupPtr(ctx); - auto n_groups = info.group_ptr_.size() - 1; auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); predt.SetDevice(ctx->gpu_id); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 0c045dda0d01..3d5dfbd674ea 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -87,30 +87,6 @@ bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree, : GetLeafIndex(tree, p_feats, cats); return tree[leaf].LeafValue(); } - -void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin, - const size_t tree_end, const size_t predict_offset, - const std::vector &thread_temp, const size_t offset, - const size_t block_size, linalg::TensorView out_predt) { - for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) { - const size_t gid = model.tree_info[tree_id]; - auto const &tree = *model.trees[tree_id]; - auto const &cats = tree.GetCategoriesMatrix(); - auto has_categorical = tree.HasCategoricalSplit(); - - if (has_categorical) { - for (std::size_t i = 0; i < block_size; ++i) { - out_predt(predict_offset + i, gid) += - PredValueByOneTree(thread_temp[offset + i], tree, cats); - } - } else { - for (std::size_t i = 0; i < block_size; ++i) { - out_predt(predict_offset + i, gid) += - PredValueByOneTree(thread_temp[offset + i], tree, cats); - } - } - } -} } // namespace scalar namespace multi { @@ -128,7 +104,7 @@ bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat, } template -void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tree, +void PredValueByOneTree(RegTree::FVec const &p_feats, MultiTargetTree const &tree, RegTree::CategoricalSplitMatrix const &cats, linalg::VectorView out_predt) { bst_node_t const leaf = p_feats.HasMissing() @@ -140,36 +116,52 @@ void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tre out_predt(i) += leaf_value(i); } } +} // namespace multi -void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin, - const size_t tree_end, const size_t predict_offset, - const std::vector &thread_temp, const size_t offset, - const size_t block_size, linalg::TensorView out_predt) { - for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) { +namespace { +void PredictByAllTrees(gbm::GBTreeModel const &model, std::uint32_t const tree_begin, + std::uint32_t const tree_end, std::size_t const predict_offset, + std::vector const &thread_temp, std::size_t const offset, + std::size_t const block_size, linalg::MatrixView out_predt) { + for (std::uint32_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) { auto const &tree = *model.trees.at(tree_id); - auto cats = 
tree.GetCategoriesMatrix(); + auto const &cats = tree.GetCategoriesMatrix(); bool has_categorical = tree.HasCategoricalSplit(); - if (has_categorical) { - for (std::size_t i = 0; i < block_size; ++i) { - auto t_predts = out_predt.Slice(predict_offset + i, linalg::All()); - PredValueByOneTree(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats, - t_predts); + if (tree.IsMultiTarget()) { + if (has_categorical) { + for (std::size_t i = 0; i < block_size; ++i) { + auto t_predts = out_predt.Slice(predict_offset + i, linalg::All()); + multi::PredValueByOneTree(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats, + t_predts); + } + } else { + for (std::size_t i = 0; i < block_size; ++i) { + auto t_predts = out_predt.Slice(predict_offset + i, linalg::All()); + multi::PredValueByOneTree(thread_temp[offset + i], *tree.GetMultiTargetTree(), + cats, t_predts); + } } } else { - for (std::size_t i = 0; i < block_size; ++i) { - auto t_predts = out_predt.Slice(predict_offset + i, linalg::All()); - PredValueByOneTree(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats, - t_predts); + auto const gid = model.tree_info[tree_id]; + if (has_categorical) { + for (std::size_t i = 0; i < block_size; ++i) { + out_predt(predict_offset + i, gid) += + scalar::PredValueByOneTree(thread_temp[offset + i], tree, cats); + } + } else { + for (std::size_t i = 0; i < block_size; ++i) { + out_predt(predict_offset + i, gid) += + scalar::PredValueByOneTree(thread_temp[offset + i], tree, cats); + } } } } } -} // namespace multi template void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature, - DataView* batch, const size_t fvec_offset, std::vector* p_feats) { + DataView *batch, const size_t fvec_offset, std::vector *p_feats) { for (size_t i = 0; i < block_size; ++i) { RegTree::FVec &feats = (*p_feats)[fvec_offset + i]; if (feats.Size() == 0) { @@ -181,8 +173,8 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_ } template -void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch, - const size_t fvec_offset, std::vector* p_feats) { +void FVecDrop(const size_t block_size, const size_t batch_offset, DataView *batch, + const size_t fvec_offset, std::vector *p_feats) { for (size_t i = 0; i < block_size; ++i) { RegTree::FVec &feats = (*p_feats)[fvec_offset + i]; const SparsePage::Inst inst = (*batch)[batch_offset + i]; @@ -190,9 +182,7 @@ void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batc } } -namespace { static std::size_t constexpr kUnroll = 8; -} // anonymous namespace struct SparsePageView { bst_row_t base_rowid; @@ -292,7 +282,7 @@ class AdapterView { template void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model, - int32_t tree_begin, int32_t tree_end, + std::uint32_t tree_begin, std::uint32_t tree_end, std::vector *p_thread_temp, int32_t n_threads, linalg::TensorView out_predt) { auto &thread_temp = *p_thread_temp; @@ -310,14 +300,8 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp); // process block of rows through all trees to keep cache locality - if (model.learner_model_param->IsVectorLeaf()) { - multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, - thread_temp, fvec_offset, block_size, out_predt); - } else { - scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, - thread_temp, 
fvec_offset, block_size, out_predt); - } - + PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, thread_temp, + fvec_offset, block_size, out_predt); FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp); }); } @@ -348,7 +332,6 @@ void FillNodeMeanValues(RegTree const* tree, std::vector* mean_values) { FillNodeMeanValues(tree, 0, mean_values); } -namespace { // init thread buffers static void InitThreadTemp(int nthread, std::vector *out) { int prev_thread_temp_size = out->size(); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index ecd399e2244f..4a5c5b104179 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -411,7 +411,7 @@ class DeviceModel { this->tree_beg_ = tree_begin; this->tree_end_ = tree_end; - this->num_group = model.learner_model_param->num_output_group; + this->num_group = model.learner_model_param->OutputLength(); } }; diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 50b90f244aad..562a0b2d44dc 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -306,9 +306,9 @@ class HistogramBuilder { // Construct a work space for building histogram. Eventually we should move this // function into histogram builder once hist tree method supports external memory. -template +template common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners, - std::vector const &nodes_to_build) { + std::vector const &nodes_to_build) { std::vector partition_size(nodes_to_build.size(), 0); for (auto const &partition : partitioners) { size_t k = 0; diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 8f297f46d70f..7550904b5753 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -889,6 +889,8 @@ void RegTree::Save(dmlc::Stream* fo) const { CHECK_EQ(param_.num_nodes, static_cast(stats_.size())); CHECK_EQ(param_.deprecated_num_roots, 1); CHECK_NE(param_.num_nodes, 0); + CHECK(!IsMultiTarget()) + << "Please use JSON/UBJSON for saving models with multi-target trees."; CHECK(!HasCategoricalSplit()) << "Please use JSON/UBJSON for saving models with categorical splits."; diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 7e5955dc81f4..012b8e78179e 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -4,36 +4,39 @@ * \brief use quantized feature values to construct a tree * \author Philip Cho, Tianqi Checn, Egor Smirnov */ -#include // for max +#include // for max, copy, transform #include // for size_t -#include // for uint32_t -#include // for unique_ptr, allocator, make_unique, make_shared -#include // for operator<<, char_traits, basic_ostream -#include // for apply +#include // for uint32_t, int32_t +#include // for unique_ptr, allocator, make_unique, shared_ptr +#include // for accumulate +#include // for basic_ostream, char_traits, operator<< #include // for move, swap #include // for vector #include "../collective/communicator-inl.h" // for Allreduce, IsDistributed #include "../collective/communicator.h" // for Operation #include "../common/hist_util.h" // for HistogramCuts, HistCollection +#include "../common/linalg_op.h" // for begin, cbegin, cend #include "../common/random.h" // for ColumnSampler #include "../common/threading_utils.h" // for ParallelFor #include "../common/timer.h" // for Monitor +#include "../common/transform_iterator.h" // for IndexTransformIter, MakeIndexTransformIter #include "../data/gradient_index.h" // for 
GHistIndexMatrix #include "common_row_partitioner.h" // for CommonRowPartitioner +#include "dmlc/omp.h" // for omp_get_thread_num #include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG #include "driver.h" // for Driver -#include "hist/evaluate_splits.h" // for HistEvaluator, UpdatePredictionCacheImpl -#include "hist/expand_entry.h" // for CPUExpandEntry +#include "hist/evaluate_splits.h" // for HistEvaluator, HistMultiEvaluator, UpdatePre... +#include "hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry #include "hist/histogram.h" // for HistogramBuilder, ConstructHistSpace #include "hist/sampler.h" // for SampleGradient -#include "param.h" // for TrainParam, GradStats -#include "xgboost/base.h" // for GradientPair, GradientPairInternal, bst_node_t +#include "param.h" // for TrainParam, SplitEntryContainer, GradStats +#include "xgboost/base.h" // for GradientPairInternal, GradientPair, bst_targ... #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for BatchIterator, BatchSet, DMatrix, MetaInfo #include "xgboost/host_device_vector.h" // for HostDeviceVector -#include "xgboost/linalg.h" // for TensorView, MatrixView, UnravelIndex, All -#include "xgboost/logging.h" // for LogCheck_EQ, LogCheck_GE, CHECK_EQ, LOG, LOG... +#include "xgboost/linalg.h" // for All, MatrixView, TensorView, Matrix, Empty +#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_GE #include "xgboost/span.h" // for Span, operator!=, SpanIterator #include "xgboost/string_view.h" // for operator<< #include "xgboost/task.h" // for ObjInfo @@ -105,6 +108,212 @@ void UpdateTree(common::Monitor *monitor_, linalg::MatrixViewStop(__func__); } +/** + * \brief Updater for building multi-target trees. The implementation simply iterates over + * each target. + */ +class MultiTargetHistBuilder { + private: + common::Monitor *monitor_{nullptr}; + TrainParam const *param_{nullptr}; + std::shared_ptr col_sampler_; + std::unique_ptr evaluator_; + // Histogram builder for each target. + std::vector> histogram_builder_; + Context const *ctx_{nullptr}; + // Partitioner for each data batch. + std::vector partitioner_; + // Pointer to last updated tree, used for update prediction cache. 
+ RegTree const *p_last_tree_{nullptr}; + + ObjInfo const *task_{nullptr}; + + public: + void UpdatePosition(DMatrix *p_fmat, RegTree const *p_tree, + std::vector const &applied) { + monitor_->Start(__func__); + std::size_t page_id{0}; + for (auto const &page : p_fmat->GetBatches(HistBatch(this->param_))) { + this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree); + page_id++; + } + monitor_->Stop(__func__); + } + + void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) { + this->evaluator_->ApplyTreeSplit(candidate, p_tree); + } + + void InitData(DMatrix *p_fmat, RegTree const *p_tree) { + monitor_->Start(__func__); + + std::size_t page_id = 0; + bst_bin_t n_total_bins = 0; + partitioner_.clear(); + for (auto const &page : p_fmat->GetBatches(HistBatch(param_))) { + if (n_total_bins == 0) { + n_total_bins = page.cut.TotalBins(); + } else { + CHECK_EQ(n_total_bins, page.cut.TotalBins()); + } + partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit()); + page_id++; + } + + bst_target_t n_targets = p_tree->NumTargets(); + histogram_builder_.clear(); + for (std::size_t i = 0; i < n_targets; ++i) { + histogram_builder_.emplace_back(); + histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, + collective::IsDistributed(), p_fmat->IsColumnSplit()); + } + + evaluator_ = std::make_unique(ctx_, p_fmat->Info(), param_, col_sampler_); + p_last_tree_ = p_tree; + monitor_->Stop(__func__); + } + + MultiExpandEntry InitRoot(DMatrix *p_fmat, linalg::MatrixView gpair, + RegTree *p_tree) { + monitor_->Start(__func__); + MultiExpandEntry best; + best.nid = RegTree::kRoot; + best.depth = 0; + + auto n_targets = p_tree->NumTargets(); + linalg::Matrix root_sum_tloc = + linalg::Empty(ctx_, ctx_->Threads(), n_targets); + CHECK_EQ(root_sum_tloc.Shape(1), gpair.Shape(1)); + auto h_root_sum_tloc = root_sum_tloc.HostView(); + common::ParallelFor(gpair.Shape(0), ctx_->Threads(), [&](auto i) { + for (bst_target_t t{0}; t < n_targets; ++t) { + h_root_sum_tloc(omp_get_thread_num(), t) += GradientPairPrecise{gpair(i, t)}; + } + }); + // Aggregate to the first row. 
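// Each worker now holds its thread-aggregated root sum in row 0; the allreduce
// below combines the per-worker sums into the global root gradient statistics.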
+ auto root_sum = h_root_sum_tloc.Slice(0, linalg::All()); + for (std::int32_t tidx{1}; tidx < ctx_->Threads(); ++tidx) { + for (bst_target_t t{0}; t < n_targets; ++t) { + root_sum(t) += h_root_sum_tloc(tidx, t); + } + } + CHECK(root_sum.CContiguous()); + collective::Allreduce( + reinterpret_cast(root_sum.Values().data()), root_sum.Size() * 2); + + std::vector nodes{best}; + std::size_t i = 0; + auto space = ConstructHistSpace(partitioner_, nodes); + for (auto const &page : p_fmat->GetBatches(HistBatch(param_))) { + for (bst_target_t t{0}; t < n_targets; ++t) { + auto t_gpair = gpair.Slice(linalg::All(), t); + histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), + nodes, {}, t_gpair.Values()); + } + i++; + } + + auto weight = evaluator_->InitRoot(root_sum); + auto weight_t = weight.HostView(); + std::transform(linalg::cbegin(weight_t), linalg::cend(weight_t), linalg::begin(weight_t), + [&](float w) { return w * param_->learning_rate; }); + + p_tree->SetLeaf(RegTree::kRoot, weight_t); + std::vector hists; + for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) { + hists.push_back(&histogram_builder_[t].Histogram()); + } + for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { + evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, &nodes); + break; + } + monitor_->Stop(__func__); + + return nodes.front(); + } + + void BuildHistogram(DMatrix *p_fmat, RegTree const *p_tree, + std::vector const &valid_candidates, + linalg::MatrixView gpair) { + monitor_->Start(__func__); + std::vector nodes_to_build; + std::vector nodes_to_sub; + + for (auto const &c : valid_candidates) { + auto left_nidx = p_tree->LeftChild(c.nid); + auto right_nidx = p_tree->RightChild(c.nid); + + auto build_nidx = left_nidx; + auto subtract_nidx = right_nidx; + auto lit = + common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); }); + auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0); + auto rit = + common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); }); + auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0); + auto fewer_right = right_sum < left_sum; + if (fewer_right) { + std::swap(build_nidx, subtract_nidx); + } + nodes_to_build.emplace_back(build_nidx, p_tree->GetDepth(build_nidx)); + nodes_to_sub.emplace_back(subtract_nidx, p_tree->GetDepth(subtract_nidx)); + } + + std::size_t i = 0; + auto space = ConstructHistSpace(partitioner_, nodes_to_build); + for (auto const &page : p_fmat->GetBatches(HistBatch(param_))) { + for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) { + auto t_gpair = gpair.Slice(linalg::All(), t); + // Make sure the gradient matrix is f-order. 
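// A per-target column slice is contiguous only when the gradient matrix is F-order,
// which lets BuildHist consume the slice as a flat span without copying.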
+ CHECK(t_gpair.Contiguous()); + histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), + nodes_to_build, nodes_to_sub, t_gpair.Values()); + } + i++; + } + monitor_->Stop(__func__); + } + + void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree, + std::vector *best_splits) { + monitor_->Start(__func__); + std::vector hists; + for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) { + hists.push_back(&histogram_builder_[t].Histogram()); + } + for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { + evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, best_splits); + break; + } + monitor_->Stop(__func__); + } + + void LeafPartition(RegTree const &tree, linalg::MatrixView gpair, + std::vector *p_out_position) { + monitor_->Start(__func__); + if (!task_->UpdateTreeLeaf()) { + return; + } + for (auto const &part : partitioner_) { + part.LeafPartition(ctx_, tree, gpair, p_out_position); + } + monitor_->Stop(__func__); + } + + public: + explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam const *param, + std::shared_ptr column_sampler, + ObjInfo const *task, common::Monitor *monitor) + : monitor_{monitor}, + param_{param}, + col_sampler_{std::move(column_sampler)}, + evaluator_{std::make_unique(ctx, info, param, col_sampler_)}, + ctx_{ctx}, + task_{task} { + monitor_->Init(__func__); + } +}; + class HistBuilder { private: common::Monitor *monitor_; @@ -155,8 +364,7 @@ class HistBuilder { // initialize temp data structure void InitData(DMatrix *fmat, RegTree const *p_tree) { monitor_->Start(__func__); - - size_t page_id{0}; + std::size_t page_id{0}; bst_bin_t n_total_bins{0}; partitioner_.clear(); for (auto const &page : fmat->GetBatches(HistBatch(param_))) { @@ -195,7 +403,7 @@ class HistBuilder { RegTree *p_tree) { CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0)); - size_t page_id = 0; + std::size_t page_id = 0; auto space = ConstructHistSpace(partitioner_, {node}); for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { std::vector nodes_to_build{node}; @@ -214,13 +422,13 @@ class HistBuilder { * of gradient histogram is equal to snode[nid] */ auto const &gmat = *(p_fmat->GetBatches(HistBatch(param_)).begin()); - std::vector const &row_ptr = gmat.cut.Ptrs(); + std::vector const &row_ptr = gmat.cut.Ptrs(); CHECK_GE(row_ptr.size(), 2); - uint32_t const ibegin = row_ptr[0]; - uint32_t const iend = row_ptr[1]; + std::uint32_t const ibegin = row_ptr[0]; + std::uint32_t const iend = row_ptr[1]; auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot]; auto begin = hist.data(); - for (uint32_t i = ibegin; i < iend; ++i) { + for (std::uint32_t i = ibegin; i < iend; ++i) { GradientPairPrecise const &et = begin[i]; grad_stat.Add(et.GetGrad(), et.GetHess()); } @@ -259,7 +467,7 @@ class HistBuilder { std::vector nodes_to_build(valid_candidates.size()); std::vector nodes_to_sub(valid_candidates.size()); - size_t n_idx = 0; + std::size_t n_idx = 0; for (auto const &c : valid_candidates) { auto left_nidx = (*p_tree)[c.nid].LeftChild(); auto right_nidx = (*p_tree)[c.nid].RightChild(); @@ -275,7 +483,7 @@ class HistBuilder { n_idx++; } - size_t page_id{0}; + std::size_t page_id{0}; auto space = ConstructHistSpace(partitioner_, nodes_to_build); for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { histogram_builder_->BuildHist(page_id, space, gidx, p_tree, @@ -311,11 +519,12 @@ class HistBuilder { /*! 
 class QuantileHistMaker : public TreeUpdater {
-  std::unique_ptr<HistBuilder> p_impl_;
+  std::unique_ptr<HistBuilder> p_impl_{nullptr};
+  std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
   std::shared_ptr<common::ColumnSampler> column_sampler_ =
       std::make_shared<common::ColumnSampler>();
   common::Monitor monitor_;
-  ObjInfo const *task_;
+  ObjInfo const *task_{nullptr};

  public:
   explicit QuantileHistMaker(Context const *ctx, ObjInfo const *task)
@@ -332,7 +541,10 @@ class QuantileHistMaker : public TreeUpdater {
                   const std::vector<RegTree *> &trees) override {
     if (trees.front()->IsMultiTarget()) {
       CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
-      LOG(FATAL) << "Not implemented.";
+      if (!p_mtimpl_) {
+        this->p_mtimpl_ = std::make_unique<MultiTargetHistBuilder>(
+            ctx_, p_fmat->Info(), param, column_sampler_, task_, &monitor_);
+      }
     } else {
       if (!p_impl_) {
         p_impl_ =
@@ -355,13 +567,14 @@ class QuantileHistMaker : public TreeUpdater {

     for (auto tree_it = trees.begin(); tree_it != trees.end(); ++tree_it) {
       if (need_copy()) {
-        // Copy gradient into buffer for sampling.
+        // Copy gradient into buffer for sampling. This converts C-order to F-order.
         std::copy(linalg::cbegin(h_gpair), linalg::cend(h_gpair), linalg::begin(h_sample_out));
       }
       SampleGradient(ctx_, *param, h_sample_out);
       auto *h_out_position = &out_position[tree_it - trees.begin()];
       if ((*tree_it)->IsMultiTarget()) {
-        LOG(FATAL) << "Not implemented.";
+        UpdateTree(&monitor_, h_sample_out, p_mtimpl_.get(), p_fmat, param,
+                   h_out_position, *tree_it);
       } else {
         UpdateTree(&monitor_, h_sample_out, p_impl_.get(), p_fmat, param,
                    h_out_position, *tree_it);
@@ -372,6 +585,9 @@ class QuantileHistMaker : public TreeUpdater {
   bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
     if (p_impl_) {
       return p_impl_->UpdatePredictionCache(data, out_preds);
+    } else if (p_mtimpl_) {
+      // Not yet supported.
+      return false;
     } else {
       return false;
     }
@@ -383,6 +599,6 @@ XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
     .describe("Grow tree using quantized histogram.")
     .set_body([](Context const *ctx, ObjInfo const *task) {
-      return new QuantileHistMaker(ctx, task);
+      return new QuantileHistMaker{ctx, task};
     });
}  // namespace xgboost::tree
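
A note on the `BuildHistogram` hunk above: for each candidate split, only the child with the smaller Hessian sum is rebuilt from the data (`nodes_to_build`); the sibling's histogram is then recovered by subtracting the built child from its parent (`nodes_to_sub`), so the O(n_rows) accumulation is paid only once per split. A minimal NumPy sketch of that subtraction trick, using made-up toy data and row counts standing in for the Hessian sums:

    import numpy as np

    rng = np.random.default_rng(0)
    n_rows, n_bins = 8, 4
    bin_idx = rng.integers(0, n_bins, size=n_rows)   # bin of each row
    grad_hess = rng.normal(size=(n_rows, 2))         # (grad, hess) per row

    def build_hist(rows: np.ndarray) -> np.ndarray:
        """The expensive O(n_rows) accumulation step."""
        hist = np.zeros((n_bins, 2))
        for r in rows:
            hist[bin_idx[r]] += grad_hess[r]
        return hist

    parent_rows = np.arange(n_rows)
    left, right = parent_rows[:3], parent_rows[3:]   # a pretend split

    # Build only the smaller child; the sibling is parent - built, O(n_bins).
    small, large = (left, right) if len(left) < len(right) else (right, left)
    sibling = build_hist(parent_rows) - build_hist(small)
    np.testing.assert_allclose(sibling, build_hist(large))
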
diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py
index 8d601f3550bc..b7864bb5025f 100644
--- a/tests/ci_build/lint_python.py
+++ b/tests/ci_build/lint_python.py
@@ -3,7 +3,7 @@
 import subprocess
 import sys
 from multiprocessing import Pool, cpu_count
-from typing import Dict, Optional, Tuple
+from typing import Dict, Tuple

 from pylint import epylint
 from test_utils import PY_PACKAGE, ROOT, cd, print_time, record_time
@@ -15,8 +15,11 @@


 @record_time
-def run_black(rel_path: str) -> bool:
-    cmd = ["black", "-q", "--check", rel_path]
+def run_black(rel_path: str, fix: bool) -> bool:
+    if fix:
+        cmd = ["black", "-q", rel_path]
+    else:
+        cmd = ["black", "-q", "--check", rel_path]
     ret = subprocess.run(cmd).returncode
     if ret != 0:
         subprocess.run(["black", "--version"])
@@ -31,8 +34,11 @@


 @record_time
-def run_isort(rel_path: str) -> bool:
-    cmd = ["isort", f"--src={SRCPATH}", "--check", "--profile=black", rel_path]
+def run_isort(rel_path: str, fix: bool) -> bool:
+    if fix:
+        cmd = ["isort", f"--src={SRCPATH}", "--profile=black", rel_path]
+    else:
+        cmd = ["isort", f"--src={SRCPATH}", "--check", "--profile=black", rel_path]
     ret = subprocess.run(cmd).returncode
     if ret != 0:
         subprocess.run(["isort", "--version"])
@@ -132,7 +138,7 @@ def run_pylint() -> bool:
 def main(args: argparse.Namespace) -> None:
     if args.format == 1:
         black_results = [
-            run_black(path)
+            run_black(path, args.fix)
             for path in [
                 # core
                 "python-package/",
@@ -166,7 +172,7 @@ def main(args: argparse.Namespace) -> None:
         sys.exit(-1)

     isort_results = [
-        run_isort(path)
+        run_isort(path, args.fix)
         for path in [
             # core
             "python-package/",
@@ -230,6 +236,11 @@ def main(args: argparse.Namespace) -> None:
     parser.add_argument("--format", type=int, choices=[0, 1], default=1)
     parser.add_argument("--type-check", type=int, choices=[0, 1], default=1)
    parser.add_argument("--pylint", type=int, choices=[0, 1], default=1)
+    parser.add_argument(
+        "--fix",
+        action="store_true",
+        help="Fix the formatting issues instead of emitting an error.",
+    )
     args = parser.parse_args()
     try:
         main(args)
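
With the new `--fix` flag, the lint script applies black and isort formatting in place instead of merely failing the check; an invocation along the lines of `python tests/ci_build/lint_python.py --type-check=0 --pylint=0 --fix` (flag combination shown for illustration only) should rewrite the tracked paths.
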
diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc
index c96b9849775b..270eacf21710 100644
--- a/tests/cpp/gbm/test_gbtree.cc
+++ b/tests/cpp/gbm/test_gbtree.cc
@@ -412,7 +412,7 @@ std::pair<Json, Json> TestModelSlice(std::string booster) {
     j++;
   }

-  // CHECK sliced model doesn't have dependency on old one
+  // CHECK sliced model doesn't have dependency on the old one
   learner.reset();

   CHECK_EQ(sliced->GetNumFeature(), kCols);
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index a059f0436117..c835444131c5 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -473,7 +473,7 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint
                                 int32_t device = Context::kCpuId) {
   size_t shape[1]{1};
   LearnerModelParam mparam(n_features, linalg::Tensor<float, 1>{{base_score}, shape, device},
-                           n_groups, 1, MultiStrategy::kComposite);
+                           n_groups, 1, MultiStrategy::kOneOutputPerTree);
   return mparam;
 }
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index 4570a010df67..d6cf33445893 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -428,7 +428,7 @@ void TestVectorLeafPrediction(Context const *ctx) {
   LearnerModelParam mparam{static_cast<bst_feature_t>(kCols),
                            linalg::Vector<float>{{0.5}, {1}, Context::kCpuId}, 1, 3,
-                           MultiStrategy::kMonolithic};
+                           MultiStrategy::kMultiOutputTree};

   std::vector<std::unique_ptr<RegTree>> trees;
   trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature});
diff --git a/tests/cpp/test_multi_target.cc b/tests/cpp/test_multi_target.cc
index d2e34235c02e..c8d371941255 100644
--- a/tests/cpp/test_multi_target.cc
+++ b/tests/cpp/test_multi_target.cc
@@ -124,11 +124,11 @@ TEST(MultiStrategy, Configure) {
   auto p_fmat = RandomDataGenerator{12ul, 3ul, 0.0}.GenerateDMatrix();
   p_fmat->Info().labels.Reshape(p_fmat->Info().num_row_, 2);
   std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
-  learner->SetParams(Args{{"multi_strategy", "monolithic"}, {"num_target", "2"}});
+  learner->SetParams(Args{{"multi_strategy", "multi_output_tree"}, {"num_target", "2"}});
   learner->Configure();
   ASSERT_EQ(learner->Groups(), 2);

-  learner->SetParams(Args{{"multi_strategy", "monolithic"}, {"num_target", "0"}});
+  learner->SetParams(Args{{"multi_strategy", "multi_output_tree"}, {"num_target", "0"}});
   ASSERT_THROW({ learner->Configure(); }, dmlc::Error);
 }
 }  // namespace xgboost
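
The C++ test above covers configuration only; for reference, the same strategy can be requested through the native Python API as well. A minimal sketch with synthetic data (shapes, seed, and round count are arbitrary):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(1994)
    X = rng.normal(size=(256, 16))
    y = rng.normal(size=(256, 2))  # two regression targets

    Xy = xgb.DMatrix(X, label=y)
    booster = xgb.train(
        {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
        Xy,
        num_boost_round=8,
    )
    assert booster.predict(Xy).shape == (256, 2)  # one column per target
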
diff --git a/tests/python-gpu/test_gpu_ranking.py b/tests/python-gpu/test_gpu_ranking.py
index b8be5dda169a..50bbc3f1c54d 100644
--- a/tests/python-gpu/test_gpu_ranking.py
+++ b/tests/python-gpu/test_gpu_ranking.py
@@ -116,7 +116,7 @@ def test_with_mq2008(objective, metric) -> None:
         x_valid,
         y_valid,
         qid_valid,
-    ) = tm.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))
+    ) = tm.data.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))

     if metric.find("map") != -1 or objective.find("map") != -1:
         y_train[y_train <= 1] = 0.0
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index 6b28296b258d..ea8d5dcb5389 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -32,6 +32,19 @@ def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
     return result


+class TestGPUUpdatersMulti:
+    @given(
+        hist_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
+    )
+    @settings(deadline=None, max_examples=50, print_blob=True)
+    def test_hist(self, param, num_rounds, dataset):
+        param["tree_method"] = "gpu_hist"
+        param = dataset.set_params(param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(result)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+
 class TestGPUUpdaters:
     cputest = test_up.TestTreeMethod()

@@ -101,7 +114,7 @@ def test_categorical_ames_housing(
     ) -> None:
         cat_parameters.update(hist_parameters)
         dataset = tm.TestDataset(
-            "ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
+            "ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse"
         )
         cat_parameters["tree_method"] = "gpu_hist"
         results = train_result(cat_parameters, dataset.get_dmat(), 16)
diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py
index acacc55f8eed..d03ce142bc83 100644
--- a/tests/python/test_basic_models.py
+++ b/tests/python/test_basic_models.py
@@ -15,13 +15,17 @@


 def json_model(model_path: str, parameters: dict) -> dict:
-    X = np.random.random((10, 3))
-    y = np.random.randint(2, size=(10,))
+    datasets = pytest.importorskip("sklearn.datasets")
+
+    X, y = datasets.make_classification(64, n_features=8, n_classes=3, n_informative=6)
+    if parameters.get("objective", None) == "multi:softmax":
+        parameters["num_class"] = 3

     dm1 = xgb.DMatrix(X, y)

     bst = xgb.train(parameters, dm1)
     bst.save_model(model_path)
+
     if model_path.endswith("ubj"):
         import ubjson

         with open(model_path, "rb") as ubjfd:
@@ -326,24 +330,43 @@ def run_model_json_io(self, parameters: dict, ext: str) -> None:
         from_ubjraw = xgb.Booster()
         from_ubjraw.load_model(ubj_raw)

-        old_from_json = from_jraw.save_raw(raw_format="deprecated")
-        old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
+        if parameters.get("multi_strategy", None) != "multi_output_tree":
+            # old binary model is not supported.
+            old_from_json = from_jraw.save_raw(raw_format="deprecated")
+            old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")

-        assert old_from_json == old_from_ubj
+            assert old_from_json == old_from_ubj

         raw_json = bst.save_raw(raw_format="json")
         pretty = json.dumps(json.loads(raw_json), indent=2) + "\n\n"
         bst.load_model(bytearray(pretty, encoding="ascii"))

-        old_from_json = from_jraw.save_raw(raw_format="deprecated")
-        old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
+        if parameters.get("multi_strategy", None) != "multi_output_tree":
+            # old binary model is not supported.
+            old_from_json = from_jraw.save_raw(raw_format="deprecated")
+            old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
+
+            assert old_from_json == old_from_ubj

-        assert old_from_json == old_from_ubj
+        rng = np.random.default_rng()
+        X = rng.random(size=from_jraw.num_features() * 10).reshape(
+            (10, from_jraw.num_features())
+        )
+        predt_from_jraw = from_jraw.predict(xgb.DMatrix(X))
+        predt_from_bst = bst.predict(xgb.DMatrix(X))
+        np.testing.assert_allclose(predt_from_jraw, predt_from_bst)

     @pytest.mark.parametrize("ext", ["json", "ubj"])
     def test_model_json_io(self, ext: str) -> None:
         parameters = {"booster": "gbtree", "tree_method": "hist"}
         self.run_model_json_io(parameters, ext)
+        parameters = {
+            "booster": "gbtree",
+            "tree_method": "hist",
+            "multi_strategy": "multi_output_tree",
+            "objective": "multi:softmax",
+        }
+        self.run_model_json_io(parameters, ext)
         parameters = {"booster": "gblinear"}
         self.run_model_json_io(parameters, ext)
         parameters = {"booster": "dart", "tree_method": "hist"}
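
As the guards added to `run_model_json_io` encode, vector-leaf models are expected to round-trip through the JSON and UBJSON formats, while the deprecated binary format is skipped for them. A sketch of the supported path (file name and data are placeholders):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(128, 8))
    y = np.stack([X[:, 0], X[:, 1]], axis=1)  # two targets

    booster = xgb.train(
        {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
        xgb.DMatrix(X, label=y),
        num_boost_round=4,
    )
    booster.save_model("model.json")  # JSON works; the "deprecated" raw format does not

    reloaded = xgb.Booster(model_file="model.json")
    np.testing.assert_allclose(
        reloaded.predict(xgb.DMatrix(X)), booster.predict(xgb.DMatrix(X))
    )
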
diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py
index fabf8672eb5b..e8375aa5e8be 100644
--- a/tests/python/test_callback.py
+++ b/tests/python/test_callback.py
@@ -465,7 +465,7 @@ def test_check_point(self):
             assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".pkl"))

     def test_callback_list(self):
-        X, y = tm.get_california_housing()
+        X, y = tm.data.get_california_housing()
         m = xgb.DMatrix(X, y)
         callbacks = [xgb.callback.EarlyStopping(rounds=10)]
         for i in range(4):
diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py
index 239271ec71bc..30de920f783a 100644
--- a/tests/python/test_ranking.py
+++ b/tests/python/test_ranking.py
@@ -82,7 +82,7 @@ def setup_class(cls):
         """
         cls.dpath = 'demo/rank/'
         (x_train, y_train, qid_train, x_test, y_test, qid_test,
-         x_valid, y_valid, qid_valid) = tm.get_mq2008(cls.dpath)
+         x_valid, y_valid, qid_valid) = tm.data.get_mq2008(cls.dpath)

         # instantiate the matrices
         cls.dtrain = xgboost.DMatrix(x_train, y_train)
diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py
index be72793e7a82..dd710f6a46bf 100644
--- a/tests/python/test_updaters.py
+++ b/tests/python/test_updaters.py
@@ -11,6 +11,7 @@
 from xgboost.testing.params import (
     cat_parameter_strategy,
     exact_parameter_strategy,
+    hist_multi_parameter_strategy,
     hist_parameter_strategy,
 )
 from xgboost.testing.updater import check_init_estimation, check_quantile_loss
@@ -18,11 +19,70 @@
 def train_result(param, dmat, num_rounds):
     result = {}
-    xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
-              evals_result=result)
+    booster = xgb.train(
+        param,
+        dmat,
+        num_rounds,
+        [(dmat, "train")],
+        verbose_eval=False,
+        evals_result=result,
+    )
+    assert booster.num_features() == dmat.num_col()
+    assert booster.num_boosted_rounds() == num_rounds
+    assert booster.feature_names == dmat.feature_names
+    assert booster.feature_types == dmat.feature_types
+
     return result


+class TestTreeMethodMulti:
+    @given(
+        exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
+        if dataset.name.endswith("-l1"):
+            return
+        param["tree_method"] = "exact"
+        param = dataset.set_params(param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+    @given(
+        exact_parameter_strategy,
+        hist_parameter_strategy,
+        strategies.integers(1, 20),
+        tm.multi_dataset_strategy,
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_approx(self, param, hist_param, num_rounds, dataset):
+        param["tree_method"] = "approx"
+        param = dataset.set_params(param)
+        param.update(hist_param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(result)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+    @given(
+        exact_parameter_strategy,
+        hist_multi_parameter_strategy,
+        strategies.integers(1, 20),
+        tm.multi_dataset_strategy,
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_hist(
+        self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset
+    ) -> None:
+        if dataset.name.endswith("-l1"):
+            return
+        param["tree_method"] = "hist"
+        param = dataset.set_params(param)
+        param.update(hist_param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(result)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+
 class TestTreeMethod:
     USE_ONEHOT = np.iinfo(np.int32).max
     USE_PART = 1
@@ -77,10 +137,14 @@ class TestTreeMethod:
         # Second prune should not change the tree
         assert after_prune == second_prune

-    @given(exact_parameter_strategy, hist_parameter_strategy, strategies.integers(1, 20),
-           tm.dataset_strategy)
+    @given(
+        exact_parameter_strategy,
+        hist_parameter_strategy,
+        strategies.integers(1, 20),
+        tm.dataset_strategy
+    )
     @settings(deadline=None, print_blob=True)
-    def test_hist(self, param, hist_param, num_rounds, dataset):
+    def test_hist(self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
         param['tree_method'] = 'hist'
         param = dataset.set_params(param)
         param.update(hist_param)
@@ -88,23 +152,6 @@ class TestTreeMethod:
         note(result)
         assert tm.non_increasing(result['train'][dataset.metric])

-    @given(tm.sparse_datasets_strategy)
-    @settings(deadline=None, print_blob=True)
-    def test_sparse(self, dataset):
-        param = {"tree_method": "hist", "max_bin": 64}
-        hist_result = train_result(param, dataset.get_dmat(), 16)
-        note(hist_result)
-        assert tm.non_increasing(hist_result['train'][dataset.metric])
-
-        param = {"tree_method": "approx", "max_bin": 64}
-        approx_result = train_result(param, dataset.get_dmat(), 16)
-        note(approx_result)
-        assert tm.non_increasing(approx_result['train'][dataset.metric])
-
-        np.testing.assert_allclose(
-            hist_result["train"]["rmse"], approx_result["train"]["rmse"]
-        )
-
     def test_hist_categorical(self):
         # hist must be same as exact on all-categorial data
         dpath = 'demo/data/'
@@ -143,6 +190,23 @@ class TestTreeMethod:
         w = [0, 0, 1, 0]
         model.fit(X, y, sample_weight=w)

+    @given(tm.sparse_datasets_strategy)
+    @settings(deadline=None, print_blob=True)
+    def test_sparse(self, dataset):
+        param = {"tree_method": "hist", "max_bin": 64}
+        hist_result = train_result(param, dataset.get_dmat(), 16)
+        note(hist_result)
+        assert tm.non_increasing(hist_result['train'][dataset.metric])
+
+        param = {"tree_method": "approx", "max_bin": 64}
+        approx_result = train_result(param, dataset.get_dmat(), 16)
+        note(approx_result)
+        assert tm.non_increasing(approx_result['train'][dataset.metric])
+
+        np.testing.assert_allclose(
+            hist_result["train"]["rmse"], approx_result["train"]["rmse"]
+        )
+
     def run_invalid_category(self, tree_method: str) -> None:
         rng = np.random.default_rng()
         # too large
@@ -365,7 +429,7 @@ def test_categorical_ames_housing(
     ) -> None:
         cat_parameters.update(hist_parameters)
         dataset = tm.TestDataset(
-            "ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
+            "ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse"
         )
         cat_parameters["tree_method"] = tree_method
         results = train_result(cat_parameters, dataset.get_dmat(), 16)
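
The new `TestTreeMethodMulti` class reuses the repository's hypothesis pattern: draw parameters and a multi-target dataset, train, and assert the training metric never increases. A condensed, self-contained rendition of that pattern, with the repository fixtures replaced by inline toy data:

    import numpy as np
    import xgboost as xgb
    from hypothesis import given, settings, strategies

    @given(strategies.integers(2, 16))
    @settings(deadline=None, max_examples=10)
    def test_metric_non_increasing(num_rounds: int) -> None:
        rng = np.random.default_rng(2023)
        X = rng.normal(size=(128, 4))
        y = np.stack([X[:, 0] ** 2, X[:, 1]], axis=1)  # two targets
        result: dict = {}
        xgb.train(
            {"tree_method": "hist"},
            xgb.DMatrix(X, label=y),
            num_boost_round=num_rounds,
            evals=[(xgb.DMatrix(X, label=y), "train")],
            verbose_eval=False,
            evals_result=result,
        )
        rmse = result["train"]["rmse"]
        assert all(b <= a + 1e-6 for a, b in zip(rmse, rmse[1:]))
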
"approx", "max_bin": 64} - approx_result = train_result(param, dataset.get_dmat(), 16) - note(approx_result) - assert tm.non_increasing(approx_result['train'][dataset.metric]) - - np.testing.assert_allclose( - hist_result["train"]["rmse"], approx_result["train"]["rmse"] - ) - def test_hist_categorical(self): # hist must be same as exact on all-categorial data dpath = 'demo/data/' @@ -143,6 +190,23 @@ def test_hist_degenerate_case(self): w = [0, 0, 1, 0] model.fit(X, y, sample_weight=w) + @given(tm.sparse_datasets_strategy) + @settings(deadline=None, print_blob=True) + def test_sparse(self, dataset): + param = {"tree_method": "hist", "max_bin": 64} + hist_result = train_result(param, dataset.get_dmat(), 16) + note(hist_result) + assert tm.non_increasing(hist_result['train'][dataset.metric]) + + param = {"tree_method": "approx", "max_bin": 64} + approx_result = train_result(param, dataset.get_dmat(), 16) + note(approx_result) + assert tm.non_increasing(approx_result['train'][dataset.metric]) + + np.testing.assert_allclose( + hist_result["train"]["rmse"], approx_result["train"]["rmse"] + ) + def run_invalid_category(self, tree_method: str) -> None: rng = np.random.default_rng() # too large @@ -365,7 +429,7 @@ def test_categorical_ames_housing( ) -> None: cat_parameters.update(hist_parameters) dataset = tm.TestDataset( - "ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse" + "ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse" ) cat_parameters["tree_method"] = tree_method results = train_result(cat_parameters, dataset.get_dmat(), 16) diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py index 369dcd421757..0bf952025c0b 100644 --- a/tests/test_distributed/test_with_dask/test_with_dask.py +++ b/tests/test_distributed/test_with_dask/test_with_dask.py @@ -1168,7 +1168,7 @@ def test_dask_aft_survival() -> None: def test_dask_ranking(client: "Client") -> None: dpath = "demo/rank/" - mq2008 = tm.get_mq2008(dpath) + mq2008 = tm.data.get_mq2008(dpath) data = [] for d in mq2008: if isinstance(d, scipy.sparse.csr_matrix):