Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] Add sampling to dispatcher #7376

Merged
merged 3 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Shortcut
from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest
from .dispatcher import DispatchContext, ApplyHistoryBest, ApplyHistoryBestOrSample
from .measure import (
MeasureInput,
MeasureResult,
Expand Down
93 changes: 92 additions & 1 deletion python/tvm/auto_scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@

import numpy as np

from tvm.contrib.utils import tempdir
from tvm.tir.expr import FloatImm
from .measure_record import load_records
from .cost_model import RandomModel, XGBModel
from .measure import LocalRPCMeasureContext
from .measure_record import RecordToFile, load_records
from .search_policy import PreloadMeasuredStates, SketchPolicy
from .search_task import SearchTask, TuningOptions
from .utils import calc_workload_dis_factor, decode_workload_key

logger = logging.getLogger("auto_scheduler")
Expand Down Expand Up @@ -301,6 +306,92 @@ def update(self, target, workload_key, state):
entry[workload_args] = (state, 1)


class ApplyHistoryBestOrSample(ApplyHistoryBest):
"""
Apply the history best config, or sample a valid schedule if no config is found.

Parameters
----------
records : str or iterator of (auto_scheduler.measure.MeasureInput,\
auto_scheduler.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. Otherwise, it is an iterator.
sample_simple_workloads: bool
When False, sampling will not apply to simple workloads (w/o reduction).
cost_model_file: str
The filename of the pre-trained XGBoost cost model. If not present, then random
model will be used.
num_measure: int
Meausre the top-N rank of sampled schedules on the device. The default -1 means
no measurement and simply return the top-1 schedule ranked by the cost model.
"""

def __init__(
self, records, sample_simple_workloads=False, cost_model_file=None, num_measure=-1
):
self.sample_simple_workloads = sample_simple_workloads
self.num_measure = num_measure
self.log_dir = tempdir()
if cost_model_file is None:
self.cost_model = RandomModel()
else:
self.cost_model = XGBModel()
self.cost_model.load(cost_model_file)

super(ApplyHistoryBestOrSample, self).__init__(
records, n_lines=None, include_compatible=True
)

def query(self, target, workload_key, has_complex_op, dag):
if has_complex_op or self.sample_simple_workloads:
ret = self._query_inside(target, workload_key)
else:
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)

if ret is None:
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
return ret

def _query_inside(self, target, workload_key):
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
if ret is not None:
return ret

# Sampling valid schedules when no existing records can be used.
task = SearchTask(workload_key=workload_key, target=target)
measure_ctx = LocalRPCMeasureContext(min_repeat_ms=300)

log_file = self.log_dir.relpath("%s.log" % decode_workload_key(workload_key)[0])

while ret is None:
tune_option = TuningOptions(
num_measure_trials=self.num_measure,
runner=measure_ctx.runner,
measure_callbacks=[RecordToFile(log_file)],
verbose=0,
)
search_policy = SketchPolicy(
task,
self.cost_model,
params={
"eps_greedy": 0.01,
"sample_init_min_population": 64,
"evolutionary_search_num_iters": 0,
},
init_search_callbacks=[PreloadMeasuredStates(log_file)],
verbose=0,
)
task.tune(tune_option, search_policy)

# Load the sampled records and query again.
self.load(log_file)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)

del measure_ctx
return ret


class FallbackContext(DispatchContext):
"""
A fallback dispatch context.
Expand Down
17 changes: 13 additions & 4 deletions tests/python/relay/test_auto_scheduler_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,16 @@ def tune_network(network, target):
):
lib = relay.build(mod, target=target, params=params)

# Sample a schedule when missing
with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
with tvm.transform.PassContext(
opt_level=3, config={"relay.backend.use_auto_scheduler": True}
):
lib2 = relay.build(mod, target=target, params=params)

# Compile without auto-scheduler and any other optimization for correctness check
with tvm.transform.PassContext(opt_level=0):
lib2 = relay.build(mod, target=target, params=params)
ref_lib = relay.build(mod, target=target, params=params)

# Check the correctness
def get_output(data, lib):
Expand All @@ -76,10 +83,12 @@ def get_output(data, lib):
else:
raise ValueError("Unknown network: " + network)

actual_output = get_output(data, lib)
expected_output = get_output(data, lib2)
actual_output1 = get_output(data, lib)
actual_output2 = get_output(data, lib2)
expected_output = get_output(data, ref_lib)

tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(actual_output1, expected_output, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(actual_output2, expected_output, rtol=1e-4, atol=1e-4)


@tvm.testing.requires_cuda
Expand Down