[Meta Schedule] Add customizable search space to PostOrderApply. (#16)
* Add customizable search space to PostOrderApply.

* Minor modification for None rule.

* Change rule to schedule rule.

* Finish relay tuning with ApplyHistoryBest.

* Fix tune relay.
zxybazh authored Jan 22, 2022
1 parent e62d1f2 commit e41b5b2
Showing 7 changed files with 210 additions and 57 deletions.
5 changes: 2 additions & 3 deletions python/tvm/meta_schedule/testing/relay_workload.py
@@ -131,6 +131,7 @@ def get_torch_model(
    import os  # type: ignore # pylint: disable=import-error,import-outside-toplevel

    def do_trace(model, inp):
        model.eval()
        model_trace = torch.jit.trace(model, inp)
        model_trace.eval()
        return model_trace
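
The model.eval() call added here puts the network into inference mode before torch.jit.trace records it, so modules such as dropout and batch norm are traced with their inference-time behavior. A minimal sketch of the pattern with a toy model (layer sizes are illustrative, not from this commit):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))
model.eval()  # disable dropout before tracing, as in do_trace above
traced = torch.jit.trace(model, torch.rand(1, 8))
traced.eval()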
@@ -178,14 +179,12 @@ def do_trace(model, inp):
        }
        configuration = config_dict[model_name]
        model = transformers.BertModel(configuration)
        input_name = "input_ids"
        A = torch.randint(10000, input_shape)

        model.eval()
        scripted_model = torch.jit.trace(model, [A], strict=False)

        input_name = "input_ids"
        shape_list = [(input_name, input_shape)]
        shape_list = [("input_ids", input_shape)]
        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
        return mod, params
    else:
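For context, relay.frontend.from_pytorch consumes the traced module together with a list of (input_name, shape) pairs, and inlining the "input_ids" literal keeps that name in one place. A condensed sketch of the import path (shape and vocabulary bound are illustrative; model is the BertModel built above):

import torch
from tvm import relay

input_shape = (1, 64)
A = torch.randint(10000, input_shape)  # random token ids below the vocab bound
model.eval()
scripted_model = torch.jit.trace(model, [A], strict=False)
mod, params = relay.frontend.from_pytorch(scripted_model, [("input_ids", input_shape)])
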
44 changes: 23 additions & 21 deletions python/tvm/meta_schedule/tune.py
@@ -19,16 +19,20 @@
import logging
import os.path
from typing import Callable, Dict, List, Optional, Union
from tvm.ir.base import structural_equal, structural_hash

import tvm
from tvm import relay
from tvm._ffi.registry import register_func
from tvm.relay import Function as RelayFunc
from tvm.relay.backend.executor_factory import ExecutorFactoryModule
from tvm.ir.base import structural_equal, structural_hash
from tvm.ir.module import IRModule
from tvm.runtime import NDArray
from tvm.meta_schedule.integration import extract_task_from_relay
from tvm.target.target import Target
from tvm.te import Tensor, create_prim_func
from tvm.tir import PrimFunc, Schedule
from tvm.relay import Function as RelayFunc

from .integration import extract_task_from_relay, ApplyHistoryBest
from .builder import Builder, LocalBuilder
from .cost_model import CostModel, XGBModel
from .database import Database, JSONDatabase, TuningRecord
@@ -216,13 +220,19 @@ class Parse:
    """Parse tuning configuration from user inputs."""

    @staticmethod
    @register_func("tvm.meta_schedule.tune.parse_mod")  # for use in ApplyHistoryBest
    def _mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
        if isinstance(mod, PrimFunc):
            mod = mod.with_attr("global_symbol", "main")
            mod = mod.with_attr("tir.noalias", True)
            mod = IRModule({"main": mod})
        if not isinstance(mod, IRModule):
            raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
        # Make sure the mod can be found in ApplyHistoryBest:
        # a different function name causes structural inequality.
        if "main" not in mod.global_var_map_:
            (func_name,) = [global_var for global_var in mod.global_var_map_]
            mod = IRModule({"main": mod[func_name]})
        return mod

    @staticmethod
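
The normalization to a single "main" function exists because ApplyHistoryBest matches workloads by structural equality, and two IRModules whose only functions differ in global symbol name do not compare equal even when the bodies match. A small sketch of that failure mode (toy compute; names are illustrative):

import tvm
from tvm.ir import IRModule, structural_equal
from tvm.te import compute, create_prim_func, placeholder

A = placeholder((128,), name="A")
B = compute((128,), lambda i: A[i] + 1.0, name="B")
func = create_prim_func([A, B])

# Same PrimFunc under two global symbols: not structurally equal, so a
# database keyed on one module would never match the other.
assert not structural_equal(IRModule({"main": func}), IRModule({"other_name": func}))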
@@ -615,7 +625,7 @@ def tune_relay(
    postprocs: Optional[TypePostproc] = None,
    mutator_probs: Optional[TypeMutatorProb] = None,
    num_threads: Optional[int] = None,
) -> List[Optional[Schedule]]:
) -> ExecutorFactoryModule:
    """Tune a Relay program with a given target.

    Parameters
@@ -647,8 +657,8 @@
    Returns
    -------
    schs : List[Optional[Schedule]]
        The tuned schedules.
    lib : ExecutorFactoryModule
        The built runtime module for the given relay workload.
    """

    logger.info("Working directory: %s", work_dir)
@@ -660,11 +670,10 @@
    # parse the tuning contexts
    for task in extracted_tasks:
        assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
        mod = Parse._mod(task.dispatched[0])
        tune_contexts.append(
            Parse._tune_context(
                tune_context=None,
                mod=mod,
                mod=Parse._mod(task.dispatched[0]),
                target=target,
                config=config,
                task_name=task.task_name,
@@ -704,16 +713,9 @@
    )
    # pylint: enable=protected-access
    task_scheduler.tune()
    schs: List[Schedule] = []
    for task in tasks:
        mod = task.mod
        workload = database.commit_workload(mod)
        bests: List[TuningRecord] = database.get_top_k(workload, top_k=1)
        if not bests:
            schs.append(None)
        else:
            assert len(bests) == 1
            sch = Schedule(mod)
            bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
            schs.append(sch)
    return schs
    with ApplyHistoryBest(database):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_meta_schedule": True},
        ):
            return relay.build(mod, target=target, params=params)
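
Because tune_relay now returns the built ExecutorFactoryModule, the caller can go straight to execution. A hedged usage sketch (the input name and shape follow the BERT workload used elsewhere in this patch; the tuning config and Relay mod/params are assumed to already exist):

import numpy as np
import tvm
from tvm.contrib import graph_executor

lib = tune_relay(mod=mod, params=params, target="llvm", config=config, work_dir="./work_dir")
dev = tvm.cpu()
runtime = graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("input_ids", np.random.randint(0, 10000, (1, 64)))
runtime.run()
print(runtime.get_output(0).numpy())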
13 changes: 9 additions & 4 deletions src/meta_schedule/integration.cc
@@ -117,12 +117,17 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
  IRModule prim_mod = dispatched.value()[0];
  ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
  ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
  const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
  prim_mod = (*parse_mod_func)(prim_mod);
  if (database->HasWorkload(prim_mod)) {
    Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
    ICHECK(records.size() == 1) << "No records was found for given workload" << prim_mod;
    return records[0]->workload->mod;
  } else
    return NullOpt;
    // todo(@zxybazh): check if records always exists when the database has the workload
    if (records.size() == 1) {
      LOG(INFO) << "Applied history best for " << task_name << "!";
      return records[0]->workload->mod;
    }
  }
  return NullOpt;
}

/**************** FFI ****************/
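The same lookup can be reproduced from Python against the tuning database. A rough sketch, assuming the JSON files come from a prior tuning run (paths are placeholders; prim_mod is already normalized via the registered parse_mod function):

from tvm.meta_schedule.database import JSONDatabase

db = JSONDatabase(
    path_workload="db_workload.json",
    path_tuning_record="db_tuning_record.json",
)
records = db.get_top_k(db.commit_workload(prim_mod), top_k=1)
best_mod = records[0].workload.mod if records else None
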
13 changes: 12 additions & 1 deletion src/meta_schedule/space_generator/post_order_apply.cc
@@ -40,7 +40,18 @@ class BlockCollector : public tir::StmtVisitor {
      blocks_to_collect_.clear();
      VisitStmt(func->body);
      for (const String& block_name : blocks_to_collect_) {
        results_.push_back(sch_->GetBlock(block_name, func_name_));
        tir::BlockRV block_rv = sch_->GetBlock(block_name, func_name_);
        // pick out the blocks with annotation for customized search space
        if (Optional<ObjectRef> custom_sch_rule_name_opt =
                tir::GetAnn<String>(sch_->GetSRef(block_rv), "schedule_rule")) {
          String custom_sch_rule_name = Downcast<String>(custom_sch_rule_name_opt.value());
          if (custom_sch_rule_name != "None") {
            const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name);
            (*custom_sch_rule_func)(sch_, block_rv);
          }
        } else {
          results_.push_back(block_rv);
        }
      }
    }
  }
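The collector above defines a small contract: a block annotated with "schedule_rule" is dispatched to the packed function registered under that name instead of the built-in schedule rules, and the literal "None" opts the block out of both. A minimal sketch of registering such a rule (the rule name is illustrative; the unit tests below exercise the same mechanism):

import tvm
from tvm.tir.schedule import BlockRV, Schedule

@tvm.register_func("my_pkg.custom_matmul_rule")
def custom_matmul_rule(sch: Schedule, block: BlockRV) -> None:
    # Build the custom search space by mutating the schedule directly.
    i, j, k = sch.get_loops(block)
    sch.split(i, factors=[None, 32])
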
2 changes: 1 addition & 1 deletion src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -138,7 +138,7 @@ void TaskSchedulerNode::Tune() {
    } else {
      SetTaskStopped(task_id);
      --running_tasks;
      LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks;
      LOG(INFO) << "Task #" << task_id + 1 << " has finished. Remaining task(s): " << running_tasks;
    }
  }
  ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished";
76 changes: 76 additions & 0 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -29,6 +29,7 @@
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir.schedule import BlockRV, Schedule
from tvm import register_func


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
@@ -50,6 +51,42 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class MatmulCustomized:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
        with T.block("root"):
            for i, j, k in T.grid(1024, 1024, 1024):
                with T.block("matmul"):
                    T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space"})
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class MatmulCustomizedNoneRule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
        with T.block("root"):
            T.block_attr({"schedule_rule": "None"})
            for i, j, k in T.grid(1024, 1024, 1024):
                with T.block("matmul"):
                    T.block_attr({"schedule_rule": "None"})
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class DuplicateMatmul:
    @T.prim_func
@@ -338,5 +375,44 @@ def correct_trace(a, b, c, d):
    )


def test_meta_schedule_post_order_apply_custom_search_space():
    @register_func("tvm.meta_schedule.test.custom_search_space")
    def custom_search_space_func(sch: Schedule, block: BlockRV):
        raise ValueError("Customized search space triggered!")

    mod = MatmulCustomized
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=[],
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    with pytest.raises(ValueError, match="Customized search space triggered!"):
        _ = post_order_apply.generate_design_space(mod)


def test_meta_schedule_post_order_apply_custom_search_space_none_rule():
    class DontCallThisRule(PyScheduleRule):
        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
            pass

        def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
            raise RuntimeError("This schedule rule should not be called!")

    mod = MatmulCustomizedNoneRule
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=[DontCallThisRule()],
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    _ = post_order_apply.generate_design_space(mod)


if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))