[AutoTVM] Fix None feature in AutoTVM tuning (#12760)
This PR introduces a couple of fixes to make AutoTVM work more
robustly:
- Fixed a very rare case where `None` could pop up in AutoTVM features;
- Fixed a misuse of `ARGS` in the testing script;
- Fixed the filename used for caching.
junrushao authored Sep 12, 2022
1 parent 9671aee commit 4d27664
Showing 3 changed files with 11 additions and 11 deletions.
13 changes: 7 additions & 6 deletions python/tvm/autotvm/testing/tune_relay.py

@@ -139,19 +139,20 @@ def _parse_args():
             tracker_key=parsed.rpc_key,
             session_timeout_sec=600,
         )
-    if ARGS.target.kind.name != "llvm" and ARGS.graph_tuner:
-        raise ValueError("GraphTuner only supports llvm target")
-    if ARGS.target.kind.name != "llvm" and ARGS.cpu_flush:
-        raise ValueError("cpu_flush only supports llvm target")
-    if ARGS.target.kind.name == "llvm" and not ARGS.cpu_flush:
-        warnings.warn("cpu_flush is not enabled for llvm target")
     return parsed
 
 
 ARGS = _parse_args()
 
 
 def main():
+    if ARGS.target.kind.name != "llvm" and ARGS.graph_tuner:
+        raise ValueError("GraphTuner only supports llvm target")
+    if ARGS.target.kind.name != "llvm" and ARGS.cpu_flush:
+        raise ValueError("cpu_flush only supports llvm target")
+    if ARGS.target.kind.name == "llvm" and not ARGS.cpu_flush:
+        warnings.warn("cpu_flush is not enabled for llvm target")
+
     log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json")
     graph_opt_sch_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}_graph_opt.log")
     measure_option = autotvm.measure_option(
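
Note on the `ARGS` fix: the deleted checks ran inside `_parse_args()`, i.e. while the module-level `ARGS = _parse_args()` assignment was still executing, so the global `ARGS` they referenced was not yet bound and the validation could only raise `NameError`. Moving the checks into `main()` runs them after `ARGS` exists. A minimal sketch of the before/after pattern (hypothetical flag, not the actual script):

import argparse


def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cpu-flush", action="store_true")
    parsed = parser.parse_args()
    # Pre-fix pattern: touching the global ARGS here raises NameError,
    # because ARGS is only bound after _parse_args() returns.
    # if not ARGS.cpu_flush: ...
    return parsed


ARGS = _parse_args()


def main():
    # Post-fix pattern: ARGS is bound by the time main() runs, so the
    # validation is safe here.
    if not ARGS.cpu_flush:
        print("warning: cpu_flush is not enabled")


if __name__ == "__main__":
    main()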
7 changes: 3 additions & 4 deletions python/tvm/autotvm/tuner/xgboost_cost_model.py

@@ -21,12 +21,11 @@
 import time
 
 import numpy as np
-
 from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
 
 from .. import feature
 from ..utils import get_rank
-from .metric import max_curve, recall_curve, cover_curve
+from .metric import cover_curve, max_curve, recall_curve
 from .model_based_tuner import CostModel, FeatureCache
 
 xgb = None
@@ -346,7 +345,7 @@ def _get_feature(self, indexes):
         ret = np.empty((len(indexes), feature_len), dtype=np.float32)
         for i, ii in enumerate(indexes):
             t = fea_cache[ii]
-            if t.shape[0] < feature_len:
+            if t is not None and t.shape[0] < feature_len:
                 t = np.pad(t, (0, feature_len - t.shape[0]))
             ret[i, :] = t if t is not None else 0
         return ret
@@ -449,8 +448,8 @@ def custom_callback(
 ):
     """callback function for xgboost to support multiple custom evaluation functions"""
     # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
     from xgboost.callback import _fmt_metric
+    from xgboost.core import EarlyStopException
 
     try:
         from xgboost.training import aggcv
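
Note on the `None` guard: in rare cases the feature cache can hold `None` for an entry, and the old code called `.shape` on it unconditionally, raising `AttributeError` before the existing `t if t is not None else 0` fallback was ever reached. A standalone sketch of the fixed padding loop (toy cache, not the real `FeatureCache`):

import numpy as np

# Toy stand-in for the AutoTVM feature cache: one entry failed to
# extract and is None.
fea_cache = {0: np.ones(3, dtype=np.float32), 1: None}
feature_len = 5

ret = np.empty((len(fea_cache), feature_len), dtype=np.float32)
for i, ii in enumerate(fea_cache):
    t = fea_cache[ii]
    # The `t is not None` guard keeps .shape from being evaluated on
    # the None entry; the fallback below then zero-fills that row.
    if t is not None and t.shape[0] < feature_len:
        t = np.pad(t, (0, feature_len - t.shape[0]))
    ret[i, :] = t if t is not None else 0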
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/relay_workload.py

@@ -230,7 +230,7 @@ def get_network(
     inputs: Tuple[str, List[int], str]
     params_bytearray: bytearray
 
-    filename = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json'
+    filename = f'relay-{name}-{layout}-{",".join(str(i) for i in input_shape)}.json'
     cached = _load_cache(cache_dir, filename)
     if cached is None:
         with multiprocessing.Pool(processes=1) as pool:
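
Note on the cache filename: without `layout` in the key, two runs of the same network and input shape but different layouts would collide on the same cached JSON file, so one run could silently load a workload cached for the other layout. A quick illustration with hypothetical values:

name = "resnet_50"
layout = "NHWC"
input_shape = [1, 3, 224, 224]

# Old key: identical for NHWC and NCHW runs of the same network.
old = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json'
# New key: the layout disambiguates the cache entry.
new = f'relay-{name}-{layout}-{",".join(str(i) for i in input_shape)}.json'
print(old)  # relay-resnet_50-1,3,224,224.json
print(new)  # relay-resnet_50-NHWC-1,3,224,224.json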
