diff --git a/optgbm/sklearn.py b/optgbm/sklearn.py index 6e9c852..4052c06 100644 --- a/optgbm/sklearn.py +++ b/optgbm/sklearn.py @@ -152,10 +152,17 @@ def __call__(self, trial: trial_module.Trial) -> float: num_boost_round=self.n_estimators, ) # Dict[str, List[float]] best_iteration = callbacks[0].best_iteration_ # type: ignore + values = eval_hist[ + "{}-mean".format(self.eval_name) + ] # type: List[float] + evals_result = { + "cv_agg": {self.eval_name: values} + } # type: Dict[str, Dict[str, List[float]]] trial.set_user_attr("best_iteration", best_iteration) + trial.set_user_attr("evals_result", evals_result) - value = eval_hist["{}-mean".format(self.eval_name)][-1] # type: float + value = values[-1] # type: float is_best_trial = True # type: bool try: @@ -650,6 +657,7 @@ def fit( None if early_stopping_rounds is None else best_iteration ) self._best_score = self.study_.best_value + self._evals_result = self.study_.best_trial.user_attrs["evals_result"] self._objective = params["objective"] self.best_params_ = {**params, **self.study_.best_params} self.n_splits_ = cv.get_n_splits(X, y, groups=groups) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 22fe1c6..1b5ea56 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -374,3 +374,15 @@ def test_plot_importance(n_jobs: int) -> None: clf.fit(X, y) lgb.plot_importance(clf) + + +def test_plot_metric() -> None: + X, y = load_breast_cancer(return_X_y=True) + + clf = OGBMClassifier( + n_estimators=n_estimators, n_trials=n_trials, refit=False + ) + + clf.fit(X, y) + + lgb.plot_metric(clf)