[MetaSchedule][UX] Make Database with-able
`ApplyHistoryBest` currently serves as an adaptor for querying the database. The logic can be
simplified so that users deal with `Database` directly rather than going through this extra
object (a usage sketch follows the task list below).

- [x] Add `EnterWithScope`/`ExitWithScope`/`Current` to Database
- [x] Migrate `te_filter_func` => "tir_filter" in Relay's pass context
- [x] Migrate `f_take_tuning_record` => "Database.query_tuning_record"
- [x] Migrate `TECompiler` to use `Database`
- [x] Remove apply-history-best

Next PR:
- Migrate `f_direct_dispatch` (potentially unify with `apply_fixed_schedule`?)
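
For illustration, a minimal sketch of the intended usage after this change (the JSON paths are placeholders, and `JSONDatabase` stands in for any `Database` implementation):

```python
# Hedged sketch: a Database used as a context manager, replacing the
# ApplyHistoryBest adaptor. File paths below are illustrative placeholders.
from tvm import meta_schedule as ms

db = ms.database.JSONDatabase(
    path_workload="db_workload.json",            # placeholder path
    path_tuning_record="db_tuning_record.json",  # placeholder path
)
with db:  # Database::EnterWithScope() pushes `db` onto the scope stack
    # While the scope is active, the database is visible globally:
    assert ms.database.Database.current().same_as(db)
# Database::ExitWithScope() pops it when the block exits
```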
junrushao committed Aug 26, 2022
1 parent 49b3c72 commit 2221492
Showing 27 changed files with 511 additions and 764 deletions.
115 changes: 0 additions & 115 deletions include/tvm/meta_schedule/apply_history_best.h

This file was deleted.

28 changes: 28 additions & 0 deletions include/tvm/meta_schedule/database.h
@@ -203,6 +203,27 @@ class DatabaseNode : public runtime::Object {
    * \return The size of the database.
    */
   virtual int64_t Size() = 0;
+  /*!
+   * \brief Query the best tuning record of the given workload from the database.
+   * \param mod The IRModule to be searched for.
+   * \param target The target to be searched for.
+   * \return The best tuning record of the given workload; NullOpt if not found.
+   */
+  virtual Optional<TuningRecord> QueryTuningRecord(IRModule mod, Target target);
+  /*!
+   * \brief Query the best schedule of the given workload from the database.
+   * \param mod The IRModule to be searched for.
+   * \param target The target to be searched for.
+   * \return The schedule of the best tuning record of the given workload; NullOpt if not found.
+   */
+  virtual Optional<tir::Schedule> QuerySchedule(IRModule mod, Target target);
+  /*!
+   * \brief Query the best IRModule of the given workload from the database.
+   * \param mod The IRModule to be searched for.
+   * \param target The target to be searched for.
+   * \return The IRModule of the best tuning record of the given workload; NullOpt if not found.
+   */
+  virtual Optional<IRModule> QueryIRModule(IRModule mod, Target target);
 
   static constexpr const char* _type_key = "meta_schedule.Database";
   TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
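
For reference, a sketch of how these queries might be issued from Python (`query_tuning_record` is named in the checklist above; the snake_case spellings of the other two, and the signatures, are assumed for illustration):

```python
# Hedged sketch: the three query flavors, mirroring the C++ declarations
# above. Each returns None when the workload is absent from the database.
from tvm import meta_schedule as ms
from tvm.ir import IRModule
from tvm.target import Target

def query_best(db: ms.database.Database, mod: IRModule, target: Target):
    record = db.query_tuning_record(mod, target)  # best TuningRecord, or None
    sched = db.query_schedule(mod, target)        # its tir.Schedule, or None
    best_mod = db.query_ir_module(mod, target)    # its lowered IRModule, or None
    return record, sched, best_mod
```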
@@ -339,6 +360,13 @@ class Database : public runtime::ObjectRef {
                               PyDatabaseNode::FGetTopK f_get_top_k,
                               PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
                               PyDatabaseNode::FSize f_size);
+  /*! \return The current Database in the scope. */
+  static Optional<Database> Current();
+  /*! \brief Enter the scope of the context manager. */
+  void EnterWithScope();
+  /*! \brief Exit the scope of the context manager. */
+  void ExitWithScope();
+
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
 };
 
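A self-contained sketch of the scoping behavior that `Current`/`EnterWithScope`/`ExitWithScope` enable; TVM's real binding goes through the FFI, so this plain-Python scope stack is only a stand-in:

```python
# Hedged stand-in for the scope machinery: a module-level stack that the
# context-manager protocol pushes to and pops from.
_SCOPE_STACK = []

class ScopedDatabase:
    def __enter__(self):  # cf. Database::EnterWithScope()
        _SCOPE_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):  # cf. ExitWithScope()
        assert _SCOPE_STACK.pop() is self

    @staticmethod
    def current():  # cf. Database::Current(); None when no scope is active
        return _SCOPE_STACK[-1] if _SCOPE_STACK else None

# Nested scopes shadow outer ones and are restored on exit:
outer, inner = ScopedDatabase(), ScopedDatabase()
with outer:
    with inner:
        assert ScopedDatabase.current() is inner
    assert ScopedDatabase.current() is outer
```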
20 changes: 0 additions & 20 deletions include/tvm/meta_schedule/extracted_task.h
@@ -76,26 +76,6 @@ class ExtractedTask : public runtime::ObjectRef {
                                             ExtractedTaskNode);
 };
 
-/*!
- * \brief The default TE task filter
- * \param args The input/output arguments of the TE compute graph
- * \param constants Raw data for constant tensors in args. If the size of this array is N, the last
- * N tensors in args will be treated as constant tensors.
- * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
- */
-Optional<tvm::tir::PrimFunc> DefaultTaskFilter(const Array<tvm::te::Tensor, void>& args,
-                                               const Array<runtime::NDArray>& constants);
-
-/*!
- * \brief The default TE task filter, with `te.extern` allowed
- * \param args The input/output arguments of the TE compute graph
- * \param constants Raw data for constant tensors in args. If the size of this array is N, the last
- * N tensors in args will be treated as constant tensors.
- * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
- */
-Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<tvm::te::Tensor, void>& args,
-                                                     const Array<runtime::NDArray>& constants);
-
 }  // namespace meta_schedule
 }  // namespace tvm
 
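One detail in the removed doc comments is worth keeping in mind: the trailing N entries of `args` are treated as constant tensors. A pure-Python sketch of that convention (simplified types, not TVM code):

```python
# Hedged sketch of the argument convention described above: with N
# constants, the last N entries of `args` are the constant tensors.
from typing import Any, List, Sequence, Tuple

def split_constant_args(
    args: Sequence[Any], constants: Sequence[Any]
) -> Tuple[List[Any], List[Any]]:
    n = len(constants)
    cut = len(args) - n
    return list(args[:cut]), list(args[cut:])

# Example: three args, one constant => the last arg is the constant tensor.
var_args, const_args = split_constant_args(["data", "weight", "bias_const"], [b"raw"])
assert var_args == ["data", "weight"] and const_args == ["bias_const"]
```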
93 changes: 51 additions & 42 deletions python/tvm/auto_scheduler/testing/tune_relay.py
@@ -15,18 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
-from distutils.util import strtobool
 import argparse
 import json
 import os
+from distutils.util import strtobool
 
 import tvm
 from tvm import auto_scheduler
 from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
 from tvm.meta_schedule.testing.relay_workload import get_network
-from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer
+from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data
 from tvm.meta_schedule.utils import cpu_count
 from tvm.support import describe
@@ -170,53 +170,62 @@ def main():
         ARGS.input_shape,
         cache_dir=ARGS.cache_dir,
     )
-    input_info = {input_name: input_shape}
+    input_info = [
+        {
+            "name": input_name,
+            "shape": input_shape,
+            "dtype": input_dtype,
+        },
+    ]
     input_data = {
-        item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape
+        item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in input_info
     }
-    for input_name, input_shape in input_info.items():
-        print(f" input_name : {input_name}")
-        print(f" input_shape: {input_shape}")
-        print(f" input_dtype: {input_dtype}")
+    for item in input_info:
+        print(f" input_name : {item['name']}")
+        print(f" input_shape: {item['shape']}")
+        print(f" input_dtype: {item['dtype']}")
 
     with ms.Profiler() as profiler:
-        tasks, task_weights = auto_scheduler.extract_tasks(
-            mod["main"],
-            params,
-            target=ARGS.target,
-            hardware_params=hardware_params,
-        )
-        for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
-            print(
-                f"==== Task {idx}: {task.desc} "
-                f"(weight {task_weight} key: {task.workload_key}) ====="
-            )
-            print(task.compute_dag)
-
-        if ARGS.num_trials > 0:
-            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
-            tuner.tune(
-                auto_scheduler.TuningOptions(
-                    num_measure_trials=ARGS.num_trials,
-                    runner=runner,
-                    measure_callbacks=[
-                        auto_scheduler.RecordToFile(log_file),
-                    ],
-                ),
-                adaptive_training=ARGS.adaptive_training,
-            )
+        with ms.Profiler.timeit("TaskExtraction"):
+            tasks, task_weights = auto_scheduler.extract_tasks(
+                mod["main"],
+                params,
+                target=ARGS.target,
+                hardware_params=hardware_params,
+            )
+            for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
+                print(
+                    f"==== Task {idx}: {task.desc} "
+                    f"(weight {task_weight} key: {task.workload_key}) ====="
+                )
+                print(task.compute_dag)
+
+        with ms.Profiler.timeit("Tuning"):
+            if ARGS.num_trials > 0:
+                tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+                tuner.tune(
+                    auto_scheduler.TuningOptions(
+                        num_measure_trials=ARGS.num_trials,
+                        runner=runner,
+                        measure_callbacks=[
+                            auto_scheduler.RecordToFile(log_file),
+                        ],
+                    ),
+                    adaptive_training=ARGS.adaptive_training,
+                )
 
         relay_build = {"graph": relay.build, "vm": relay.vm.compile}[ARGS.backend]
-        with auto_scheduler.ApplyHistoryBest(log_file):
-            with tvm.transform.PassContext(
-                opt_level=3,
-                config={"relay.backend.use_auto_scheduler": True},
-            ):
-                lib = relay_build(
-                    mod,
-                    target=ARGS.target,
-                    params=params,
-                )
+        with ms.Profiler.timeit("PostTuningCompilation"):
+            with auto_scheduler.ApplyHistoryBest(log_file):
+                with tvm.transform.PassContext(
+                    opt_level=3,
+                    config={"relay.backend.use_auto_scheduler": True},
+                ):
+                    lib = relay_build(
+                        mod,
+                        target=ARGS.target,
+                        params=params,
+                    )
     print("Tuning Time:")
     print(profiler.table())

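The diff above wraps each phase in `ms.Profiler.timeit(...)`. A minimal sketch of that pattern in isolation, with arbitrary phase labels:

```python
# Hedged sketch: ms.Profiler accumulates named scopes; table() reports the
# time spent in each. Phase names below are arbitrary labels.
import time
from tvm import meta_schedule as ms

with ms.Profiler() as profiler:
    with ms.Profiler.timeit("PhaseA"):
        time.sleep(0.05)
    with ms.Profiler.timeit("PhaseB"):
        time.sleep(0.05)
print(profiler.table())
```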
1 change: 0 additions & 1 deletion python/tvm/meta_schedule/__init__.py
@@ -30,7 +30,6 @@
     search_strategy,
     space_generator,
 )
-from .apply_history_best import ApplyHistoryBest
 from .extracted_task import ExtractedTask
 from .profiler import Profiler
 from .relay_integration import (
(Diffs for the remaining 22 changed files are not shown.)
