From 87374d11aa159514ae916d9dda8c25725262bd2e Mon Sep 17 00:00:00 2001 From: masahi Date: Sun, 5 May 2019 21:17:29 +0900 Subject: [PATCH] [ROCm] Fix dense autotvm template registration (#3136) * Fix rocm dense autotvm template * suppres lint warning --- topi/python/topi/cuda/__init__.py | 1 + topi/python/topi/rocm/dense.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index aca410b93276..65ed0ff10dad 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -11,6 +11,7 @@ from .reduction import schedule_reduce from .softmax import schedule_softmax from .injective import schedule_injective, schedule_elemwise, schedule_broadcast +from .dense import schedule_dense from .pooling import schedule_pool, schedule_global_pool from .extern import schedule_extern from .nn import schedule_lrn, schedule_l2_normalize diff --git a/topi/python/topi/rocm/dense.py b/topi/python/topi/rocm/dense.py index a8c033f0bd73..6fca7cd79656 100644 --- a/topi/python/topi/rocm/dense.py +++ b/topi/python/topi/rocm/dense.py @@ -14,18 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-variable +# pylint: disable=invalid-name, unused-variable, unused-argument """Schedule for dense operator""" from __future__ import absolute_import as _abs import tvm +from tvm import autotvm from tvm.contrib import rocblas import topi from ..nn.dense import dense, dense_default from .. import tag from .. import generic -@dense.register("rocm") -def dense_rocm(data, weight, bias=None, out_dtype=None): +@autotvm.register_topi_compute(dense, "rocm", "direct") +def dense_rocm(cfg, data, weight, bias=None, out_dtype=None): """Dense operator for rocm backend. Parameters @@ -67,8 +68,8 @@ def dense_rocm(data, weight, bias=None, out_dtype=None): return dense_default(data, weight, bias, out_dtype) -@generic.schedule_dense.register(["rocm"]) -def schedule_dense(outs): +@autotvm.register_topi_schedule(generic.schedule_dense, "rocm", "direct") +def schedule_dense(cfg, outs): """Schedule for dense operator. Parameters @@ -85,4 +86,4 @@ def schedule_dense(outs): target = tvm.target.current_target() if target.target_name == "rocm" and "rocblas" in target.libs: return generic.schedule_extern(outs) - return topi.cuda.schedule_dense(outs) + return topi.cuda.schedule_dense(cfg, outs)