Skip to content

Commit

Permalink
[Auto Scheduler] Mali Support (apache#7132)
Browse files Browse the repository at this point in the history
* [Auto Scheduler] Mali Support

* Fix doc

* fix lint

* address comments

* fix doc
  • Loading branch information
FrozenGene authored and Tushar Dey committed Jan 20, 2021
1 parent 5746b1a commit 3ce88a6
Show file tree
Hide file tree
Showing 8 changed files with 538 additions and 46 deletions.
8 changes: 6 additions & 2 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,13 @@ def auto_schedule_topi(outs, has_complex_op):

key = register_workload_tensors(dag.hash_key(), io_tensors)

# only enable layout rewrite for cpu backend
# only enable layout rewrite for cpu / mali backend
target = tvm.target.Target.current()
enable_layout_rewrite = "cpu" in target.keys
enable_layout_rewrite_targets = ["cpu", "mali"]
enable_layout_rewrite = any(
enable_layout_rewrite_target in target.keys
for enable_layout_rewrite_target in enable_layout_rewrite_targets
)

env = TracingEnvironment.current
if env is None:
Expand Down
55 changes: 55 additions & 0 deletions python/tvm/relay/op/strategy/mali.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import re
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from .generic import *
from .. import op as _op

Expand Down Expand Up @@ -69,6 +70,38 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
raise RuntimeError(
"Unsupported weight layout {} for conv2d NCHW".format(kernel_layout)
)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
if not is_auto_scheduler_enabled():
raise RuntimeError(
"conv2d NHWC layout is not enabled for mali without auto_scheduler."
)
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
naive_schedule,
name="conv2d_nhwc.mali",
)
is_winograd_applicable = False
if len(kernel.shape) == 4:
kernel_h, kernel_w, _, _ = get_const_tuple(kernel.shape)
is_winograd_applicable = (
"float" in data.dtype
and "float" in kernel.dtype
and kernel_h == 3
and kernel_w == 3
and stride_h == 1
and stride_w == 1
and dilation_h == 1
and dilation_w == 1
)
if is_winograd_applicable:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
naive_schedule, # this implementation should never be picked by autotvm
name="conv2d_nhwc.winograd",
plevel=15,
)

else:
raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
Expand All @@ -79,6 +112,17 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.mali",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
if not is_auto_scheduler_enabled():
raise RuntimeError(
"depthwise_conv2d NHWC layout is not enabled for mali without auto_scheduler."
)
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
naive_schedule,
name="depthwise_conv2d_nhwc.mali",
)
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout))
else: # group_conv2d
Expand All @@ -105,6 +149,17 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.mali",
)
elif layout == "NHWC":
if not is_auto_scheduler_enabled():
raise RuntimeError(
"Winograd conv2d NHWC is not enabled for mali without auto_scheduler."
)
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
naive_schedule, # this implementation should never be picked by autotvm
name="conv2d_nhwc_winograd_without_weight_transform",
plevel=15,
)
else:
raise RuntimeError(
"Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
Expand Down
75 changes: 60 additions & 15 deletions python/tvm/topi/mali/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
"""conv2d schedule on ARM Mali GPU"""
import logging
import tvm
from tvm import te
from tvm import relay
Expand All @@ -25,10 +26,13 @@
from ..utils import traverse_inline, get_const_int, get_const_tuple
from .. import nn
from ..nn.winograd_util import winograd_transform_matrices
from ..nn.conv2d import conv2d_winograd_nhwc, _conv2d_winograd_nhwc_impl

# reuse some compute declarations from ARM CPU
from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw

logger = logging.getLogger("topi")


@autotvm.register_topi_compute("conv2d_nchw_spatial_pack.mali")
def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
Expand Down Expand Up @@ -188,8 +192,12 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):


##### WINOGRAD TEMPLATE #####
def _pick_tile_size(data, kernel):
N, CI, H, W = get_const_tuple(data.shape)
def _pick_tile_size(data, kernel, layout="NCHW"):
if layout == "NCHW":
N, CI, H, W = get_const_tuple(data.shape)
else:
assert layout == "NHWC"
N, H, W, CI = get_const_tuple(data.shape)

if H % 4 == 0:
return 4
Expand Down Expand Up @@ -467,30 +475,54 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
dispatch_ctx = autotvm.task.DispatchContext.current

_, outs = relay.backend.compile_engine.select_implementation(
new_attrs = {k: attrs[k] for k in attrs.keys()}

strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
dilation = attrs.get_int_tuple("dilation")
data_layout = attrs["data_layout"]
kernel_layout = attrs["kernel_layout"]
data, kernel = tinfos
out_dtype = out_type.dtype

impl, outs = relay.backend.compile_engine.select_implementation(
relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)
if workload is None:
# The best implementation is not an AutoTVM template,
# we then assume it's not necessary to alter this op.
# The best implementation is not an AutoTVM template.
# It may be from the auto-scheduler
if impl.name.find("winograd") != -1:
if dilation != (1, 1):
logger.warning("Does not support weight pre-transform for dilated convolution.")
return None

assert data_layout == "NHWC" and kernel_layout == "HWIO"
N, H, W, CI = get_const_tuple(data.shape)
KH, KW, _, CO = get_const_tuple(kernel.shape)

# Pre-compute weight transformation in winograd
tile_size = _pick_tile_size(tinfos[0], tinfos[1], layout="NHWC")

# HWIO -> OIHW
kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1])
# alpha, alpha, CO, CI
weight = relay.nn.contrib_conv2d_winograd_weight_transform(
kernel_transform, tile_size=tile_size
)
new_attrs["tile_size"] = tile_size
new_attrs["channels"] = CO
return relay.nn.contrib_conv2d_winograd_without_weight_transform(
inputs[0], weight, **new_attrs
)

return None
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

topi_tmpl = workload[0]
new_attrs = {k: attrs[k] for k in attrs.keys()}

strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
dilation = attrs.get_int_tuple("dilation")
data_layout = attrs["data_layout"]
kernel_layout = attrs["kernel_layout"]
data, kernel = tinfos
out_dtype = out_type.dtype

idxd = tvm.tir.indexdiv

if topi_tmpl == "conv2d_nchw_spatial_pack.mali":
Expand Down Expand Up @@ -545,6 +577,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return None


@conv2d_winograd_nhwc.register(["mali"])
def conv2d_winograd_nhwc_mali(
data, weight, strides, padding, dilation, out_dtype, pre_computed=False
):
"""Conv2D Winograd in NHWC layout.
This is a clean version to be used by the auto-scheduler for mali.
"""
tile_size = _pick_tile_size(data, weight, layout="NHWC")
return _conv2d_winograd_nhwc_impl(
data, weight, strides, padding, dilation, out_dtype, tile_size, pre_computed
)


##### SCHECULE UTILITIES #####
def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
""" tile and bind to GPU threads """
Expand Down
13 changes: 12 additions & 1 deletion python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,18 @@ def conv2d_nhwc(
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
# todo(merrymercy): wrap this with a more general interface.
if len(Filter.shape) >= 10:
if len(Filter.shape) == 17:
# For mali.
# GPU tile structure is SSSRRSRS
# You could refer function comment of DoMultiLevelTiling
# in the utils.h to see more detail explanation.
kernel_h = Filter.shape[6] * Filter.shape[9] * Filter.shape[13]
kernel_w = Filter.shape[7] * Filter.shape[10] * Filter.shape[14]
channel = Filter.shape[8] * Filter.shape[11] * Filter.shape[15]
num_filter = Filter.shape[12] * Filter.shape[16]
for i in range(6):
num_filter *= Filter.shape[i]
elif len(Filter.shape) >= 10:
# For cpu tile structure SSRSRS
base = len(Filter.shape) - 10
kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base]
Expand Down
65 changes: 39 additions & 26 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,35 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model,
node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01));
} else if (IsGPUTask(node->search_task)) {
// Sketch Generation Rules
node->sketch_rules.push_back(&rule_add_cache_read_stage);
node->sketch_rules.push_back(&rule_special_compute_location_gpu);
node->sketch_rules.push_back(&rule_always_inline);
node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
node->sketch_rules.push_back(&rule_cross_thread_reduction);
node->sketch_rules.push_back(&rule_add_cache_write_stage);
node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
node->sketch_rules.push_back(&rule_multi_level_tiling);
node->sketch_rules.push_back(&rule_skip_stage);
if (node->search_task->target->GetAttr<String>("device", "") == "mali") {
node->sketch_rules.push_back(&rule_always_inline);
node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
node->sketch_rules.push_back(&rule_add_rfactor);
node->sketch_rules.push_back(&rule_add_cache_write_stage);
node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
node->sketch_rules.push_back(&rule_multi_level_tiling);
node->sketch_rules.push_back(&rule_skip_stage);
} else {
node->sketch_rules.push_back(&rule_add_cache_read_stage);
node->sketch_rules.push_back(&rule_special_compute_location_gpu);
node->sketch_rules.push_back(&rule_always_inline);
node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
node->sketch_rules.push_back(&rule_cross_thread_reduction);
node->sketch_rules.push_back(&rule_add_cache_write_stage);
node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
node->sketch_rules.push_back(&rule_multi_level_tiling);
node->sketch_rules.push_back(&rule_skip_stage);
}

// Initial Population Generation Rules
node->init_rules.push_back(&init_fill_tile_size);
node->init_rules.push_back(&init_thread_bind);
node->init_rules.push_back(&init_unroll);

if (node->search_task->target->GetAttr<String>("device", "") == "mali") {
node->init_rules.push_back(&init_vectorization);
}

// Mutation Rules for Evolutionary Search
node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10));
Expand Down Expand Up @@ -389,23 +403,22 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
std::vector<State> temp_states(population);

// Sample a batch of states randomly
support::parallel_for(0, population,
[this, &temp_states, &sketches, &rand_gens](int index) {
// Randomly choose a sketch
State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
// Apply random annotation rules one by one
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
}
if (valid) {
temp_states[index] = std::move(tmp_s);
}
});
support::parallel_for(0, population, [this, &temp_states, &sketches, &rand_gens](int index) {
// Randomly choose a sketch
State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
// Apply random annotation rules one by one
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
}
if (valid) {
temp_states[index] = std::move(tmp_s);
}
});

// Filter out the states that were failed to apply initial rules
Array<State> cand_states;
Expand Down
16 changes: 16 additions & 0 deletions src/auto_scheduler/search_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,22 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else if (target->kind->device_type == kDLOpenCL) {
if (target->GetAttr<String>("device", "") == "mali") {
// We cannot use device API to get hardware attributes like CUDA,
// because like Mali target is normally on the remote machine.
int max_shared_memory_per_block = 32768;
int max_local_memory_per_block = INT32_MAX; // skip the check on local memory
int max_threads_per_block = 256;
int warp_size = 1;
int max_vthread_extent = 1;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else {
// add other opencl target
auto target_device = target->GetAttr<String>("device", "");
LOG(FATAL) << "No default hardware parameters for opencl target device: " << target_device;
}
} else {
LOG(FATAL) << "No default hardware parameters for target: " << target;
}
Expand Down
5 changes: 3 additions & 2 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) {
const auto& target = (*targets.begin()).second;
Pass major_pass = transform::AutoSchedulerLayoutRewrite();

if (target->kind->device_type == kDLCPU && pass_ctx.PassEnabled(major_pass->Info())) {
bool enable_layout_rewrite_targets =
target->kind->device_type == kDLCPU || target->GetAttr<String>("device", "") == "mali";
if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
With<Target> tctx(target);
relay_module = major_pass(relay_module);
// Defuse ops to fold constants, then fuse them again
Expand Down
Loading

0 comments on commit 3ce88a6

Please sign in to comment.