Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Aug 18, 2020
1 parent 90ae47c commit e51fc04
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
8 changes: 6 additions & 2 deletions python/tvm/auto_scheduler/auto_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=No
seed or random.randint(1, 1 << 30), verbose, init_search_callbacks)

def generate_sketches(self, print_for_debug=False):
""" Generate the sketches. This is mainly used for debugging and testing.
""" Generate the sketches.
This python interface is mainly used for debugging and testing.
The actual search is all doen in c++.
Parameters
----------
Expand All @@ -181,7 +183,9 @@ def generate_sketches(self, print_for_debug=False):
return sketches

def sample_initial_population(self, pop_size):
"""Sample initial population. This is mainly used for debugging and testing.
"""Sample initial population.
This python interface is mainly used for debugging and testing.
The actual search is all doen in c++.
Parameters
----------
Expand Down
23 changes: 15 additions & 8 deletions python/tvm/auto_scheduler/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get(self, key, matrix, default=None):
The name of the attribute
matrix: xgb.DMatrix
The matrix
default: Optional
default: Optional[Any]
The default value if the item does not exist
"""
return self.context_dict[key].get(matrix.handle.value, default)
Expand All @@ -64,7 +64,7 @@ def set(self, key, matrix, value):
The name of the attribute
matrix: xgb.DMatrix
The matrix
value: Optional
value: Optional[Any]
The new value
"""
self.context_dict[key][matrix.handle.value] = value
Expand All @@ -76,7 +76,7 @@ class XGBModel(PythonBasedModel):
"""Train a XGBoost model to predict the normalized throughputs of programs.
Let the normalized throughput be the score of a program (higher is better). We predict
(approximiate) the score of a program = the sum of the scores of all stages in this program.
the (approximiate) score of a program = the sum of the scores of all stages in this program.
i.e. score(P) = score_s0 + score_s1 + ... + score_sn,
where score_si is the score of Stage i in Program P.
Expand Down Expand Up @@ -129,6 +129,7 @@ def update(self, inputs, results):
"""
if len(inputs) <= 0:
return
assert len(inputs) == len(results)

self.inputs.extend(inputs)
self.results.extend(results)
Expand Down Expand Up @@ -254,8 +255,8 @@ def load_log_file(self, file_name, n_lines=None):
----------
file_name: str
The filename
n_lines: int
Only
n_lines: Optional[int]
Only load first n lines of the log file
"""
inputs, results = RecordReader(file_name).read_lines(n_lines)
logger.info("XGBModel: Loaded %s measurement records from %s", len(inputs), file_name)
Expand Down Expand Up @@ -399,7 +400,9 @@ def pack_sum_square_error(preds, dtrain):
Returns
-------
gradient and hessian
gradient: np.ndarray
hessian: np.ndarray
gradient and hessian according to the xgboost format
"""
pack_ids = dmatrix_context.get("pack_ids", dtrain)
weight = dtrain.get_weight()
Expand Down Expand Up @@ -427,7 +430,9 @@ def pack_sum_rmse(raw_preds, labels):
Returns
-------
The name and value of the metric
name: str
score: float
The name and score of this metric
"""
pack_ids = dmatrix_context.get("pack_ids", labels)
preds = predict_throughput_pack_sum(raw_preds, pack_ids)[pack_ids]
Expand Down Expand Up @@ -458,7 +463,9 @@ def feval(preds, labels):
Returns
-------
The name and value of the metric
name: str
score: float
The name and score of this metric
"""
group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)])
pack_ids = dmatrix_context.get("pack_ids", labels)
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from tvm.contrib import tar, ndk

from . import _ffi_api
from .loop_state import StateObject
from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \
check_remote

Expand All @@ -71,11 +72,12 @@ class MeasureInput(Object):
Parameters
----------
task : SearchTask
The SearchTask of this measure.
state : State
The SearchTask of this measurement.
state : Union[State, StateObject]
The State to be measured.
"""
def __init__(self, task, state):
state = state if isinstance(state, StateObject) else state.state_object
self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state)


Expand Down

0 comments on commit e51fc04

Please sign in to comment.