Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][UX] Make Database with-able #12520

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 0 additions & 115 deletions include/tvm/meta_schedule/apply_history_best.h

This file was deleted.

28 changes: 28 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,27 @@ class DatabaseNode : public runtime::Object {
* \return The size of the database.
*/
virtual int64_t Size() = 0;
/*!
* \brief Query the best record of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \return The best record of the given workload; NullOpt if not found.
*/
virtual Optional<TuningRecord> QueryTuningRecord(IRModule mod, Target target);
/*!
* \brief Query the best schedule of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \return The schedule in the best schedule of the given workload; NullOpt if not found.
*/
virtual Optional<tir::Schedule> QuerySchedule(IRModule mod, Target target);
/*!
* \brief Query the best IRModule of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
*/
virtual Optional<IRModule> QueryIRModule(IRModule mod, Target target);

static constexpr const char* _type_key = "meta_schedule.Database";
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
Expand Down Expand Up @@ -339,6 +360,13 @@ class Database : public runtime::ObjectRef {
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FSize f_size);
/*! \return The current Database in the scope. */
static Optional<Database> Current();
/*! \brief Entering the scope of the context manager */
void EnterWithScope();
/*! \brief Exiting the scope of the context manager */
void ExitWithScope();

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
};

Expand Down
20 changes: 0 additions & 20 deletions include/tvm/meta_schedule/extracted_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,6 @@ class ExtractedTask : public runtime::ObjectRef {
ExtractedTaskNode);
};

/*!
* \brief The default TE task filter
* \param args The input/output arguments of the TE compute graph
* \param constants Raw data for constant tensors in args. If the size of this array is N, the last
* N tensors in args will be treated as constant tensors.
* \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
*/
Optional<tvm::tir::PrimFunc> DefaultTaskFilter(const Array<tvm::te::Tensor, void>& args,
const Array<runtime::NDArray>& constants);

/*!
* \brief The default TE task filter, with `te.extern` allowed
* \param args The input/output arguments of the TE compute graph
* \param constants Raw data for constant tensors in args. If the size of this array is N, the last
* N tensors in args will be treated as constant tensors.
* \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
*/
Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<tvm::te::Tensor, void>& args,
const Array<runtime::NDArray>& constants);

} // namespace meta_schedule
} // namespace tvm

Expand Down
93 changes: 51 additions & 42 deletions python/tvm/auto_scheduler/testing/tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
from distutils.util import strtobool
import argparse
import json
import os
from distutils.util import strtobool

import tvm
from tvm import auto_scheduler
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer
from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data
from tvm.meta_schedule.utils import cpu_count
from tvm.support import describe

Expand Down Expand Up @@ -170,53 +170,62 @@ def main():
ARGS.input_shape,
cache_dir=ARGS.cache_dir,
)
input_info = {input_name: input_shape}
input_info = [
{
"name": input_name,
"shape": input_shape,
"dtype": input_dtype,
},
]
input_data = {
item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape
item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in input_info
}
for input_name, input_shape in input_info.items():
print(f" input_name : {input_name}")
print(f" input_shape: {input_shape}")
print(f" input_dtype: {input_dtype}")
for item in input_info:
print(f" input_name : {item['name']}")
print(f" input_shape: {item['shape']}")
print(f" input_dtype: {item['dtype']}")

with ms.Profiler() as profiler:
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"],
params,
target=ARGS.target,
hardware_params=hardware_params,
)
for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
print(
f"==== Task {idx}: {task.desc} "
f"(weight {task_weight} key: {task.workload_key}) ====="
)
print(task.compute_dag)

if ARGS.num_trials > 0:
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(
auto_scheduler.TuningOptions(
num_measure_trials=ARGS.num_trials,
runner=runner,
measure_callbacks=[
auto_scheduler.RecordToFile(log_file),
],
),
adaptive_training=ARGS.adaptive_training,
with ms.Profiler.timeit("TaskExtraction"):
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"],
params,
target=ARGS.target,
hardware_params=hardware_params,
)
for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
print(
f"==== Task {idx}: {task.desc} "
f"(weight {task_weight} key: {task.workload_key}) ====="
)
print(task.compute_dag)

with ms.Profiler.timeit("Tuning"):
if ARGS.num_trials > 0:
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(
auto_scheduler.TuningOptions(
num_measure_trials=ARGS.num_trials,
runner=runner,
measure_callbacks=[
auto_scheduler.RecordToFile(log_file),
],
),
adaptive_training=ARGS.adaptive_training,
)

relay_build = {"graph": relay.build, "vm": relay.vm.compile}[ARGS.backend]
with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
):
lib = relay_build(
mod,
target=ARGS.target,
params=params,
)
with ms.Profiler.timeit("PostTuningCompilation"):
with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
):
lib = relay_build(
mod,
target=ARGS.target,
params=params,
)
print("Tuning Time:")
print(profiler.table())

Expand Down
1 change: 0 additions & 1 deletion python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
search_strategy,
space_generator,
)
from .apply_history_best import ApplyHistoryBest
from .extracted_task import ExtractedTask
from .profiler import Profiler
from .relay_integration import (
Expand Down
Loading