Skip to content

Commit

Permalink
[Meta Schedule] Add ApplyHisotryBest Meta Schedule Context (#10049)
Browse files Browse the repository at this point in the history
* Add ApplyHisotryBest.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

* Retrigger CI.

* Update integration.py

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
  • Loading branch information
7 people authored Jan 26, 2022
1 parent 1b9b05e commit 94c4e0e
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 2 deletions.
9 changes: 8 additions & 1 deletion python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tvm.target import Target
from tvm.tir import PrimFunc

from .database import Database
from . import _ffi_api


Expand Down Expand Up @@ -174,7 +175,13 @@ def __init__(self) -> None:

@register_object("meta_schedule.ApplyHistoryBest")
class ApplyHistoryBest(MetaScheduleContext):
pass
"""An integration context that allows application of historically best record from database"""

database: Database
""" The database to be queried from"""

def __init__(self, database) -> None:
self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member


def extract_task(
Expand Down
22 changes: 21 additions & 1 deletion src/meta_schedule/integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <tvm/relay/function.h>
#include <tvm/tir/function.h>

#include "./utils.h"

namespace tvm {
namespace meta_schedule {

Expand Down Expand Up @@ -112,7 +114,21 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) {

Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
Optional<Array<IRModule>> dispatched) {
throw;
ICHECK(dispatched.defined());
ICHECK_EQ(dispatched.value().size(), 1);
ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
IRModule prim_mod = dispatched.value()[0];
ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
// Unify func name to make sure it can be found in database
prim_mod = UnifyFuncName(prim_mod);
if (database->HasWorkload(prim_mod)) {
Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
if (records.size() == 1) {
LOG(INFO) << "Applied history best for " << task_name << ".";
return records[0]->workload->mod;
}
}
return NullOpt;
}

/**************** FFI ****************/
Expand Down Expand Up @@ -146,6 +162,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery")
TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction {
return TaskExtraction();
});
TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest")
.set_body_typed([](Database database) -> ApplyHistoryBest {
return ApplyHistoryBest(database);
});

} // namespace meta_schedule
} // namespace tvm
16 changes: 16 additions & 0 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,22 @@ inline int GetTargetNumCores(const Target& target) {
return num_cores;
}

/*!
* \brief Unify the function name in workload to "main".
* \param mod The workload.
* \return The new workload with unified function name.
* \note If the name is not unified, the workload may not be found in database.
*/
inline IRModule UnifyFuncName(const IRModule& mod) {
if (!mod->ContainGlobalVar("main") && mod->GetGlobalTypeVars().size() == 1) {
IRModule new_mod = IRModule(
Map<GlobalVar, BaseFunc>({{GlobalVar("main"), mod->functions[mod->GetGlobalVars()[0]]}}));
return new_mod;
} else {
return mod;
}
}

} // namespace meta_schedule
} // namespace tvm

Expand Down
58 changes: 58 additions & 0 deletions tests/python/unittest/test_meta_schedule_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
import tvm
from tvm import meta_schedule as ms
from tvm.ir.module import IRModule
from tvm.tir import Schedule
from tvm.target import Target
from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
from tvm.meta_schedule.integration import (
ExtractedTask,
MetaScheduleContext,
TaskExtraction,
ApplyHistoryBest,
)
from tvm.meta_schedule.testing import get_network
from tvm.script import tir as T
Expand Down Expand Up @@ -116,5 +120,59 @@ def test_meta_schedule_integration_extract_from_resnet():
assert len(extracted_tasks) == 30


def test_meta_schedule_integration_apply_history_best():
class DummyDatabase(PyDatabase):
def __init__(self):
super().__init__()
self.records = []
self.workload_reg = []

def has_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return True
return False

def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)

def commit_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return workload
workload = Workload(mod)
self.workload_reg.append(workload)
return workload

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
return list(
filter(
lambda x: x.workload == workload,
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
)
)[: int(top_k)]

def __len__(self) -> int:
return len(self.records)

def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))

mod, _, _, _ = get_network(
name="resnet-18",
batch_size=1,
layout="NHWC",
dtype="float32",
)
database = DummyDatabase()
env = ApplyHistoryBest(database)
workload = database.commit_workload(MockModule)
database.commit_tuning_record(
TuningRecord(Schedule(MockModule).trace, [1.0], workload, Target("llvm"), [])
)
mod = env.query(task_name="mock-task", mod=mod, dispatched=[MockModule])
assert tvm.ir.structural_equal(mod, workload.mod)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 94c4e0e

Please sign in to comment.