Skip to content

Commit

Permalink
[AutoScheduler] Improve test cases (#6657)
Browse files Browse the repository at this point in the history
* Improve test cases

* update

* fix lint

* fix lint

* trigger CI

* address comments

* trigger CI
  • Loading branch information
merrymercy authored Oct 11, 2020
1 parent f6657a6 commit dd60d24
Show file tree
Hide file tree
Showing 14 changed files with 202 additions and 103 deletions.
9 changes: 6 additions & 3 deletions python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
""" Namespace for TVM Auto-scheduler. """

from . import compute_dag
from . import feature
from . import loop_state
from . import measure
from . import measure_record
from . import loop_state
from . import search_policy
from . import search_task
from . import utils
from . import workload_registry
from . import feature

# Shortcut
from .auto_schedule import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule
from .compute_dag import ComputeDAG
from .cost_model import RandomModel, XGBModel
from .measure import (
Expand All @@ -38,5 +40,6 @@
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .workload_registry import register_workload, make_workload_key
34 changes: 8 additions & 26 deletions python/tvm/auto_scheduler/auto_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@

import tvm._ffi
from tvm.runtime import Object
from tvm.target import Target
from .measure import LocalBuilder, LocalRunner
from .workload_registry import make_workload_key
from .compute_dag import ComputeDAG
from .cost_model import XGBModel
from .search_policy import SketchPolicy
from .search_task import SearchTask
from . import _ffi_api


Expand All @@ -61,30 +63,6 @@ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes):
)


@tvm._ffi.register_object("auto_scheduler.SearchTask")
class SearchTask(Object):
"""The computation information and hardware parameters for a schedule search task.
Parameters
----------
dag : ComputeDAG
The ComputeDAG for the corresponding compute declaration.
workload_key : str
The workload key for the corresponding compute declaration.
target : tvm.target.Target
The target device of this search task.
target_host : Optional[tvm.target.Target]
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
"""

def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
self.__init_handle_by_constructor__(
_ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
)


@tvm._ffi.register_object("auto_scheduler.TuningOptions")
class TuningOptions(Object):
"""This controls the options of performance tuning.
Expand Down Expand Up @@ -169,9 +147,9 @@ def create_task(func, args, target, target_host=None, hardware_params=None):
Can be the a function or the function name.
args : Union[Tuple[Any, ...], List[Any]]
The args of the function.
target : tvm.target.Target
target : Union[tvm.target.Target, str]
The target device of this search task.
target_host : Optional[tvm.target.Target]
target_host : Optional[Union[tvm.target.Target, str]]
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
Expand All @@ -182,6 +160,10 @@ def create_task(func, args, target, target_host=None, hardware_params=None):
"""
workload_key = make_workload_key(func, args)
dag = ComputeDAG(workload_key)
if isinstance(target, str):
target = Target(target)
if isinstance(target_host, str):
target_host = Target(target_host)
return SearchTask(dag, workload_key, target, target_host, hardware_params)


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
A builder builds the executable binary files and a runner runs the binary files to
get the measurement results. The flow of data structures is
. `ProgramBuilder` `ProgramRunner`
. `ProgramBuilder` `ProgramRunner`
`MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
We implement these in python to utilize python's multiprocessing and error handling.
Expand Down
53 changes: 52 additions & 1 deletion python/tvm/auto_scheduler/measure_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import tvm._ffi
from tvm.runtime import Object
from .measure import MeasureCallback, MeasureErrorNo
from .compute_dag import ComputeDAG
from .measure import MeasureErrorNo, MeasureInput, MeasureCallback
from .search_task import SearchTask
from . import _ffi_api


Expand Down Expand Up @@ -70,6 +72,13 @@ def read_lines(self, max_lines=None, skip_lines=0):
The MeasureInputs loaded from the log file.
results : List[auto_scheduler.measure.MeasureResult]
The MeasureResults loaded from the log file.
Notes
-----
Some unimportant and expensive fields in the returned MeasureInput are not deserialized
for faster read speed (e.g. input.task.compute_dag, input.state.stages).
If you want to use them, you can call the :code:`recover_measure_input` below
to rebuild these fields.
"""
inputs, results = _ffi_api.RecordReaderReadLines(
self, max_lines if max_lines else -1, skip_lines
Expand All @@ -96,6 +105,13 @@ def load_records(filename):
Returns
-------
logs : List[auto_scheduler.measure.MeasureInput, auto_scheduler.measure.MeasureResult]
Notes
-----
Some unimportant and expensive fields in the returned MeasureInput are not deserialized
for faster read speed (e.g., input.task.compute_dag, input.state.stages).
If you want to use them, you can call the :code:`recover_measure_input` below
to rebuild these fields.
"""
return zip(*RecordReader(filename).read_lines())

Expand Down Expand Up @@ -159,3 +175,38 @@ def load_best(filename, workload_key=None, target=None):
best_res = res

return best_inp, best_res


def recover_measure_input(inp, rebuild_state=False):
"""
Recover a deserialized MeasureInput by rebuilding the missing fields.
1. Rebuid the compute_dag in inp.task
2. (Optional) Rebuild the stages in inp.state
Parameters
----------
inp: MeasureInput
The deserialized MeasureInput
rebuild_state: bool = False
Whether rebuild the stages in MeasureInput.State
Returns
-------
new_input: MeasureInput
The fully recovered MeasureInput with all fields rebuilt.
"""
task = inp.task
new_task = SearchTask(
ComputeDAG(task.workload_key),
task.workload_key,
task.target,
task.target_host,
task.hardware_params,
)

if rebuild_state:
new_state = new_task.compute_dag.infer_bound_from_state(inp.state)
else:
new_state = inp.state

return MeasureInput(new_task, new_state)
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class SketchPolicy(SearchPolicy):
"gpu_multi_level_tiling_structure": "SSSRRSRS",
# Notice: the default thread bind policy of GPU assumes the tiling structure to have at
# least 3 spatial tiling levels in outermost
"max_innermost_split_factor": 16,
"max_innermost_split_factor": 64,
"max_vectorize_size": 16,
"disable_change_compute_location": 0,
}
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" The definiton of SearchTask """

import tvm._ffi
from tvm.runtime import Object

from . import _ffi_api


@tvm._ffi.register_object("auto_scheduler.SearchTask")
class SearchTask(Object):
"""The computation information and hardware parameters for a schedule search task.
Parameters
----------
dag : ComputeDAG
The ComputeDAG for the corresponding compute declaration.
workload_key : str
The workload key for the corresponding compute declaration.
target : tvm.target.Target
The target device of this search task.
target_host : Optional[tvm.target.Target]
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
"""

def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
self.__init_handle_by_constructor__(
_ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
)
6 changes: 4 additions & 2 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(

PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
int max_innermost_split_factor =
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

StateNode* pstate = state->CopyOnWrite();
// Scan the transformation history and randomly fill tiles size for all SplitStep
for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
Expand All @@ -459,8 +462,7 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
CHECK(ps->extent);
int extent = GetIntImm(ps->extent.value());
const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
extent, ps->lengths.size(),
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor));
extent, ps->lengths.size(), max_innermost_split_factor);
const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];

pstate->transform_steps.Set(
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_auto_scheduler_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def conv2d_winograd_nhwc_auto_scheduler_test(


def get_tiled_matmul():
"""Get a compute dag and a state for tiled matmul"""
A, B, C = matmul_auto_scheduler_test(512, 512, 512)
dag = auto_scheduler.ComputeDAG([A, B, C])

Expand Down
16 changes: 8 additions & 8 deletions tests/python/unittest/test_auto_scheduler_cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,9 @@


def get_sample_records(number):
"""Generate random a list of random MeasureInput and MeasureResult pairs"""
"""Generate a list of random MeasureInput and MeasureResult pairs"""
N = 128
workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
dag = auto_scheduler.ComputeDAG(workload_key)
target = tvm.target.Target("llvm")
task = auto_scheduler.SearchTask(dag, workload_key, target)
task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), "llvm")
policy = auto_scheduler.SketchPolicy(task, verbose=0)
states = policy.sample_initial_population(number)

Expand All @@ -43,11 +40,11 @@ def get_sample_records(number):
for _ in range(len(inputs))
]

return task, dag, inputs, results
return task, inputs, results


def test_random_model():
task, dag, inputs, results = get_sample_records(50)
task, inputs, results = get_sample_records(50)

model = auto_scheduler.RandomModel()
model.update(inputs, results)
Expand All @@ -56,7 +53,7 @@ def test_random_model():


def test_xgb_model():
task, dag, inputs, results = get_sample_records(50)
task, inputs, results = get_sample_records(50)

model = auto_scheduler.XGBModel(num_warmup_sample=-1)
model.update(inputs, results)
Expand All @@ -66,13 +63,16 @@ def test_xgb_model():
costs = [np.mean([x.value for x in res.costs]) for res in results]
throughputs = np.min(costs) / costs

# test regression quality
rmse = np.sqrt(np.mean([np.square(pred - label) for pred, label in zip(preds, throughputs)]))
assert rmse <= 0.3

# test loading a record file
with tempfile.NamedTemporaryFile() as fp:
auto_scheduler.save_records(fp.name, inputs, results)
model.update_from_file(fp.name)

# test model serialization
with tempfile.NamedTemporaryFile() as fp:
model.save(fp.name)
model.load(fp.name)
Expand Down
19 changes: 8 additions & 11 deletions tests/python/unittest/test_auto_scheduler_evolutionary_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def predict(self, task, states):
assert found


@pytest.mark.skip(reason="flaky")
def test_mutate_parallel():
"""
The test case initializes evo search with a batch of "bad" states and check whether
Expand All @@ -95,20 +94,18 @@ def predict(self, task, states):
scores.append(1 if self.is_good_state(state) else 0)
return scores

workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (1024, 1024, 1024))
dag = auto_scheduler.ComputeDAG(workload_key)
task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
task = auto_scheduler.create_task(matmul_auto_scheduler_test, (1024, 1024, 1024), "llvm")
policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
states = policy.sample_initial_population(100)

bad_states = []
for state in states:
if not MockCostModel.is_good_state(state):
bad_states.append(state)

found = False
retry_ct = 0
while retry_ct < 5 and not found:
while retry_ct < 10 and not found:
states = policy.sample_initial_population(100)
bad_states = []
for state in states:
if not MockCostModel.is_good_state(state):
bad_states.append(state)

new_states = policy.evolutionary_search(bad_states, 50)
for state in new_states:
if MockCostModel.is_good_state(state):
Expand Down
17 changes: 7 additions & 10 deletions tests/python/unittest/test_auto_scheduler_layout_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,9 @@ def test_apply_steps_with_layout_rewrite():

def test_layout_rewrite_correctness():
N = 128
target = "llvm"
workload = matmul_auto_scheduler_test
workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
dag = auto_scheduler.ComputeDAG(workload_key)
target = tvm.target.Target(target)
task = auto_scheduler.SearchTask(dag, workload_key, target)
target = tvm.target.Target("llvm")
task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
dag = task.compute_dag

with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name
Expand All @@ -60,7 +57,7 @@ def test_layout_rewrite_correctness():
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
auto_scheduler.auto_schedule(task, search_policy, tuning_options)
inp, _ = auto_scheduler.load_best(log_file, workload_key, target)
inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
s_ref, bufs_ref = dag.apply_steps_from_state(inp.state, layout_rewrite=False)
np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
Expand Down Expand Up @@ -89,10 +86,10 @@ def test_layout_rewrite_correctness():
np_args_ref[1] = np_args_ref[1].transpose(new_order)
np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

func = tvm.build(s, bufs, target=inp.task.target, target_host=inp.task.target_host)
func_ref = tvm.build(s_ref, bufs_ref, target="llvm")
func = tvm.build(s, bufs, target=target)
func_ref = tvm.build(s_ref, bufs_ref, target=target)

ctx = tvm.context(str(inp.task.target))
ctx = tvm.context(str(target))
ctx_ref = tvm.cpu()

args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
Expand Down
Loading

0 comments on commit dd60d24

Please sign in to comment.