From 0317aca1d11c2c9de65c45f2ccd3b9159b18ea4f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 24 Jul 2020 21:09:06 -0700 Subject: [PATCH] [TOPI] Fix CUDA Library Tuning (#6132) --- python/tvm/autotvm/task/space.py | 7 +++++-- topi/python/topi/cuda/conv2d.py | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index fbf474fc4df7..49376613992f 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -33,6 +33,7 @@ import numpy as np from tvm.te import schedule, thread_axis +from tvm.tir import expr from tvm.autotvm.util import get_const_int Axis = namedtuple('Axis', ['space', 'index']) @@ -733,10 +734,12 @@ def add_flop(self, flop): Parameters --------- - flop: int or float + flop: int or float or IntImm or FloatImm number of float operations """ - self.flop += flop + if isinstance(flop, (expr.IntImm, expr.FloatImm)): + flop = flop.value + self.flop += float(flop) def raise_error(self, msg): """register error in config diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index d98d630d6415..973c21642757 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -18,6 +18,7 @@ """Compute definition for conv2d with cuda backend""" from tvm import te from tvm import autotvm +from tvm.autotvm.task.space import OtherOptionEntity from tvm.contrib import cudnn from .. import nn, generic @@ -99,6 +100,10 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1, else: dtype = data.dtype + cfg.define_knob('algo', range(8)) + if cfg.is_fallback: # Let CUDNN choose the best algo + cfg['algo'] = OtherOptionEntity(-1) + return cudnn.conv_forward(data, kernel, [pt, pl], # cudnn padding pt, pl on both sides of input @@ -106,7 +111,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1, [dilation_h, dilation_w], conv_mode=1, tensor_format=tensor_format, - algo=-1, # let CUDNN choose the best algo + algo=cfg['algo'].val, conv_dtype=dtype, groups=groups)