diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index 378aaa387c43..49376613992f 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -33,6 +33,7 @@ import numpy as np from tvm.te import schedule, thread_axis +from tvm.tir import expr from tvm.autotvm.util import get_const_int Axis = namedtuple('Axis', ['space', 'index']) @@ -736,9 +737,9 @@ def add_flop(self, flop): flop: int or float or IntImm or FloatImm number of float operations """ - if not isinstance(flop, (int, float)): + if isinstance(flop, (expr.IntImm, expr.FloatImm)): flop = flop.value - self.flop += flop + self.flop += float(flop) def raise_error(self, msg): """register error in config diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index 5b84a424eac0..973c21642757 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -111,7 +111,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1, [dilation_h, dilation_w], conv_mode=1, tensor_format=tensor_format, - algo=cfg['algo'], + algo=cfg['algo'].val, conv_dtype=dtype, groups=groups)