Skip to content

Commit

Permalink
[TUZ-152] Implement a pass that applies scheduling for a target backe…
Browse files Browse the repository at this point in the history
…nd (apache#20)

* Enable minimal tuning.

* Implement wrapper around minimalist pass that schedules a module for a target device.

* Add option to change number of trials per task.

* Typing and lint clean up.

* Format long lines.

* Change to trials and disable logging.

* Fixed number of trials and removed logging.

* Remove commented out code.

* Fix test.

---------

Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
  • Loading branch information
2 people authored and vinx13 committed Mar 1, 2023
1 parent f351a04 commit 83effce
Show file tree
Hide file tree
Showing 10 changed files with 289 additions and 17 deletions.
11 changes: 9 additions & 2 deletions python/tvm/meta_schedule/relax_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ def tune_relax(
target: Union[str, Target],
work_dir: str,
max_trials_global: int,
*,
max_trials_per_task: Optional[int] = None,
num_trials_per_iter: int = 64,
builder: Builder.BuilderType = "local",
runner: Runner.RunnerType = "local",
*,
database: Database.DatabaseType = "json",
cost_model: CostModel.CostModelType = "xgb",
measure_callbacks: MeasureCallback.CallbackListType = "default",
Expand Down Expand Up @@ -231,11 +231,11 @@ def _tune_relax(
target: Union[str, Target],
work_dir: str,
max_trials_global: int,
*,
max_trials_per_task: Optional[int] = None,
num_trials_per_iter: int = 64,
builder: Builder.BuilderType = "local",
runner: Runner.RunnerType = "local",
*,
database: Database.DatabaseType = "json",
cost_model: CostModel.CostModelType = "xgb",
measure_callbacks: MeasureCallback.CallbackListType = "default",
Expand Down Expand Up @@ -286,6 +286,13 @@ def _tune_relax(
ret_mod : IRModule
IRModule
"""

if isinstance(num_trials_per_iter, IntImm):
num_trials_per_iter = int(num_trials_per_iter)

if isinstance(max_trials_per_task, IntImm):
max_trials_per_task = int(max_trials_per_task)

if isinstance(max_trials_global, IntImm):
max_trials_global = int(max_trials_global)

Expand Down
5 changes: 3 additions & 2 deletions python/tvm/meta_schedule/tir_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def tune_tir(
target: Union[str, Target],
work_dir: str,
max_trials_global: int,
*,
num_trials_per_iter: int = 64,
builder: Builder.BuilderType = "local",
runner: Runner.RunnerType = "local",
*,
database: Database.DatabaseType = "json",
cost_model: CostModel.CostModelType = "xgb",
measure_callbacks: MeasureCallback.CallbackListType = "default",
Expand Down Expand Up @@ -103,6 +103,7 @@ def tune_tir(
"""
(logger,) = get_loggers_from_work_dir(work_dir, [task_name])
(seed,) = fork_seed(seed, n=1)
num_trials_per_iter = int(num_trials_per_iter)
return tune_tasks(
tasks=[
TuneContext(
Expand Down Expand Up @@ -136,10 +137,10 @@ def _tune_tir(
target: Union[str, Target],
work_dir: str,
max_trials_global: int,
*,
num_trials_per_iter: int = 64,
builder: Builder.BuilderType = "local",
runner: Runner.RunnerType = "local",
*,
database: Database.DatabaseType = "json",
cost_model: CostModel.CostModelType = "xgb",
measure_callbacks: MeasureCallback.CallbackListType = "default",
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@

# Import to register the legalization functions.
from . import legalize_ops
from .schedule import ScheduleForTarget
108 changes: 108 additions & 0 deletions python/tvm/relax/transform/schedule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Relax passes related to scheduling functions for target hardware."""
import logging
import tempfile
from typing import Union, List

from tvm import relax
from tvm.ir import transform, IRModule
from tvm.target import Target
from tvm import meta_schedule as ms
from .tuning_api import Trace


@transform.module_pass(opt_level=2, name="schedule_for_target")
class ScheduleForTarget:
"""Apply a minimal set of transformations to enable running on a specific target."""

def __init__(self, target: Union[Target, str]):
"""
This function returns a pass which applies basic schedule transformations to each
primitive function in the input module for the specified target. This is useful
when a hardware target requires certain intrinsics for kernels to be valid. For
example, on GPUs, each kernel must have loops bound to a thread and block index.
By default, primitive functions do not contain this binding. Applying a single
step of Metaschedule's transform rules inserts bindings that enable the functions
to be run.
Thus, this pass is a convenience wrapper around the simplist possible invocation
of Metaschedule tuning. It performs only a few schedules per task, skips benchmarking,
and verifies that they can be built.
Parameters
----------
target : Union[Target, str]
The tvm target that fucntions should be scheduled for.
"""
if isinstance(target, str):
target = Target(target)
self.target = target
# Create a fake runner function that does not perform benchmarking. This
# allows us to save time when transforming primitive functions in the module.
@ms.derived_object
class FakeRunner(ms.runner.PyRunner):
def run(
self, runner_inputs: List[ms.runner.RunnerInput]
) -> List[ms.runner.RunnerFuture]:
return [ms.runner.LocalRunnerFuture([0.0], None)]

self.runner = FakeRunner()

def transform_module(self, mod: IRModule, ctx: transform.PassContext) -> IRModule:
"""Apply a minimal set of tuning to transform the input module.
Parameters
----------
mod : IRModule
The input module to schedule.
ctx : transform.PassContext
Information about the current pass, not currently used.
Returns
-------
scheduled_mod : IRModule
The input module with hardware specific transformations applied.
"""
# Extract the number of tasks in the input module so that we can
# determine the minimal number of transformations to try.
num_tasks = len(ms.relax_integration.extract_tasks(mod, self.target))
# Disable logging for this pass.
logging_level = logging.getLogger("tvm.meta_schedule").level
logging.getLogger("tvm.meta_schedule").setLevel(logging.CRITICAL)
# Perform a minimal set of metaschedule tuning on the input module.
with tempfile.TemporaryDirectory() as work_dir:
with self.target, transform.PassContext(trace=Trace(mod), opt_level=0):
# Create a pass that finds one valid schedule per task in the module.
tuning_pass = relax.transform.MetaScheduleTuneIRMod(
params={},
work_dir=work_dir,
max_trials_global=num_tasks,
max_trials_per_task=1,
runner=self.runner,
)

# Apply the pass on our module.
mod = tuning_pass(mod)

# Use the results of tuning to schedule the module.
application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir)
mod = application_pass(mod)
# Re-enable normal logging.
logging.getLogger("tvm.meta_schedule").setLevel(logging_level)
return mod
20 changes: 18 additions & 2 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np # type: ignore
import tvm.ir
from tvm.runtime import NDArray
from tvm import meta_schedule as ms
from . import _ffi_api
from .legalize_ops.common import LegalizeFunc

Expand Down Expand Up @@ -489,6 +490,7 @@ def MetaScheduleApplyDatabase(
def MetaScheduleTuneTIR(
work_dir: str,
max_trials_global: int,
runner: Optional[ms.runner.Runner] = None,
) -> tvm.ir.transform.Pass:
"""Tune TIR with MetaSchedule.
Parameters
Expand All @@ -497,17 +499,23 @@ def MetaScheduleTuneTIR(
work directory
max_trials_gloabl: int
maximum number of total trials allowed for tuning
runner: Optional[ms.runner.Runner]
runner for tuning
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global) # type: ignore
if runner is None:
runner = ms.runner.LocalRunner()
return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global, runner) # type: ignore


def MetaScheduleTuneIRMod(
params: Dict[str, NDArray],
work_dir: str,
max_trials_global: int,
max_trials_per_task: Optional[int] = None,
runner: Optional[ms.runner.Runner] = None,
) -> tvm.ir.transform.Pass:
"""Tune Relax IRModule with MetaSchedule.
Parameters
Expand All @@ -518,11 +526,19 @@ def MetaScheduleTuneIRMod(
work directory
max_trials_gloabl: int
maximum number of total trials allowed for tuning
max_trials_per_task: Optional[int]
maximum number of trials allowed for each task
runner: Optional[ms.runner.Runner]
runner for the tuning pass
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) # type: ignore
if runner is None:
runner = ms.runner.LocalRunner()
return _ffi_api.MetaScheduleTuneIRMod( # type: ignore
params, work_dir, max_trials_global, max_trials_per_task, runner
)


def _wrap_class_function_pass(pass_cls, pass_info):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/transform/tuning_api/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
import tvm
from tvm.runtime import Object
from tvm.ir.module import IRModule
from tvm.relax import Expr
from tvm.tir.schedule.trace import JSON_TYPE, _json_from_tvm
from tvm._ffi import register_object
from . import _ffi_api
from ...expr import Expr

logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name

Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ inline void clear_logging(const char* file, int lineno, PackedFunc logging_func)
logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), file, lineno, "");
} else {
// this would clear all logging output in the console
runtime::detail::LogMessage(file, lineno, TVM_LOG_LEVEL_INFO).stream()
<< "\033c\033[3J\033[2J\033[0m\033[H";
logging_func(static_cast<int>(PyLogMessage::Level::INFO), file, lineno,
"\033c\033[3J\033[2J\033[0m\033[H");
}
}

Expand Down
30 changes: 23 additions & 7 deletions src/relax/transform/meta_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief Pass for meta_schedule tuning
*/
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/tuning_api.h>
#include <tvm/tir/transform.h>
Expand All @@ -33,10 +34,13 @@ namespace transform {
class MetaScheduleTuner {
public:
explicit MetaScheduleTuner(Target target, String work_dir, Integer max_trials_global,
Integer max_trials_per_task, meta_schedule::Runner runner,
Map<String, runtime::NDArray> params = {})
: target_(target),
work_dir_(work_dir),
max_trials_global_(max_trials_global),
max_trials_per_task_(max_trials_per_task),
runner_(runner),
params_(params) {
candgen_func_ = runtime::Registry::Get("relax.tuning_api.default_generate_candidate");
ICHECK(candgen_func_) << "Default candidate generation function is not found.";
Expand All @@ -48,7 +52,11 @@ class MetaScheduleTuner {
IRModule TuneIRMod(IRModule mod, transform::PassContext ctx) {
Trace trace = Downcast<Trace>(ctx->GetCurrentTrace());
ctx->PopTrace();
Choice choice("tvm.meta_schedule.tune_relax", {params_, target_, work_dir_, max_trials_global_},
String builder = "local";
Integer num_trials_per_iter = 64;
Choice choice("tvm.meta_schedule.tune_relax",
{params_, target_, work_dir_, max_trials_global_, max_trials_per_task_,
num_trials_per_iter, builder, runner_},
"relax.tuning_api.Choice.default_constr_func", {});
Knob knob("meta_schedule.tune_irmod", {{"0", choice}});
Array<Trace> candidates = (*candgen_func_)(Array<Knob>({knob}), trace);
Expand All @@ -64,8 +72,10 @@ class MetaScheduleTuner {
// TODO(@sunggg): Whenever we tune tir, assume we start a new trace w/o pushing to the trace
// stack. Revisit later when we collect more usecases.
Trace trace = Trace((*normalize_mod_func_)(f), {}, {});

Choice choice("tvm.meta_schedule.tune_tir", {target_, work_dir_, max_trials_global_},
String builder = "local";
Integer num_trials_per_iter = 64;
Choice choice("tvm.meta_schedule.tune_tir",
{target_, work_dir_, max_trials_global_, num_trials_per_iter, builder, runner_},
"relax.tuning_api.Choice.default_constr_func", {});
Knob knob("meta_schedule.tune_primfunc", {{"0", choice}});
Array<Trace> candidates = (*candgen_func_)(Array<Knob>({knob}), trace);
Expand All @@ -78,6 +88,8 @@ class MetaScheduleTuner {
Target target_;
String work_dir_;
Integer max_trials_global_;
Integer max_trials_per_task_;
meta_schedule::Runner runner_;
Map<String, runtime::NDArray> params_;
const runtime::PackedFunc* candgen_func_;
const runtime::PackedFunc* normalize_mod_func_;
Expand Down Expand Up @@ -138,23 +150,27 @@ Pass MetaScheduleApplyDatabase(Optional<String> work_dir) {
}

Pass MetaScheduleTuneIRMod(Map<String, runtime::NDArray> params, String work_dir,
Integer max_trials_global) {
Integer max_trials_global, Integer max_trials_per_task,
meta_schedule::Runner runner) {
Target target = Target::Current(false);
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
PassContext ctx) {
return MetaScheduleTuner(target, work_dir, max_trials_global, params).TuneIRMod(m, ctx);
return MetaScheduleTuner(target, work_dir, max_trials_global, max_trials_per_task, runner,
params)
.TuneIRMod(m, ctx);
};
return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0,
/*pass name*/ "MetaScheduleTuneIRModule",
/*required*/ {},
/*traceable*/ true);
}

Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) {
Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global, meta_schedule::Runner runner) {
Target target = Target::Current(false);
runtime::TypedPackedFunc<tir::PrimFunc(tir::PrimFunc, IRModule, PassContext)> pass_func =
[=](tir::PrimFunc f, IRModule mod, PassContext ctx) {
return MetaScheduleTuner(target, work_dir, max_trials_global).TuneTIR(f, ctx);
return MetaScheduleTuner(target, work_dir, max_trials_global, max_trials_global, runner)
.TuneTIR(f, ctx);
};
return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0,
/*pass name*/ "MetaScheduleTuneTIR",
Expand Down
Loading

0 comments on commit 83effce

Please sign in to comment.