Skip to content

Commit

Permalink
add decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
YJ Shi committed Jul 27, 2022
1 parent 701ba3c commit 9444134
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,23 @@
from .metric import max_curve


if TYPE_CHECKING:
def optional_xgboost_callback(XGBoostCustomCallback):
# pylint:disable = import-outside-toplevel
try:
from xgboost.callback import TrainingCallback # type: ignore
except ImportError:

class TrainingCallback: # type: ignore
pass

class OptXGBoostCustomCallback(XGBoostCustomCallback, TrainingCallback):
pass

return OptXGBoostCustomCallback


if TYPE_CHECKING:

import xgboost as xgb # type: ignore

from ..tune_context import TuneContext
Expand Down Expand Up @@ -645,7 +654,8 @@ def average_peak_score(ys_pred: np.ndarray):
return eval_result


class XGBoostCustomCallback(TrainingCallback):
@optional_xgboost_callback
class XGBoostCustomCallback:
"""Custom callback class for xgboost to support multiple custom evaluation functions"""

def __init__(
Expand Down

0 comments on commit 9444134

Please sign in to comment.