Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove deprecated feval. #11051

Merged
merged 9 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion demo/guide-python/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Demo for using cross validation
===============================
"""

import os

import numpy as np
Expand Down Expand Up @@ -83,9 +84,12 @@ def logregobj(preds, dtrain):

def evalerror(preds, dtrain):
labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
return "error", float(sum(labels != (preds > 0.0))) / len(labels)


param = {"max_depth": 2, "eta": 1}
# train with customized objective
xgb.cv(param, dtrain, num_round, nfold=5, seed=0, obj=logregobj, feval=evalerror)
xgb.cv(
param, dtrain, num_round, nfold=5, seed=0, obj=logregobj, custom_metric=evalerror
)
1 change: 1 addition & 0 deletions doc/python/python_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Core Data Structure
.. autoclass:: xgboost.Booster
:members:
:show-inheritance:
:special-members: __getitem__

.. autoclass:: xgboost.DataIter
:members:
Expand Down
6 changes: 0 additions & 6 deletions python-package/xgboost/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,6 @@ async def _train_async(
num_boost_round: int,
evals: Optional[Sequence[Tuple[DaskDMatrix, str]]],
obj: Optional[Objective],
feval: Optional[Metric],
early_stopping_rounds: Optional[int],
verbose_eval: Union[int, bool],
xgb_model: Optional[Booster],
Expand Down Expand Up @@ -816,7 +815,6 @@ def do_train( # pylint: disable=too-many-positional-arguments
evals_result=local_history,
evals=evals if len(evals) != 0 else None,
obj=obj,
feval=feval,
custom_metric=custom_metric,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
Expand Down Expand Up @@ -870,7 +868,6 @@ def train( # pylint: disable=unused-argument
*,
evals: Optional[Sequence[Tuple[DaskDMatrix, str]]] = None,
obj: Optional[Objective] = None,
feval: Optional[Metric] = None,
early_stopping_rounds: Optional[int] = None,
xgb_model: Optional[Booster] = None,
verbose_eval: Union[int, bool] = True,
Expand Down Expand Up @@ -1675,7 +1672,6 @@ async def _fit_async(
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=self.early_stopping_rounds,
Expand Down Expand Up @@ -1784,7 +1780,6 @@ async def _fit_async(
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=self.early_stopping_rounds,
Expand Down Expand Up @@ -1986,7 +1981,6 @@ async def _fit_async(
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=None,
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=self.early_stopping_rounds,
Expand Down
2 changes: 1 addition & 1 deletion python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def task(i: int) -> float:

Metric used for monitoring the training result and early stopping. It can be a
string or list of strings as names of predefined metric in XGBoost (See
doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any
:doc:`/parameter`), one of the metrics in :py:mod:`sklearn.metrics`, or any
other user defined metric that looks like `sklearn.metrics`.

If custom objective is also provided, then custom metric should implement the
Expand Down
33 changes: 30 additions & 3 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,9 +662,29 @@ def predictor_equal(lhs: xgb.DMatrix, rhs: xgb.DMatrix) -> bool:
M = TypeVar("M", xgb.Booster, xgb.XGBModel)


def eval_error_metric(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, np.float64]:
"""Evaluation metric for xgb.train"""
def logregobj(preds: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
"""Binary regression custom objective."""
labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
grad = preds - labels
hess = preds * (1.0 - preds)
return grad, hess


def eval_error_metric(
predt: np.ndarray, dtrain: xgb.DMatrix, rev_link: bool
) -> Tuple[str, np.float64]:
"""Evaluation metric for xgb.train.

Parameters
----------
rev_link : Whether the metric needs to apply the reverse link function (activation).

"""
label = dtrain.get_label()
if rev_link:
predt = 1.0 / (1.0 + np.exp(-predt))
assert (0.0 <= predt).all() and (predt <= 1.0).all()
r = np.zeros(predt.shape)
gt = predt > 0.5
if predt.size == 0:
Expand All @@ -675,8 +695,15 @@ def eval_error_metric(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, np.f
return "CustomErr", np.sum(r)


def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> np.float64:
def eval_error_metric_skl(
y_true: np.ndarray, y_score: np.ndarray, rev_link: bool = False
) -> np.float64:
"""Evaluation metric that looks like metrics provided by sklearn."""

if rev_link:
y_score = 1.0 / (1.0 + np.exp(-y_score))
assert (0.0 <= y_score).all() and (y_score <= 1.0).all()

r = np.zeros(y_score.shape)
gt = y_score > 0.5
r[gt] = 1 - y_true[gt]
Expand Down
84 changes: 31 additions & 53 deletions python-package/xgboost/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""Training Library containing training routines."""
import copy
import os
import warnings
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
Expand All @@ -28,26 +27,6 @@
_CVFolds = Sequence["CVPack"]


def _configure_custom_metric(
feval: Optional[Metric], custom_metric: Optional[Metric]
) -> Optional[Metric]:
if feval is not None:
link = (
"https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
)
warnings.warn(
"`feval` is deprecated, use `custom_metric` instead. They have "
"different behavior when custom objective is also used."
f"See {link} for details on the `custom_metric`."
)
if feval is not None and custom_metric is not None:
raise ValueError(
"Both `feval` and `custom_metric` are supplied. Use `custom_metric` instead."
)
eval_metric = custom_metric if custom_metric is not None else feval
return eval_metric


@_deprecate_positional_args
def train(
params: Dict[str, Any],
Expand All @@ -56,7 +35,6 @@ def train(
*,
evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
obj: Optional[Objective] = None,
feval: Optional[Metric] = None,
maximize: Optional[bool] = None,
early_stopping_rounds: Optional[int] = None,
evals_result: Optional[TrainingCallback.EvalsLog] = None,
Expand All @@ -81,23 +59,27 @@ def train(
obj
Custom objective function. See :doc:`Custom Objective
</tutorials/custom_metric_obj>` for details.
feval :
.. deprecated:: 1.6.0
Use `custom_metric` instead.
maximize :
Whether to maximize feval.
Whether to maximize custom_metric.

early_stopping_rounds :

Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.

Requires at least one item in **evals**.

The method returns the model from the last iteration (not the best one). Use
custom callback or model slicing if the best model is desired.
If there's more than one item in **evals**, the last entry will be used for early
stopping.
custom callback :py:class:`~xgboost.callback.EarlyStopping` or :py:meth:`model
slicing <xgboost.Booster.__getitem__>` if the best model is desired. If there's
more than one item in **evals**, the last entry will be used for early stopping.

If there's more than one metric in the **eval_metric** parameter given in
**params**, the last metric will be used for early stopping.

If early stopping occurs, the model will have two additional fields:
``bst.best_score``, ``bst.best_iteration``.

evals_result :
This dictionary stores the evaluation results of all the items in watchlist.

Expand All @@ -113,15 +95,22 @@ def train(

verbose_eval :
Requires at least one item in **evals**.

If **verbose_eval** is True then the evaluation metric on the validation set is
printed at each boosting stage.
If **verbose_eval** is an integer then the evaluation metric on the validation set
is printed at every given **verbose_eval** boosting stage. The last boosting stage
/ the boosting stage found by using **early_stopping_rounds** is also printed.
Example: with ``verbose_eval=4`` and at least one item in **evals**, an evaluation metric
is printed every 4 boosting stages, instead of every boosting stage.

If **verbose_eval** is an integer then the evaluation metric on the validation
set is printed at every given **verbose_eval** boosting stage. The last boosting
stage / the boosting stage found by using **early_stopping_rounds** is also
printed.

Example: with ``verbose_eval=4`` and at least one item in **evals**, an
evaluation metric is printed every 4 boosting stages, instead of every boosting
stage.

xgb_model :
Xgb model to be loaded before training (allows training continuation).

callbacks :
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using
Expand All @@ -145,15 +134,17 @@ def train(
.. versionadded 1.6.0

Custom metric function. See :doc:`Custom Metric </tutorials/custom_metric_obj>`
for details.
for details. The metric receives transformed predictioin (after applying the
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
reverse link function) when using a builtin objective, and raw output when using
a custom objective.

Returns
-------
Booster : a trained booster model

"""

callbacks = [] if callbacks is None else copy.copy(list(callbacks))
metric_fn = _configure_custom_metric(feval, custom_metric)
evals = list(evals) if evals else []

bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
Expand All @@ -165,12 +156,7 @@ def train(
if early_stopping_rounds:
callbacks.append(EarlyStopping(rounds=early_stopping_rounds, maximize=maximize))
cb_container = CallbackContainer(
callbacks,
metric=metric_fn,
# For old `feval` parameter, the behavior is unchanged. For the new
# `custom_metric`, it will receive proper prediction result when custom objective
# is not used.
output_margin=callable(obj) or metric_fn is feval,
callbacks, metric=custom_metric, output_margin=callable(obj)
)

bst = cb_container.before_training(bst)
Expand Down Expand Up @@ -423,7 +409,6 @@ def cv(
folds: XGBStratifiedKFold = None,
metrics: Sequence[str] = (),
obj: Optional[Objective] = None,
feval: Optional[Metric] = None,
maximize: Optional[bool] = None,
early_stopping_rounds: Optional[int] = None,
fpreproc: Optional[FPreProcCallable] = None,
Expand Down Expand Up @@ -464,11 +449,9 @@ def cv(
Custom objective function. See :doc:`Custom Objective
</tutorials/custom_metric_obj>` for details.

feval : function
.. deprecated:: 1.6.0
Use `custom_metric` instead.
maximize : bool
Whether to maximize feval.
Whether to maximize the evaluataion metric (score or error).

early_stopping_rounds: int
Activates early stopping. Cross-Validation metric (average of validation
metric computed over CV folds) needs to improve at least once in
Expand Down Expand Up @@ -559,8 +542,6 @@ def cv(
shuffle=shuffle,
)

metric_fn = _configure_custom_metric(feval, custom_metric)

# setup callbacks
callbacks = [] if callbacks is None else copy.copy(list(callbacks))

Expand All @@ -570,10 +551,7 @@ def cv(
if early_stopping_rounds:
callbacks.append(EarlyStopping(rounds=early_stopping_rounds, maximize=maximize))
callbacks_container = CallbackContainer(
callbacks,
metric=metric_fn,
is_cv=True,
output_margin=callable(obj) or metric_fn is feval,
callbacks, metric=custom_metric, is_cv=True, output_margin=callable(obj)
)

booster = _PackedBooster(cvfolds)
Expand Down
2 changes: 2 additions & 0 deletions tests/ci_build/lint_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class LintersPaths:
"tests/python/test_dt.py",
"tests/python/test_demos.py",
"tests/python/test_eval_metrics.py",
"tests/python/test_early_stopping.py",
"tests/python/test_multi_target.py",
"tests/python/test_objectives.py",
"tests/python/test_predict.py",
Expand Down Expand Up @@ -54,6 +55,7 @@ class LintersPaths:
"demo/guide-python/callbacks.py",
"demo/guide-python/categorical.py",
"demo/guide-python/cat_pipeline.py",
"demo/guide-python/cross_validation.py",
"demo/guide-python/feature_weights.py",
"demo/guide-python/model_parser.py",
"demo/guide-python/sklearn_parallel.py",
Expand Down
Loading
Loading