Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Meta Schedule] Add customizable search space to PostOrderApply. #16

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/tvm/meta_schedule/testing/relay_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def get_torch_model(
import os # type: ignore # pylint: disable=import-error,import-outside-toplevel

def do_trace(model, inp):
    """Trace ``model`` with the example input ``inp`` and return the traced module.

    Both the source model and the traced module are switched to eval mode so
    that tracing and later inference are free of train-time behavior
    (e.g. dropout).
    """
    model.eval()
    traced = torch.jit.trace(model, inp)
    traced.eval()
    return traced
Expand Down Expand Up @@ -178,14 +179,12 @@ def do_trace(model, inp):
}
configuration = config_dict[model_name]
model = transformers.BertModel(configuration)
input_name = "input_ids"
A = torch.randint(10000, input_shape)

model.eval()
scripted_model = torch.jit.trace(model, [A], strict=False)

input_name = "input_ids"
shape_list = [(input_name, input_shape)]
shape_list = [("input_ids", input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
return mod, params
else:
Expand Down
44 changes: 23 additions & 21 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@
import logging
import os.path
from typing import Callable, Dict, List, Optional, Union
from tvm.ir.base import structural_equal, structural_hash

import tvm
from tvm import relay
from tvm._ffi.registry import register_func
from tvm.relay import Function as RelayFunc
from tvm.relay.backend.executor_factory import ExecutorFactoryModule
from tvm.ir.base import structural_equal, structural_hash
from tvm.ir.module import IRModule
from tvm.runtime import NDArray
from tvm.meta_schedule.integration import extract_task_from_relay
from tvm.target.target import Target
from tvm.te import Tensor, create_prim_func
from tvm.tir import PrimFunc, Schedule
from tvm.relay import Function as RelayFunc

from .integration import extract_task_from_relay, ApplyHistoryBest
from .builder import Builder, LocalBuilder
from .cost_model import CostModel, XGBModel
from .database import Database, JSONDatabase, TuningRecord
Expand Down Expand Up @@ -216,13 +220,19 @@ class Parse:
"""Parse tuning configuration from user inputs."""

@staticmethod
@register_func("tvm.meta_schedule.tune.parse_mod")  # for use in ApplyHistoryBest
def _mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
    """Normalize user input into an IRModule whose single function is named "main".

    A bare PrimFunc is wrapped into an IRModule; a module whose only function
    has some other name is re-keyed to "main", because ApplyHistoryBest looks
    workloads up by structural equality and a different top-level function
    name makes otherwise-identical modules structurally unequal.
    """
    if isinstance(mod, PrimFunc):
        func = mod.with_attr("global_symbol", "main").with_attr("tir.noalias", True)
        mod = IRModule({"main": func})
    if not isinstance(mod, IRModule):
        raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
    # in order to make sure the mod can be found in ApplyHistoryBest
    # different func name can cause structural unequal
    if "main" not in mod.global_var_map_:
        (func_name,) = mod.global_var_map_
        mod = IRModule({"main": mod[func_name]})
    return mod

@staticmethod
Expand Down Expand Up @@ -615,7 +625,7 @@ def tune_relay(
postprocs: Optional[TypePostproc] = None,
mutator_probs: Optional[TypeMutatorProb] = None,
num_threads: Optional[int] = None,
) -> List[Optional[Schedule]]:
) -> ExecutorFactoryModule:
"""Tune a TIR IRModule with a given target.

Parameters
Expand Down Expand Up @@ -647,8 +657,8 @@ def tune_relay(

Returns
-------
schs : List[Optional[Schedule]]
The tuned schedules.
lib : ExecutorFactoryModule
The built runtime module for the given relay workload.
"""

logger.info("Working directory: %s", work_dir)
Expand All @@ -660,11 +670,10 @@ def tune_relay(
# parse the tuning contexts
for task in extracted_tasks:
assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
mod = Parse._mod(task.dispatched[0])
tune_contexts.append(
Parse._tune_context(
tune_context=None,
mod=mod,
mod=Parse._mod(task.dispatched[0]),
target=target,
config=config,
task_name=task.task_name,
Expand Down Expand Up @@ -704,16 +713,9 @@ def tune_relay(
)
# pylint: enable=protected-access
task_scheduler.tune()
schs: List[Schedule] = []
for task in tasks:
mod = task.mod
workload = database.commit_workload(mod)
bests: List[TuningRecord] = database.get_top_k(workload, top_k=1)
if not bests:
schs.append(None)
else:
assert len(bests) == 1
sch = Schedule(mod)
bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
schs.append(sch)
return schs
with ApplyHistoryBest(database):
with tvm.transform.PassContext(
opt_level=3,
config={"relay.backend.use_meta_schedule": True},
):
return relay.build(mod, target=target, params=params)
13 changes: 9 additions & 4 deletions src/meta_schedule/integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,17 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
IRModule prim_mod = dispatched.value()[0];
ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
prim_mod = (*parse_mod_func)(prim_mod);
if (database->HasWorkload(prim_mod)) {
Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
ICHECK(records.size() == 1) << "No records was found for given workload" << prim_mod;
return records[0]->workload->mod;
} else
return NullOpt;
// TODO(@zxybazh): check whether a record always exists when the database has the workload
if (records.size() == 1) {
LOG(INFO) << "Applied history best for " << task_name << "!";
return records[0]->workload->mod;
}
}
return NullOpt;
}

/**************** FFI ****************/
Expand Down
13 changes: 12 additions & 1 deletion src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,18 @@ class BlockCollector : public tir::StmtVisitor {
blocks_to_collect_.clear();
VisitStmt(func->body);
for (const String& block_name : blocks_to_collect_) {
results_.push_back(sch_->GetBlock(block_name, func_name_));
tir::BlockRV block_rv = sch_->GetBlock(block_name, func_name_);
// pick out the blocks with annotation for customized search space
if (Optional<ObjectRef> custom_sch_rule_name_opt =
tir::GetAnn<String>(sch_->GetSRef(block_rv), "schedule_rule")) {
String custom_sch_rule_name = Downcast<String>(custom_sch_rule_name_opt.value());
if (custom_sch_rule_name != "None") {
const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name);
(*custom_sch_rule_func)(sch_, block_rv);
}
} else {
results_.push_back(block_rv);
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/task_scheduler/task_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ void TaskSchedulerNode::Tune() {
} else {
SetTaskStopped(task_id);
--running_tasks;
LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks;
LOG(INFO) << "Task #" << task_id + 1 << " has finished. Remaining task(s): " << running_tasks;
}
}
ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished";
Expand Down
76 changes: 76 additions & 0 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir.schedule import BlockRV, Schedule
from tvm import register_func


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
Expand All @@ -50,6 +51,42 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


# 1024x1024x1024 matmul whose "matmul" block carries a "schedule_rule"
# annotation naming a globally registered function; PostOrderApply is expected
# to dispatch that block to the registered custom search-space function
# instead of the built-in schedule rules.
@tvm.script.ir_module
class MatmulCustomized:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
        with T.block("root"):
            for i, j, k in T.grid(1024, 1024, 1024):
                with T.block("matmul"):
                    # Registry key of the custom rule (registered in the test below).
                    T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space"})
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


# Same matmul workload, but with the "schedule_rule" annotation set to the
# literal string "None" on both the root and the "matmul" block: this opts the
# blocks out of schedule rules entirely, so PostOrderApply should neither call
# a custom search-space function nor apply any configured schedule rule.
@tvm.script.ir_module
class MatmulCustomizedNoneRule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
        with T.block("root"):
            T.block_attr({"schedule_rule": "None"})
            for i, j, k in T.grid(1024, 1024, 1024):
                with T.block("matmul"):
                    T.block_attr({"schedule_rule": "None"})
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

@tvm.script.ir_module
class DuplicateMatmul:
@T.prim_func
Expand Down Expand Up @@ -338,5 +375,44 @@ def correct_trace(a, b, c, d):
)


def test_meta_schedule_post_order_apply_custom_search_space():
    """PostOrderApply must dispatch a block annotated with "schedule_rule" to
    the registered custom search-space function instead of the built-in rules.
    """

    @register_func("tvm.meta_schedule.test.custom_search_space")
    def custom_search_space_func(sch: Schedule, block: BlockRV):
        # Raising here proves the callback was actually invoked: the
        # pytest.raises below fails if PostOrderApply ignores the annotation.
        raise ValueError("Customized search space triggered!")

    mod = MatmulCustomized
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=[],  # no built-in rules: only the annotation can trigger work
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    with pytest.raises(ValueError, match="Customized search space triggered!"):
        _ = post_order_apply.generate_design_space(mod)


def test_meta_schedule_post_order_apply_custom_search_space_none_rule():
    """A "schedule_rule" annotation of "None" opts a block out of schedule
    rules entirely: generating the design space must not invoke any rule.
    """

    class DontCallThisRule(PyScheduleRule):
        # Sentinel rule: if PostOrderApply ever applies it to the annotated
        # blocks, the RuntimeError fails the test.
        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
            pass

        def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
            raise RuntimeError("This schedule rule should not be called!")

    context = TuneContext(
        mod=MatmulCustomizedNoneRule,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=[DontCallThisRule()],
    )
    space_gen = PostOrderApply()
    space_gen.initialize_with_tune_context(context)
    _ = space_gen.generate_design_space(MatmulCustomizedNoneRule)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
Loading