Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/tvm/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,6 @@ def round(x):
return te.compute(x.shape, lambda *i: te.round(x(*i)))


@tvm.te.tag_scope(tag=tag.ELEMWISE)
def log(x):
"""Take logarithm of input x.

Expand All @@ -464,10 +463,11 @@ def log(x):
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.log(x(*i)))
if x.dtype.startswith("int"):
x = te.compute(x.shape, lambda *i: x(*i).astype("float32"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The intermediate te.compute for casting should be tagged as elemwise to allow for operator fusion. Without this tag, the cast operation might not be inlined, potentially leading to suboptimal performance.

Suggested change
x = te.compute(x.shape, lambda *i: x(*i).astype("float32"))
x = te.compute(x.shape, lambda *i: x(*i).astype("float32"), tag=tag.ELEMWISE)

return te.compute(x.shape, lambda *i: te.log(x(*i)), tag=tag.ELEMWISE)


@tvm.te.tag_scope(tag=tag.ELEMWISE)
def log2(x):
"""Take logarithm to the base 2 of input x.

Expand All @@ -481,7 +481,9 @@ def log2(x):
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.log2(x(*i)))
if x.dtype.startswith("int"):
x = te.compute(x.shape, lambda *i: x(*i).astype("float32"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the log function, the intermediate te.compute for casting here should be tagged as elemwise to enable operator fusion. This ensures that the cast can be inlined by the scheduler.

Suggested change
x = te.compute(x.shape, lambda *i: x(*i).astype("float32"))
x = te.compute(x.shape, lambda *i: x(*i).astype("float32"), tag=tag.ELEMWISE)

return te.compute(x.shape, lambda *i: te.log2(x(*i)), tag=tag.ELEMWISE)


def log10(x):
Expand Down
Loading