
Commit fa13992
Calling XGBModel.fit() should clear the Booster by default (#6562)
* Calling XGBModel.fit() should clear the Booster by default

* Document the behavior of fit()

* Allow sklearn object to be passed in directly via xgb_model argument

* Fix lint
hcho3 authored Dec 31, 2020
1 parent 5e9e525 commit fa13992
Showing 1 changed file with 16 additions and 3 deletions.
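
Before the diff, a minimal sketch of the new semantics described in the commit message (synthetic data; names are illustrative, not from the commit):

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X, y = rng.rand(100, 4), rng.rand(100)

    reg = xgb.XGBRegressor(n_estimators=5)
    reg.fit(X, y)
    reg.fit(X, y)  # with this change, re-fits from scratch instead of adding 5 more rounds
    print(len(reg.get_booster().get_dump()))  # 5 trees, not 10

    # To continue training, the fitted sklearn object can now be passed directly:
    reg2 = xgb.XGBRegressor(n_estimators=5)
    reg2.fit(X, y, xgb_model=reg)
    print(len(reg2.get_booster().get_dump()))  # 10 trees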
python-package/xgboost/sklearn.py (16 additions, 3 deletions)
@@ -501,8 +501,10 @@ def _configure_fit(
         eval_metric: Optional[Union[Callable, str, List[str]]],
         params: Dict[str, Any],
     ) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
-        model = self._Booster if hasattr(self, "_Booster") else None
-        model = booster if booster is not None else model
+        # pylint: disable=protected-access, no-self-use
+        model = booster
+        if hasattr(model, '_Booster'):
+            model = model._Booster  # Handle the case when xgb_model is a sklearn model object
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
             if callable(eval_metric):
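
The added lines duck-type the ``xgb_model`` argument: anything carrying a ``_Booster`` attribute is treated as a fitted sklearn-API estimator. A standalone sketch of that logic (the helper name is hypothetical, not part of the library):

    def unwrap_xgb_model(xgb_model):
        # A fitted sklearn-API estimator (e.g. XGBClassifier) stores its
        # underlying Booster in the `_Booster` attribute; unwrap it.
        if hasattr(xgb_model, '_Booster'):
            return xgb_model._Booster
        # Otherwise it is already a Booster, a model file path, or None.
        return xgb_model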
@@ -518,7 +520,11 @@ def fit(self, X, y, *, sample_weight=None, base_margin=None,
             feature_weights=None,
             callbacks=None):
         # pylint: disable=invalid-name,attribute-defined-outside-init
-        """Fit gradient boosting model
+        """Fit gradient boosting model.
+
+        Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
+        scratch. To resume training from a previous checkpoint, explicitly pass the ``xgb_model``
+        argument.
 
         Parameters
         ----------
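
A hedged sketch of the resume-from-checkpoint pattern the new docstring note describes, assuming a version with JSON model support (the file name and data are illustrative):

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X, y = rng.rand(100, 4), rng.rand(100)

    reg = xgb.XGBRegressor(n_estimators=50)
    reg.fit(X, y)
    reg.save_model('checkpoint.json')  # persist the trained Booster

    # Resume later: xgb_model also accepts a saved model file or a Booster.
    reg = xgb.XGBRegressor(n_estimators=50)
    reg.fit(X, y, xgb_model='checkpoint.json')  # trains 50 more rounds on top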
@@ -1212,6 +1218,10 @@ def fit(self, X, y, *, group, sample_weight=None, base_margin=None,
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """Fit gradient boosting ranker
+
+        Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
+        scratch. To resume training from a previous checkpoint, explicitly pass the ``xgb_model``
+        argument.
 
         Parameters
         ----------
         X : array_like
@@ -1322,6 +1332,9 @@ def fit(self, X, y, *, group, sample_weight=None, base_margin=None,
             raise ValueError(
                 'Custom evaluation metric is not yet supported for XGBRanker.')
         params.update({'eval_metric': eval_metric})
+        if hasattr(xgb_model, '_Booster'):
+            # Handle the case when xgb_model is a sklearn model object
+            xgb_model = xgb_model._Booster  # pylint: disable=protected-access
 
         self._Booster = train(params, train_dmatrix,
                               self.n_estimators,
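
The same unwrapping is applied on the ranker path. A usage sketch with synthetic data (group sizes and names are illustrative):

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X = rng.rand(120, 4)
    y = rng.randint(4, size=120)  # graded relevance labels
    group = [40, 40, 40]          # three query groups of 40 documents each

    ranker = xgb.XGBRanker(n_estimators=5)
    ranker.fit(X, y, group=group)

    # Continue training: the fitted sklearn object is unwrapped to its Booster.
    ranker2 = xgb.XGBRanker(n_estimators=5)
    ranker2.fit(X, y, group=group, xgb_model=ranker)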
