From 25b06ce57039f4d841300a499e480f177038935e Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 16 Feb 2023 20:26:34 -0800 Subject: [PATCH] [TUZ-152] Implement a pass that applies scheduling for a target backend (#20) * Enable minimal tuning. * Implement wrapper around minimalist pass that schedules a module for a target device. * Add option to change number of trials per task. * Typing and lint clean up. * Format long lines. * Change to trials and disable logging. * Fixed number of trials and removed logging. * Remove commented out code. * Fix test. --------- Co-authored-by: Xiyou Zhou --- python/tvm/meta_schedule/relax_integration.py | 11 +- python/tvm/meta_schedule/tir_integration.py | 5 +- python/tvm/relax/transform/__init__.py | 1 + python/tvm/relax/transform/schedule.py | 108 ++++++++++++++++++ python/tvm/relax/transform/transform.py | 20 +++- .../relax/transform/tuning_api/primitives.py | 2 +- src/meta_schedule/utils.h | 4 +- src/relax/transform/meta_schedule.cc | 30 +++-- .../test_transform_meta_schedule_tuning.py | 34 +++++- .../test_transform_schedule_for_target.py | 91 +++++++++++++++ 10 files changed, 289 insertions(+), 17 deletions(-) create mode 100644 python/tvm/relax/transform/schedule.py create mode 100644 tests/python/relax/test_transform_schedule_for_target.py diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index db22214b768f..6abb9d261655 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -146,11 +146,11 @@ def tune_relax( target: Union[str, Target], work_dir: str, max_trials_global: int, - *, max_trials_per_task: Optional[int] = None, num_trials_per_iter: int = 64, builder: Builder.BuilderType = "local", runner: Runner.RunnerType = "local", + *, database: Database.DatabaseType = "json", cost_model: CostModel.CostModelType = "xgb", measure_callbacks: MeasureCallback.CallbackListType = "default", @@ -231,11 +231,11 @@ def _tune_relax( target: Union[str, Target], work_dir: str, max_trials_global: int, - *, max_trials_per_task: Optional[int] = None, num_trials_per_iter: int = 64, builder: Builder.BuilderType = "local", runner: Runner.RunnerType = "local", + *, database: Database.DatabaseType = "json", cost_model: CostModel.CostModelType = "xgb", measure_callbacks: MeasureCallback.CallbackListType = "default", @@ -286,6 +286,13 @@ def _tune_relax( ret_mod : IRModule IRModule """ + + if isinstance(num_trials_per_iter, IntImm): + num_trials_per_iter = int(num_trials_per_iter) + + if isinstance(max_trials_per_task, IntImm): + max_trials_per_task = int(max_trials_per_task) + if isinstance(max_trials_global, IntImm): max_trials_global = int(max_trials_global) diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index d5f5ee86e0b8..717e7d1747a7 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -45,10 +45,10 @@ def tune_tir( target: Union[str, Target], work_dir: str, max_trials_global: int, - *, num_trials_per_iter: int = 64, builder: Builder.BuilderType = "local", runner: Runner.RunnerType = "local", + *, database: Database.DatabaseType = "json", cost_model: CostModel.CostModelType = "xgb", measure_callbacks: MeasureCallback.CallbackListType = "default", @@ -103,6 +103,7 @@ def tune_tir( """ (logger,) = get_loggers_from_work_dir(work_dir, [task_name]) (seed,) = fork_seed(seed, n=1) + num_trials_per_iter = int(num_trials_per_iter) return tune_tasks( 
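+        # Note: integer arguments such as num_trials_per_iter can arrive as IntImm
+        # objects when tune_tir is reached through the relax tuning-API FFI, hence
+        # the int() normalization above before the tuning tasks are constructed.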
         tasks=[
             TuneContext(
@@ -136,10 +137,10 @@ def _tune_tir(
     target: Union[str, Target],
     work_dir: str,
     max_trials_global: int,
-    *,
     num_trials_per_iter: int = 64,
     builder: Builder.BuilderType = "local",
     runner: Runner.RunnerType = "local",
+    *,
     database: Database.DatabaseType = "json",
     cost_model: CostModel.CostModelType = "xgb",
     measure_callbacks: MeasureCallback.CallbackListType = "default",
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 78f450b25ce2..70939891b61c 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -21,3 +21,4 @@
 
 # Import to register the legalization functions.
 from . import legalize_ops
+from .schedule import ScheduleForTarget
diff --git a/python/tvm/relax/transform/schedule.py b/python/tvm/relax/transform/schedule.py
new file mode 100644
index 000000000000..dfeaa29a548b
--- /dev/null
+++ b/python/tvm/relax/transform/schedule.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Relax passes related to scheduling functions for target hardware."""
+import logging
+import tempfile
+from typing import Union, List
+
+from tvm import relax
+from tvm.ir import transform, IRModule
+from tvm.target import Target
+from tvm import meta_schedule as ms
+from .tuning_api import Trace
+
+
+@transform.module_pass(opt_level=2, name="schedule_for_target")
+class ScheduleForTarget:
+    """Apply a minimal set of transformations to enable running on a specific target."""
+
+    def __init__(self, target: Union[Target, str]):
+        """
+        Construct a pass that applies basic schedule transformations to each primitive
+        function in the input module for the specified target. This is useful when a
+        hardware target requires certain intrinsics for its kernels to be valid. For
+        example, on GPUs each kernel must have its loops bound to thread and block
+        indices. By default, primitive functions do not contain these bindings;
+        applying a single step of MetaSchedule's transform rules inserts them so that
+        the functions can run.
+
+        Thus, this pass is a convenience wrapper around the simplest possible invocation
+        of MetaSchedule tuning: it tries only a few schedules per task, skips
+        benchmarking, and verifies only that the schedules can be built.
+
+        Parameters
+        ----------
+        target : Union[Target, str]
+            The TVM target that functions should be scheduled for.
+        """
+        if isinstance(target, str):
+            target = Target(target)
+        self.target = target
+        # Create a fake runner that does not perform benchmarking. This saves time
+        # when transforming the primitive functions in the module.
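+        # The stub runner defined below returns a LocalRunnerFuture containing a
+        # single zero-valued run time for every candidate, so each schedule that
+        # builds successfully is accepted without ever executing on hardware.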
+ @ms.derived_object + class FakeRunner(ms.runner.PyRunner): + def run( + self, runner_inputs: List[ms.runner.RunnerInput] + ) -> List[ms.runner.RunnerFuture]: + return [ms.runner.LocalRunnerFuture([0.0], None)] + + self.runner = FakeRunner() + + def transform_module(self, mod: IRModule, ctx: transform.PassContext) -> IRModule: + """Apply a minimal set of tuning to transform the input module. + + Parameters + ---------- + mod : IRModule + The input module to schedule. + ctx : transform.PassContext + Information about the current pass, not currently used. + + Returns + ------- + scheduled_mod : IRModule + The input module with hardware specific transformations applied. + """ + # Extract the number of tasks in the input module so that we can + # determine the minimal number of transformations to try. + num_tasks = len(ms.relax_integration.extract_tasks(mod, self.target)) + # Disable logging for this pass. + logging_level = logging.getLogger("tvm.meta_schedule").level + logging.getLogger("tvm.meta_schedule").setLevel(logging.CRITICAL) + # Perform a minimal set of metaschedule tuning on the input module. + with tempfile.TemporaryDirectory() as work_dir: + with self.target, transform.PassContext(trace=Trace(mod), opt_level=0): + # Create a pass that finds one valid schedule per task in the module. + tuning_pass = relax.transform.MetaScheduleTuneIRMod( + params={}, + work_dir=work_dir, + max_trials_global=num_tasks, + max_trials_per_task=1, + runner=self.runner, + ) + + # Apply the pass on our module. + mod = tuning_pass(mod) + + # Use the results of tuning to schedule the module. + application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir) + mod = application_pass(mod) + # Re-enable normal logging. + logging.getLogger("tvm.meta_schedule").setLevel(logging_level) + return mod diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 97daae49412a..56db7ffe60f8 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -23,6 +23,7 @@ import numpy as np # type: ignore import tvm.ir from tvm.runtime import NDArray +from tvm import meta_schedule as ms from . import _ffi_api from .legalize_ops.common import LegalizeFunc @@ -489,6 +490,7 @@ def MetaScheduleApplyDatabase( def MetaScheduleTuneTIR( work_dir: str, max_trials_global: int, + runner: Optional[ms.runner.Runner] = None, ) -> tvm.ir.transform.Pass: """Tune TIR with MetaSchedule. Parameters @@ -497,17 +499,23 @@ def MetaScheduleTuneTIR( work directory max_trials_gloabl: int maximum number of total trials allowed for tuning + runner: Optional[ms.runner.Runner] + runner for tuning Returns ------- ret: tvm.ir.transform.Pass """ - return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global) # type: ignore + if runner is None: + runner = ms.runner.LocalRunner() + return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global, runner) # type: ignore def MetaScheduleTuneIRMod( params: Dict[str, NDArray], work_dir: str, max_trials_global: int, + max_trials_per_task: Optional[int] = None, + runner: Optional[ms.runner.Runner] = None, ) -> tvm.ir.transform.Pass: """Tune Relax IRModule with MetaSchedule. 
Parameters @@ -518,11 +526,19 @@ def MetaScheduleTuneIRMod( work directory max_trials_gloabl: int maximum number of total trials allowed for tuning + max_trials_per_task: Optional[int] + maximum number of trials allowed for each task + runner: Optional[ms.runner.Runner] + runner for the tuning pass Returns ------- ret: tvm.ir.transform.Pass """ - return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) # type: ignore + if runner is None: + runner = ms.runner.LocalRunner() + return _ffi_api.MetaScheduleTuneIRMod( # type: ignore + params, work_dir, max_trials_global, max_trials_per_task, runner + ) def _wrap_class_function_pass(pass_cls, pass_info): diff --git a/python/tvm/relax/transform/tuning_api/primitives.py b/python/tvm/relax/transform/tuning_api/primitives.py index 67b81ba7e99c..a9c734fd1354 100644 --- a/python/tvm/relax/transform/tuning_api/primitives.py +++ b/python/tvm/relax/transform/tuning_api/primitives.py @@ -21,10 +21,10 @@ import tvm from tvm.runtime import Object from tvm.ir.module import IRModule -from tvm.relax import Expr from tvm.tir.schedule.trace import JSON_TYPE, _json_from_tvm from tvm._ffi import register_object from . import _ffi_api +from ...expr import Expr logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 9a372dde8f6d..8f42928b72f1 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -154,8 +154,8 @@ inline void clear_logging(const char* file, int lineno, PackedFunc logging_func) logging_func(static_cast(PyLogMessage::Level::CLEAR), file, lineno, ""); } else { // this would clear all logging output in the console - runtime::detail::LogMessage(file, lineno, TVM_LOG_LEVEL_INFO).stream() - << "\033c\033[3J\033[2J\033[0m\033[H"; + logging_func(static_cast(PyLogMessage::Level::INFO), file, lineno, + "\033c\033[3J\033[2J\033[0m\033[H"); } } diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index d444ba16654f..337a63b22f51 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -22,6 +22,7 @@ * \brief Pass for meta_schedule tuning */ #include +#include #include #include #include @@ -33,10 +34,13 @@ namespace transform { class MetaScheduleTuner { public: explicit MetaScheduleTuner(Target target, String work_dir, Integer max_trials_global, + Integer max_trials_per_task, meta_schedule::Runner runner, Map params = {}) : target_(target), work_dir_(work_dir), max_trials_global_(max_trials_global), + max_trials_per_task_(max_trials_per_task), + runner_(runner), params_(params) { candgen_func_ = runtime::Registry::Get("relax.tuning_api.default_generate_candidate"); ICHECK(candgen_func_) << "Default candidate generation function is not found."; @@ -48,7 +52,11 @@ class MetaScheduleTuner { IRModule TuneIRMod(IRModule mod, transform::PassContext ctx) { Trace trace = Downcast(ctx->GetCurrentTrace()); ctx->PopTrace(); - Choice choice("tvm.meta_schedule.tune_relax", {params_, target_, work_dir_, max_trials_global_}, + String builder = "local"; + Integer num_trials_per_iter = 64; + Choice choice("tvm.meta_schedule.tune_relax", + {params_, target_, work_dir_, max_trials_global_, max_trials_per_task_, + num_trials_per_iter, builder, runner_}, "relax.tuning_api.Choice.default_constr_func", {}); Knob knob("meta_schedule.tune_irmod", {{"0", choice}}); Array candidates = (*candgen_func_)(Array({knob}), trace); @@ -64,8 +72,10 @@ class MetaScheduleTuner { // TODO(@sunggg): Whenever 
we tune tir, assume we start a new trace w/o pushing to the trace // stack. Revisit later when we collect more usecases. Trace trace = Trace((*normalize_mod_func_)(f), {}, {}); - - Choice choice("tvm.meta_schedule.tune_tir", {target_, work_dir_, max_trials_global_}, + String builder = "local"; + Integer num_trials_per_iter = 64; + Choice choice("tvm.meta_schedule.tune_tir", + {target_, work_dir_, max_trials_global_, num_trials_per_iter, builder, runner_}, "relax.tuning_api.Choice.default_constr_func", {}); Knob knob("meta_schedule.tune_primfunc", {{"0", choice}}); Array candidates = (*candgen_func_)(Array({knob}), trace); @@ -78,6 +88,8 @@ class MetaScheduleTuner { Target target_; String work_dir_; Integer max_trials_global_; + Integer max_trials_per_task_; + meta_schedule::Runner runner_; Map params_; const runtime::PackedFunc* candgen_func_; const runtime::PackedFunc* normalize_mod_func_; @@ -138,11 +150,14 @@ Pass MetaScheduleApplyDatabase(Optional work_dir) { } Pass MetaScheduleTuneIRMod(Map params, String work_dir, - Integer max_trials_global) { + Integer max_trials_global, Integer max_trials_per_task, + meta_schedule::Runner runner) { Target target = Target::Current(false); runtime::TypedPackedFunc pass_func = [=](IRModule m, PassContext ctx) { - return MetaScheduleTuner(target, work_dir, max_trials_global, params).TuneIRMod(m, ctx); + return MetaScheduleTuner(target, work_dir, max_trials_global, max_trials_per_task, runner, + params) + .TuneIRMod(m, ctx); }; return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0, /*pass name*/ "MetaScheduleTuneIRModule", @@ -150,11 +165,12 @@ Pass MetaScheduleTuneIRMod(Map params, String work_dir /*traceable*/ true); } -Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { +Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global, meta_schedule::Runner runner) { Target target = Target::Current(false); runtime::TypedPackedFunc pass_func = [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { - return MetaScheduleTuner(target, work_dir, max_trials_global).TuneTIR(f, ctx); + return MetaScheduleTuner(target, work_dir, max_trials_global, max_trials_global, runner) + .TuneTIR(f, ctx); }; return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, /*pass name*/ "MetaScheduleTuneTIR", diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py index d87ea5cec728..cb110536f40a 100644 --- a/tests/python/relax/test_transform_meta_schedule_tuning.py +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+from typing import List import tempfile import tvm @@ -79,7 +80,9 @@ def test_ms_tuning_irmodule(): with tempfile.TemporaryDirectory() as work_dir: with target, transform.PassContext(trace=Trace(mod), opt_level=0): tuning_pass = relax.transform.MetaScheduleTuneIRMod( - params={}, work_dir=work_dir, max_trials_global=4 + params={}, + work_dir=work_dir, + max_trials_global=4, ) out_mod = tuning_pass(mod) assert PassContext.current().get_trace_stack_size() == 1 @@ -111,5 +114,34 @@ def test_ms_tuning_primfunc(): assert not tvm.ir.structural_equal(mod, out_mod) +def test_ms_tuning_minimal(): + @ms.derived_object + class FakeRunner(ms.runner.PyRunner): + def run(self, runner_inputs: List[ms.runner.RunnerInput]) -> List[ms.runner.RunnerFuture]: + return [ms.runner.LocalRunnerFuture([0.0], None)] + + mod = InputModule + assert isinstance(mod, IRModule) + num_tasks = len(ms.relax_integration.extract_tasks(mod, target)) + with tempfile.TemporaryDirectory() as work_dir: + with target, transform.PassContext(trace=Trace(mod), opt_level=0): + tuning_pass = relax.transform.MetaScheduleTuneIRMod( + params={}, + work_dir=work_dir, + max_trials_global=4 * num_tasks, + max_trials_per_task=1, + runner=FakeRunner(), + ) + out_mod = tuning_pass(mod) + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 1 + tvm.ir.assert_structural_equal(mod, out_mod) + + application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir) + + out_mod = application_pass(mod) + assert not tvm.ir.structural_equal(mod, out_mod) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_schedule_for_target.py b/tests/python/relax/test_transform_schedule_for_target.py new file mode 100644 index 000000000000..57295a95d1fe --- /dev/null +++ b/tests/python/relax/test_transform_schedule_for_target.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
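+# Tests for relax.transform.ScheduleForTarget. Applying the pass should rewrite
+# the module's primitive functions for the chosen target (parallel loops on CPU,
+# thread and block bindings on GPU) without running any real benchmarks.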
+ +from typing import List +import tempfile + +import tvm +import tvm.testing +import tvm.meta_schedule as ms +from tvm import relax +from tvm.ir import transform +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.relax.transform.tuning_api import Trace +from tvm.script import relax as R +from tvm.script import tir as T + + +@tvm.script.ir_module +class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + k = T.var("int32") + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + C = T.match_buffer(z, (32, 32)) + + for (i0, j0, k0) in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + @T.prim_func + def tir_relu(x: T.handle, y: T.handle): + T.func_attr({"global_symbol": "tir_relu"}) + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + for (i, j) in T.grid(32, 32): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.max(A[vi, vj], 0.0) + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir(tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + +@tvm.testing.parametrize_targets( + "llvm --num-cores=2", + "cuda -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32", +) +def test_schedule_for_target(target, dev): + # Check that input device is valid on current machine. + if dev in [enabled_dev for _, enabled_dev in tvm.testing.enabled_targets()]: + mod = InputModule + assert isinstance(mod, IRModule) + out_mod = relax.transform.ScheduleForTarget(target)(mod) + assert not tvm.ir.structural_equal(mod, out_mod) + + # Perform an additional check depending on the target. + if dev == tvm.cpu(): + # On CPU, parallelization should have been inserted. + assert "parallel" in out_mod.script() + if dev == tvm.gpu(): + # On GPU, thread bindings should have been inserted. + assert "thread_binding" in out_mod.script() + + +if __name__ == "__main__": + tvm.testing.main()
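Usage note (not part of the patch): a minimal sketch of how the new pass is meant to be invoked, mirroring the test above. `InputModule` stands in for any Relax IRModule, and the target string is a placeholder.

```python
# Minimal usage sketch for the ScheduleForTarget pass added by this patch.
# Assumes `InputModule` is a Relax IRModule like the one defined in the test.
import tvm
from tvm import relax

mod = InputModule
# The pass accepts either a Target object or a target string.
scheduled_mod = relax.transform.ScheduleForTarget("cuda")(mod)

# Each primitive function now carries target-legal scheduling (for example,
# thread and block bindings on GPU), so the module can be built without any
# further tuning, e.g.:
# ex = relax.build(scheduled_mod, target="cuda")
```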