Skip to content

Commit

Permalink
[TIR] Expose shift_left and shift_right to Python (#12584)
Browse files Browse the repository at this point in the history
This PR exposes the following TIR operation in python:

- `shift_left`: tested [here](https://github.com/apache/tvm/blob/1afd0593956066635ee49297b731726c9218c91c/tests/python/unittest/test_tir_transform_simplify.py#L487)
- `shift_right`: add new unittest

Co-authored-by: yongwww <yongcale@gmail.com>
  • Loading branch information
cyx-6 and yongwww authored Aug 25, 2022
1 parent b8fbfe2 commit cd8fd91
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift
from .op import q_multiply_shift, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
38 changes: 38 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,44 @@ def q_multiply_shift(x, y, q, s):
return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)


def shift_left(x, y, span=None):
"""Return the result of x left shifted by y bits.
Parameters
----------
x : PrimExpr
Input argument.
y : PrimExpr
Input argument.
Returns
-------
z : PrimExpr
The result.
"""
return _ffi_api.left_shift(x, y, span)


def shift_right(x, y, span=None):
"""Return the result of x right shifted by y bits.
Parameters
----------
x : PrimExpr
Input argument.
y : PrimExpr
Input argument.
Returns
-------
z : PrimExpr
The result.
"""
return _ffi_api.right_shift(x, y, span)


def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_tir_op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ def test_tir_op_vectorcombine():
assert expr.op.name == "tir.vectorcombine"


def test_tir_op_shift_left():
x = tir.Var("x", dtype="int32")
y = tir.Var("x", dtype="int32")
expr = tir.shift_left(x, y)
assert expr.op.name == "tir.shift_left"


def test_tir_op_shift_right():
x = tir.Var("x", dtype="int32")
y = tir.Var("x", dtype="int32")
expr = tir.shift_right(x, y)
assert expr.op.name == "tir.shift_right"


def test_tir_op_TVMBackendAllocWorkspace():
expr = tir.TVMBackendAllocWorkspace(0, 1, 2, 3, 4)
assert expr.op.name == "tir.TVMBackendAllocWorkspace"
Expand Down Expand Up @@ -154,5 +168,7 @@ def test_tir_op_TVMBackendFreeWorkspace():
test_tir_op_vectorlow()
test_tir_op_vectorhigh()
test_tir_op_vectorcombine()
test_tir_op_shift_left()
test_tir_op_shift_right()
test_tir_op_TVMBackendAllocWorkspace()
test_tir_op_TVMBackendFreeWorkspace()

0 comments on commit cd8fd91

Please sign in to comment.