diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index fb306f9e599b..61b39aad9114 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -450,7 +450,6 @@ def round(x): return te.compute(x.shape, lambda *i: te.round(x(*i))) -@tvm.te.tag_scope(tag=tag.ELEMWISE) def log(x): """Take logarithm of input x. @@ -464,10 +463,11 @@ def log(x): y : tvm.te.Tensor The result. """ - return te.compute(x.shape, lambda *i: te.log(x(*i))) + if x.dtype.startswith("int"): + x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) + return te.compute(x.shape, lambda *i: te.log(x(*i)), tag=tag.ELEMWISE) -@tvm.te.tag_scope(tag=tag.ELEMWISE) def log2(x): """Take logarithm to the base 2 of input x. @@ -481,7 +481,9 @@ def log2(x): y : tvm.te.Tensor The result. """ - return te.compute(x.shape, lambda *i: te.log2(x(*i))) + if x.dtype.startswith("int"): + x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) + return te.compute(x.shape, lambda *i: te.log2(x(*i)), tag=tag.ELEMWISE) def log10(x):