diff --git a/demo/guide-python/callbacks.py b/demo/guide-python/callbacks.py
index 42fe397dba73..be03b1693d99 100644
--- a/demo/guide-python/callbacks.py
+++ b/demo/guide-python/callbacks.py
@@ -1,9 +1,9 @@
-'''
+"""
 Demo for using and defining callback functions
 ==============================================
 
 .. versionadded:: 1.3.0
-'''
+"""
 import argparse
 import os
 import tempfile
@@ -17,10 +17,11 @@
 
 
 class Plotting(xgb.callback.TrainingCallback):
-    '''Plot evaluation result during training. Only for demonstration purpose as it's quite
+    """Plot evaluation result during training. Only for demonstration purposes as it's quite
     slow to draw.
-    '''
+    """
+
     def __init__(self, rounds):
         self.fig = plt.figure()
         self.ax = self.fig.add_subplot(111)
@@ -31,16 +32,16 @@ def __init__(self, rounds):
         plt.ion()
 
     def _get_key(self, data, metric):
-        return f'{data}-{metric}'
+        return f"{data}-{metric}"
 
     def after_iteration(self, model, epoch, evals_log):
-        '''Update the plot.'''
+        """Update the plot."""
         if not self.lines:
             for data, metric in evals_log.items():
                 for metric_name, log in metric.items():
                     key = self._get_key(data, metric_name)
                     expanded = log + [0] * (self.rounds - len(log))
-                    self.lines[key], = self.ax.plot(self.x, expanded, label=key)
+                    (self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
                     self.ax.legend()
         else:
             # https://pythonspot.com/matplotlib-update-plot/
@@ -55,8 +56,8 @@ def after_iteration(self, model, epoch, evals_log):
 
 
 def custom_callback():
-    '''Demo for defining a custom callback function that plots evaluation result during
-    training.'''
+    """Demo for defining a custom callback function that plots evaluation result during
+    training."""
     X, y = load_breast_cancer(return_X_y=True)
     X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
@@ -69,15 +70,16 @@ def custom_callback():
     # Pass it to the `callbacks` parameter as a list.
     xgb.train(
         {
-            'objective': 'binary:logistic',
-            'eval_metric': ['error', 'rmse'],
-            'tree_method': 'hist',
+            "objective": "binary:logistic",
+            "eval_metric": ["error", "rmse"],
+            "tree_method": "hist",
             "device": "cuda",
         },
         D_train,
-        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
+        evals=[(D_train, "Train"), (D_valid, "Valid")],
         num_boost_round=num_boost_round,
-        callbacks=[plotting])
+        callbacks=[plotting],
+    )
 
 
 def check_point_callback():
@@ -90,10 +92,10 @@ def check(as_pickle):
         if i == 0:
             continue
         if as_pickle:
-            path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
+            path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
         else:
-            path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
-        assert(os.path.exists(path))
+            path = os.path.join(tmpdir, "model_" + str(i) + ".json")
+        assert os.path.exists(path)
 
     X, y = load_breast_cancer(return_X_y=True)
     m = xgb.DMatrix(X, y)
@@ -101,31 +103,36 @@ def check(as_pickle):
     with tempfile.TemporaryDirectory() as tmpdir:
         # Use callback class from xgboost.callback
         # Feel free to subclass/customize it to suit your need.
-        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
-                                                      iterations=rounds,
-                                                      name='model')
-        xgb.train({'objective': 'binary:logistic'}, m,
-                  num_boost_round=10,
-                  verbose_eval=False,
-                  callbacks=[check_point])
+        check_point = xgb.callback.TrainingCheckPoint(
+            directory=tmpdir, iterations=rounds, name="model"
+        )
+        xgb.train(
+            {"objective": "binary:logistic"},
+            m,
+            num_boost_round=10,
+            verbose_eval=False,
+            callbacks=[check_point],
+        )
         check(False)
 
         # This version of checkpoint saves everything including parameters and
        # model. See: doc/tutorials/saving_model.rst
-        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
-                                                      iterations=rounds,
-                                                      as_pickle=True,
-                                                      name='model')
-        xgb.train({'objective': 'binary:logistic'}, m,
-                  num_boost_round=10,
-                  verbose_eval=False,
-                  callbacks=[check_point])
+        check_point = xgb.callback.TrainingCheckPoint(
+            directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
+        )
+        xgb.train(
+            {"objective": "binary:logistic"},
+            m,
+            num_boost_round=10,
+            verbose_eval=False,
+            callbacks=[check_point],
+        )
         check(True)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--plot', default=1, type=int)
+    parser.add_argument("--plot", default=1, type=int)
     args = parser.parse_args()
     check_point_callback()
diff --git a/doc/python/model.rst b/doc/python/model.rst
index c854043b36f1..5ea38164ae15 100644
--- a/doc/python/model.rst
+++ b/doc/python/model.rst
@@ -37,3 +37,7 @@ The sliced model is a copy of selected trees, that means the model itself is imm
 during slicing. This feature is the basis of `save_best` option in early stopping
 callback. See :ref:`sphx_glr_python_examples_individual_trees.py` for a worked example
 on how to combine prediction with sliced trees.
+
+.. note::
+
+   The returned model slice doesn't contain attributes like :py:attr:`~xgboost.Booster.best_iteration` and :py:attr:`~xgboost.Booster.best_score`.
diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py
index 88e34073711c..6077aa1e3188 100644
--- a/python-package/xgboost/callback.py
+++ b/python-package/xgboost/callback.py
@@ -134,13 +134,17 @@ def __init__(
         is_cv: bool = False,
     ) -> None:
         self.callbacks = set(callbacks)
-        if metric is not None:
-            msg = (
-                "metric must be callable object for monitoring. For "
-                + "builtin metrics, passing them in training parameter"
-                + " will invoke monitor automatically."
-            )
-            assert callable(metric), msg
+        for cb in callbacks:
+            if not isinstance(cb, TrainingCallback):
+                raise TypeError("callback must be an instance of `TrainingCallback`.")
+
+        msg = (
+            "metric must be a callable object for monitoring. For builtin metrics,"
+            " passing them as training parameters invokes the monitor automatically."
+        )
+        if metric is not None and not callable(metric):
+            raise TypeError(msg)
+
         self.metric = metric
         self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
         self._output_margin = output_margin
@@ -170,16 +174,6 @@ def after_training(self, model: _Model) -> _Model:
         else:
             assert isinstance(model, Booster), msg
 
-        if not self.is_cv:
-            if model.attr("best_score") is not None:
-                model.best_score = float(cast(str, model.attr("best_score")))
-                model.best_iteration = int(cast(str, model.attr("best_iteration")))
-            else:
-                # Due to compatibility with version older than 1.4, these attributes are
-                # added to Python object even if early stopping is not used.
-                model.best_iteration = model.num_boosted_rounds() - 1
-                model.set_attr(best_iteration=str(model.best_iteration))
-
         return model
 
     def before_iteration(
@@ -267,9 +261,14 @@ class LearningRateScheduler(TrainingCallback):
     def __init__(
         self, learning_rates: Union[Callable[[int], float], Sequence[float]]
     ) -> None:
-        assert callable(learning_rates) or isinstance(
+        if not callable(learning_rates) and not isinstance(
             learning_rates, collections.abc.Sequence
-        )
+        ):
+            raise TypeError(
+                "Invalid learning rates, expecting callable or sequence, got: "
+                f"{type(learning_rates)}"
+            )
+
         if callable(learning_rates):
             self.learning_rates = learning_rates
         else:
@@ -302,24 +301,28 @@ class EarlyStopping(TrainingCallback):
     save_best :
         Whether training should return the best model or the last model.
     min_delta :
-        Minimum absolute change in score to be qualified as an improvement.
 
         .. versionadded:: 1.5.0
 
-    .. code-block:: python
+        Minimum absolute change in score to qualify as an improvement.
 
-        es = xgboost.callback.EarlyStopping(
-            rounds=2,
-            min_delta=1e-3,
-            save_best=True,
-            maximize=False,
-            data_name="validation_0",
-            metric_name="mlogloss",
-        )
-        clf = xgboost.XGBClassifier(tree_method="gpu_hist", callbacks=[es])
+    Examples
+    --------
+
+    .. code-block:: python
 
-        X, y = load_digits(return_X_y=True)
-        clf.fit(X, y, eval_set=[(X, y)])
+        es = xgboost.callback.EarlyStopping(
+            rounds=2,
+            min_delta=1e-3,
+            save_best=True,
+            maximize=False,
+            data_name="validation_0",
+            metric_name="mlogloss",
+        )
+        clf = xgboost.XGBClassifier(tree_method="hist", device="cuda", callbacks=[es])
+
+        X, y = load_digits(return_X_y=True)
+        clf.fit(X, y, eval_set=[(X, y)])
     """
 
     # pylint: disable=too-many-arguments
@@ -363,7 +366,7 @@ def maximize(new: _Score, best: _Score) -> bool:
             return numpy.greater(get_s(new) - self._min_delta, get_s(best))
 
         def minimize(new: _Score, best: _Score) -> bool:
-            """New score should be smaller than the old one."""
+            """The new score should be smaller than the old one."""
             return numpy.greater(get_s(best) - self._min_delta, get_s(new))
 
         if self.maximize is None:
@@ -419,38 +422,53 @@ def after_iteration(
     ) -> bool:
         epoch += self.starting_round  # training continuation
         msg = "Must have at least 1 validation dataset for early stopping."
-        assert len(evals_log.keys()) >= 1, msg
-        data_name = ""
+        if len(evals_log.keys()) < 1:
+            raise ValueError(msg)
+
+        # Get data name
         if self.data:
-            for d, _ in evals_log.items():
-                if d == self.data:
-                    data_name = d
-            if not data_name:
-                raise ValueError("No dataset named:", self.data)
+            data_name = self.data
         else:
             # Use the last one as default.
             data_name = list(evals_log.keys())[-1]
-        assert isinstance(data_name, str) and data_name
+        if data_name not in evals_log:
+            raise ValueError(f"No dataset named: {data_name}")
+
+        if not isinstance(data_name, str):
+            raise TypeError(
+                f"The name of the dataset should be a string. Got: {type(data_name)}"
+            )
         data_log = evals_log[data_name]
 
-        # Filter out scores that can not be used for early stopping.
+        # Get metric name
         if self.metric_name:
             metric_name = self.metric_name
         else:
             # Use last metric by default.
-            assert isinstance(data_log, collections.OrderedDict)
             metric_name = list(data_log.keys())[-1]
+        if metric_name not in data_log:
+            raise ValueError(f"No metric named: {metric_name}")
+
+        # The latest score
         score = data_log[metric_name][-1]
         return self._update_rounds(score, data_name, metric_name, model, epoch)
 
     def after_training(self, model: _Model) -> _Model:
+        if not self.save_best:
+            return model
+
         try:
-            if self.save_best:
-                model = model[: int(model.attr("best_iteration")) + 1]
+            best_iteration = model.best_iteration
+            best_score = model.best_score
+            assert best_iteration is not None and best_score is not None
+            model = model[: best_iteration + 1]
+            model.best_iteration = best_iteration
+            model.best_score = best_score
         except XGBoostError as e:
             raise XGBoostError(
-                "`save_best` is not applicable to current booster"
+                "`save_best` is not applicable to the current booster"
             ) from e
+
         return model
 
 
@@ -462,8 +480,6 @@ class EvaluationMonitor(TrainingCallback):
 
     Parameters
     ----------
-    metric :
-        Extra user defined metric.
     rank :
         Which worker should be used for printing the result.
     period :
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 4cacd61f3bb9..febcf43d7b5d 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -1905,7 +1905,7 @@ def attributes(self) -> Dict[str, Optional[str]]:
         attr_names = from_cstr_to_pystr(sarr, length)
         return {n: self.attr(n) for n in attr_names}
 
-    def set_attr(self, **kwargs: Optional[str]) -> None:
+    def set_attr(self, **kwargs: Optional[Any]) -> None:
         """Set the attribute of the Booster.
 
         Parameters
@@ -2574,10 +2574,35 @@ def load_model(self, fname: ModelIn) -> None:
         else:
             raise TypeError("Unknown file type: ", fname)
 
-        if self.attr("best_iteration") is not None:
-            self.best_iteration = int(cast(int, self.attr("best_iteration")))
-        if self.attr("best_score") is not None:
-            self.best_score = float(cast(float, self.attr("best_score")))
+    @property
+    def best_iteration(self) -> int:
+        """The best iteration during training."""
+        best = self.attr("best_iteration")
+        if best is not None:
+            return int(best)
+
+        raise AttributeError(
+            "`best_iteration` is only defined when early stopping is used."
+        )
+
+    @best_iteration.setter
+    def best_iteration(self, iteration: int) -> None:
+        self.set_attr(best_iteration=iteration)
+
+    @property
+    def best_score(self) -> float:
+        """The best evaluation score during training."""
+        best = self.attr("best_score")
+        if best is not None:
+            return float(best)
+
+        raise AttributeError(
+            "`best_score` is only defined when early stopping is used."
+        )
+
+    @best_score.setter
+    def best_score(self, score: float) -> None:
+        self.set_attr(best_score=score)
 
     def num_boosted_rounds(self) -> int:
         """Get number of boosted rounds. For gblinear this is reset to 0 after
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 46a3ffa4aec1..a1a2519a37a5 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -230,10 +230,10 @@ def task(i: int) -> float:
     subsample : Optional[float]
         Subsample ratio of the training instance.
     sampling_method :
-        Sampling method. Used only by `gpu_hist` tree method.
-          - `uniform`: select random training instances uniformly.
-          - `gradient_based` select random training instances with higher probability when
-            the gradient and hessian are larger. (cf. CatBoost)
+        Sampling method. Used only by the GPU version of the ``hist`` tree method.
+          - ``uniform``: select random training instances uniformly.
+          - ``gradient_based``: select random training instances with higher probability
+            when the gradient and hessian are larger. (cf. CatBoost)
     colsample_bytree : Optional[float]
         Subsample ratio of columns when constructing each tree.
     colsample_bylevel : Optional[float]
@@ -986,12 +986,12 @@ def fit(
         X :
             Feature matrix. See :ref:`py-data` for a list of supported types.
 
-            When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
+            When the ``tree_method`` is set to ``hist``, internally, the
             :py:class:`QuantileDMatrix` will be used instead of the
             :py:class:`DMatrix` for conserving memory. However, this has performance
             implications when the device of input data is not matched with algorithm.
             For instance, if the
-            input is a numpy array on CPU but ``gpu_hist`` is used for training, then
-            the data is first processed on CPU then transferred to GPU.
+            input is a numpy array on CPU but ``cuda`` is used for training, then the
+            data is first processed on CPU then transferred to GPU.
         y :
             Labels
         sample_weight :
@@ -1273,19 +1273,10 @@ def feature_names_in_(self) -> np.ndarray:
         )
         return np.array(feature_names)
 
-    def _early_stopping_attr(self, attr: str) -> Union[float, int]:
-        booster = self.get_booster()
-        try:
-            return getattr(booster, attr)
-        except AttributeError as e:
-            raise AttributeError(
-                f"`{attr}` in only defined when early stopping is used."
-            ) from e
-
     @property
     def best_score(self) -> float:
         """The best score obtained by early stopping."""
-        return float(self._early_stopping_attr("best_score"))
+        return self.get_booster().best_score
 
     @property
     def best_iteration(self) -> int:
@@ -1293,7 +1284,7 @@ def best_iteration(self) -> int:
         for instance if the best iteration is the first round, then best_iteration is
         0.
         """
-        return int(self._early_stopping_attr("best_iteration"))
+        return self.get_booster().best_iteration
 
     @property
     def feature_importances_(self) -> np.ndarray:
@@ -1920,12 +1911,12 @@ def fit(
             | 1   | :math:`x_{20}` | :math:`x_{21}` |
             +-----+----------------+----------------+
 
-            When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
+            When the ``tree_method`` is set to ``hist``, internally, the
             :py:class:`QuantileDMatrix` will be used instead of the
             :py:class:`DMatrix` for conserving memory. However, this has performance
             implications when the device of input data is not matched with algorithm.
             For instance, if the
-            input is a numpy array on CPU but ``gpu_hist`` is used for training, then
-            the data is first processed on CPU then transferred to GPU.
+            input is a numpy array on CPU but ``cuda`` is used for training, then the
+            data is first processed on CPU then transferred to GPU.
         y :
             Labels
         group :
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index a238e73c8501..aa3c18a01e8b 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -28,17 +28,6 @@
 _CVFolds = Sequence["CVPack"]
 
 
-def _assert_new_callback(callbacks: Optional[Sequence[TrainingCallback]]) -> None:
-    is_new_callback: bool = not callbacks or all(
-        isinstance(c, TrainingCallback) for c in callbacks
-    )
-    if not is_new_callback:
-        link = "https://xgboost.readthedocs.io/en/latest/python/callbacks.html"
-        raise ValueError(
-            f"Old style callback was removed in version 1.6. See: {link}."
- ) - - def _configure_custom_metric( feval: Optional[Metric], custom_metric: Optional[Metric] ) -> Optional[Metric]: @@ -170,7 +159,6 @@ def train( bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model) start_iteration = 0 - _assert_new_callback(callbacks) if verbose_eval: verbose_eval = 1 if verbose_eval is True else verbose_eval callbacks.append(EvaluationMonitor(period=verbose_eval)) @@ -247,7 +235,7 @@ def eval( result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds] return result - def set_attr(self, **kwargs: Optional[str]) -> Any: + def set_attr(self, **kwargs: Optional[Any]) -> Any: """Iterate through folds for setting attributes""" for f in self.cvfolds: f.bst.set_attr(**kwargs) @@ -274,11 +262,20 @@ def best_iteration(self) -> int: """Get best_iteration""" return int(cast(int, self.cvfolds[0].bst.attr("best_iteration"))) + @best_iteration.setter + def best_iteration(self, iteration: int) -> None: + """Get best_iteration""" + self.set_attr(best_iteration=iteration) + @property def best_score(self) -> float: """Get best_score.""" return float(cast(float, self.cvfolds[0].bst.attr("best_score"))) + @best_score.setter + def best_score(self, score: float) -> None: + self.set_attr(best_score=score) + def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarray: """ @@ -551,7 +548,6 @@ def cv( # setup callbacks callbacks = [] if callbacks is None else copy.copy(list(callbacks)) - _assert_new_callback(callbacks) if verbose_eval: verbose_eval = 1 if verbose_eval is True else verbose_eval diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index ca5d56e4c39b..e4eb72df53d1 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -35,6 +35,7 @@ class LintersPaths: "demo/dask/", "demo/json-model/json_parser.py", "demo/guide-python/cat_in_the_dat.py", + "demo/guide-python/callbacks.py", "demo/guide-python/categorical.py", "demo/guide-python/feature_weights.py", "demo/guide-python/sklearn_parallel.py", diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index d3ec05e6eb54..56c9fdabdde3 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -1,7 +1,6 @@ import json import os import tempfile -from contextlib import nullcontext from typing import Union import pytest @@ -104,15 +103,6 @@ def test_early_stopping(self): dump = booster.get_dump(dump_format='json') assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 - # No early stopping, best_iteration should be set to last epoch - booster = xgb.train({'objective': 'binary:logistic', - 'eval_metric': 'error'}, D_train, - evals=[(D_train, 'Train'), (D_valid, 'Valid')], - num_boost_round=10, - evals_result=evals_result, - verbose_eval=True) - assert booster.num_boosted_rounds() - 1 == booster.best_iteration - def test_early_stopping_custom_eval(self): D_train = xgb.DMatrix(self.X_train, self.y_train) D_valid = xgb.DMatrix(self.X_valid, self.y_valid) @@ -204,8 +194,9 @@ def test_early_stopping_save_best_model(self): X, y = load_breast_cancer(return_X_y=True) n_estimators = 100 early_stopping_rounds = 5 - early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, - save_best=True) + early_stop = xgb.callback.EarlyStopping( + rounds=early_stopping_rounds, save_best=True + ) cls = xgb.XGBClassifier( n_estimators=n_estimators, eval_metric=tm.eval_error_metric_skl, @@ -216,20 +207,27 @@ def test_early_stopping_save_best_model(self): dump = 
         dump = booster.get_dump(dump_format='json')
         assert len(dump) == booster.best_iteration + 1
 
-        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
-                                                save_best=True)
+        early_stop = xgb.callback.EarlyStopping(
+            rounds=early_stopping_rounds, save_best=True
+        )
         cls = xgb.XGBClassifier(
-            booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
+            booster="gblinear",
+            n_estimators=10,
+            eval_metric=tm.eval_error_metric_skl,
+            callbacks=[early_stop],
         )
         with pytest.raises(ValueError):
-            cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
+            cls.fit(X, y, eval_set=[(X, y)])
 
         # No error
         early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                                 save_best=False)
         xgb.XGBClassifier(
-            booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
-        ).fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
+            booster="gblinear",
+            n_estimators=10,
+            eval_metric=tm.eval_error_metric_skl,
+            callbacks=[early_stop],
+        ).fit(X, y, eval_set=[(X, y)])
 
     def test_early_stopping_continuation(self):
         from sklearn.datasets import load_breast_cancer
@@ -252,8 +250,11 @@ def test_early_stopping_continuation(self):
             cls.load_model(path)
         assert cls._Booster is not None
         early_stopping_rounds = 3
-        cls.set_params(eval_metric=tm.eval_error_metric_skl)
-        cls.fit(X, y, eval_set=[(X, y)], early_stopping_rounds=early_stopping_rounds)
+        cls.set_params(
+            eval_metric=tm.eval_error_metric_skl,
+            early_stopping_rounds=early_stopping_rounds,
+        )
+        cls.fit(X, y, eval_set=[(X, y)])
         booster = cls.get_booster()
         assert booster.num_boosted_rounds() == \
             booster.best_iteration + early_stopping_rounds + 1
@@ -280,20 +281,20 @@ def run_eta_decay(self, tree_method):
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 4
 
-        warning_check = nullcontext()
-
         # learning_rates as a list
         # init eta with 0 to check whether learning_rates work
         param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
                  'objective': 'binary:logistic', 'eval_metric': 'error',
                  'tree_method': tree_method}
         evals_result = {}
-        with warning_check:
-            bst = xgb.train(param, dtrain, num_round, watchlist,
-                            callbacks=[scheduler([
-                                0.8, 0.7, 0.6, 0.5
-                            ])],
-                            evals_result=evals_result)
+        bst = xgb.train(
+            param,
+            dtrain,
+            num_round,
+            evals=watchlist,
+            callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
+            evals_result=evals_result,
+        )
         eval_errors_0 = list(map(float, evals_result['eval']['error']))
         assert isinstance(bst, xgb.core.Booster)
         # validation error should decrease, if eta > 0
@@ -304,11 +305,15 @@ def run_eta_decay(self, tree_method):
                  'objective': 'binary:logistic', 'eval_metric': 'error',
                  'tree_method': tree_method}
         evals_result = {}
-        with warning_check:
-            bst = xgb.train(param, dtrain, num_round, watchlist,
-                            callbacks=[scheduler(
-                                [0.8, 0.7, 0.6, 0.5])],
-                            evals_result=evals_result)
+
+        bst = xgb.train(
+            param,
+            dtrain,
+            num_round,
+            evals=watchlist,
+            callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
+            evals_result=evals_result,
+        )
         eval_errors_1 = list(map(float, evals_result['eval']['error']))
         assert isinstance(bst, xgb.core.Booster)
         # validation error should decrease, if learning_rate > 0
@@ -320,12 +325,14 @@ def run_eta_decay(self, tree_method):
                  'eval_metric': 'error',
                  'tree_method': tree_method
                  }
         evals_result = {}
-        with warning_check:
-            bst = xgb.train(param, dtrain, num_round, watchlist,
-                            callbacks=[scheduler(
-                                [0, 0, 0, 0]
-                            )],
-                            evals_result=evals_result)
+        bst = xgb.train(
+            param,
+            dtrain,
+            num_round,
+            evals=watchlist,
+            callbacks=[scheduler([0, 0, 0, 0])],
+            evals_result=evals_result,
+        )
         eval_errors_2 = list(map(float, evals_result['eval']['error']))
         assert isinstance(bst, xgb.core.Booster)
         # validation error should not decrease, if eta/learning_rate = 0
@@ -336,12 +343,14 @@ def eta_decay(ithround, num_boost_round=num_round):
             return num_boost_round / (ithround + 1)
 
         evals_result = {}
-        with warning_check:
-            bst = xgb.train(param, dtrain, num_round, watchlist,
-                            callbacks=[
-                                scheduler(eta_decay)
-                            ],
-                            evals_result=evals_result)
+        bst = xgb.train(
+            param,
+            dtrain,
+            num_round,
+            evals=watchlist,
+            callbacks=[scheduler(eta_decay)],
+            evals_result=evals_result,
+        )
         eval_errors_3 = list(map(float, evals_result['eval']['error']))
 
         assert isinstance(bst, xgb.core.Booster)
@@ -351,8 +360,7 @@ def eta_decay(ithround, num_boost_round=num_round):
         for i in range(1, len(eval_errors_0)):
             assert eval_errors_3[i] != eval_errors_2[i]
 
-        with warning_check:
-            xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
+        xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
 
     def run_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
         # check decay has effect on leaf output.
@@ -378,7 +386,7 @@ def eta_decay_0(i):
             param,
             dtrain,
             num_round,
-            watchlist,
+            evals=watchlist,
             callbacks=[scheduler(eta_decay_0)],
         )
 
@@ -391,7 +399,7 @@ def eta_decay_1(i: int) -> float:
             param,
             dtrain,
             num_round,
-            watchlist,
+            evals=watchlist,
             callbacks=[scheduler(eta_decay_1)],
         )
         bst_json0 = bst0.save_raw(raw_format="json")
@@ -474,3 +482,24 @@ def test_callback_list(self):
             callbacks=callbacks,
         )
         assert len(callbacks) == 1
+
+    def test_attribute_error(self) -> None:
+        from sklearn.datasets import load_breast_cancer
+
+        X, y = load_breast_cancer(return_X_y=True)
+
+        clf = xgb.XGBClassifier(n_estimators=8)
+        clf.fit(X, y, eval_set=[(X, y)])
+
+        with pytest.raises(AttributeError, match="early stopping is used"):
+            clf.best_iteration
+
+        with pytest.raises(AttributeError, match="early stopping is used"):
+            clf.best_score
+
+        booster = clf.get_booster()
+        with pytest.raises(AttributeError, match="early stopping is used"):
+            booster.best_iteration
+
+        with pytest.raises(AttributeError, match="early stopping is used"):
+            booster.best_score
diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py
index 04a7d70cb948..6ed9c39f7e7e 100644
--- a/tests/python/test_predict.py
+++ b/tests/python/test_predict.py
@@ -173,7 +173,7 @@ def test_predict(self):
         np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
 
         with pytest.raises(ValueError):
-            booster.predict(test, iteration_range=(0, booster.best_iteration + 2))
+            booster.predict(test, iteration_range=(0, booster.num_boosted_rounds() + 2))
 
         default = booster.predict(test)
 
@@ -181,7 +181,7 @@ def test_predict(self):
         np.testing.assert_allclose(range_full, default)
 
         range_full = booster.predict(
-            test, iteration_range=(0, booster.best_iteration + 1)
+            test, iteration_range=(0, booster.num_boosted_rounds())
         )
         np.testing.assert_allclose(range_full, default)
 
diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py
index 3cbe6a4216b2..6b2f9630136d 100644
--- a/tests/python/test_training_continuation.py
+++ b/tests/python/test_training_continuation.py
@@ -100,8 +100,8 @@ def run_training_continuation(self, xgb_params_01, xgb_params_02,
         res2 = mean_squared_error(
             y_2class,
             gbdt_04.predict(
-                dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
-            )
+                dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
+            ),
         )
         assert res1 == res2
 
@@ -112,7 +112,7 @@ def run_training_continuation(self, xgb_params_01, xgb_params_02,
         res2 = mean_squared_error(
             y_2class,
             gbdt_04.predict(
-                dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
+                dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
             )
         )
         assert res1 == res2
@@ -126,7 +126,7 @@ def run_training_continuation(self, xgb_params_01, xgb_params_02,
 
         res1 = gbdt_05.predict(dtrain_5class)
         res2 = gbdt_05.predict(
-            dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1)
+            dtrain_5class, iteration_range=(0, gbdt_05.num_boosted_rounds())
         )
         np.testing.assert_almost_equal(res1, res2)
 
@@ -138,15 +138,16 @@ def test_training_continuation_json(self):
 
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_training_continuation_updaters_json(self):
         # Picked up from R tests.
-        updaters = 'grow_colmaker,prune,refresh'
+        updaters = "grow_colmaker,prune,refresh"
         params = self.generate_parameters()
         for p in params:
-            p['updater'] = updaters
+            p["updater"] = updaters
         self.run_training_continuation(params[0], params[1], params[2])
 
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_changed_parameter(self):
         from sklearn.datasets import load_breast_cancer
+
         X, y = load_breast_cancer(return_X_y=True)
         clf = xgb.XGBClassifier(n_estimators=2)
         clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
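The container change in ``callback.py`` means everything passed through ``callbacks`` must subclass ``xgboost.callback.TrainingCallback``. A minimal sketch of the interface for reference; the class name and body here are illustrative, not part of the patch:

.. code-block:: python

    import xgboost as xgb

    class PrintIteration(xgb.callback.TrainingCallback):
        """Illustrative callback: report each finished round."""

        def after_iteration(self, model, epoch, evals_log):
            # `evals_log` maps data name -> metric name -> history list.
            print(f"finished round {epoch}")
            return False  # returning True requests early termination

    # A non-TrainingCallback entry in `callbacks` now raises TypeError up
    # front instead of tripping an assertion inside the container.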
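A sketch of the new ``Booster.best_iteration`` / ``best_score`` behavior that ``test_attribute_error`` above exercises; the dataset choice is arbitrary:

.. code-block:: python

    import xgboost as xgb
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)
    dtrain = xgb.DMatrix(X, y)

    # Without early stopping, the attributes are undefined and accessing
    # them raises AttributeError instead of pointing at the last round.
    booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=4)
    try:
        booster.best_iteration
    except AttributeError as e:
        print(e)

    # With early stopping, both properties are backed by booster attributes.
    booster = xgb.train(
        {"objective": "binary:logistic"},
        dtrain,
        num_boost_round=16,
        evals=[(dtrain, "Train")],
        early_stopping_rounds=2,
    )
    print(booster.best_iteration, booster.best_score)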
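The reworked ``EarlyStopping.after_training`` slices the booster to the best iteration and then writes ``best_iteration``/``best_score`` back onto the slice, since plain slicing drops them (see the ``model.rst`` note above). A sketch of the resulting invariant; the estimator settings are illustrative:

.. code-block:: python

    import xgboost as xgb
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)
    es = xgb.callback.EarlyStopping(rounds=2, save_best=True)
    clf = xgb.XGBClassifier(n_estimators=64, eval_metric="logloss", callbacks=[es])
    clf.fit(X, y, eval_set=[(X, y)])

    booster = clf.get_booster()
    # The returned booster keeps exactly best_iteration + 1 rounds, and the
    # two attributes survive the slice.
    assert booster.num_boosted_rounds() == booster.best_iteration + 1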
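Because ``best_iteration`` is now tied to early stopping, the tests spell the full prediction range with ``num_boosted_rounds()``. A self-contained sketch with synthetic data:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 4))
    y = rng.integers(0, 2, size=256)
    dtrain = xgb.DMatrix(X, y)
    booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=8)

    # (0, num_boosted_rounds()) covers every tree, matching the default.
    full = booster.predict(dtrain, iteration_range=(0, booster.num_boosted_rounds()))
    np.testing.assert_allclose(full, booster.predict(dtrain))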