Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
YJ Shi committed Jul 21, 2022
1 parent 6e031bf commit 2fda92d
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ class XGBoostCallback(TrainingCallback):
"""Base class for XGBoost callbacks."""

def __call__(self, env: "xgb.core.CallbackEnv"):
"""Compatibility with xgboost<1.3"""
# Compatibility with xgboost < 1.3
return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)

def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
Expand Down Expand Up @@ -805,6 +805,7 @@ def __init__(
self.aggregated_cv = None

def init(self, model: "xgb.Booster"):
"""Internal function for intialization"""
booster: "xgb.Booster" = model
self.state["best_iteration"] = 0
self.state["best_score"] = float("inf")
Expand All @@ -820,10 +821,12 @@ def init(self, model: "xgb.Booster"):
booster.set_attr(best_score=str(self.state["best_score"]))

def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
"""Internal function for after_iteration"""
# pylint:disable = import-outside-toplevel
try:
from xgboost.callback import _fmt_metric # type: ignore
except ImportError:
"""Compatibility with xgboost>=1.6"""
# Compatibility with xgboost >= 1.6

def _fmt_metric(value, show_stdv=True):
if len(value) == 2:
Expand All @@ -834,6 +837,7 @@ def _fmt_metric(value, show_stdv=True):
return f"{value[0]}:{value[1]:.5f}"
raise ValueError("wrong metric value", value)

import xgboost as xgb
from xgboost import rabit # type: ignore

try:
Expand Down

0 comments on commit 2fda92d

Please sign in to comment.