From e51fc042fadb88e2574774083dec578d4935d4f4 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 18 Aug 2020 10:48:10 -0700 Subject: [PATCH] address comments --- python/tvm/auto_scheduler/auto_schedule.py | 8 +++++-- .../auto_scheduler/cost_model/xgb_model.py | 23 ++++++++++++------- python/tvm/auto_scheduler/measure.py | 6 +++-- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py index a81b2bb7ede5..eb5a3fb49934 100644 --- a/python/tvm/auto_scheduler/auto_schedule.py +++ b/python/tvm/auto_scheduler/auto_schedule.py @@ -161,7 +161,9 @@ def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=No seed or random.randint(1, 1 << 30), verbose, init_search_callbacks) def generate_sketches(self, print_for_debug=False): - """ Generate the sketches. This is mainly used for debugging and testing. + """ Generate the sketches. + This python interface is mainly used for debugging and testing. + The actual search is all doen in c++. Parameters ---------- @@ -181,7 +183,9 @@ def generate_sketches(self, print_for_debug=False): return sketches def sample_initial_population(self, pop_size): - """Sample initial population. This is mainly used for debugging and testing. + """Sample initial population. + This python interface is mainly used for debugging and testing. + The actual search is all doen in c++. Parameters ---------- diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py index fac4056b24f4..043a3f474bf2 100644 --- a/python/tvm/auto_scheduler/cost_model/xgb_model.py +++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py @@ -49,7 +49,7 @@ def get(self, key, matrix, default=None): The name of the attribute matrix: xgb.DMatrix The matrix - default: Optional + default: Optional[Any] The default value if the item does not exist """ return self.context_dict[key].get(matrix.handle.value, default) @@ -64,7 +64,7 @@ def set(self, key, matrix, value): The name of the attribute matrix: xgb.DMatrix The matrix - value: Optional + value: Optional[Any] The new value """ self.context_dict[key][matrix.handle.value] = value @@ -76,7 +76,7 @@ class XGBModel(PythonBasedModel): """Train a XGBoost model to predict the normalized throughputs of programs. Let the normalized throughput be the score of a program (higher is better). We predict - (approximiate) the score of a program = the sum of the scores of all stages in this program. + the (approximiate) score of a program = the sum of the scores of all stages in this program. i.e. score(P) = score_s0 + score_s1 + ... + score_sn, where score_si is the score of Stage i in Program P. @@ -129,6 +129,7 @@ def update(self, inputs, results): """ if len(inputs) <= 0: return + assert len(inputs) == len(results) self.inputs.extend(inputs) self.results.extend(results) @@ -254,8 +255,8 @@ def load_log_file(self, file_name, n_lines=None): ---------- file_name: str The filename - n_lines: int - Only + n_lines: Optional[int] + Only load first n lines of the log file """ inputs, results = RecordReader(file_name).read_lines(n_lines) logger.info("XGBModel: Loaded %s measurement records from %s", len(inputs), file_name) @@ -399,7 +400,9 @@ def pack_sum_square_error(preds, dtrain): Returns ------- - gradient and hessian + gradient: np.ndarray + hessian: np.ndarray + gradient and hessian according to the xgboost format """ pack_ids = dmatrix_context.get("pack_ids", dtrain) weight = dtrain.get_weight() @@ -427,7 +430,9 @@ def pack_sum_rmse(raw_preds, labels): Returns ------- - The name and value of the metric + name: str + score: float + The name and score of this metric """ pack_ids = dmatrix_context.get("pack_ids", labels) preds = predict_throughput_pack_sum(raw_preds, pack_ids)[pack_ids] @@ -458,7 +463,9 @@ def feval(preds, labels): Returns ------- - The name and value of the metric + name: str + score: float + The name and score of this metric """ group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) pack_ids = dmatrix_context.get("pack_ids", labels) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 7d97ade1ca84..925de2f871e6 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -48,6 +48,7 @@ from tvm.contrib import tar, ndk from . import _ffi_api +from .loop_state import StateObject from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \ check_remote @@ -71,11 +72,12 @@ class MeasureInput(Object): Parameters ---------- task : SearchTask - The SearchTask of this measure. - state : State + The SearchTask of this measurement. + state : Union[State, StateObject] The State to be measured. """ def __init__(self, task, state): + state = state if isinstance(state, StateObject) else state.state_object self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state)