Skip to content

Commit

Permalink
[ONNX]MaxRoiPool, Mod & Xor op support added (apache#5729)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent b4175cd commit 98147a7
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 1 deletion.
34 changes: 33 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,20 @@ def _impl_v1(cls, inputs, attr, params):
return _op.nn.dense(inputs[0], input_1_t)


class Mod(OnnxOpConverter):
    """ Operator converter for Mod.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Mod op take 2 inputs, {} given".format(len(inputs))
        # ONNX semantics: fmod=0 (the default) computes integer modulo where
        # the result takes the sign of the divisor (Python's %), which maps to
        # TVM's floor_mod; fmod=1 computes C-style fmod where the result takes
        # the sign of the dividend, which maps to TVM's truncated mod.
        # Use .get() because 'fmod' is an optional attribute (default 0).
        if attr.get('fmod', 0) == 0:
            op_name = "floor_mod"
        else:
            op_name = "mod"
        return AttrCvt(op_name)(inputs, {}, params)


class MaxPool(Pool):
""" Operator converter for MaxPool
"""
Expand Down Expand Up @@ -1660,8 +1674,23 @@ def _impl_v1(cls, inputs, attr, params):
return _op.topk(inputs[0], k=K, axis=axis)


class MaxRoiPool(OnnxOpConverter):
    """Operator converter for MaxRoiPool.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # inputs: [feature map, rois]; fixed typo in the original message ("MMaxRoiPool").
        assert len(inputs) == 2, "MaxRoiPool op take 2 inputs, {} given".format(len(inputs))

        data = inputs[0]
        rois = inputs[1]
        pooled_shape = attr.get("pooled_shape")
        # ONNX defines spatial_scale as optional with default 1.0.
        spatial_scale = attr.get("spatial_scale", 1.0)

        return _vision.roi_pool(data, rois, pooled_shape, spatial_scale)


class RoiAlign(OnnxOpConverter):
"""Operator converter for TopK
"""Operator converter for RoiAlign.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
Expand Down Expand Up @@ -1778,6 +1807,8 @@ def _get_convert_map(opset):
'SoftPlus': SoftPlus.get_converter(opset),
'Gemm': Gemm.get_converter(opset),
'MatMul': MatMul.get_converter(opset),
'Mod': Mod.get_converter(opset),
'Xor': Renamer('logical_xor'),

# defs/nn
'AveragePool': AveragePool.get_converter(opset),
Expand All @@ -1797,6 +1828,7 @@ def _get_convert_map(opset):
'LSTM': LSTM.get_converter(opset),

# defs/vision
'MaxRoiPool': MaxRoiPool.get_converter(opset),
'RoiAlign': RoiAlign.get_converter(opset),

# defs/reduction
Expand Down
132 changes: 132 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2250,6 +2250,135 @@ def test_pooling():
auto_pad='SAME_UPPER')


def verify_mod(x_shape, y_shape, fmod, dtype='float32'):
    """Build a single-node ONNX Mod graph and compare TVM output with numpy."""
    x_np = np.random.uniform(size=x_shape).astype(dtype)
    y_np = np.random.uniform(size=y_shape).astype(dtype)
    # Replace any zero divisors so the reference computation cannot divide by zero.
    y_np = np.where(y_np == 0, 1, y_np)

    np_out = np.fmod(x_np, y_np) if fmod else np.mod(x_np, y_np)
    out_shape = np_out.shape

    mod_node = helper.make_node("Mod",
                                inputs=["x", "y"],
                                outputs=["z"],
                                fmod=fmod)

    onnx_dtype = TensorProto.FLOAT if dtype == "float32" else TensorProto.INT32
    input_infos = [
        helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)),
        helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)),
    ]
    output_infos = [
        helper.make_tensor_value_info("z", onnx_dtype, list(out_shape)),
    ]
    graph = helper.make_graph([mod_node],
                              "mod_test",
                              inputs=input_infos,
                              outputs=output_infos)
    model = helper.make_model(graph, producer_name='mod_test')

    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape)
        tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5)


def test_mod():
    """Exercise Mod in both integer-mod (fmod=0) and fmod (fmod=1) modes,
    with and without broadcasting, for float32 and int32 inputs."""
    # integer mod
    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=0)
    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, dtype="int32")

    # fmod
    verify_mod(x_shape=[1, 1, 32], y_shape=[1, 32, 32], fmod=1)
    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, dtype="int32")


def verify_xor(x_shape, y_shape):
    """Build a single-node ONNX Xor graph and compare TVM output with numpy."""
    x_np = np.random.choice(a=[False, True], size=x_shape).astype("bool")
    y_np = np.random.choice(a=[False, True], size=y_shape).astype("bool")

    np_out = np.logical_xor(x_np, y_np)
    out_shape = np_out.shape

    xor_node = helper.make_node("Xor", inputs=["x", "y"], outputs=["z"])

    onnx_dtype = TensorProto.BOOL
    input_infos = [
        helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)),
        helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)),
    ]
    output_infos = [
        helper.make_tensor_value_info("z", onnx_dtype, list(out_shape)),
    ]
    graph = helper.make_graph([xor_node],
                              "xor_test",
                              inputs=input_infos,
                              outputs=output_infos)
    model = helper.make_model(graph, producer_name='xor_test')

    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape)
        tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5)


def test_xor():
    """Exercise Xor with equal shapes and with a broadcast divisor shape."""
    # same-shape operands
    verify_xor(x_shape=[1, 32, 32], y_shape=[1, 32, 32])

    # broadcast second operand
    verify_xor(x_shape=[1, 32, 32], y_shape=[1, 1, 32])


def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_shape):
    """Build a single-node ONNX MaxRoiPool graph and compare TVM output with
    onnxruntime. Pass spatial_scale=None to omit the attribute entirely."""
    x_np = np.random.uniform(size=x_shape).astype('float32')
    rois_np = np.random.uniform(size=rois_shape).astype('float32')

    # Only attach spatial_scale when the caller supplies one, so the
    # omitted-attribute default path is also exercised.
    node_attrs = {"pooled_shape": pooled_shape}
    if spatial_scale is not None:
        node_attrs["spatial_scale"] = spatial_scale
    pool_node = helper.make_node("MaxRoiPool",
                                 inputs=["x", "rois"],
                                 outputs=["y"],
                                 **node_attrs)

    input_infos = [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
        helper.make_tensor_value_info("rois", TensorProto.FLOAT, list(rois_shape)),
    ]
    output_infos = [
        helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape)),
    ]
    graph = helper.make_graph([pool_node],
                              "pool_test",
                              inputs=input_infos,
                              outputs=output_infos)

    model = helper.make_model(graph, producer_name='pool_test')

    onnx_out = get_onnxruntime_output(model, [x_np, rois_np], 'float32')[0]
    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(model, [x_np, rois_np], target, ctx, out_shape)
        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)


def test_max_roi_pool():
    """Exercise MaxRoiPool with the default spatial_scale and an explicit one."""
    # default spatial_scale (attribute omitted)
    verify_max_roi_pool(x_shape=[1, 3, 6, 6],
                        rois_shape=[3, 5],
                        pooled_shape=[1, 1],
                        spatial_scale=None,
                        out_shape=[3, 3, 1, 1])

    # explicit spatial_scale
    verify_max_roi_pool(x_shape=[1, 3, 10, 10],
                        rois_shape=[4, 5],
                        pooled_shape=[2, 2],
                        spatial_scale=2.0,
                        out_shape=[4, 3, 2, 2])


def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"):
x_np = np.random.uniform(size=x_shape).astype('float32')

Expand Down Expand Up @@ -2739,4 +2868,7 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_
test_resize()
test_nonzero()
test_topk()
test_mod()
test_xor()
test_max_roi_pool()
test_roialign()

0 comments on commit 98147a7

Please sign in to comment.