Skip to content

Commit

Permalink
{relay,topi}.reinterpret support
Browse files Browse the repository at this point in the history
= Motivation

It's useful to expose the tvm::reinterpret functionality to Relay/TOPI users, as
this allows them to build (fused) operators leveraging the bitwise
reinterpretation of an operator. An example is approximate transcendental
functions, which can be implemented similar to:

```.py
    def C(x):
        return relay.expr.const(x, "float32")

    def approx_exp(x):
        x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
        x = C(127.0) + x * C(1.44269504)
        xf = relay.floor(x)
        i = relay.cast(xf, "int32")
        x = x - xf
        Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
        exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
        exponent = relay.reinterpret(exponent, "float32")
        return exponent * Y

    def approx_sigmoid(x):
        # <2.0e-5 absolute error over [-5, 5]
        y = approx_exp(x)
        return y / (y + C(1.0))

    def approx_tanh(x):
        # <4.0e-5 absolute error over [-5, 5]
        x = x * C(2.0)
        y = approx_exp(x)
        return (y - C(1.0)) / (y + C(1.0))
```

See unit tests for implementations of these approximate transendentals.
  • Loading branch information
ajtulloch committed Jul 23, 2019
1 parent 19eb829 commit f0470d8
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ List of operators
topi.sigmoid
topi.clip
topi.cast
topi.reinterpret
topi.transpose
topi.flip
topi.strided_slice
Expand Down Expand Up @@ -133,6 +134,7 @@ topi
.. autofunction:: topi.sigmoid
.. autofunction:: topi.clip
.. autofunction:: topi.cast
.. autofunction:: topi.reinterpret
.. autofunction:: topi.transpose
.. autofunction:: topi.flip
.. autofunction:: topi.strided_slice
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
_reg.register_schedule("repeat", schedule_broadcast)
_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("reinterpret", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,26 @@ def cast(data, dtype):
return _relay_make.cast(data, dtype)


def reinterpret(data, dtype):
"""Reinterpret input tensor to data type.
Parameters
----------
data : relay.Expr
The input data to the operator.
dtype: str
The target data type
Returns
-------
result : relay.Expr
The reinterpreted result.
"""
from .. import _make as _relay_make
return _relay_make.reinterpret(data, dtype)


def expand_dims(data, axis, num_newaxis=1):
"""Insert `num_newaxis` axises at the position given by `axis`.
Expand Down
46 changes: 46 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,52 @@ RELAY_REGISTER_OP("cast")
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

// relay.reinterpret
bool ReinterpretRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "Reinterpret: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* param = attrs.as<CastAttrs>();
reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
return true;
}

Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_type, const Target& target) {
const CastAttrs* param = attrs.as<CastAttrs>();
CHECK(param != nullptr);
DataType dtype = param->dtype;
return {topi::reinterpret(inputs[0], dtype)};
}

Expr MakeReinterpret(Expr data, DataType dtype) {
auto attrs = make_node<CastAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("reinterpret");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReinterpret, args, rv);
});

RELAY_REGISTER_OP("reinterpret")
.describe(R"code(Reinterpret the data into a new data type.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.CastAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reinterpret", CastRel)
.set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

// relay.expand_dims
TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);

Expand Down
64 changes: 64 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_cast():
assert "dtype=" in yy.astext()
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")


def test_clip():
a = relay.var("a", relay.TensorType((10, 4), "float32"))
y = relay.clip(a, 1., 4.)
Expand All @@ -88,6 +89,69 @@ def test_clip():
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)


def test_reinterpret():
a = relay.var("a", relay.TensorType((1000, 4), "float32"))
y = relay.reinterpret(a, "int32")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((1000, 4), "int32")

data = np.random.randn(1000, 4).astype('float32') * 1000
intrp = create_executor()
op_res = intrp.evaluate(y, {a: relay.const(data)})
ref_res = data.view("int32")
np.testing.assert_equal(op_res.asnumpy(), ref_res)


def test_approximate_transcendental():
def C(x):
return relay.expr.const(x, "float32")

def approx_exp(x):
# An approximation derived from Opus,
# https://github.com/xiph/opus/blob/c1c247/celt/mathops.h#L147-L165
x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
x = C(127.0) + x * C(1.44269504)
xf = relay.floor(x)
i = relay.cast(xf, "int32")
x = x - xf
Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
exponent = relay.reinterpret(exponent, "float32")
return exponent * Y

def approximate_sigmoid(x):
y = approx_exp(x)
return y / (y + C(1.0))

def approximate_tanh(x):
x = x * C(2.0)
y = approx_exp(x)
return (y - C(1.0)) / (y + C(1.0))

a = relay.var("a", relay.TensorType((1000,), "float32"))
y = approximate_sigmoid(a)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((1000,), "float32")
data = np.linspace(-5, 5, 1000).astype("float32")
intrp = create_executor()
op_res = intrp.evaluate(y, {a: relay.const(data)})

def reference_sigmoid(x):
return np.exp(-np.logaddexp(0, -x))
np.testing.assert_allclose(op_res.asnumpy(), reference_sigmoid(data), atol=2e-5, rtol=1e-9)

y = approximate_tanh(a)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((1000,), "float32")
data = np.linspace(-5, 5, 1000).astype("float32")
intrp = create_executor()
op_res = intrp.evaluate(y, {a: relay.const(data)})

def reference_tanh(x):
return np.tanh(x)
np.testing.assert_allclose(op_res.asnumpy(), reference_tanh(data), atol=4e-5, rtol=1e-9)


def test_squeeze():
def verify_squeeze(shape, dtype, axis):
x = relay.var("x", relay.TensorType(shape, dtype))
Expand Down
36 changes: 28 additions & 8 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,34 @@ inline Tensor cast(const Tensor& x,
}

/*!
* \brief Creates an operation that sum each element of a tensor
*
* \param xs The input tensor array
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sum operation
*/
* \brief Reinterpret each element of x to the given type.
* \param x The input tensor
* \param type The type to cast to
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the reinterpret operation
*/
inline Tensor reinterpret(const Tensor& x, Type type, std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape,
[&](const Array<Var>& i) {
return tvm::ir::Call::make(type, "reinterpret", {x(i)},
tvm::ir::Call::PureIntrinsic);
},
name, tag);
}

/*!
* \brief Creates an operation that sum each element of a tensor
*
* \param xs The input tensor array
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sum operation
*/
inline Tensor elemwise_sum(const Array<Tensor>& xs,
std::string name = "T_elemwise_sum",
std::string tag = kElementWise) {
Expand Down
18 changes: 18 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,21 @@ def cast(x, dtype):
return tvm.compute(
x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
return tvm.make._cast(dtype, x)

def reinterpret(x, dtype):
"""Reinterpret input to specified data type.
Parameters
----------
x : tvm.Tensor
Input argument.
dtype : str
Data type.
Returns
-------
y : tvm.Tensor
The result.
"""
return cpp.reinterpret(x, dtype)
6 changes: 6 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ TVM_REGISTER_GLOBAL("topi.cast")
*rv = cast(args[0], args[1]);
});


TVM_REGISTER_GLOBAL("topi.reinterpret")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = reinterpret(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.elemwise_sum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = elemwise_sum(args[0]);
Expand Down
34 changes: 34 additions & 0 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,29 @@ def check_device(device):
check_device(device)


def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
A = tvm.placeholder(shape=in_shape, name="A", dtype=in_dtype)
B = topi.reinterpret(A, out_dtype)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_elemwise(B)
foo = tvm.build(s, [A, B], device, name="reinterpret")
data_npy = generator(in_shape).astype(in_dtype)
out_npy = data_npy.view(B.dtype)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(in_shape).astype(B.dtype), ctx)
foo(data_nd, out_nd)
np.testing.assert_equal(out_nd.asnumpy(), out_npy)

for device in get_all_backend():
check_device(device)


def verify_transpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.transpose(A, axes)
Expand Down Expand Up @@ -434,6 +457,17 @@ def test_expand_dims():
verify_expand_dims((3, 10), (1, 3, 10), -3, 1)


def test_reinterpret():
verify_reinterpret((1000,), "float32", "int32",
lambda shape: np.random.randn(*shape) * 1000)
verify_reinterpret((1000,), "float16", "int16",
lambda shape: np.random.randn(*shape) * 100)
verify_reinterpret((1000,), "int16", "uint16",
lambda shape: np.random.randint(-1000, 1000, size=shape))
verify_reinterpret((1000,), "uint32", "int32",
lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))


def test_transpose():
verify_transpose((3, 10, 2), (1, 0, 2))
verify_transpose((3, 10, 5), (2, 0, 1))
Expand Down

0 comments on commit f0470d8

Please sign in to comment.