diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py
index bafdd521bffb..24e57928778d 100644
--- a/python/tvm/meta_schedule/testing/__init__.py
+++ b/python/tvm/meta_schedule/testing/__init__.py
@@ -15,4 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 """Testing utilities in meta schedule"""
-from .utils import DummyDatabase, DummyBuilder, DummyRunner, DummyRunnerFuture, DummyMutator
+from .utils import (
+    DummyDatabase,
+    DummyBuilder,
+    DummyRunner,
+    DummyRunnerFuture,
+    DummyMutator,
+    apply_fixed_schedules,
+)
diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py
index b7ef34914089..e22677a3b918 100644
--- a/python/tvm/meta_schedule/testing/utils.py
+++ b/python/tvm/meta_schedule/testing/utils.py
@@ -15,11 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """Testing utilitiy functions in meta schedule"""
-from typing import List, Optional
 import random
+from typing import List, Optional, Callable, Dict, Union
 
 import tvm
-
+from tvm.relay import Function as RelayFunc
+from tvm.tir import Schedule
+from tvm.target import Target
+from tvm.runtime import NDArray
 from tvm.meta_schedule import TuneContext  # pylint: disable=unused-import
 from tvm.meta_schedule.utils import derived_object
 from tvm.meta_schedule.mutator.mutator import PyMutator
@@ -32,6 +35,9 @@
     PyRunnerFuture,
     PyRunner,
 )
+from tvm.meta_schedule.tune import Parse, extract_task_from_relay
+from tvm.meta_schedule.integration import ExtractedTask
+
 
 from tvm.ir import IRModule
 from tvm.tir.schedule import Trace
@@ -110,3 +116,46 @@ def initialize_with_tune_context(self, context: "TuneContext") -> None:
 
     def apply(self, trace: Trace, _) -> Optional[Trace]:
         return Trace(trace.insts, {})
+
+
+def apply_fixed_schedules(
+    relay_mod: Union[RelayFunc, IRModule],
+    target: Union[str, Target],
+    params: Optional[Dict[str, NDArray]],
+    schedule_fn: Callable[[ExtractedTask, Schedule], bool],
+):
+    """Apply fixed schedules (manually written, without any tunable knobs) as specified by
+    schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.
+
+    Parameters
+    ----------
+    relay_mod : Union[RelayFunc, IRModule]
+        The Relay module to which the fixed schedules are applied.
+    target : Union[str, Target]
+        The target used to extract tasks.
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        The associated parameters of the module.
+    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
+        A callable that is applied to each extracted task and the corresponding default schedule.
+        Returns True if the given schedule should be committed to the database, False otherwise.
+
+    Returns
+    -------
+    database : Database
+        The database containing dummy tuning records for manually scheduled traces.
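+
+    Examples
+    --------
+    A sketch of typical usage; ``my_dense_schedule`` is a placeholder for any
+    manually written TIR schedule function:
+
+    .. code-block:: python
+
+        def schedule_fn(task, sch):
+            if "dense" in task.task_name:
+                my_dense_schedule(sch)  # placeholder: apply the fixed schedule in place
+                return True
+            return False
+
+        database = apply_fixed_schedules(relay_mod, "llvm", params, schedule_fn)
+        with ApplyHistoryBest(database):
+            ...  # build the Relay module as usual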
+    """
+    target = Target(target) if isinstance(target, str) else target
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+    database = DummyDatabase()
+
+    for task in extracted_tasks:
+        mod = Parse._mod(task.dispatched[0])
+        sch = Schedule(mod)
+
+        if schedule_fn(task, sch):
+            workload = database.commit_workload(mod)
+            # [0.0] is a dummy measurement; ApplyHistoryBest only looks up the trace.
+            tune_rec = TuningRecord(sch.trace, [0.0], workload, target, [])
+            database.commit_tuning_record(tune_rec)
+
+    return database
diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py
index e59639170d0f..78d0ddeda32f 100644
--- a/tests/python/unittest/test_meta_schedule_multi_anchor.py
+++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py
@@ -14,16 +14,12 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import os
-import tempfile
-
 import numpy as np
 import tvm
 import tvm.testing
 from tvm import relay
-from tvm.meta_schedule.tune import Parse, extract_task_from_relay
-from tvm.meta_schedule.database import TuningRecord, JSONDatabase
+from tvm.meta_schedule.testing import apply_fixed_schedules
 from tvm.meta_schedule.integration import ApplyHistoryBest
@@ -72,39 +68,20 @@ def test_dense_dense():
     # print(relay.transform.InferType()(relay_mod))
 
-    target = "llvm"
-
     data_np = np.random.randn(*data_shape).astype("float32")
     weight1_np = np.random.randn(*weight_shape).astype("float32")
     weight2_np = np.random.randn(*weight_shape).astype("float32")
+    target = "llvm"
     params = {"weight1": weight1_np, "weight2": weight2_np}
 
-    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
-
-    assert len(extracted_tasks) == 1
-
-    task = extracted_tasks[0]
-
-    mod = Parse._mod(task.dispatched[0])
-
-    with tempfile.TemporaryDirectory() as work_dir:
-        database = JSONDatabase(
-            path_workload=os.path.join(work_dir, "database_workload.json"),
-            path_tuning_record=os.path.join(work_dir, "database_tuning_record.json"),
-        )
-
-        workload = database.commit_workload(mod)
-
-        sch = tvm.tir.Schedule(mod)
-
-        schedule_dense_dense(sch)
-
-        # print(sch.mod.script())
-
-        tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
+    def schedule_fn(task, sch):
+        if "nn_dense_nn_dense" in task.task_name:
+            schedule_dense_dense(sch)
+            return True
+        return False
 
-        database.commit_tuning_record(tune_rec)
+    database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
 
     with ApplyHistoryBest(database):
         with tvm.transform.PassContext(
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index 63fbac6748c7..76cd82920c35 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -31,8 +31,8 @@
 from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload
 from tvm.meta_schedule.integration import ApplyHistoryBest
 from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.testing import apply_fixed_schedules
 from tvm.meta_schedule.tune import (
-    Parse,
     extract_task_from_relay,
     tune_extracted_tasks,
     tune_relay,
@@ -480,52 +480,46 @@ def manual_tir_common(do_tune=False):
     params = {"weight": weight_np, "bias": bias_np}
 
-    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+    if do_tune:
+        extracted_tasks = extract_task_from_relay(relay_mod, target, params)
 
-    # Filter out tasks that we don't intend to schedule / tune with TIR.
-    tune_tasks = list(
-        filter(
-            lambda task: "dense" in task.task_name,
-            extracted_tasks,
+        # Filter out tasks that we don't intend to schedule / tune with TIR.
+        tune_tasks = list(
+            filter(
+                lambda task: "dense" in task.task_name,
+                extracted_tasks,
+            )
+        )
+        config = ReplayTraceConfig(
+            num_trials_per_iter=64,
+            max_trials_per_task=64,
+            max_trials_global=20000,
         )
-    )
 
-    with tempfile.TemporaryDirectory() as work_dir:
-        if do_tune:
-            config = ReplayTraceConfig(
-                num_trials_per_iter=64,
-                max_trials_per_task=64,
-                max_trials_global=20000,
-            )
+        with tempfile.TemporaryDirectory() as work_dir:
             # postprocs=lambda: [] is important to prevent default post processors from
             # tampering with the manual schedule.
             database = tune_extracted_tasks(
                 tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: []
             )
-        else:
-            database = JSONDatabase(
-                path_workload=osp.join(work_dir, "database_workload.json"),
-                path_tuning_record=osp.join(work_dir, "database_tuning_record.json"),
-            )
+    else:
+
+        def schedule_fn(task, sch):
+            if "dense" not in task.task_name:
+                return False
 
-            for task in tune_tasks:
-                mod = Parse._mod(task.dispatched[0])
-                workload = database.commit_workload(mod)
+            block = sch.get_block("compute")
 
-                sch = tvm.tir.Schedule(mod)
-                block = sch.get_block("compute")
+            # Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
+            schedule_rule = sch.get(block).annotations["schedule_rule"]
 
-                # Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
-                schedule_rule = sch.get(block).annotations["schedule_rule"]
+            assert "dense_vnni" in schedule_rule
 
-                if "dense_vnni" in schedule_rule:
-                    schedule_dense(block, M, False, sch)
+            schedule_dense(block, M, False, sch)
 
-                # [0.0] is for dummy measurement. There is only one tuning record so ApplyHistoryBest
-                # will always have only one option.
-                tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
+            return True
 
-                database.commit_tuning_record(tune_rec)
+        database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
 
     with ApplyHistoryBest(database):
         with tvm.transform.PassContext(
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py
index 5ac6a24a423a..ebce33965914 100644
--- a/tests/python/unittest/test_meta_schedule_tune_tir.py
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -91,127 +91,6 @@ def test_tune_matmul_cuda():
     print(sch.trace)
 
 
-@pytest.mark.skip("Integeration test")
-def test_tune_matmul_cuda_tensor_core():
-    n = 512
-    mod = create_prim_func(te_workload.matmul_fp16(n, n, n))
-    target = Target("nvidia/geforce-rtx-3070")
-    config = ReplayTraceConfig(
-        num_trials_per_iter=32,
-        max_trials_per_task=320,
-        max_trials_global=320,
-    )
-
-    class DefaultTensorCore:
-        @staticmethod
-        def _sch_rules():
-            from tvm.meta_schedule import (
-                schedule_rule as M,  # pylint: disable=import-outside-toplevel
-            )
-
-            return [
-                M.AutoInline(
-                    into_producer=False,
-                    into_consumer=True,
-                    inline_const_tensor=True,
-                    disallow_if_then_else=False,
-                    require_injective=False,
-                    require_ordered=False,
-                    disallow_op=None,
-                ),
-                M.MultiLevelTiling(
-                    structure="SSSRRSRS",
-                    tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
-                    # use_tensor_core=True,
-                    max_innermost_factor=64,
-                    vector_load_lens=[1, 2, 3, 4],
-                    reuse_read=schedule_rule.ReuseType(
-                        req="must",
-                        levels=[4],
-                        scope="shared",
-                    ),
-                    reuse_write=schedule_rule.ReuseType(
-                        req="no",
-                        levels=[],
-                        scope="",
-                    ),
-                ),
-                M.AutoInline(
-                    into_producer=True,
-                    into_consumer=True,
-                    inline_const_tensor=True,
-                    disallow_if_then_else=False,
-                    require_injective=False,
-                    require_ordered=False,
-                    disallow_op=None,
-                ),
-                M.ParallelizeVectorizeUnroll(
-                    max_jobs_per_core=-1,  # disable parallelize
-                    max_vectorize_extent=-1,  # disable vectorize
-                    unroll_max_steps=[0, 16, 64, 512, 1024],
-                    unroll_explicit=True,
-                ),
-            ]
-
-        @staticmethod
-        def _postproc():
-            from tvm.meta_schedule import (
-                postproc as M,  # pylint: disable=import-outside-toplevel
-            )
-
-            return [
-                M.RewriteCooperativeFetch(),
-                M.RewriteParallelVectorizeUnroll(),
-                M.RewriteReductionBlock(),
-                M.RewriteTensorCore(),
-                M.VerifyGPUCode(),
-            ]
-
-    with tempfile.TemporaryDirectory() as work_dir:
-        sch: Schedule = tune_tir(
-            mod=mod,
-            target=target,
-            config=config,
-            work_dir=work_dir,
-            space=PostOrderApply(),
-            sch_rules=DefaultTensorCore._sch_rules,
-            postprocs=DefaultTensorCore._postproc,
-            num_threads=None,
-        )
-        if sch is None:
-            print("No valid schedule found!")
-        else:
-            print(sch.mod.script())
-            print(sch.trace)
-
-        import numpy as np
-        from tvm.contrib import nvcc
-
-        ctx = tvm.gpu(0)
-        if nvcc.have_tensorcore(ctx.compute_version):
-            with tvm.transform.PassContext():
-                func = tvm.build(sch.mod["main"], [], "cuda")
-                print(sch.mod.script())
-                print(func.imported_modules[0].get_source())
-            a_np = np.random.uniform(size=(n, n)).astype("float16")
-            b_np = np.random.uniform(size=(n, n)).astype("float16")
-            a = tvm.nd.array(a_np, ctx)
-            b = tvm.nd.array(b_np, ctx)
-            c = tvm.nd.array(np.zeros((n, n), dtype="float32"), ctx)
-            evaluator = func.time_evaluator(
-                func.entry_name, ctx, number=3, repeat=1, min_repeat_ms=40
-            )
-            print("matmul with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3))
-
-            np.testing.assert_allclose(
-                c.asnumpy(),
-                np.matmul(a_np.astype("float32"), b_np.astype("float32")),
-                rtol=1e-4,
-                atol=1e-4,
-            )
-
-
 if __name__ == """__main__""":
     test_tune_matmul_cpu()
     test_tune_matmul_cuda()
-    test_tune_matmul_cuda_tensor_core()