[MetaSchedule] Introduce MergedDatabase
Following up #12520 and #12626, this PR introduces `MergedDatabase`,
which allows users to compose multiple databases so that the high-level
IR can select the best tuning record among them.

The `MergedDatabase` also comes with an extra field `preferred` that allows
users to override tuning records from other databases. A classic use case
of the `preferred` parameter is providing handcrafted schedule functions:

```python
def schedule_fn(sch: tir.Schedule) -> bool:
  if "nn_conv2d" in sch.mod.attrs["task_name"]:
    handcrafted_scheduling(sch)
    return True
  return False

with ms.database.MergedDatabase(
  preferred=ms.database.ScheduleFnDatabase(schedule_fn),
  # ^^^^ override scheduling decisions
  databases=[database],
  fallback=libtorch_database,
  # ^^^^ fallback to libtorch
):
  lib = relay.build(...)
```
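For reference, here is a minimal sketch of the query-priority rules, mirroring the unit test `test_meta_schedule_database_merge` added in this commit. `db_fast`, `db_slow`, and `db_hand` are hypothetical `MemoryDatabase` instances that each hold one tuning record for `mod` (mean run times 0.5 s, 1.0 s, and 10.0 s respectively), and `mod`/`target` are assumed to be set up as in that test:

```python
# Among the regular `databases`, the record with the lowest mean run time wins.
merged = ms.database.MergedDatabase(databases=[db_slow, db_fast])
record = merged.query_tuning_record(mod=mod, target=target, workload_name="main")
# -> the 0.5 s record from db_fast

# A `preferred` database that answers the query overrides all other databases.
merged = ms.database.MergedDatabase(preferred=db_hand, databases=[db_slow, db_fast])
record = merged.query_tuning_record(mod=mod, target=target, workload_name="main")
# -> the 10.0 s record from db_hand

# `fallback` is consulted only when neither `preferred` nor `databases` answer.
merged = ms.database.MergedDatabase(fallback=db_hand)
record = merged.query_tuning_record(mod=mod, target=target, workload_name="main")
# -> the 10.0 s record from db_hand
```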
junrushao committed Aug 30, 2022
1 parent 9e88723 commit aec10ac
Showing 8 changed files with 287 additions and 29 deletions.
12 changes: 12 additions & 0 deletions include/tvm/meta_schedule/database.h
@@ -357,6 +357,18 @@ class Database : public runtime::ObjectRef {
*/
  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
                                       bool allow_missing);
  /*!
   * \brief Create a database merged from multiple databases.
   * \param preferred The preferred databases. If one of the preferred databases responds to a
   * query, all other databases will be ignored.
   * \param databases The databases to be merged.
   * \param fallback The fallback databases. If none of the databases answers a query,
   * the response from the first fallback database that responds will be used.
   * \return The merged database.
   */
  TVM_DLL static Database MergedDatabase(Array<Database, void> preferred,
                                         Array<Database, void> databases,
                                         Array<Database, void> fallback);
/*!
* \brief Create a database with customized methods on the python-side.
* \param f_has_workload The packed function of `HasWorkload`.
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/database/__init__.py
@@ -21,4 +21,5 @@
from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
from .memory_database import MemoryDatabase
from .merged_database import MergedDatabase
from .schedule_fn_database import ScheduleFnDatabase
94 changes: 94 additions & 0 deletions python/tvm/meta_schedule/database/merged_database.py
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A database consists of multiple databases."""
from typing import List, Union

from tvm._ffi import register_object

from .. import _ffi_api
from .database import Database


@register_object("meta_schedule.MergedDatabase")
class MergedDatabase(Database):
"""A database composed of multiple databases, allowing users to guide IR rewriting using
combined knowledge of those databases.
Besides querying from all databases and picking the best running time, this database also
comes with two extra sets of databases:
- preferred: If the preferred database responds to a query, all responses of other databases
will be overridden and ignored.
- fallback: If all databases don't respond to a query, the fallback databases will be used.
Examples
--------
An example of using the merged database:
.. code-block:: python
def schedule_conv2d(sch: tir.Schedule) -> bool:
if "nn_conv2d" in sch.mod.attrs["task_name"]:
handcrafted_scheduling(sch)
return True
return False
with ms.database.MergedDatabase(
preferred=ScheduleFnDatabase(schedule_conv2d), # override schedule for conv2d
databases=[existing_db0, existing_db1, existing_db2], # use existing databases
fallback=libtorch, # fallback to libtorch
):
lib = relay.build(...)
"""

def __init__(
self,
*,
preferred: Union[None, Database, List[Database]] = None,
databases: Union[None, Database, List[Database]] = None,
fallback: Union[None, Database, List[Database]] = None,
) -> None:
"""Construct a merged database from multiple databases.
Parameters
----------
preferred : Union[None, Database, List[Database]] = None
The preferred databases. If one of the preferred database responses to a
query, all other databases will be ignored.
databases : Union[None, Database, List[Database]] = None
The list of databases to merge.
fallback : Union[None, Database, List[Database]] = None
The fallback databases. If all the databases didn't answer a query,
the response from the first fallback database that responds will be used.
"""
if preferred is None:
preferred = []
elif isinstance(preferred, Database):
preferred = [preferred]
if databases is None:
databases = []
elif isinstance(databases, Database):
databases = [databases]
if fallback is None:
fallback = []
elif isinstance(fallback, Database):
fallback = [fallback]
self.__init_handle_by_constructor__(
_ffi_api.DatabaseMergedDatabase, # type: ignore # pylint: disable=no-member
preferred,
databases,
fallback,
)
22 changes: 0 additions & 22 deletions src/meta_schedule/database/json_database.cc
@@ -25,28 +25,6 @@
namespace tvm {
namespace meta_schedule {

/*! \brief The struct defining comparison function of sorting by mean run seconds. */
struct SortTuningRecordByMeanRunSecs {
  static const constexpr double kMaxMeanTime = 1e10;

  static double Mean(const Array<FloatImm>& a) {
    if (a.empty()) {
      return kMaxMeanTime;
    }
    double sum = 0.0;
    for (const FloatImm& i : a) {
      sum += i->value;
    }
    return sum / a.size();
  }

  bool operator()(const TuningRecord& a, const TuningRecord& b) const {
    double a_time = Mean(a->run_secs.value_or({}));
    double b_time = Mean(b->run_secs.value_or({}));
    return a_time < b_time;
  }
};

/*!
* \brief Read lines from a json file.
* \param path The path to the json file.
112 changes: 112 additions & 0 deletions src/meta_schedule/database/merged_database.cc
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class MergedDatabaseNode : public DatabaseNode {
 public:
  Array<Database> preferred;
  Array<Database> databases;
  Array<Database> fallback;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("preferred", &preferred);
    v->Visit("databases", &databases);
    v->Visit("fallback", &fallback);
  }

  static constexpr const char* _type_key = "meta_schedule.MergedDatabase";
  TVM_DECLARE_FINAL_OBJECT_INFO(MergedDatabaseNode, DatabaseNode);

 public:
  Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
                                           const String& task_name) final {
    for (const Database& db : preferred) {
      if (Optional<TuningRecord> record = db->QueryTuningRecord(mod, target, task_name)) {
        return record;
      }
    }
    std::vector<TuningRecord> results;
    results.reserve(databases.size());
    for (const Database& db : databases) {
      if (Optional<TuningRecord> record = db->QueryTuningRecord(mod, target, task_name)) {
        ICHECK(record.value()->run_secs.defined());
        results.push_back(record.value());
      }
    }
    std::sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs());
    if (!results.empty()) {
      return results[0];
    }
    for (const Database& db : fallback) {
      if (Optional<TuningRecord> record = db->QueryTuningRecord(mod, target, task_name)) {
        return record;
      }
    }
    return NullOpt;
  }

  bool HasWorkload(const IRModule& mod) final {
    LOG(FATAL) << "NotImplementedError: MergedDatabase.HasWorkload";
    throw;
  }

  Workload CommitWorkload(const IRModule& mod) final {
    LOG(FATAL) << "NotImplementedError: MergedDatabase.CommitWorkload";
    throw;
  }

  void CommitTuningRecord(const TuningRecord& record) final {
    LOG(FATAL) << "NotImplementedError: MergedDatabase.CommitTuningRecord";
    throw;
  }

  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
    LOG(FATAL) << "NotImplementedError: MergedDatabase.GetTopK";
    throw;
  }

  Array<TuningRecord> GetAllTuningRecords() final {
    LOG(FATAL) << "NotImplementedError: MergedDatabase.GetAllTuningRecords";
    throw;
  }

  int64_t Size() final {
    LOG(FATAL) << "NotImplementedError: MergedDatabase.size";
    throw;
  }
};

Database Database::MergedDatabase(Array<Database> preferred, Array<Database> databases,
                                  Array<Database> fallback) {
  ObjectPtr<MergedDatabaseNode> n = make_object<MergedDatabaseNode>();
  n->preferred = std::move(preferred);
  n->databases = std::move(databases);
  n->fallback = std::move(fallback);
  return Database(n);
}

TVM_REGISTER_NODE_TYPE(MergedDatabaseNode);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMergedDatabase")
    .set_body_typed(Database::MergedDatabase);

} // namespace meta_schedule
} // namespace tvm
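For readers skimming the C++ above, the lookup order of `QueryTuningRecord` can be paraphrased in Python as a rough, non-authoritative sketch (the C++ implementation is the source of truth); `query_db` stands in for a per-database query call, and empty timing lists are treated as effectively infinite, mirroring `SortTuningRecordByMeanRunSecs`:

```python
def merged_query_sketch(preferred, databases, fallback, query_db):
    """Rough paraphrase of MergedDatabaseNode::QueryTuningRecord (sketch only)."""

    def mean_run_secs(record):
        # Empty timing lists rank last, mirroring SortTuningRecordByMeanRunSecs.
        if not record.run_secs:
            return 1e10
        return sum(float(s) for s in record.run_secs) / len(record.run_secs)

    # 1. Preferred databases short-circuit: the first one that answers wins.
    for db in preferred:
        record = query_db(db)
        if record is not None:
            return record
    # 2. All regular databases are queried; the record with the lowest mean
    #    run time is returned.
    candidates = [r for r in (query_db(db) for db in databases) if r is not None]
    if candidates:
        return min(candidates, key=mean_run_secs)
    # 3. Fallback databases are consulted only when nothing else answered.
    for db in fallback:
        record = query_db(db)
        if record is not None:
            return record
    return None
```

Note that every other `Database` method (`CommitTuningRecord`, `GetTopK`, and so on) raises a fatal `NotImplementedError` above, so the merged database is strictly a read-only view used for queries.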
22 changes: 22 additions & 0 deletions src/meta_schedule/utils.h
@@ -404,6 +404,28 @@ inline Array<Integer> AsIntArray(const ObjectRef& obj) {
  return results;
}

/*! \brief The struct defining comparison function of sorting by mean run seconds. */
struct SortTuningRecordByMeanRunSecs {
  static const constexpr double kMaxMeanTime = 1e10;

  static double Mean(const Array<FloatImm>& a) {
    if (a.empty()) {
      return kMaxMeanTime;
    }
    double sum = 0.0;
    for (const FloatImm& i : a) {
      sum += i->value;
    }
    return sum / a.size();
  }

  bool operator()(const TuningRecord& a, const TuningRecord& b) const {
    double a_time = Mean(a->run_secs.value_or({}));
    double b_time = Mean(b->run_secs.value_or({}));
    return a_time < b_time;
  }
};

} // namespace meta_schedule
} // namespace tvm

9 changes: 2 additions & 7 deletions tests/python/unittest/test_link_params.py
@@ -412,17 +412,12 @@ def schedule_fn(sch):
            return True
        return False

    link_params = True

    with StringIO() as stderr_buf, redirect_stderr(stderr_buf):
        with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext(
            opt_level=3,
            config={
                "relay.backend.use_meta_schedule": True,
                "relay.FuseOps.link_params": link_params,
            },
            config={"relay.backend.use_meta_schedule": True},
        ):
            executor = Executor("graph", {"link-params": link_params})
            executor = Executor("graph")
            lib = relay.build(relay_mod, target=target, executor=executor)

    # Workload look up should succeed. This does not work when the test is invoked from pytest.
44 changes: 44 additions & 0 deletions tests/python/unittest/test_meta_schedule_database.py
@@ -294,5 +294,49 @@ def test_meta_schedule_database_reload():
    _equal_record(ret[1], records[2])


def test_meta_schedule_database_merge():
    mod: IRModule = Matmul
    target = tvm.target.Target("llvm")
    arg_info = ms.arg_info.ArgInfo.from_prim_func(func=mod["main"])
    db_1 = ms.database.MemoryDatabase()
    db_2 = ms.database.MemoryDatabase()
    db_preferred = ms.database.MemoryDatabase()
    db_fallback = ms.database.MemoryDatabase()
    trace = _create_schedule(mod, _schedule_matmul).trace

    def query(db):
        return db.query_tuning_record(mod=mod, target=target, workload_name="main").run_secs

    def commit_record(db, run_sec):
        db.commit_tuning_record(
            ms.database.TuningRecord(
                trace,
                workload=db.commit_workload(mod),
                run_secs=[run_sec],
                target=target,
                args_info=arg_info,
            )
        )

    commit_record(db_1, 1.0)
    (run_sec,) = query(db_1)
    assert run_sec.value == 1.0

    commit_record(db_2, 0.5)
    (run_sec,) = query(db_2)
    assert run_sec.value == 0.5

    (run_secs,) = query(ms.database.MergedDatabase(databases=[db_1, db_2]))
    assert run_secs.value == 0.5

    commit_record(db_preferred, 10.0)
    (run_secs,) = query(ms.database.MergedDatabase(preferred=db_preferred, databases=[db_1, db_2]))
    assert run_secs.value == 10.0

    commit_record(db_fallback, 10.0)
    (run_secs,) = query(ms.database.MergedDatabase(fallback=db_fallback))
    assert run_secs.value == 10.0


if __name__ == "__main__":
    tvm.testing.main()
