Skip to content

Commit

Permalink
[TOPI] Add left_shift, right_shift, clip, cast (#504)
Browse files Browse the repository at this point in the history
* [TOPI] Add left_shift, right_shift, clip, cast

* [TOPI] Add test

* [TOPI] Fix
  • Loading branch information
ZihengJiang authored and tqchen committed Oct 3, 2017
1 parent 4fdef3a commit af8cbdd
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
84 changes: 84 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,87 @@ def sigmoid(x):
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.sigmoid(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def left_shift(x, n):
"""Take n bits left shift of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
n : int
Number of bits.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: x(*i) << n)


@tvm.tag_scope(tag=tag.ELEMWISE)
def right_shift(x, n):
"""Take n bits right shift of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
n : int
Number of bits.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: x(*i) >> n)


@tvm.tag_scope(tag=tag.ELEMWISE)
def clip(x, a_min, a_max):
"""Clip (limit) the values in an array. Given an interval, values
outside the interval are clipped to the interval edges.
Parameters
----------
x : tvm.Tensor
Input argument.
a_min : int or float
Minimum value.
a_max : int or float
Maximum value.
Returns
-------
y : tvm.Tensor
The result.
"""
def _compute(*indices):
value = x(*indices)
const_min = tvm.const(a_min, value.dtype)
const_max = tvm.const(a_max, value.dtype)
return tvm.max(tvm.min(value, const_max), const_min)
return tvm.compute(x.shape, _compute)


@tvm.tag_scope(tag=tag.ELEMWISE)
def cast(x, dtype):
"""Cast input to specified data type.
Parameters
----------
x : tvm.Tensor
Input argument.
dtype : str
Data type.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: x(*i).astype(dtype))
43 changes: 43 additions & 0 deletions topi/tests/python/test_topi_clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Test code for clip operator"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize


def verify_clip(N, a_min, a_max, dtype):
A = tvm.placeholder((N, N), dtype=dtype, name='A')
B = topi.clip(A, a_min, a_max)
s = tvm.create_schedule([B.op])

# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_clip")
def get_ref_data():
a_np = np.random.uniform(a_min*2, a_max*2, size=(N, N)).astype(dtype)
b_np = np.clip(a_np, a_min, a_max)
return a_np, b_np
a_np, b_np = get_ref_data()

def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.cpu(0) if device == "llvm" else tvm.gpu(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device, name="clip")
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['llvm']:
check_device(device)

def test_clip():
verify_clip(1024, -127, 127, 'int8')
verify_clip(1024, -127, 127, 'int16')
verify_clip(1024, -127, 127, 'float32')


if __name__ == "__main__":
test_clip()

0 comments on commit af8cbdd

Please sign in to comment.