Skip to content

Commit

Permalink
Implement multi-target for hist.
Browse files Browse the repository at this point in the history
- Add new hist tree builder.
- Move data fetchers for tests.
- Dispatch function calls in gbm base on the tree type.
  • Loading branch information
trivialfis committed Mar 22, 2023
1 parent 9b6cc0e commit bc56079
Show file tree
Hide file tree
Showing 28 changed files with 836 additions and 371 deletions.
35 changes: 29 additions & 6 deletions demo/guide-python/multioutput_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_regression_multioutput.html#sphx-glr-auto-examples-ensemble-plot-random-forest-regression-multioutput-py
See :doc:`/tutorials/multioutput` for more information.
.. note::
The feature is experimental. For the monolithic strategy, many features are missing.
"""

import argparse
Expand Down Expand Up @@ -40,19 +45,26 @@ def gen_circle() -> Tuple[np.ndarray, np.ndarray]:
return X, y


def rmse_model(plot_result: bool):
def rmse_model(plot_result: bool, strategy: str):
"""Draw a circle with 2-dim coordinate as target variables."""
X, y = gen_circle()
# Train a regressor on it
reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64)
reg = xgb.XGBRegressor(
tree_method="hist",
n_estimators=128,
n_jobs=16,
max_depth=8,
multi_strategy=strategy,
subsample=0.6,
)
reg.fit(X, y, eval_set=[(X, y)])

y_predt = reg.predict(X)
if plot_result:
plot_predt(y, y_predt, "multi")


def custom_rmse_model(plot_result: bool) -> None:
def custom_rmse_model(plot_result: bool, strategy: str) -> None:
"""Train using Python implementation of Squared Error."""

# As the experimental support status, custom objective doesn't support matrix as
Expand Down Expand Up @@ -88,9 +100,10 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
{
"tree_method": "hist",
"num_target": y.shape[1],
"multi_strategy": strategy,
},
dtrain=Xy,
num_boost_round=100,
num_boost_round=128,
obj=squared_log,
evals=[(Xy, "Train")],
evals_result=results,
Expand All @@ -107,6 +120,16 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
args = parser.parse_args()
# Train with builtin RMSE objective
rmse_model(args.plot == 1)
# - One model per output.
rmse_model(args.plot == 1, "composite")

# - One model for all outputs, this is still working in progress, many features are
# missing.
rmse_model(args.plot == 1, "monolithic")

# Train with custom objective.
custom_rmse_model(args.plot == 1)
# - One model per output.
custom_rmse_model(args.plot == 1, "composite")
# - One model for all outputs, this is still working in progress, many features are
# missing.
custom_rmse_model(args.plot == 1, "monolithic")
12 changes: 12 additions & 0 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,18 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See :doc:`/tutorials/feature_interaction_constraint` for more information.

* ``multi_strategy``, [default = ``composite``]

.. versionadded:: 2.0.0

.. note:: This parameter is working-in-progress.

- The strategy used for training multi-target models, including multi-target regression
and multi-class classification. See :doc:`/tutorials/multioutput` for more information.

- ``composite``: One model for each target.
- ``monolithic``: Use multi-target trees.

.. _cat-param:

Parameters for Categorical Feature
Expand Down
28 changes: 27 additions & 1 deletion doc/tutorials/multioutput.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ can be simultaneously classified as both sci-fi and comedy. For detailed explan
terminologies related to different multi-output models please refer to the
:doc:`scikit-learn user guide <sklearn:modules/multiclass>`.

Internally, XGBoost builds one model for each target similar to sklearn meta estimators,
**********************************
Training with One-Model-Per-Target
**********************************

By default, XGBoost builds one model for each target similar to sklearn meta estimators,
with the added benefit of reusing data and other integrated features like SHAP. For a
worked example of regression, see
:ref:`sphx_glr_python_examples_multioutput_regression.py`. For multi-label classification,
Expand All @@ -36,3 +40,25 @@ dense matrix for labels.
The feature is still under development with limited support from objectives and metrics.

*************************
Training with Vector Leaf
*************************

.. versionadded:: 2.0

.. note::

This is still working-in-progress, and many features are missing.

XGBoost can optionally build multi-output trees with the size of leaf equals to the number
of targets when the tree method `hist` is used. The behavior can be controlled by the
``multi_strategy`` training parameter, which can take the value `composite` (the default)
for building one model per-target or `monolithic` for building multi-output trees.

.. code-block:: python
clf = xgb.XGBClassifier(tree_method="hist", multi_strategy="monolithic")
See :ref:`sphx_glr_python_examples_multioutput_regression.py` for a worked example with
regression.
8 changes: 4 additions & 4 deletions include/xgboost/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -530,17 +530,17 @@ class TensorView {
/**
* \brief Number of items in the tensor.
*/
LINALG_HD [[nodiscard]] std::size_t Size() const { return size_; }
[[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
/**
* \brief Whether this is a contiguous array, both C and F contiguous returns true.
*/
LINALG_HD [[nodiscard]] bool Contiguous() const {
[[nodiscard]] LINALG_HD bool Contiguous() const {
return data_.size() == this->Size() || this->CContiguous() || this->FContiguous();
}
/**
* \brief Whether it's a c-contiguous array.
*/
LINALG_HD [[nodiscard]] bool CContiguous() const {
[[nodiscard]] LINALG_HD bool CContiguous() const {
StrideT stride;
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
// It's contiguous if the stride can be calculated from shape.
Expand All @@ -550,7 +550,7 @@ class TensorView {
/**
* \brief Whether it's a f-contiguous array.
*/
LINALG_HD [[nodiscard]] bool FContiguous() const {
[[nodiscard]] LINALG_HD bool FContiguous() const {
StrideT stride;
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
// It's contiguous if the stride can be calculated from shape.
Expand Down
15 changes: 15 additions & 0 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,19 @@ def task(i: int) -> float:
needs to be set to have categorical feature support. See :doc:`Categorical Data
</tutorials/categorical>` and :ref:`cat-param` for details.
multi_strategy : Optional[str]
.. versionadded:: 2.0.0
.. note:: This parameter is working-in-progress.
The strategy used for training multi-target models, including multi-target
regression and multi-class classification. See :doc:`/tutorials/multioutput` for
more information.
- ``composite``: One model for each target.
- ``monolithic``: Use multi-target trees.
eval_metric : Optional[Union[str, List[str], Callable]]
.. versionadded:: 1.6.0
Expand Down Expand Up @@ -624,6 +637,7 @@ def __init__(
feature_types: Optional[FeatureTypes] = None,
max_cat_to_onehot: Optional[int] = None,
max_cat_threshold: Optional[int] = None,
multi_strategy: Optional[str] = None,
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
callbacks: Optional[List[TrainingCallback]] = None,
Expand Down Expand Up @@ -670,6 +684,7 @@ def __init__(
self.feature_types = feature_types
self.max_cat_to_onehot = max_cat_to_onehot
self.max_cat_threshold = max_cat_threshold
self.multi_strategy = multi_strategy
self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds
self.callbacks = callbacks
Expand Down
Loading

0 comments on commit bc56079

Please sign in to comment.