diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index 2ad448808bdb..bf9287a8eb18 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -131,6 +131,7 @@ def get_torch_model(
     import os  # type: ignore # pylint: disable=import-error,import-outside-toplevel
 
     def do_trace(model, inp):
+        model.eval()
         model_trace = torch.jit.trace(model, inp)
         model_trace.eval()
         return model_trace
@@ -178,14 +179,12 @@ def do_trace(model, inp):
         }
         configuration = config_dict[model_name]
         model = transformers.BertModel(configuration)
-        input_name = "input_ids"
         A = torch.randint(10000, input_shape)
 
         model.eval()
         scripted_model = torch.jit.trace(model, [A], strict=False)
 
-        input_name = "input_ids"
-        shape_list = [(input_name, input_shape)]
+        shape_list = [("input_ids", input_shape)]
         mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
         return mod, params
     else:
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index 16f4104ad8ec..ee9198eb1def 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -19,16 +19,20 @@
 import logging
 import os.path
 from typing import Callable, Dict, List, Optional, Union
 
-from tvm.ir.base import structural_equal, structural_hash
+import tvm
+from tvm import relay
+from tvm._ffi.registry import register_func
+from tvm.relay import Function as RelayFunc
+from tvm.relay.backend.executor_factory import ExecutorFactoryModule
+from tvm.ir.base import structural_equal, structural_hash
 from tvm.ir.module import IRModule
 from tvm.runtime import NDArray
-from tvm.meta_schedule.integration import extract_task_from_relay
 from tvm.target.target import Target
 from tvm.te import Tensor, create_prim_func
 from tvm.tir import PrimFunc, Schedule
-from tvm.relay import Function as RelayFunc
 
+from .integration import extract_task_from_relay, ApplyHistoryBest
 from .builder import Builder, LocalBuilder
 from .cost_model import CostModel, XGBModel
 from .database import Database, JSONDatabase, TuningRecord
@@ -216,6 +220,7 @@ class Parse:
     """Parse tuning configuration from user inputs."""
 
     @staticmethod
+    @register_func("tvm.meta_schedule.tune.parse_mod")  # for use in ApplyHistoryBest
    def _mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
         if isinstance(mod, PrimFunc):
             mod = mod.with_attr("global_symbol", "main")
@@ -223,6 +228,11 @@ def _mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
             mod = IRModule({"main": mod})
         if not isinstance(mod, IRModule):
             raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
+        # Ensure the module's single function is keyed as "main" so it can be found
+        # in ApplyHistoryBest: a different name makes the modules structurally unequal.
+        if "main" not in mod.global_var_map_:
+            (func_name,) = [global_var for global_var in mod.global_var_map_]
+            mod = IRModule({"main": mod[func_name]})
         return mod
 
     @staticmethod
@@ -615,7 +625,7 @@ def tune_relay(
     postprocs: Optional[TypePostproc] = None,
     mutator_probs: Optional[TypeMutatorProb] = None,
     num_threads: Optional[int] = None,
-) -> List[Optional[Schedule]]:
+) -> ExecutorFactoryModule:
     """Tune a TIR IRModule with a given target.
 
     Parameters
     ----------
@@ -647,8 +657,8 @@ def tune_relay(
 
     Returns
     -------
-    schs : List[Optional[Schedule]]
-        The tuned schedules.
+    lib : ExecutorFactoryModule
+        The built runtime module for the given relay workload.
     """
     logger.info("Working directory: %s", work_dir)
 
@@ -660,11 +670,10 @@ def tune_relay(
     # parse the tuning contexts
     for task in extracted_tasks:
         assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
-        mod = Parse._mod(task.dispatched[0])
         tune_contexts.append(
             Parse._tune_context(
                 tune_context=None,
-                mod=mod,
+                mod=Parse._mod(task.dispatched[0]),
                 target=target,
                 config=config,
                 task_name=task.task_name,
@@ -704,16 +713,9 @@ def tune_relay(
     )
     # pylint: enable=protected-access
     task_scheduler.tune()
-    schs: List[Schedule] = []
-    for task in tasks:
-        mod = task.mod
-        workload = database.commit_workload(mod)
-        bests: List[TuningRecord] = database.get_top_k(workload, top_k=1)
-        if not bests:
-            schs.append(None)
-        else:
-            assert len(bests) == 1
-            sch = Schedule(mod)
-            bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
-            schs.append(sch)
-    return schs
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            return relay.build(mod, target=target, params=params)
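With this change `tune_relay` returns a built runtime module rather than a list of schedules. A minimal usage sketch, modeled on the updated unit test later in this patch; `mod` and `params` are assumed to come from a Relay frontend, and the input name, shape, target string, and work directory are illustrative:

    import numpy as np
    import tvm
    from tvm.contrib import graph_executor
    from tvm.meta_schedule import ReplayTraceConfig
    from tvm.meta_schedule.tune import tune_relay
    from tvm.target import Target

    # `mod` and `params` are assumed, e.g. from relay.frontend.from_pytorch.
    lib = tune_relay(
        mod=mod,
        params=params,
        target=Target("llvm --num-cores=16"),
        config=ReplayTraceConfig(num_trials_per_iter=32, num_trials_total=32),
        work_dir="/tmp/tune_relay_work_dir",
    )

    # The returned ExecutorFactoryModule plugs directly into the graph executor.
    dev = tvm.cpu()
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input("input0", np.random.randn(1, 3, 299, 299).astype("float32"))
    module.run()
    print(module.get_output(0).numpy())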
""" logger.info("Working directory: %s", work_dir) @@ -660,11 +670,10 @@ def tune_relay( # parse the tuning contexts for task in extracted_tasks: assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" - mod = Parse._mod(task.dispatched[0]) tune_contexts.append( Parse._tune_context( tune_context=None, - mod=mod, + mod=Parse._mod(task.dispatched[0]), target=target, config=config, task_name=task.task_name, @@ -704,16 +713,9 @@ def tune_relay( ) # pylint: enable=protected-access task_scheduler.tune() - schs: List[Schedule] = [] - for task in tasks: - mod = task.mod - workload = database.commit_workload(mod) - bests: List[TuningRecord] = database.get_top_k(workload, top_k=1) - if not bests: - schs.append(None) - else: - assert len(bests) == 1 - sch = Schedule(mod) - bests[0].trace.apply_to_schedule(sch, remove_postproc=False) - schs.append(sch) - return schs + with ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + return relay.build(mod, target=target, params=params) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 130b3a534b70..f62f5a91d394 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -117,12 +117,17 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRMod IRModule prim_mod = dispatched.value()[0]; ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; ICHECK(HasOnlyOneFunction(mod)) << mod; + const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); + prim_mod = (*parse_mod_func)(prim_mod); if (database->HasWorkload(prim_mod)) { Array records = database->GetTopK(database->CommitWorkload(prim_mod), 1); - ICHECK(records.size() == 1) << "No records was found for given workload" << prim_mod; - return records[0]->workload->mod; - } else - return NullOpt; + // todo(@zxybazh): check if records always exists when the database has the workload + if (records.size() == 1) { + LOG(INFO) << "Applied history best for " << task_name << "!"; + return records[0]->workload->mod; + } + } + return NullOpt; } /**************** FFI ****************/ diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 3f685407817b..fff7c2711218 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -40,7 +40,18 @@ class BlockCollector : public tir::StmtVisitor { blocks_to_collect_.clear(); VisitStmt(func->body); for (const String& block_name : blocks_to_collect_) { - results_.push_back(sch_->GetBlock(block_name, func_name_)); + tir::BlockRV block_rv = sch_->GetBlock(block_name, func_name_); + // pick out the blocks with annotation for customized search space + if (Optional custom_sch_rule_name_opt = + tir::GetAnn(sch_->GetSRef(block_rv), "schedule_rule")) { + String custom_sch_rule_name = Downcast(custom_sch_rule_name_opt.value()); + if (custom_sch_rule_name != "None") { + const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name); + (*custom_sch_rule_func)(sch_, block_rv); + } + } else { + results_.push_back(block_rv); + } } } } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index e3cc51a479ef..28f95b2dc0dd 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -138,7 +138,7 @@ void 
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index e3cc51a479ef..28f95b2dc0dd 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -138,7 +138,7 @@ void TaskSchedulerNode::Tune() {
     } else {
       SetTaskStopped(task_id);
       --running_tasks;
-      LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks;
+      LOG(INFO) << "Task #" << (task_id + 1) << " has finished. Remaining task(s): " << running_tasks;
     }
   }
   ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished";
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index b78e67817ebf..8e13b31c3c53 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -29,6 +29,7 @@
 from tvm.script import tir as T
 from tvm.target import Target
 from tvm.tir.schedule import BlockRV, Schedule
+from tvm import register_func
 
 
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
@@ -50,6 +51,42 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
 
 
+@tvm.script.ir_module
+class MatmulCustomized:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+        T.func_attr({"global_symbol": "main"})
+        A = T.match_buffer(a, (1024, 1024), "float32")
+        B = T.match_buffer(b, (1024, 1024), "float32")
+        C = T.match_buffer(c, (1024, 1024), "float32")
+        with T.block("root"):
+            for i, j, k in T.grid(1024, 1024, 1024):
+                with T.block("matmul"):
+                    T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space"})
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = 0.0
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.script.ir_module
+class MatmulCustomizedNoneRule:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+        T.func_attr({"global_symbol": "main"})
+        A = T.match_buffer(a, (1024, 1024), "float32")
+        B = T.match_buffer(b, (1024, 1024), "float32")
+        C = T.match_buffer(c, (1024, 1024), "float32")
+        with T.block("root"):
+            T.block_attr({"schedule_rule": "None"})
+            for i, j, k in T.grid(1024, 1024, 1024):
+                with T.block("matmul"):
+                    T.block_attr({"schedule_rule": "None"})
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = 0.0
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
 @tvm.script.ir_module
 class DuplicateMatmul:
     @T.prim_func
@@ -338,5 +375,43 @@ def correct_trace(a, b, c, d):
     )
 
 
+def test_meta_schedule_post_order_apply_custom_search_space():
+    @register_func("tvm.meta_schedule.test.custom_search_space")
+    def custom_search_space_func(sch: Schedule, block: BlockRV):
+        raise ValueError("Customized search space triggered!")
+
+    mod = MatmulCustomized
+    context = TuneContext(
+        mod=mod,
+        target=Target("llvm"),
+        task_name="Custom Search Space Task",
+        sch_rules=[],
+    )
+    post_order_apply = PostOrderApply()
+    post_order_apply.initialize_with_tune_context(context)
+    with pytest.raises(ValueError, match="Customized search space triggered!"):
+        _ = post_order_apply.generate_design_space(mod)
+
+
+def test_meta_schedule_post_order_apply_custom_search_space_none_rule():
+    class DontCallThisRule(PyScheduleRule):
+        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+            pass
+
+        def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
+            raise RuntimeError("This schedule rule should not be called!")
+
+    mod = MatmulCustomizedNoneRule
+    context = TuneContext(
+        mod=mod,
+        target=Target("llvm"),
+        task_name="Custom Search Space Task",
+        sch_rules=[DontCallThisRule()],
+    )
+    post_order_apply = PostOrderApply()
+    post_order_apply.initialize_with_tune_context(context)
+    _ = post_order_apply.generate_design_space(mod)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index 02264f797127..09aaa08d5185 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -17,19 +17,63 @@
 # pylint: disable=missing-docstring
 import logging
 import tempfile
-from typing import List, Tuple
-
 import pytest
+import numpy as np
+from typing import List, Tuple
+
+import tvm
+from tvm import relay
+from tvm.ir import IRModule
+from tvm.runtime.ndarray import cpu, cuda
+from tvm.target.target import Target
+from tvm.contrib import graph_executor
 from tvm.meta_schedule import ReplayTraceConfig
+from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
 from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model
 from tvm.meta_schedule.tune import tune_relay
-from tvm.target.target import Target
-from tvm.tir import Schedule
 
 logging.basicConfig()
 logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
 
 
+class DummyDatabase(PyDatabase):
+    def __init__(self):
+        super().__init__()
+        self.records = []
+        self.workload_reg = []
+
+    def has_workload(self, mod: IRModule) -> bool:
+        for workload in self.workload_reg:
+            if tvm.ir.structural_equal(workload.mod, mod):
+                return True
+        return False
+
+    def commit_tuning_record(self, record: TuningRecord) -> None:
+        self.records.append(record)
+
+    def commit_workload(self, mod: IRModule) -> Workload:
+        for workload in self.workload_reg:
+            if tvm.ir.structural_equal(workload.mod, mod):
+                return workload
+        workload = Workload(mod)
+        self.workload_reg.append(workload)
+        return workload
+
+    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+        return list(
+            filter(
+                lambda x: x.workload == workload,
+                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
+            )
+        )[: int(top_k)]
+
+    def __len__(self) -> int:
+        return len(self.records)
+
+    def print_results(self) -> None:
+        print("\n".join([str(r) for r in self.records]))
+
+
 @pytest.mark.skip("Integration test")
 @pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
 @pytest.mark.parametrize("batch_size", [1])
@@ -39,19 +83,26 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
     if model_name == "inception_v3" and batch_size == 1:
         pytest.skip("inception_v3 does not handle batch_size of 1")
 
     input_shape: Tuple[int, ...]
-    if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
-        input_shape = (batch_size, 3, 299, 299)
-    elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION:
-        input_shape = (batch_size, 3, 299, 299)
-    elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION:
-        input_shape = (1, 3, 300, 300)
-    elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
-        input_shape = (batch_size, 3, 3, 299, 299)
-    elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+    input_name = "input0"
+    dev = cpu() if str(target).startswith("llvm") else cuda()
+    if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
         seq_length = 128
+        input_name = "input_ids"
         input_shape = (batch_size, seq_length)
+        data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev)  # embedding size
     else:
-        raise ValueError("Unsupported model: " + model_name)
+        if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
+            input_shape = (batch_size, 3, 299, 299)
+        elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION:
+            input_shape = (batch_size, 3, 299, 299)
+        elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION:
+            input_shape = (1, 3, 300, 300)
+        elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
+            input_shape = (batch_size, 3, 3, 299, 299)
+        else:
+            raise ValueError("Unsupported model: " + model_name)
+        data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev)
+
     output_shape: Tuple[int, int] = (batch_size, 1000)
 
     mod, params = get_torch_model(
@@ -63,7 +114,8 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
 
     with tempfile.TemporaryDirectory() as work_dir:
         target = Target(target)
-        schs: List[Schedule] = tune_relay(
+        database = DummyDatabase()
+        rt_mod: tvm.runtime.Module = tune_relay(
             mod=mod,
             params=params,
             target=target,
@@ -72,20 +124,28 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
                 num_trials_total=32,
             ),
             work_dir=work_dir,
+            database=database,
         )
-        for i, sch in enumerate(schs):
-            print("-" * 10 + f" Part {i+1}/{len(schs)} " + "-" * 10)
-            if sch is None:
-                print("No valid schedule found!")
-            else:
-                print(sch.mod.script())
-                print(sch.trace)
+        # Compile without meta-schedule for the correctness check
+        with tvm.transform.PassContext(opt_level=0):
+            rt_mod2 = relay.build(mod, target=target, params=params)
+
+        def get_output(data, lib):
+            module = graph_executor.GraphModule(lib["default"](dev))
+            module.set_input(input_name, data)
+            module.run()
+            return module.get_output(0).numpy()
+
+        # Check correctness
+        actual_output = get_output(data, rt_mod)
+        expected_output = get_output(data, rt_mod2)
+        assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
 
 
 if __name__ == """__main__""":
-    test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
+    # test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
     test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
-    test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
-    test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")
+    # test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16")
+    # test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
+    # test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
+    # test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")