Skip to content

Commit

Permalink
[AutoScheduler] Add sampling to dispatcher (#7376)
Browse files Browse the repository at this point in the history
* [AutoScheduler] Add sampling to dispatcher

* address comment

* make measurment configurable
  • Loading branch information
comaniac authored Feb 9, 2021
1 parent 2b8d113 commit 0716c2a
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Shortcut
from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest
from .dispatcher import DispatchContext, ApplyHistoryBest, ApplyHistoryBestOrSample
from .measure import (
MeasureInput,
MeasureResult,
Expand Down
93 changes: 92 additions & 1 deletion python/tvm/auto_scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@

import numpy as np

from tvm.contrib.utils import tempdir
from tvm.tir.expr import FloatImm
from .measure_record import load_records
from .cost_model import RandomModel, XGBModel
from .measure import LocalRPCMeasureContext
from .measure_record import RecordToFile, load_records
from .search_policy import PreloadMeasuredStates, SketchPolicy
from .search_task import SearchTask, TuningOptions
from .utils import calc_workload_dis_factor, decode_workload_key

logger = logging.getLogger("auto_scheduler")
Expand Down Expand Up @@ -301,6 +306,92 @@ def update(self, target, workload_key, state):
entry[workload_args] = (state, 1)


class ApplyHistoryBestOrSample(ApplyHistoryBest):
"""
Apply the history best config, or sample a valid schedule if no config is found.
Parameters
----------
records : str or iterator of (auto_scheduler.measure.MeasureInput,\
auto_scheduler.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. Otherwise, it is an iterator.
sample_simple_workloads: bool
When False, sampling will not apply to simple workloads (w/o reduction).
cost_model_file: str
The filename of the pre-trained XGBoost cost model. If not present, then random
model will be used.
num_measure: int
Meausre the top-N rank of sampled schedules on the device. The default -1 means
no measurement and simply return the top-1 schedule ranked by the cost model.
"""

def __init__(
self, records, sample_simple_workloads=False, cost_model_file=None, num_measure=-1
):
self.sample_simple_workloads = sample_simple_workloads
self.num_measure = num_measure
self.log_dir = tempdir()
if cost_model_file is None:
self.cost_model = RandomModel()
else:
self.cost_model = XGBModel()
self.cost_model.load(cost_model_file)

super(ApplyHistoryBestOrSample, self).__init__(
records, n_lines=None, include_compatible=True
)

def query(self, target, workload_key, has_complex_op, dag):
if has_complex_op or self.sample_simple_workloads:
ret = self._query_inside(target, workload_key)
else:
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)

if ret is None:
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
return ret

def _query_inside(self, target, workload_key):
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
if ret is not None:
return ret

# Sampling valid schedules when no existing records can be used.
task = SearchTask(workload_key=workload_key, target=target)
measure_ctx = LocalRPCMeasureContext(min_repeat_ms=300)

log_file = self.log_dir.relpath("%s.log" % decode_workload_key(workload_key)[0])

while ret is None:
tune_option = TuningOptions(
num_measure_trials=self.num_measure,
runner=measure_ctx.runner,
measure_callbacks=[RecordToFile(log_file)],
verbose=0,
)
search_policy = SketchPolicy(
task,
self.cost_model,
params={
"eps_greedy": 0.01,
"sample_init_min_population": 64,
"evolutionary_search_num_iters": 0,
},
init_search_callbacks=[PreloadMeasuredStates(log_file)],
verbose=0,
)
task.tune(tune_option, search_policy)

# Load the sampled records and query again.
self.load(log_file)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)

del measure_ctx
return ret


class FallbackContext(DispatchContext):
"""
A fallback dispatch context.
Expand Down
17 changes: 13 additions & 4 deletions tests/python/relay/test_auto_scheduler_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,16 @@ def tune_network(network, target):
):
lib = relay.build(mod, target=target, params=params)

# Sample a schedule when missing
with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
with tvm.transform.PassContext(
opt_level=3, config={"relay.backend.use_auto_scheduler": True}
):
lib2 = relay.build(mod, target=target, params=params)

# Compile without auto-scheduler and any other optimization for correctness check
with tvm.transform.PassContext(opt_level=0):
lib2 = relay.build(mod, target=target, params=params)
ref_lib = relay.build(mod, target=target, params=params)

# Check the correctness
def get_output(data, lib):
Expand All @@ -76,10 +83,12 @@ def get_output(data, lib):
else:
raise ValueError("Unknown network: " + network)

actual_output = get_output(data, lib)
expected_output = get_output(data, lib2)
actual_output1 = get_output(data, lib)
actual_output2 = get_output(data, lib2)
expected_output = get_output(data, ref_lib)

tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(actual_output1, expected_output, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(actual_output2, expected_output, rtol=1e-4, atol=1e-4)


@tvm.testing.requires_cuda
Expand Down

0 comments on commit 0716c2a

Please sign in to comment.