Skip to content

Commit

Permalink
[ROCm][Auto scheduler] Support Auto scheduler and NHWC convolution on…
Browse files Browse the repository at this point in the history
… ROCm (apache#7038)

* add nhwc + winograd support to rocm strategy

* support rocm hw parameters in search task

* run analysis pass for rocm too

* run black

* pylint fix

* use IsGPUTask function

Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
  • Loading branch information
2 people authored and electriclilies committed Feb 18, 2021
1 parent d1640a9 commit 6540b6b
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 13 deletions.
52 changes: 45 additions & 7 deletions python/tvm/relay/op/strategy/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
"""Definition of ROCm operator strategy."""
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from .generic import *
from .. import op as _op
from .cuda import judge_winograd, naive_schedule


@schedule_lrn.register("rocm")
Expand Down Expand Up @@ -67,20 +69,56 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
name="conv2d_nchw_winograd.cuda",
plevel=5,
)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
name="conv2d_nhwc.cuda",
)
N, H, W, _ = get_const_tuple(data.shape)
KH, KW, CI, CO = get_const_tuple(kernel.shape)

(_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
N,
H,
W,
KH,
KW,
CI,
CO,
padding,
stride_h,
stride_w,
dilation_h,
dilation_w,
data.dtype,
kernel.dtype,
pre_flag=False,
)

if judge_winograd_autotvm:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
name="conv2d_nhwc_winograd_direct.cuda",
plevel=5,
)

if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
naive_schedule, # this implementation should never be picked by autotvm
name="conv2d_nhwc.winograd",
plevel=15,
)
elif layout == "HWCN":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda",
)
# TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
# elif layout == "NHWC":
# assert kernel_layout == "HWIO"
# strategy.add_implementation(
# wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
# wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
# name="conv2d_nhwc.cuda")
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
Expand Down
3 changes: 2 additions & 1 deletion src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <unordered_map>
#include <vector>

#include "search_policy/utils.h"
#include "utils.h"

namespace tvm {
Expand Down Expand Up @@ -1296,7 +1297,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
}
auto mod = IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));

if (task->target->kind->device_type == kDLGPU) {
if (IsGPUTask(task)) {
auto pass_list = Array<tvm::transform::Pass>();
// Phase 0
pass_list.push_back(tir::transform::InjectPrefetch());
Expand Down
13 changes: 8 additions & 5 deletions src/auto_scheduler/search_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief Meta information and hardware parameters for a search task.
*/

#include <dlpack/dlpack.h>
#include <tvm/auto_scheduler/search_task.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
Expand Down Expand Up @@ -52,11 +53,13 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l

HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
const Target& target_host) {
if (target->kind->device_type == kDLCPU) {
const auto device_type = target->kind->device_type;
if (device_type == kDLCPU) {
return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
} else if (target->kind->device_type == kDLGPU) {
auto ctx = TVMContext{kDLGPU, 0};
auto func = tvm::runtime::Registry::Get("device_api.gpu");
} else if (device_type == kDLGPU || device_type == kDLROCM) {
auto ctx = TVMContext{static_cast<DLDeviceType>(device_type), 0};
auto device_name = device_type == kDLGPU ? "device_api.gpu" : "device_api.rocm";
auto func = tvm::runtime::Registry::Get(device_name);
ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());

Expand All @@ -77,7 +80,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else if (target->kind->device_type == kDLMetal) {
} else if (device_type == kDLMetal) {
// Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
// This setting looks working for Metal GPUs later than A10
int max_shared_memory_per_block = 32 * 1024;
Expand Down

0 comments on commit 6540b6b

Please sign in to comment.